Просмотр исходного кода

Optimize instruction following & parallel inference

Lengyue 1 год назад
Родитель
Сommit
eac93fa2f4
3 измененных файлов с 46 добавлено и 25 удалено
  1. 10 4
      fish_speech/datasets/text.py
  2. 35 21
      tools/llama/generate.py
  3. 1 0
      tools/webui.py

+ 10 - 4
fish_speech/datasets/text.py

@@ -178,7 +178,7 @@ class AutoAugTextDataset(IterableDataset):
         interactive_prob: float = 0.5,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
-        use_speaker: bool = True,
+        use_speaker: bool | float = True,
         causual: bool = True,
         use_negative_samples: bool = False,
         num_codebooks: Optional[int] = None,
@@ -319,6 +319,12 @@ class AutoAugTextDataset(IterableDataset):
         else:
             remaining_tokens = self.max_length
 
+        # Use speaker
+        if isinstance(self.use_speaker, float):
+            use_speaker = random.random() < self.use_speaker
+        else:
+            use_speaker = self.use_speaker
+
         all_tokens, all_labels = [], []
         while remaining_tokens > 0 and len(samples) > 0:
             sentence = samples.pop(0)
@@ -336,7 +342,7 @@ class AutoAugTextDataset(IterableDataset):
                 tokens, labels = self.pack_sentences(
                     sentences=[text],
                     semantics=[sentence.semantics],
-                    speaker=response.name if (self.use_speaker and idx == 0) else None,
+                    speaker=response.name if use_speaker else None,
                     add_bos=idx == 0,
                 )
 
@@ -349,7 +355,7 @@ class AutoAugTextDataset(IterableDataset):
             tokens, labels = self.pack_sentences(
                 final_text,
                 semantics=final_semantic,
-                speaker=response.name if self.use_speaker else None,
+                speaker=response.name if use_speaker else None,
                 add_bos=True,
             )
             all_tokens.append(tokens)
@@ -440,7 +446,7 @@ class AutoAugTextDataset(IterableDataset):
         speaker: Optional[str] = None,
         add_bos: bool = True,
     ):
-        if speaker is not None:
+        if speaker is None:
             speaker = "assistant"
 
         final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"

+ 35 - 21
tools/llama/generate.py

@@ -14,7 +14,7 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
+from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
 from fish_speech.text.clean import clean_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -291,7 +291,7 @@ def encode_tokens(
 ):
     string = clean_text(string)
 
-    if speaker is not None:
+    if speaker is None:
         speaker = "assistant"
 
     string = (
@@ -309,7 +309,10 @@ def encode_tokens(
     tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
 
     # Codebooks
-    zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
+    zeros = (
+        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
+        * CODEBOOK_PAD_TOKEN_ID
+    )
     prompt = torch.cat((tokens, zeros), dim=0)
 
     if prompt_tokens is None:
@@ -331,13 +334,23 @@ def encode_tokens(
         )
         data = data[:num_codebooks]
 
+    # Add eos token for each codebook
+    data = torch.cat(
+        (
+            data,
+            torch.ones((data.size(0), 1), dtype=torch.int, device=device)
+            * CODEBOOK_EOS_TOKEN_ID,
+        ),
+        dim=1,
+    )
+
     # Since 1.0, we use <|semantic|>
     s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
-    main_token_ids = torch.tensor(
-        [[s0_token_id] * data.size(1)],
-        dtype=torch.int,
-        device=device,
+    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    main_token_ids = (
+        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
     )
+    main_token_ids[0, -1] = end_token_id
 
     data = torch.cat((main_token_ids, data), dim=0)
     prompt = torch.cat((prompt, data), dim=1)
@@ -450,6 +463,20 @@ def generate_long(
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
+
+    if use_prompt:
+        encoded.append(
+            encode_tokens(
+                tokenizer,
+                prompt_text,
+                prompt_tokens=prompt_tokens,
+                bos=True,
+                device=device,
+                speaker=speaker,
+                num_codebooks=model.config.num_codebooks,
+            )
+        )
+
     for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
@@ -457,25 +484,12 @@ def generate_long(
                 string=text,
                 bos=idx == 0 and not use_prompt,
                 device=device,
-                speaker=None,
+                speaker=speaker,
                 num_codebooks=model.config.num_codebooks,
             )
         )
         logger.info(f"Encoded text: {text}")
 
-    if use_prompt:
-        encoded_prompt = encode_tokens(
-            tokenizer,
-            prompt_text,
-            prompt_tokens=prompt_tokens,
-            bos=True,
-            device=device,
-            speaker=speaker,
-            num_codebooks=model.config.num_codebooks,
-        )
-
-        encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
-
     for sample_idx in range(num_samples):
         torch.cuda.synchronize()
         global_encoded = []

+ 1 - 0
tools/webui.py

@@ -233,6 +233,7 @@ def build_app():
                 speaker,
             ],
             [audio, error],
+            concurrency_limit=1,
         )
 
     return app