Lengyue, 2 years ago
parent commit 9649873532

+ 5 - 3
fish_speech/datasets/text.py

@@ -188,9 +188,10 @@ class AutoAugTextDataset(IterableDataset):
 
     def augment(self):
         # 50% chance of pure text or pure phones
-        mode = "sample"
-        if random.random() < 0.5:
-            mode = random.choice(["text", "phones"])
+        # mode = "sample"
+        # if random.random() < 0.5:
+        #     mode = random.choice(["text", "phones"])
+        mode = "phones"
 
         # Random sample based on speaker using a truncated normal distribution
         a = torch.tensor([0], dtype=torch.float32)
@@ -250,6 +251,7 @@ class AutoAugTextDataset(IterableDataset):
             )
             tokens = torch.tensor([tokens], dtype=torch.long)
             labels = tokens.clone()
+            labels[0, : len(encoded) + 1] = -100  # Mask out the <s> and query tokens
         else:
             # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
             tokens = (
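
Note on the added masking line: PyTorch's cross-entropy loss skips any target equal to ignore_index (-100 by default), so setting the <s> and query positions to -100 means the loss is computed only on the semantic-token response. A minimal sketch with made-up shapes, not repo code:

    import torch
    import torch.nn.functional as F

    # Positions labeled -100 contribute nothing to the loss, so only the
    # response region after the prompt is trained on.
    logits = torch.randn(1, 6, 32)          # (batch, seq, vocab)
    labels = torch.randint(0, 32, (1, 6))
    labels[0, :3] = -100                    # mask <s> + query positions
    loss = F.cross_entropy(logits.transpose(1, 2), labels)  # ignore_index=-100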

+ 65 - 18
fish_speech/models/text2semantic/generate.py

@@ -36,10 +36,20 @@ def multinomial_sample_one_no_sync(
 
 def logits_to_probs(
     logits,
+    previous_tokens: Optional[torch.Tensor] = None,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     top_p: Optional[int] = None,
+    repetition_penalty: float = 1.0,
 ):
+    if previous_tokens is not None and repetition_penalty != 1.0:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)
+
     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         cum_probs = torch.cumsum(
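
The new repetition-penalty block follows the usual gather/scale/scatter pattern: logits of tokens that have already appeared are pushed toward "less likely", dividing positive logits by the penalty and multiplying negative ones by it. A self-contained toy run with invented values:

    import torch

    logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
    previous_tokens = torch.tensor([0, 3])
    penalty = 1.1
    score = torch.gather(logits, dim=-1, index=previous_tokens)
    score = torch.where(score < 0, score * penalty, score / penalty)
    logits.scatter_(dim=-1, index=previous_tokens, src=score)
    print(logits)  # tensor([1.8182, -1.0000, 0.5000, 2.7273])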
@@ -64,21 +74,35 @@ def logits_to_probs(
 
 def sample(
     logits,
+    previous_tokens: Optional[torch.Tensor] = None,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     top_p: Optional[int] = None,
+    repetition_penalty: float = 1.0,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs(logits[0, -1], temperature, top_k, top_p)
+    probs = logits_to_probs(
+        logits[0, -1], previous_tokens, temperature, top_k, top_p, repetition_penalty
+    )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
 
 
 def decode_token(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
+    model: Transformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
 ) -> torch.Tensor:
     # input_pos: [B, S]
     logits = model.forward_generate(x, input_pos)
-    codebooks = [sample(logits.token_logits, **sampling_kwargs)[0]]
+    codebooks = [
+        sample(
+            logits.token_logits,
+            previous_tokens=previous_tokens[0] if previous_tokens is not None else None,
+            **sampling_kwargs,
+        )[0]
+    ]
 
     # Disable <s> and </s> tokens for codebooks
     if model.config.num_codebooks != 0:
@@ -86,7 +110,13 @@ def decode_token(
 
         for i in range(model.config.num_codebooks):
             codebooks.append(
-                sample(logits.codebook_logits[:, :, i], **sampling_kwargs)[0]
+                sample(
+                    logits.codebook_logits[:, :, i],
+                    previous_tokens=previous_tokens[i + 1]
+                    if previous_tokens is not None
+                    else None,
+                    **sampling_kwargs,
+                )[0]
             )
 
     return torch.stack(codebooks, dim=0)
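
Because decode_token stacks the text head first and the codebook heads after it, the history tensor has one row per head, and each sample call should only see its own row (hence previous_tokens[i + 1] for codebook i). A shape sketch under that assumption:

    import torch

    # Assumed layout after torch.stack(codebooks, dim=0) across steps:
    # row 0 is the text head, rows 1..num_codebooks the codebook heads.
    num_codebooks, steps = 2, 4
    previous_tokens = torch.randint(0, 10, (1 + num_codebooks, steps))
    text_history = previous_tokens[0]      # for token_logits
    cb0_history = previous_tokens[0 + 1]   # for codebook_logits[:, :, 0]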
@@ -105,7 +135,13 @@ def decode_n_tokens(
         with torch.backends.cuda.sdp_kernel(
             enable_flash=False, enable_mem_efficient=False, enable_math=True
         ):  # Actually better for Inductor to codegen attention here
-            next_token = decode_token(model, cur_token, input_pos, **sampling_kwargs)
+            next_token = decode_token(
+                model,
+                cur_token,
+                input_pos,
+                torch.concat(new_tokens, dim=1) if len(new_tokens) > 0 else None,
+                **sampling_kwargs,
+            )
         input_pos += 1
         new_tokens.append(next_token.clone())
         callback(new_tokens[-1])
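
decode_n_tokens rebuilds the full history on every step by concatenating the per-step outputs; with the shapes assumed above, each element of new_tokens is a (num_heads, 1) column:

    import torch

    # Toy history growth: (num_heads, 1) columns concatenated along dim=1
    # give the (num_heads, steps_so_far) tensor handed to decode_token.
    new_tokens = [torch.tensor([[5], [2], [7]]), torch.tensor([[1], [9], [4]])]
    history = torch.concat(new_tokens, dim=1)  # shape (3, 2)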
@@ -139,6 +175,10 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    # if T + max_new_tokens > 1024:
+    #     max_new_tokens = 1024 - T
+    #     print(f"Truncating max_new_tokens to {max_new_tokens}")
+
     T_new = T + max_new_tokens
     if interactive:
         max_seq_length = 350
@@ -180,7 +220,8 @@ def generate(
 
 def encode_tokens(tokenizer, string, bos=True, device="cuda"):
     # data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_04.npy
-    prompt = g2p("剑,就和茶一样,细细品味才能理解其中风雅。 " + string)
+
+    prompt = g2p("算啦,虽然他罪无可恕,但也有可怜的地方嘛。" + string)
     prompt = [
         (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
         for _, i in prompt
@@ -189,7 +230,7 @@ def encode_tokens(tokenizer, string, bos=True, device="cuda"):
     string = f"[INST] {prompt} [/INST]"
     print("Encoding string:", string)
 
-    data = np.load("data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_02.npy")
+    data = np.load("data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_03.npy")
     codes = [f"<s:{i}>" for i in data[0]]
 
     tokens = tokenizer.encode(
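
For context, encode_tokens wraps the g2p phonemes in <p:...> tags (punctuation and the pad symbol pass through), places them in an [INST] ... [/INST] instruction, and turns the reference .npy codes into <s:i> tags; how the two parts are joined inside tokenizer.encode is cut off in the diff. A toy sketch of just the tag formats, with invented values:

    # Invented values; only the tag formats match the code above.
    phones = ["<p:j>", "<p:ian4>", ","]               # tagged g2p output
    prompt = f"[INST] {' '.join(phones)} [/INST]"
    codes = [f"<s:{i}>" for i in (12, 507, 3989)]     # reference semantic codes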
@@ -222,7 +263,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         model = Transformer(
             ModelArgs(
                 max_seq_len=4096,
-                vocab_size=32312,
+                vocab_size=36408,
                 n_layer=24,
                 n_head=16,
                 dim=1024,
@@ -272,7 +313,8 @@ def main(
     num_samples: int = 5,
     max_new_tokens: int = 100,
     top_k: int = None,
-    top_p: int = None,
+    top_p: float = 1.0,
+    repetition_penalty: float = 1.0,
     temperature: float = 0.8,
     checkpoint_path: Path = Path(
         "results/text2semantic_400m/checkpoints/step_000025000.ckpt"
@@ -306,7 +348,6 @@ def main(
     tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     print(prompt)
     encoded = encode_tokens(tokenizer, f"{prompt}", bos=True, device=device)
-    print(encoded[0])
     prompt_length = encoded.size(1)
 
     torch.manual_seed(1234)
@@ -370,6 +411,7 @@ def main(
                 temperature=temperature,
                 top_k=top_k,
                 top_p=top_p,
+                repetition_penalty=repetition_penalty,
             )
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -384,11 +426,14 @@ def main(
 
         if not interactive:
             print(tokenizer.decode(y[0].tolist()))
-            codes = y[1:, prompt_length - 120 : -1] - 2
+            # Map generated <s:i> ids back to raw codebook indices
+            codes = y[0, prompt_length:-1]
+            codes = codes - 32311
+            # print(codes)
             assert (codes >= 0).all()
             import numpy as np
 
-            np.save(f"codes_{i}.npy", codes.cpu().numpy())
+            np.save(f"codes_{i}.npy", codes[None].cpu().numpy())
         else:
             print()
         tokens_generated = y.size(1) - prompt_length
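
The new extraction keeps everything between the prompt and the final token and subtracts 32311, which implies semantic token <s:i> sits at vocabulary id 32311 + i; with the enlarged vocab of 36408 that leaves 36408 - 32311 = 4097 ids, consistent with a 4096-entry codebook plus one extra, though the diff does not confirm the exact split. A toy round trip under that assumption:

    import numpy as np
    import torch

    y = torch.tensor([[9, 32311 + 5, 32311 + 7, 2]])  # prompt, two codes, </s>
    prompt_length = 1
    codes = y[0, prompt_length:-1] - 32311            # -> tensor([5, 7])
    assert (codes >= 0).all()
    np.save("codes_demo.npy", codes[None].cpu().numpy())  # (1, T) for infer_vq.py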
@@ -414,7 +459,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prompt",
         type=str,
-        default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+        default="在情感分析功能中,我们让大语言模型分析了一段经典散文。可以看到虽然分析的角度比较浅显,但没有逻辑错误,还是可以自洽的。这也不能怪AI,如果我们提前告知它作者所处的时代背景,相信它一定可以回答得更好。而在这个中文翻译功能中,英特尔大语言模型的表现就更加令我意外了,",
         help="Input prompt.",
     )
     parser.add_argument(
@@ -424,17 +469,18 @@ if __name__ == "__main__":
     )
     parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.")
     parser.add_argument(
-        "--max_new_tokens", type=int, default=768, help="Maximum number of new tokens."
+        "--max_new_tokens", type=int, default=1024, help="Maximum number of new tokens."
     )
-    parser.add_argument("--top_k", type=int, default=None, help="Top-k for sampling.")
-    parser.add_argument("--top_p", type=int, default=0.7, help="Top-k for sampling.")
+    parser.add_argument("--top_k", type=int, default=50, help="Top-k for sampling.")
+    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p for sampling.")
+    parser.add_argument("--repetition_penalty", type=float, default=1.1, help="Repetition penalty for sampling.")
     parser.add_argument(
-        "--temperature", type=float, default=1.0, help="Temperature for sampling."
+        "--temperature", type=float, default=0.8, help="Temperature for sampling."
     )
     parser.add_argument(
         "--checkpoint_path",
         type=Path,
-        default=Path("results/text2semantic_400m/step_000035000_weights.ckpt"),
+        default=Path("results/text2semantic_400m/step_000090000_weights.ckpt"),
         help="Model checkpoint path.",
     )
     parser.add_argument(
@@ -450,6 +496,7 @@ if __name__ == "__main__":
         args.max_new_tokens,
         args.top_k,
         args.top_p,
+        args.repetition_penalty,
         args.temperature,
         args.checkpoint_path,
         args.compile,

+ 7 - 3
fish_speech/models/text2semantic/llama.py

@@ -141,10 +141,14 @@ class Transformer(nn.Module):
         return x.sum(dim=3)
 
     def compute(
-        self, x: Tensor, freqs_cis: Tensor, mask: Tensor
+        self,
+        x: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
         for layer in self.layers:
-            x = layer(x, freqs_cis, mask)
+            x = layer(x, freqs_cis, mask, input_pos=input_pos)
 
         x = self.norm(x)
         logits = self.output(x)
@@ -202,7 +206,7 @@ class Transformer(nn.Module):
 
         # TODO: support key padding mask for generation
 
-        return self.compute(x, freqs_cis, mask)
+        return self.compute(x, freqs_cis, mask, input_pos=input_pos)
 
 
 class TransformerBlock(nn.Module):
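
Threading input_pos down through compute is what lets each layer write its keys and values into a preallocated KV cache at the token's absolute position during incremental decoding, rather than appending. A minimal indexing sketch (cache layout assumed, not taken from the repo):

    import torch

    max_seq_len, dim = 8, 4
    k_cache = torch.zeros(max_seq_len, dim)
    input_pos = torch.tensor([3])             # absolute position of the new token
    k_cache[input_pos] = torch.randn(1, dim)  # write in place instead of appending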

+ 29 - 27
tools/infer_vq.py

@@ -34,7 +34,7 @@ def main():
 
     # Load audio
     audio = librosa.load(
-        "data/StarRail/Chinese/停云/chapter2_1_tingyun_142.wav",
+        "data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_04.wav",
         sr=model.sampling_rate,
         mono=True,
     )[0]
@@ -72,37 +72,39 @@ def main():
     _, indices, _ = model.vq_encoder(text_features, feature_masks)
     print(indices.shape)
 
+    speaker_features = model.speaker_encoder(gt_mels, mel_masks)
+
     # Restore
-    # indices = np.load("codes_0.npy")
-    # indices = torch.from_numpy(indices).to(model.device).long()
+    indices = np.load("codes_0.npy")
+    indices = torch.from_numpy(indices).to(model.device).long()
+    print(indices)
     # indices = indices.unsqueeze(1).unsqueeze(-1)
-    # mel_lengths = indices.shape[2] * (
-    #     model.downsample.total_strides if model.downsample is not None else 1
-    # )
-    # mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
-    # mel_masks = torch.ones(
-    #     (1, 1, mel_lengths), device=model.device, dtype=torch.float32
-    # )
+    mel_lengths = indices.shape[1] * (
+        model.downsample.total_strides if model.downsample is not None else 1
+    )
+    mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
+    mel_masks = torch.ones(
+        (1, 1, mel_lengths), device=model.device, dtype=torch.float32
+    )
 
-    # print(mel_lengths)
+    print(mel_lengths)
 
     # Reference speaker
-    ref_audio = librosa.load(
-        "data/StarRail/Chinese/符玄/chapter2_8_fuxuan_104.wav",
-        sr=model.sampling_rate,
-        mono=True,
-    )[0]
-    ref_audios = torch.from_numpy(ref_audio).to(model.device)[None, None, :]
-    ref_audio_lengths = torch.tensor(
-        [ref_audios.shape[2]], device=model.device, dtype=torch.long
-    )
-    ref_mels = model.mel_transform(ref_audios, sample_rate=model.sampling_rate)
-    ref_mel_lengths = ref_audio_lengths // model.hop_length
-    ref_mel_masks = torch.unsqueeze(
-        sequence_mask(ref_mel_lengths, ref_mels.shape[2]), 1
-    ).to(gt_mels.dtype)
-    speaker_features = model.speaker_encoder(ref_mels, ref_mel_masks)
-    # speaker_features = model.speaker_encoder(gt_mels, mel_masks)
+    # ref_audio = librosa.load(
+    #     "data/StarRail/Chinese/符玄/chapter2_8_fuxuan_104.wav",
+    #     sr=model.sampling_rate,
+    #     mono=True,
+    # )[0]
+    # ref_audios = torch.from_numpy(ref_audio).to(model.device)[None, None, :]
+    # ref_audio_lengths = torch.tensor(
+    #     [ref_audios.shape[2]], device=model.device, dtype=torch.long
+    # )
+    # ref_mels = model.mel_transform(ref_audios, sample_rate=model.sampling_rate)
+    # ref_mel_lengths = ref_audio_lengths // model.hop_length
+    # ref_mel_masks = torch.unsqueeze(
+    #     sequence_mask(ref_mel_lengths, ref_mels.shape[2]), 1
+    # ).to(gt_mels.dtype)
+    # speaker_features = model.speaker_encoder(ref_mels, ref_mel_masks)
 
     print("indices", indices.shape)
     text_features = model.vq_encoder.decode(indices)
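
With the restore path enabled, mask lengths now come from the loaded codes instead of the ground-truth mel: the (1, T) array saved by generate.py is upsampled by the decoder's total stride to get the mel frame count. A sketch with an assumed stride value:

    import numpy as np
    import torch

    indices = torch.from_numpy(np.load("codes_0.npy")).long()  # shape (1, T)
    total_strides = 4                       # assumed downsample factor
    mel_frames = int(indices.shape[1]) * total_strides
    mel_masks = torch.ones((1, 1, mel_frames), dtype=torch.float32)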