Lengyue před 2 roky
rodič
revize
70f258bb90
Změněny 2 soubory: 37 přidání a 11 odebrání
  1. 5 3
      fish_speech/datasets/text.py
  2. 32 8
      tools/llama/generate.py

+ 5 - 3
fish_speech/datasets/text.py

@@ -444,7 +444,7 @@ class AutoAugTextDataset(IterableDataset):
             sentences = [f"[SPK: {speaker}]"] + sentences
 
         final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
-        final_text += final_text + "<|im_start|>assistant<|im_sep|>"
+        final_text = final_text + "<|im_start|>assistant<|im_sep|>"
 
         encoded = self.tokenizer.encode(
             final_text,
@@ -650,13 +650,15 @@ if __name__ == "__main__":
     from tqdm import tqdm
 
     ds = AutoAugTextDataset(
-        ["data/protos/libirtts-test"],
+        ["data/protos/test"],
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
-        use_speaker=True,
+        use_speaker=False,
         interactive_prob=1.0,
         use_negative_samples=False,
     )
 
     for i in ds:
         print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
+        # i["labels"][0][i["labels"][0] == -100] = 0
+        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
         break

+ 32 - 8
tools/llama/generate.py

@@ -14,6 +14,7 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
+from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
 from fish_speech.text.clean import clean_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -147,9 +148,9 @@ def decode_one_token_naive(
         codebooks.append(
             sample(
                 x.codebook_logits[:, :, i],
-                previous_tokens=previous_tokens[i + 1]
-                if previous_tokens is not None
-                else None,
+                previous_tokens=(
+                    previous_tokens[i + 1] if previous_tokens is not None else None
+                ),
                 **sampling_kwargs,
             )[0]
         )
@@ -163,6 +164,7 @@ def decode_n_tokens(
     input_pos: torch.Tensor,
     num_new_tokens: int,
     eos_token_id: int = 2,
+    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ):
@@ -197,8 +199,11 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
         )
 
-        # TODO: use tokenizer's eos
-        if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
+        if (
+            cur_token[0, 0, -1] == eos_token_id
+            or cur_token[0, 0, -1] == im_end_id
+            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
+        ):
             break
 
     return previous_tokens[:, : i + 1]
@@ -212,6 +217,7 @@ def generate(
     prompt: torch.Tensor,
     max_new_tokens: int,
     eos_token_id: int = 2,
+    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     precision: torch.dtype = torch.bfloat16,
     **sampling_kwargs,
@@ -256,6 +262,7 @@ def generate(
         input_pos,
         max_new_tokens - 1,
         eos_token_id=eos_token_id,
+        im_end_id=im_end_id,
         decode_one_token=decode_one_token,
         **sampling_kwargs,
     )
@@ -283,9 +290,12 @@ def encode_tokens(
     string = (
         f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>assistant<|im_sep|>"
     )
+    if bos:
+        string = f"<|begin_of_sequence|>{string}"
+
     new_tokens = tokenizer.encode(
         string,
-        add_special_tokens=bos,
+        add_special_tokens=False,
         max_length=10**6,
         truncation=False,
     )
@@ -392,7 +402,11 @@ def split_text(text, min_length):
 
 
 @click.command()
-@click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
+@click.option(
+    "--text",
+    type=str,
+    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
 @click.option("--prompt-text", type=str, default=None)
 @click.option(
     "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
@@ -457,6 +471,8 @@ def main(
         else None
     )
 
+    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
@@ -539,6 +555,7 @@ def main(
                 prompt=cat_encoded,
                 max_new_tokens=max_new_tokens,
                 eos_token_id=tokenizer.eos_token_id,
+                im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 precision=precision,
                 temperature=temperature,
@@ -575,8 +592,15 @@ def main(
                 logger.warning(f"Negative code found: {codes}, retrying ...")
                 continue
 
+            decoded = y[:, prompt_length:-1].clone()
+            if decoded[0, -1] != im_end_id:  # <im_end>
+                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
+                decoded = torch.cat(
+                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
+                )
+
             # But for global encoding, we should keep the <im_end> token
-            global_encoded.append(y[:, prompt_length:-1].clone())
+            global_encoded.append(decoded)
             all_codes.append(codes)
             seg_idx += 1