
Support speaker finetune

Lengyue 2 years ago
parent
commit
2966ba9019

+ 5 - 1
data_server/src/main.rs

@@ -106,7 +106,11 @@ impl DataService for MyDataService {
                 .cloned() // Clone each &Sentence to get Sentence
                 .collect();
 
-            Ok(Response::new(SampledData { samples: sentences }))
+            Ok(Response::new(SampledData {
+                name: group.name.clone(),
+                source: group.source.clone(),
+                samples: sentences,
+            }))
         } else {
             Err(Status::internal("Failed to select a group"))
         }
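
The sampling response now carries the selected group's identity alongside its sentences. A minimal Python client sketch of the new shape (the stub module name and the server address are assumptions; the dataset code below connects to whatever `server` it is given):

    import grpc

    from fish_speech.datasets.protos import text_data_pb2, text_data_pb2_grpc

    channel = grpc.insecure_channel("localhost:50051")  # hypothetical address
    stub = text_data_pb2_grpc.DataServiceStub(channel)
    response = stub.SampleData(text_data_pb2.SampleDataRequest(num_samples=4))

    # New in this commit: the group's name and source travel with the samples.
    print(response.name, response.source)
    for sentence in response.samples:
        print(sentence.text)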

+ 84 - 0
fish_speech/configs/text2semantic_finetune_spk.yaml

@@ -0,0 +1,84 @@
+defaults:
+  - base
+  - _self_
+
+project: text2semantic_400m_finetune_spk
+max_length: 4096
+ckpt_path: results/text2semantic_400m_finetune/checkpoints/step_000010000.ckpt
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 2
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  max_steps: 1000
+  precision: bf16-true
+  limit_val_batches: 10
+
+# Tokenizer Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-v1
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+val_dataset:
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 8
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+
+  model:
+    # ~ 400M parameters
+    _target_: fish_speech.models.text2semantic.llama.Transformer
+    config:
+      _target_: fish_speech.models.text2semantic.llama.ModelArgs
+      max_seq_len: 4096
+      vocab_size: 36408
+      n_layer: 24
+      n_head: 16
+      dim: 1024
+      rope_base: 10000
+      norm_eps: 1e-5
+      num_codebooks: 4  # number of codebooks
+      codebook_size: 168 # codebook size 160 + 8 special tokens
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    weight_decay: 0.1
+    betas: [0.9, 0.95]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 100
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1
+
+# Callbacks
+callbacks:
+  model_checkpoint:
+    every_n_train_steps: 1000
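
For intuition, the lr_lambda above hands LambdaLR a per-step multiplier: 100 steps of linear warmup, then cosine decay to final_lr_ratio (10%) of the peak by trainer.max_steps. The actual helper lives in fish_speech.scheduler; this sketch is an assumption about its semantics:

    import math

    def cosine_warmup_lambda(step, num_warmup_steps=100, num_training_steps=1000, final_lr_ratio=0.1):
        # Linear warmup from 0 up to the peak multiplier of 1.0.
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        # Cosine decay from 1.0 down to final_lr_ratio.
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return final_lr_ratio + (1.0 - final_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    # With lr=1e-4: step 0 -> 0.0, step 100 -> 1e-4, step 1000 -> 1e-5.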

+ 3 - 1
fish_speech/datasets/protos/text-data.proto

@@ -20,7 +20,9 @@ message TextData {
 }
 
 message SampledData {
-    repeated Sentence samples = 1;
+    string source = 1;
+    string name = 2;
+    repeated Sentence samples = 3;
 }
 
 message SampleDataRequest {
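
Note that `samples` moves from field 1 to field 3, which changes the wire format: SampledData payloads serialized before this commit will not decode correctly, so client and server bindings must be regenerated together. Constructing the new message from Python (field values are hypothetical):

    from fish_speech.datasets.protos import text_data_pb2

    msg = text_data_pb2.SampledData(
        source="wenetspeech",  # hypothetical
        name="SPK0001",        # hypothetical
        samples=[text_data_pb2.Sentence(text="hello world")],
    )
    print(msg)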

+ 6 - 6
fish_speech/datasets/protos/text_data_pb2.py

@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"Q\n\x08Sentence\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0e\n\x06phones\x18\x02 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"c\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tlanguages\x18\x03 \x03(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"3\n\x0bSampledData\x12$\n\x07samples\x18\x01 \x03(\x0b\x32\x13.text_data.Sentence"(\n\x11SampleDataRequest\x12\x13\n\x0bnum_samples\x18\x01 \x01(\r2S\n\x0b\x44\x61taService\x12\x44\n\nSampleData\x12\x1c.text_data.SampleDataRequest\x1a\x16.text_data.SampledData"\x00\x62\x06proto3'
+    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"Q\n\x08Sentence\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0e\n\x06phones\x18\x02 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"c\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tlanguages\x18\x03 \x03(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentence"(\n\x11SampleDataRequest\x12\x13\n\x0bnum_samples\x18\x01 \x01(\r2S\n\x0b\x44\x61taService\x12\x44\n\nSampleData\x12\x1c.text_data.SampleDataRequest\x1a\x16.text_data.SampledData"\x00\x62\x06proto3'
 )
 
 _globals = globals()
@@ -29,9 +29,9 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _globals["_TEXTDATA"]._serialized_start = 142
     _globals["_TEXTDATA"]._serialized_end = 241
     _globals["_SAMPLEDDATA"]._serialized_start = 243
-    _globals["_SAMPLEDDATA"]._serialized_end = 294
-    _globals["_SAMPLEDATAREQUEST"]._serialized_start = 296
-    _globals["_SAMPLEDATAREQUEST"]._serialized_end = 336
-    _globals["_DATASERVICE"]._serialized_start = 338
-    _globals["_DATASERVICE"]._serialized_end = 421
+    _globals["_SAMPLEDDATA"]._serialized_end = 324
+    _globals["_SAMPLEDATAREQUEST"]._serialized_start = 326
+    _globals["_SAMPLEDATAREQUEST"]._serialized_end = 366
+    _globals["_DATASERVICE"]._serialized_start = 368
+    _globals["_DATASERVICE"]._serialized_end = 451
 # @@protoc_insertion_point(module_scope)
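
This file is protoc output and is not meant to be edited by hand; the hunk above is just the regenerated descriptor for the new SampledData fields. A command along these lines reproduces it (the include and output paths are assumptions based on this repository's layout):

    python -m grpc_tools.protoc \
        -I fish_speech/datasets/protos \
        --python_out=fish_speech/datasets/protos \
        --grpc_python_out=fish_speech/datasets/protos \
        fish_speech/datasets/protos/text-data.proto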

+ 14 - 6
fish_speech/datasets/text.py

@@ -181,6 +181,7 @@ class AutoAugTextDataset(IterableDataset):
         repetition_prob: float = 0.0,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
+        use_speaker: bool = True,
     ):
         """
         Args:
@@ -199,6 +200,7 @@ class AutoAugTextDataset(IterableDataset):
         self.max_length = max_length
         self.tokenizer = tokenizer
         self.repetition_prob = repetition_prob
+        self.use_speaker = use_speaker
 
         # Read all lines, and group by speaker
         self.channel = grpc.insecure_channel(server)
@@ -218,6 +220,8 @@ class AutoAugTextDataset(IterableDataset):
                     for i in phones
                 ]
             )
+        else:
+            sentence = clean_text(sentence)
 
         tokens = self.tokenizer.encode(
             f"{sentence}",
@@ -268,6 +272,9 @@ class AutoAugTextDataset(IterableDataset):
             final_text.append(text)
             final_semantic.append(sentence.semantics)
 
+        if self.use_speaker:
+            final_text = [f"[SPK: {response.name}]"] + final_text
+
         final_text = "[INST] " + " ".join(final_text) + " [/INST]"
         encoded = self.tokenizer.encode(
             final_text,
@@ -441,15 +448,16 @@ if __name__ == "__main__":
 
     from tqdm import tqdm
 
-    # ds = AutoAugTextDataset(
-    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
-    # )
-
-    ds = StreamTextDataset(
-        prefix="en/",
+    ds = AutoAugTextDataset(
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
+        use_speaker=True,
     )
 
+    # ds = StreamTextDataset(
+    #     prefix="en/",
+    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
+    # )
+
     dm = TextDataModule(
         train_dataset=ds,
         val_dataset=ds,
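
The net effect on a training example: with use_speaker enabled, the group name returned by the data server is prepended as a speaker tag ahead of the instruction wrapper. Roughly (sentence and speaker values are illustrative):

    # Mirrors the prompt construction in AutoAugTextDataset above.
    final_text = ["Hello there.", "How are you?"]  # cleaned / g2p-processed sentences
    speaker = "SPK0001"                            # hypothetical response.name

    final_text = [f"[SPK: {speaker}]"] + final_text
    prompt = "[INST] " + " ".join(final_text) + " [/INST]"
    print(prompt)  # [INST] [SPK: SPK0001] Hello there. How are you? [/INST]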

+ 7 - 0
tools/llama/generate.py

@@ -263,6 +263,7 @@ def encode_tokens(
     prompt_text=None,
     prompt_tokens=None,
     use_g2p=False,
+    speaker=None,
 ):
     if prompt_text is not None:
         string = prompt_text + " " + string
@@ -277,6 +278,9 @@ def encode_tokens(
     else:
         string = clean_text(string)
 
+    if speaker is not None:
+        string = f"[SPK: {speaker}] {string}"
+
     string = f"[INST] {string} [/INST]"
 
     tokens = tokenizer.encode(
@@ -373,6 +377,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 @click.option("--compile/--no-compile", default=False)
 @click.option("--use-g2p/--no-g2p", default=True)
 @click.option("--seed", type=int, default=42)
+@click.option("--speaker", type=str, default=None)
 def main(
     text: str,
     prompt_text: Optional[str],
@@ -389,6 +394,7 @@ def main(
     compile: bool,
     use_g2p: bool,
     seed: int,
+    speaker: Optional[str],
 ) -> None:
     device = "cuda"
     precision = torch.bfloat16
@@ -415,6 +421,7 @@ def main(
         bos=True,
         device=device,
         use_g2p=use_g2p,
+        speaker=speaker,
     )
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")
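
Since encode_tokens mirrors the dataset-side format, the tag lands inside the same [INST] ... [/INST] wrapper at inference time. A hypothetical invocation (only the flags visible in this diff are spelled out; the text, prompt, and checkpoint options keep whatever names the script already defines):

    python tools/llama/generate.py --speaker "SPK0001" --use-g2p --seed 42

Presumably the speaker string only helps when it matches one of the group names seen during finetuning.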