
Support vq encoding & update generate

Lengyue 1 year ago
parent commit 2b148f2eb8

+ 126 - 0
fish_speech/configs/vits_decoder_finetune.yaml

@@ -0,0 +1,126 @@
+defaults:
+  - base
+  - _self_
+
+project: vits_decoder
+ckpt_path: checkpoints/Bert-VITS2/ensemble.pth
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: auto
+  strategy: ddp_find_unused_parameters_true
+  precision: 32
+  max_steps: 100_000
+  val_check_interval: 1000
+
+sample_rate: 44100
+hop_length: 512
+num_mels: 128
+n_fft: 2048
+win_length: 2048
+
+# Tokenizer Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/fish-speech-1
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vits.VITSDataset
+  filelist: data/source/Genshin/filelist.train.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  suffix: ".lab"
+  tokenizer: ${tokenizer}
+  sentence_mask_ratio: 0.2
+
+val_dataset:
+  _target_: fish_speech.datasets.vits.VITSDataset
+  filelist: data/source/Genshin/filelist.test.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  suffix: ".lab"
+  tokenizer: ${tokenizer}
+
+data:
+  _target_: fish_speech.datasets.vits.VITSDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 8
+  val_batch_size: 4
+  tokenizer: ${tokenizer}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vits_decoder.VITSDecoder
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  freeze_discriminator: false
+
+  weight_mel: 45.0
+  weight_kl: 1.0
+
+  generator:
+    _target_: fish_speech.models.vits_decoder.modules.models.SynthesizerTrn
+    spec_channels: 1025
+    segment_size: 32
+    inter_channels: 192
+    hidden_channels: 192
+    filter_channels: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 3
+    p_dropout: 0.1
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 8, 2, 2]
+    gin_channels: 512
+    vq_mask_ratio: 0.2
+    ref_mask_ratio: 0.2
+
+  discriminator:
+    _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator
+    periods: [2, 3, 5, 7, 11]
+
+  mel_transform:
+    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+
+  spec_transform:
+    _target_: fish_speech.utils.spectrogram.LinearSpectrogram
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    mode: pow2_sqrt
+  
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    betas: [0.8, 0.99]
+    eps: 1e-6
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _partial_: true
+    gamma: 0.999999
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - generator
+      - discriminator
+
+  model_checkpoint:
+    every_n_train_steps: 1000
+    save_top_k: 10
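
For orientation, a minimal sketch (not part of the commit) of composing this config with Hydra and instantiating the objects it declares. The config path and the direct use of instantiate are assumptions; in practice training is launched through the repo's Hydra train entry point, and the checkpoint/filelist paths in the YAML are local placeholders.

from hydra import compose, initialize
from hydra.utils import instantiate

# Compose the finetune config; config_path is relative to the calling file
# and assumes the fish_speech/configs layout shown in this commit.
with initialize(version_base=None, config_path="fish_speech/configs"):
    cfg = compose(config_name="vits_decoder_finetune")

# Build the Lightning module and datamodule described above.
# Note: instantiating cfg.data pulls the tokenizer from the HF Hub and
# expects the filelists referenced in the YAML to exist locally.
model = instantiate(cfg.model)
data = instantiate(cfg.data)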

+ 0 - 0
fish_speech/configs/vits_decoder.yaml → fish_speech/configs/vits_decoder_pretrain.yaml


+ 16 - 0
fish_speech/models/vits_decoder/modules/vq_encoder.py

@@ -83,3 +83,19 @@ class VQEncoder(nn.Module):
         z = self.quantizer.decode(indices) * mel_masks_float_conv
 
         return z
+
+    @torch.no_grad()
+    def encode(self, audios, audio_lengths, sr=None):
+        audios = audios.float()
+
+        mels = self.spec(audios, sample_rate=sr)
+        mel_lengths = audio_lengths // self.spec.hop_length
+        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        mels = mels * mel_masks_float_conv
+
+        # Encode
+        encoded_features = self.encoder(mels) * mel_masks_float_conv
+        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+        return self.quantizer.encode(encoded_features), feature_lengths
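
A hedged usage sketch for the new encode method (not part of the commit). How the VQEncoder instance is constructed or restored from a checkpoint is assumed; the call signature follows the diff above.

import torch
import torchaudio

# vq_encoder: an already-initialized VQEncoder (construction / checkpoint loading assumed)
audio, sr = torchaudio.load("sample.wav")        # (channels, num_samples)
audio = audio.mean(dim=0, keepdim=True)          # mono, shape (1, num_samples)
audio_lengths = torch.tensor([audio.shape[-1]])

indices, feature_lengths = vq_encoder.encode(audio, audio_lengths, sr=sr)
# indices: quantizer codes; feature_lengths: mel lengths divided by the
# quantizer's downsample factor, as computed in the method above.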

+ 48 - 32
tools/llama/generate.py

@@ -3,8 +3,9 @@ import queue
 import string
 import threading
 import time
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
 
 import click
 import hydra
@@ -439,6 +440,13 @@ def split_text(text, min_length):
     return segments
 
 
+@dataclass
+class GenerateResponse:
+    action: Literal["sample", "next"]
+    codes: Optional[torch.Tensor] = None
+    text: Optional[str] = None
+
+
 def generate_long(
     *,
     model,
@@ -458,7 +466,6 @@ def generate_long(
     speaker: Optional[str] = None,
     prompt_text: Optional[str] = None,
     prompt_tokens: Optional[torch.Tensor] = None,
-    is_streaming: bool = False,
 ):
     assert 0 < top_p <= 1, "top_p must be in (0, 1]"
     assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
@@ -508,7 +515,6 @@ def generate_long(
             torch.cuda.synchronize()
 
         global_encoded = []
-        all_codes = []
         seg_idx = 0
 
         while seg_idx < len(encoded):
@@ -594,22 +600,24 @@ def generate_long(
 
             # But for global encoding, we should keep the <im_end> token
             global_encoded.append(decoded)
+            assert (codes >= 0).all(), f"Negative code found: {codes}"
+            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
+            seg_idx += 1
 
-            if is_streaming:
-                assert (codes >= 0).all(), f"Negative code found: {codes}"
-                yield codes
-            else:
-                all_codes.append(codes)
+        # This indicates the end of the current sample
+        yield GenerateResponse(action="next")
 
-            seg_idx += 1
 
-        if is_streaming:
-            # This indicates the end of the current sample
-            yield "next"
-        else:
-            all_codes = torch.cat(all_codes, dim=1)
-            assert (all_codes >= 0).all(), f"Negative code found: {codes}"
-            yield all_codes
+@dataclass
+class WrappedGenerateResponse:
+    status: Literal["success", "error"]
+    response: Optional[GenerateResponse | Exception] = None
+
+
+@dataclass
+class GenerateRequest:
+    request: dict
+    response_queue: queue.Queue
 
 
 def launch_thread_safe_queue(
@@ -617,8 +625,8 @@ def launch_thread_safe_queue(
     checkpoint_path,
     device,
     precision,
-    max_length,
-    compile=False,
+    max_length: int,
+    compile: bool = False,
 ):
     input_queue = queue.Queue()
     init_event = threading.Event()
@@ -630,26 +638,22 @@ def launch_thread_safe_queue(
         init_event.set()
 
         while True:
-            item = input_queue.get()
+            item: GenerateRequest | None = input_queue.get()
             if item is None:
                 break
 
-            kwargs = item["request"]
-            response_queue = item["response_queue"]
+            kwargs = item.request
+            response_queue = item.response_queue
 
             try:
-                item["success"] = True
                 for chunk in generate_long(
                     model=model, decode_one_token=decode_one_token, **kwargs
                 ):
-                    response_queue.put(chunk)
-
-                response_queue.put("done")
+                    response_queue.put(
+                        WrappedGenerateResponse(status="success", response=chunk)
+                    )
             except Exception as e:
-                item["success"] = False
-                item["response"] = e
-
-                response_queue.put("done")
+                response_queue.put(WrappedGenerateResponse(status="error", response=e))
 
     threading.Thread(target=worker, daemon=True).start()
     init_event.wait()
@@ -753,9 +757,21 @@ def main(
         prompt_tokens=prompt_tokens,
     )
 
-    for idx, codes in enumerate(generator):
-        np.save(f"codes_{idx}.npy", codes.cpu().numpy())
-        logger.info(f"Saved codes to codes_{idx}.npy")
+    idx = 0
+    codes = []
+
+    for response in generator:
+        if response.action == "sample":
+            codes.append(response.codes)
+            logger.info(f"Sampled text: {response.text}")
+        elif response.action == "next":
+            if codes:
+                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
+                logger.info(f"Saved codes to codes_{idx}.npy")
+            logger.info("Next sample")
+            codes = []  # reset so the next sample doesn't accumulate previous codes
+            idx += 1
+        else:
+            logger.error(f"Error: {response}")
 
 
 if __name__ == "__main__":
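
Finally, a hedged sketch (not part of the commit) of a caller consuming the new queue protocol. input_queue is assumed to come from launch_thread_safe_queue, and the request kwargs are illustrative placeholders forwarded to generate_long.

import queue

from tools.llama.generate import GenerateRequest, WrappedGenerateResponse

# input_queue: obtained from launch_thread_safe_queue(...) (model setup assumed)
response_queue = queue.Queue()
input_queue.put(
    GenerateRequest(
        request=dict(text="Hello world", top_p=0.7, repetition_penalty=1.5),
        response_queue=response_queue,
    )
)

segments = []
while True:
    wrapped: WrappedGenerateResponse = response_queue.get()
    if wrapped.status == "error":
        raise wrapped.response               # the worker forwards the original exception
    if wrapped.response.action == "next":
        break                                # end of the current sample
    segments.append(wrapped.response.codes)  # one segment of semantic codes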