فهرست منبع

Support mix codebook training

Lengyue 2 سال پیش
والد
کامیت
9ac8edef1e

+ 2 - 3
fish_speech/configs/text2semantic.yaml

@@ -46,14 +46,13 @@ model:
     config:
     config:
       _target_: fish_speech.models.text2semantic.llama.ModelArgs
       _target_: fish_speech.models.text2semantic.llama.ModelArgs
       max_seq_len: 4096
       max_seq_len: 4096
-      vocab_size: 32312
+      vocab_size: 36408
       n_layer: 24
       n_layer: 24
       n_head: 16
       n_head: 16
       dim: 1024
       dim: 1024
       rope_base: 10000
       rope_base: 10000
       norm_eps: 1e-5
       norm_eps: 1e-5
-      codebook_size: 168
-      num_codebooks: 4
+      num_codebooks: 0  # single codebook
 
 
   optimizer:
   optimizer:
     _target_: torch.optim.AdamW
     _target_: torch.optim.AdamW

+ 44 - 22
fish_speech/datasets/text.py

@@ -222,36 +222,58 @@ class AutoAugTextDataset(IterableDataset):
             final_text.append(text)
             final_text.append(text)
             final_semantic.append(sentence.semantics)
             final_semantic.append(sentence.semantics)
 
 
-        final_text = "[INST] " + "<pad>".join(final_text) + " [/INST]"
+        final_text = "[INST] " + " ".join(final_text) + " [/INST]"
         encoded = self.tokenizer.encode(
         encoded = self.tokenizer.encode(
             final_text,
             final_text,
-            max_length=self.max_length,
             add_special_tokens=False,
             add_special_tokens=False,
             truncation=False,
             truncation=False,
+            max_length=10**6,
         )
         )
         semantic_length = sum([len(i[0].values) for i in final_semantic])
         semantic_length = sum([len(i[0].values) for i in final_semantic])
 
 
-        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-        tokens = (
-            [self.tokenizer.bos_token_id]
-            + encoded
-            + [self.tokenizer.pad_token_id] * semantic_length
-            + [self.tokenizer.eos_token_id]
-        )
-        codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
-        for segment in final_semantic:
-            for book_idx, book in enumerate(segment):
-                for j in book.values:
-                    codes[book_idx].append(int(j) + 2)
-
-        for book in codes:
-            book.append(1)
-
-        tokens = [tokens] + codes
-        tokens = torch.tensor(tokens, dtype=torch.long)
+        # Single codebook
+        if len(final_semantic[0]) == 1:
+            semantic_tokens = [f"<s:{j}>" for i in final_semantic for j in i[0].values]
+            tokenized = self.tokenizer.encode(
+                " ".join(semantic_tokens),
+                add_special_tokens=False,
+                truncation=False,
+                max_length=10**6,
+            )
 
 
-        labels = tokens.clone()
-        labels[1:, : len(encoded) + 1] = -100  # Mask out the <s> tokens for semantic
+            # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
+            tokens = (
+                [self.tokenizer.bos_token_id]
+                + encoded
+                + tokenized
+                + [self.tokenizer.eos_token_id]
+            )
+            tokens = torch.tensor([tokens], dtype=torch.long)
+            labels = tokens.clone()
+        else:
+            # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
+            tokens = (
+                [self.tokenizer.bos_token_id]
+                + encoded
+                + [self.tokenizer.pad_token_id] * semantic_length
+                + [self.tokenizer.eos_token_id]
+            )
+            codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
+            for segment in final_semantic:
+                for book_idx, book in enumerate(segment):
+                    for j in book.values:
+                        codes[book_idx].append(int(j) + 2)
+
+            for book in codes:
+                book.append(1)
+
+            tokens = [tokens] + codes
+
+            tokens = torch.tensor(tokens, dtype=torch.long)
+            labels = tokens.clone()
+            labels[
+                1:, : len(encoded) + 1
+            ] = -100  # Mask out the <s> tokens for semantic
 
 
         return {
         return {
             "tokens": tokens[:, :-1],
             "tokens": tokens[:, :-1],

+ 77 - 27
fish_speech/models/text2semantic/generate.py

@@ -9,6 +9,7 @@ import time
 from pathlib import Path
 from pathlib import Path
 from typing import Optional, Tuple
 from typing import Optional, Tuple
 
 
+import numpy as np
 import torch
 import torch
 import torch._dynamo.config
 import torch._dynamo.config
 import torch._inductor.config
 import torch._inductor.config
@@ -33,7 +34,24 @@ def multinomial_sample_one_no_sync(
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
 
 
 
-def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+def logits_to_probs(
+    logits,
+    temperature: float = 1.0,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+):
+    if top_p is not None and top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cum_probs = torch.cumsum(
+            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+        )
+        sorted_indices_to_remove = cum_probs > top_p
+        sorted_indices_to_remove[0] = False  # keep at least one option
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            dim=0, index=sorted_indices, src=sorted_indices_to_remove
+        )
+        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
     logits = logits / max(temperature, 1e-5)
     logits = logits / max(temperature, 1e-5)
 
 
     if top_k is not None:
     if top_k is not None:
@@ -44,8 +62,13 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs
     return probs
 
 
 
 
-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+def sample(
+    logits,
+    temperature: float = 1.0,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs(logits[0, -1], temperature, top_k, top_p)
     idx_next = multinomial_sample_one_no_sync(probs)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
     return idx_next, probs
 
 
@@ -57,10 +80,15 @@ def decode_token(
     logits = model.forward_generate(x, input_pos)
     logits = model.forward_generate(x, input_pos)
     codebooks = [sample(logits.token_logits, **sampling_kwargs)[0]]
     codebooks = [sample(logits.token_logits, **sampling_kwargs)[0]]
 
 
-    # Disable <s> and </s> tokens for 2-n codebooks
-    logits.codebook_logits[:, :, 1:, :2] = -float("Inf")
-    for i in range(model.config.num_codebooks):
-        codebooks.append(sample(logits.codebook_logits[:, :, i], **sampling_kwargs)[0])
+    # Disable <s> and </s> tokens for codebooks
+    if model.config.num_codebooks != 0:
+        logits.codebook_logits[:, :, :, :2] = -float("Inf")
+
+        for i in range(model.config.num_codebooks):
+            codebooks.append(
+                sample(logits.codebook_logits[:, :, i], **sampling_kwargs)[0]
+            )
+
     return torch.stack(codebooks, dim=0)
     return torch.stack(codebooks, dim=0)
 
 
 
 
@@ -83,8 +111,8 @@ def decode_n_tokens(
         callback(new_tokens[-1])
         callback(new_tokens[-1])
         cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
         cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
 
 
-        # TODO: use tokenizer
-        if (cur_token[0, 1:, 0] == 1).any():
+        # TODO: use tokenizer's eos
+        if (cur_token[0, 0, -1] == 2).any():
             print("EOS detected, stopping generation")
             print("EOS detected, stopping generation")
             break
             break
 
 
@@ -151,14 +179,41 @@ def generate(
 
 
 
 
 def encode_tokens(tokenizer, string, bos=True, device="cuda"):
 def encode_tokens(tokenizer, string, bos=True, device="cuda"):
+    # data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_04.npy
+    prompt = g2p("剑,就和茶一样,细细品味才能理解其中风雅。 " + string)
+    prompt = [
+        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
+        for _, i in prompt
+    ]
+    prompt = " ".join(prompt)
+    string = f"[INST] {prompt} [/INST]"
+    print("Encoding string:", string)
+
+    data = np.load("data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_02.npy")
+    codes = [f"<s:{i}>" for i in data[0]]
+
     tokens = tokenizer.encode(
     tokens = tokenizer.encode(
-        string, max_length=10**6, add_special_tokens=bos, truncation=False
+        string + " ".join(codes),
+        max_length=10**6,
+        add_special_tokens=bos,
+        truncation=False,
     )
     )
     tokens = torch.tensor([tokens], dtype=torch.int, device=device)
     tokens = torch.tensor([tokens], dtype=torch.int, device=device)
 
 
     # Codebooks
     # Codebooks
-    zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
-    return torch.cat((tokens, zeros), dim=0)
+    # zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
+    # prompt = torch.cat((tokens, zeros), dim=0)
+
+    # # Get prompt tokens
+    # data = np.load("data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_02.npy")
+    # data = torch.from_numpy(data).to(device=device, dtype=torch.int) + 2
+
+    # zeros = torch.zeros((1, data.size(1)), dtype=torch.int, device=device) + 32311 # 32311 is the <pad> token
+    # data = torch.cat((zeros, data), dim=0)
+    # prompt = torch.cat((prompt, data), dim=1)
+    # print(prompt)
+
+    return tokens
 
 
 
 
 def _load_model(checkpoint_path, device, precision, use_tp):
 def _load_model(checkpoint_path, device, precision, use_tp):
@@ -174,7 +229,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):
                 rope_base=10000,
                 rope_base=10000,
                 norm_eps=1e-5,
                 norm_eps=1e-5,
                 codebook_size=168,
                 codebook_size=168,
-                num_codebooks=4,
+                num_codebooks=0,
             )
             )
         )
         )
 
 
@@ -216,13 +271,13 @@ def main(
     interactive: bool = False,
     interactive: bool = False,
     num_samples: int = 5,
     num_samples: int = 5,
     max_new_tokens: int = 100,
     max_new_tokens: int = 100,
-    top_k: int = 200,
+    top_k: int = None,
+    top_p: int = None,
     temperature: float = 0.8,
     temperature: float = 0.8,
     checkpoint_path: Path = Path(
     checkpoint_path: Path = Path(
         "results/text2semantic_400m/checkpoints/step_000025000.ckpt"
         "results/text2semantic_400m/checkpoints/step_000025000.ckpt"
     ),
     ),
     compile: bool = True,
     compile: bool = True,
-    compile_prefill: bool = False,
     profile: Optional[Path] = None,
     profile: Optional[Path] = None,
     tokenizer: str = "fishaudio/speech-lm-v1",
     tokenizer: str = "fishaudio/speech-lm-v1",
 ) -> None:
 ) -> None:
@@ -249,16 +304,8 @@ def main(
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-    prompt = g2p(prompt)
-    prompt = [
-        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
-        for _, i in prompt
-    ]
-    prompt = " ".join(prompt)
     print(prompt)
     print(prompt)
-    encoded = encode_tokens(
-        tokenizer, f"[INST] {prompt} [/INST]", bos=True, device=device
-    )
+    encoded = encode_tokens(tokenizer, f"{prompt}", bos=True, device=device)
     print(encoded[0])
     print(encoded[0])
     prompt_length = encoded.size(1)
     prompt_length = encoded.size(1)
 
 
@@ -322,6 +369,7 @@ def main(
                 callback=callback,
                 callback=callback,
                 temperature=temperature,
                 temperature=temperature,
                 top_k=top_k,
                 top_k=top_k,
+                top_p=top_p,
             )
             )
         if i == -1:
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -336,7 +384,7 @@ def main(
 
 
         if not interactive:
         if not interactive:
             print(tokenizer.decode(y[0].tolist()))
             print(tokenizer.decode(y[0].tolist()))
-            codes = y[1:, prompt_length:-1] - 2
+            codes = y[1:, prompt_length - 120 : -1] - 2
             assert (codes >= 0).all()
             assert (codes >= 0).all()
             import numpy as np
             import numpy as np
 
 
@@ -378,14 +426,15 @@ if __name__ == "__main__":
     parser.add_argument(
     parser.add_argument(
         "--max_new_tokens", type=int, default=768, help="Maximum number of new tokens."
         "--max_new_tokens", type=int, default=768, help="Maximum number of new tokens."
     )
     )
-    parser.add_argument("--top_k", type=int, default=10, help="Top-k for sampling.")
+    parser.add_argument("--top_k", type=int, default=None, help="Top-k for sampling.")
+    parser.add_argument("--top_p", type=float, default=0.7, help="Top-p for sampling.")
     parser.add_argument(
     parser.add_argument(
         "--temperature", type=float, default=1.0, help="Temperature for sampling."
         "--temperature", type=float, default=1.0, help="Temperature for sampling."
     )
     )
     parser.add_argument(
     parser.add_argument(
         "--checkpoint_path",
         "--checkpoint_path",
         type=Path,
         type=Path,
-        default=Path("results/text2semantic_400m/step_000025000_weights.ckpt"),
+        default=Path("results/text2semantic_400m/step_000035000_weights.ckpt"),
         help="Model checkpoint path.",
         help="Model checkpoint path.",
     )
     )
     parser.add_argument(
     parser.add_argument(
@@ -400,6 +449,7 @@ if __name__ == "__main__":
         args.num_samples,
         args.num_samples,
         args.max_new_tokens,
         args.max_new_tokens,
         args.top_k,
         args.top_k,
+        args.top_p,
         args.temperature,
         args.temperature,
         args.checkpoint_path,
         args.checkpoint_path,
         args.compile,
         args.compile,

+ 22 - 13
fish_speech/models/text2semantic/lit_module.py

@@ -36,20 +36,22 @@ class TextToSemantic(L.LightningModule):
 
 
         # Generate labels
         # Generate labels
         labels = batch["labels"]
         labels = batch["labels"]
-        token_loss = F.cross_entropy(
+        loss = F.cross_entropy(
             outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
             outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
             labels[:, 0].reshape(-1),
             labels[:, 0].reshape(-1),
             ignore_index=-100,
             ignore_index=-100,
         )
         )
 
 
-        codebook_labels = labels[:, 1:].mT
-        semantic_loss = F.cross_entropy(
-            outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
-            codebook_labels.reshape(-1),
-            ignore_index=-100,
-        )
+        # If we have a codebook, add the loss
+        if self.model.config.num_codebooks != 0:
+            codebook_labels = labels[:, 1:].mT
+            semantic_loss = F.cross_entropy(
+                outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
+                codebook_labels.reshape(-1),
+                ignore_index=-100,
+            )
 
 
-        loss = token_loss + semantic_loss
+            loss = loss + semantic_loss
 
 
         self.log(
         self.log(
             f"{stage}/loss",
             f"{stage}/loss",
@@ -61,11 +63,18 @@ class TextToSemantic(L.LightningModule):
         )
         )
 
 
         # Top-5 accuracy
         # Top-5 accuracy
-        _, indices = outputs.codebook_logits.topk(5, dim=-1)
-        correct = indices.eq(codebook_labels.unsqueeze(-1))
-        correct[codebook_labels == -100] = 0
-        correct = correct.sum()
-        accuracy = correct / (codebook_labels != -100).sum()
+        if self.model.config.num_codebooks == 0:
+            _, indices = outputs.token_logits.topk(5, dim=-1)
+            correct = indices.eq(labels[:, 0].unsqueeze(-1))
+            correct[labels[:, 0] == -100] = 0
+            correct = correct.sum()
+            accuracy = correct / (labels[:, 0] != -100).sum()
+        else:
+            _, indices = outputs.codebook_logits.topk(5, dim=-1)
+            correct = indices.eq(codebook_labels.unsqueeze(-1))
+            correct[codebook_labels == -100] = 0
+            correct = correct.sum()
+            accuracy = correct / (codebook_labels != -100).sum()
 
 
         self.log(
         self.log(
             f"{stage}/top_5_accuracy",
             f"{stage}/top_5_accuracy",

+ 43 - 43
fish_speech/models/text2semantic/llama.py

@@ -30,6 +30,7 @@ class ModelArgs:
     # Additional decoding heads
     # Additional decoding heads
     codebook_size: int = 160
     codebook_size: int = 160
     num_codebooks: int = 4
     num_codebooks: int = 4
+    codebook_padding_idx: int = 0
 
 
     def __post_init__(self):
     def __post_init__(self):
         if self.n_local_heads == -1:
         if self.n_local_heads == -1:
@@ -123,38 +124,39 @@ class Transformer(nn.Module):
                 max_batch_size, max_seq_len, self.config.n_local_heads, head_dim
                 max_batch_size, max_seq_len, self.config.n_local_heads, head_dim
             )
             )
 
 
-    def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
-        # x: (batch, num_codebooks + 1, seq_len)
-        seq_len = x.size(2)
-
+    def embed(self, x: Tensor) -> Tensor:
         # Here we want to merge the embeddings of the codebooks
         # Here we want to merge the embeddings of the codebooks
+        if self.config.num_codebooks == 0:
+            return self.embeddings(x[:, 0])
+
         vocab_embeds = [self.embeddings(x[:, 0])]
         vocab_embeds = [self.embeddings(x[:, 0])]
         for i in range(self.config.num_codebooks):
         for i in range(self.config.num_codebooks):
             emb = self.embeddings(
             emb = self.embeddings(
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
             )
             )
+            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
             vocab_embeds.append(emb)
             vocab_embeds.append(emb)
 
 
         x = torch.stack(vocab_embeds, dim=3)
         x = torch.stack(vocab_embeds, dim=3)
-        x = x.mean(dim=3)
-
-        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
-        freqs_cis = self.freqs_cis[:seq_len]
-
-        # Not that the causal mask here follows the definition of scaled_dot_product_attention
-        # That is, FALSE means masked out
-        # To maintain consistency, key_padding_mask use TRUE to mask out
-        if key_padding_mask is not None:
-            mask = mask & key_padding_mask[:, None, None, :].logical_not()
+        return x.sum(dim=3)
 
 
+    def compute(
+        self, x: Tensor, freqs_cis: Tensor, mask: Tensor
+    ) -> TransformerForwardResult:
         for layer in self.layers:
         for layer in self.layers:
             x = layer(x, freqs_cis, mask)
             x = layer(x, freqs_cis, mask)
 
 
         x = self.norm(x)
         x = self.norm(x)
         logits = self.output(x)
         logits = self.output(x)
         token_logits = logits[:, :, : self.config.vocab_size]
         token_logits = logits[:, :, : self.config.vocab_size]
-        codebook_logits = logits[:, :, self.config.vocab_size :]
 
 
+        if self.config.num_codebooks == 0:
+            return TransformerForwardResult(
+                token_logits=token_logits,
+                codebook_logits=None,
+            )
+
+        codebook_logits = logits[:, :, self.config.vocab_size :]
         codebook_logits = rearrange(
         codebook_logits = rearrange(
             codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
             codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
         )
         )
@@ -164,6 +166,26 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
             codebook_logits=codebook_logits,
         )
         )
 
 
+    def forward(
+        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        # x: (batch, num_codebooks + 1, seq_len)
+        seq_len = x.size(2)
+
+        # Here we want to merge the embeddings of the codebooks
+        x = self.embed(x)
+
+        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
+        freqs_cis = self.freqs_cis[:seq_len]
+
+        # Not that the causal mask here follows the definition of scaled_dot_product_attention
+        # That is, FALSE means masked out
+        # To maintain consistency, key_padding_mask use TRUE to mask out
+        if key_padding_mask is not None:
+            mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+        return self.compute(x, freqs_cis, mask)
+
     def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
     def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         # x: (batch, num_codebooks + 1, 1)
         # x: (batch, num_codebooks + 1, 1)
 
 
@@ -171,38 +193,16 @@ class Transformer(nn.Module):
             self.max_seq_len != -1 and self.max_batch_size != -1
             self.max_seq_len != -1 and self.max_batch_size != -1
         ), "Please call setup_caches before forward_generate"
         ), "Please call setup_caches before forward_generate"
 
 
-        # Here we want to merge the embeddings of the codebooks
-        vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_codebooks):
-            emb = self.embeddings(
-                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
-            )
-            vocab_embeds.append(emb)
-
-        x = torch.stack(vocab_embeds, dim=3)
-        x = x.mean(dim=3)
+        x = self.embed(x)
 
 
         mask = self.causal_mask[
         mask = self.causal_mask[
             None, None, input_pos, : self.max_seq_len
             None, None, input_pos, : self.max_seq_len
         ]  # (B, N, Q, K)
         ]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[input_pos]
         freqs_cis = self.freqs_cis[input_pos]
 
 
-        for layer in self.layers:
-            x = layer(x, freqs_cis, mask, input_pos=input_pos)
-
-        x = self.norm(x)
-        logits = self.output(x)
-        token_logits = logits[:, :, : self.config.vocab_size]
-        codebook_logits = logits[:, :, self.config.vocab_size :]
-
-        codebook_logits = rearrange(
-            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
-        )
+        # TODO: support key padding mask for generation
 
 
-        return TransformerForwardResult(
-            token_logits=token_logits,
-            codebook_logits=codebook_logits,
-        )
+        return self.compute(x, freqs_cis, mask)
 
 
 
 
 class TransformerBlock(nn.Module):
 class TransformerBlock(nn.Module):
@@ -339,8 +339,8 @@ if __name__ == "__main__":
         dim=768,
         dim=768,
         rope_base=10000,
         rope_base=10000,
         norm_eps=1e-5,
         norm_eps=1e-5,
-        codebook_size=168,
-        num_codebooks=4,
+        codebook_size=0,
+        num_codebooks=0,
     )
     )
 
 
     model = Transformer(args)
     model = Transformer(args)
@@ -352,4 +352,4 @@ if __name__ == "__main__":
     key_padding_mask[0, 2:] = True
     key_padding_mask[0, 2:] = True
     x1 = model(inputs, key_padding_mask=key_padding_mask)
     x1 = model(inputs, key_padding_mask=key_padding_mask)
     print(x1.token_logits.shape)
     print(x1.token_logits.shape)
-    print(x1.codebook_logits.shape)
+    # print(x1.codebook_logits.shape)

+ 0 - 534
fish_speech/models/text2semantic/modules.py

@@ -1,534 +0,0 @@
-import math
-from typing import Optional
-
-import torch
-from einops import rearrange
-from torch import nn
-from torch.nn import functional as F
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-
-
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
-    """
-    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
-
-    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
-    and the end index 'end'. The 'theta' parameter scales the frequencies.
-    The returned tensor contains complex values in complex64 data type.
-
-    Args:
-        dim (int): Dimension of the frequency tensor.
-        end (int): End index for precomputing frequencies.
-        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
-
-    Returns:
-        torch.Tensor: Precomputed frequency tensor with complex exponentials.
-    """
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)  # type: ignore
-    freqs = torch.outer(t, freqs).float()  # type: ignore
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
-
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    """
-    Reshape frequency tensor for broadcasting it with another tensor.
-
-    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
-    for the purpose of broadcasting the frequency tensor during element-wise operations.
-
-    Args:
-        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
-        x (torch.Tensor): Target tensor for broadcasting compatibility.
-
-    Returns:
-        torch.Tensor: Reshaped frequency tensor.
-
-    Raises:
-        AssertionError: If the frequency tensor doesn't match the expected shape.
-        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
-    """
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-
-def apply_rotary_emb(
-    x: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, x_)
-    return torch.view_as_real(x_ * freqs_cis).flatten(3).type_as(x)
-
-
-class MultiheadAttention(nn.Module):
-    def __init__(self, d_model, nhead, dropout=0.1, is_cross_attention=False):
-        super().__init__()
-        assert d_model % nhead == 0
-        self.nhead = nhead
-        self.d_model = d_model
-        self.head_dim = d_model // nhead
-        self.is_cross_attention = is_cross_attention
-
-        # Auto fuse linear projection
-        if is_cross_attention:
-            self.q_proj = nn.Linear(d_model, d_model)
-            self.kv_proj = nn.Linear(d_model, d_model * 2)
-        else:
-            self.qkv_proj = nn.Linear(d_model, d_model * 3)
-
-        self.o_proj = nn.Linear(d_model, d_model)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(
-        self,
-        q,
-        freqs_cis_q,
-        kv=None,
-        freqs_cis_kv=None,
-        attn_mask=None,
-        input_pos=None,
-        kv_cache=None,
-    ):
-        if self.is_cross_attention:
-            q = self.q_proj(q)
-            if kv is None:
-                assert self.kv_cache is not None, "kv_cache should be initialized"
-                k, v = None
-            else:
-                # Using kv cache
-                kv = self.kv_proj(kv)
-                k, v = torch.chunk(kv, 2, dim=-1)
-        else:
-            assert kv is None, f"kv should be None for self attention"
-            assert (
-                freqs_cis_kv is None
-            ), f"freqs_cis_kv should be None for self attention"
-            q, k, v = torch.chunk(self.qkv_proj(q), 3, dim=-1)
-
-        # max_batch_size, max_seq_length, n_heads, head_dim
-        q = rearrange(q, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
-        q = apply_rotary_emb(q, freqs_cis_q)
-
-        if freqs_cis_kv is None:
-            freqs_cis_kv = freqs_cis_q
-
-        # Only do when self attention or cross attention without kv cache
-        if k is not None:
-            assert v is not None, "v should not be None when k is not None"
-            k = rearrange(k, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
-            v = rearrange(v, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
-            k = apply_rotary_emb(k, freqs_cis_kv)
-
-        if kv_cache is not None:
-            if k is None:
-                assert v is None, "v should be None when k is None"
-                k, v = kv_cache[0], kv_cache[1]
-            else:
-                k = torch.cat([kv_cache[0], k], dim=1)
-                v = torch.cat([kv_cache[1], v], dim=1)
-                kv_cache = (k, v)
-
-        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
-        value = F.scaled_dot_product_attention(
-            q,
-            k,
-            v,
-            attn_mask=attn_mask,
-            dropout_p=self.dropout.p if self.training else 0,
-        )
-
-        value = rearrange(value, "b h t d -> b t (h d)")
-        return self.o_proj(value), kv_cache
-
-
-class GluMLP(nn.Module):
-    def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
-        super().__init__()
-
-        if intermediate_size is None:
-            intermediate_size = hidden_size * (11 / 3)
-            intermediate_size = round(intermediate_size / 8) * 8
-
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = activation()
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
-class RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
-        return self.weight * hidden_states.to(input_dtype)
-
-
-class TransformerEncoderLayer(nn.Module):
-    def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
-        super().__init__()
-
-        self.attention = MultiheadAttention(hidden_size, nhead, dropout=dropout)
-        self.ffn = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
-
-        self.attention_norm = RMSNorm(hidden_size, eps=1e-6)
-        self.ffn_norm = RMSNorm(hidden_size, eps=1e-6)
-
-    def forward(
-        self,
-        x,
-        freqs_cis,
-        attn_mask=None,
-        input_pos=None,
-    ):
-        x = (
-            x
-            + self.attention(
-                q=self.attention_norm(x),
-                freqs_cis_q=freqs_cis,
-                attn_mask=attn_mask,
-                input_pos=input_pos,
-            )[0]
-        )
-
-        return x + self.ffn(self.ffn_norm(x))
-
-
-class TransformerDecoderLayer(nn.Module):
-    def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
-        super().__init__()
-
-        self.self_attention = MultiheadAttention(hidden_size, nhead, dropout=dropout)
-        self.cross_attention = MultiheadAttention(
-            hidden_size, nhead, dropout=dropout, is_cross_attention=True
-        )
-        self.ffn = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
-
-        self.self_attention_norm = RMSNorm(hidden_size, eps=1e-6)
-        self.cross_attention_norm = RMSNorm(hidden_size, eps=1e-6)
-        self.ffn_norm = RMSNorm(hidden_size, eps=1e-6)
-
-    def forward(
-        self,
-        x,
-        context,
-        freqs_cis_q,
-        freqs_cis_kv,
-        self_attn_mask=None,
-        cross_attn_mask=None,
-        input_pos=None,
-    ):
-        x = x + self.self_attention(
-            q=self.self_attention_norm(x),
-            freqs_cis_q=freqs_cis_q,
-            attn_mask=self_attn_mask,
-            input_pos=input_pos,
-        )
-
-        x = x + self.cross_attention(
-            q=self.cross_attention_norm(x),
-            kv=context,
-            freqs_cis_q=freqs_cis_q,
-            freqs_cis_kv=freqs_cis_kv,
-            attn_mask=cross_attn_mask,
-            input_pos=input_pos,
-        )
-
-        return x + self.ffn(self.ffn_norm(x))
-
-
-class Transformer(nn.Module):
-    def __init__(
-        self,
-        vocab_size,
-        codebook_size,
-        num_codebooks,
-        hidden_size=1024,
-        intermediate_size=None,
-        nhead=16,
-        num_encoder_layers=12,
-        num_decoder_layers=12,
-        dropout=0.1,
-        max_position=4096,
-    ):
-        super().__init__()
-
-        self.encoder_embedding = nn.Embedding(vocab_size, hidden_size)
-        self.decoder_embeddings = nn.ModuleList(
-            [nn.Embedding(codebook_size, hidden_size) for _ in range(num_codebooks)]
-        )
-        self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
-        self.codebook_size = codebook_size
-        self.num_codebooks = num_codebooks
-        self.nhead = nhead
-
-        self.encoder = nn.ModuleList(
-            [
-                TransformerEncoderLayer(
-                    hidden_size=hidden_size,
-                    intermediate_size=intermediate_size,
-                    nhead=nhead,
-                    dropout=dropout,
-                )
-                for _ in range(num_encoder_layers)
-            ]
-        )
-
-        self.decoder = nn.ModuleList(
-            [
-                TransformerDecoderLayer(
-                    hidden_size=hidden_size,
-                    intermediate_size=intermediate_size,
-                    nhead=nhead,
-                    dropout=dropout,
-                )
-                for _ in range(num_decoder_layers)
-            ]
-        )
-
-        self.register_buffer(
-            "freqs_cis",
-            precompute_freqs_cis(hidden_size // nhead, max_position, theta=10000.0),
-        )
-
-        causual_mask = torch.triu(
-            torch.ones(max_position, max_position), diagonal=1
-        ).bool()
-        causual_mask = torch.zeros(max_position, max_position).masked_fill(
-            causual_mask, float("-inf")
-        )
-
-        self.register_buffer("causual_mask", causual_mask)
-
-        # The following are reserved for kv cache
-        self.max_batch_size = -1
-        self.max_seq_length = -1
-
-    def setup_kv_caches(self, max_batch_size, max_seq_length):
-        if (
-            self.max_seq_length >= max_seq_length
-            and self.max_batch_size >= max_batch_size
-        ):
-            return
-
-        if max_seq_length % 8 != 0:
-            max_seq_length = max_seq_length + (8 - max_seq_length % 8)
-
-        self.max_seq_length = max_seq_length
-        self.max_batch_size = max_batch_size
-
-        for b in self.decoder:
-            b.self_attention.kv_cache = KVCache(
-                max_batch_size,
-                max_seq_length,
-                b.self_attention.nhead,
-                b.self_attention.head_dim,
-            ).to(b.self_attention_norm.weight.device)
-
-            b.cross_attention.kv_cache = KVCache(
-                max_batch_size,
-                max_seq_length,
-                b.cross_attention.nhead,
-                b.cross_attention.head_dim,
-            ).to(b.cross_attention_norm.weight.device)
-
-    def get_key_padding_mask(self, key_padding_mask, q_size=None):
-        # inputs: (B, T) bool ->
-        assert key_padding_mask.dtype == torch.bool and key_padding_mask.ndim == 2
-
-        key_padding_mask = (
-            key_padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.nhead, -1, -1)
-        )
-
-        key_padding_mask = key_padding_mask.reshape(
-            key_padding_mask.shape[0], self.nhead, 1, key_padding_mask.shape[1]
-        )
-
-        if q_size is not None:
-            key_padding_mask = key_padding_mask.expand(-1, -1, q_size, -1)
-
-        new_mask = torch.zeros(
-            *key_padding_mask.shape, dtype=torch.float, device=key_padding_mask.device
-        )
-        new_mask = new_mask.masked_fill(key_padding_mask, float("-inf"))
-
-        return new_mask
-
-    def forward_encoder(
-        self, inputs, input_mask=None
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        # inputs: (B, T)
-        # input_mask: (B, T), bool mask
-        inputs = self.encoder_embedding(inputs)
-
-        # Calculate mask
-        if input_mask is None:
-            # Assume no padding
-            input_mask = torch.zeros(
-                inputs.shape[0], inputs.shape[1], dtype=torch.bool, device=inputs.device
-            )
-
-        input_mask = self.get_key_padding_mask(input_mask, q_size=None).to(inputs.dtype)
-
-        freqs_cis = self.freqs_cis[: inputs.shape[1]]
-        input_mask_self = input_mask.expand(-1, -1, inputs.shape[1], -1)
-
-        for layer in self.encoder:
-            inputs = layer(inputs, freqs_cis=freqs_cis, attn_mask=input_mask_self)
-
-        return inputs, input_mask
-
-    def forward_decoder(
-        self, codes, inputs, input_mask, codes_mask=None, input_pos=None
-    ):
-        # codes: (B, C, T)
-        # inputs: (B, T, N)
-
-        print(f"Codes: {codes.shape}, Inputs: {inputs.shape}")
-        codes = rearrange(codes, "b c t -> c b t")
-        codes = torch.stack(
-            [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
-        )
-        codes = torch.mean(codes, dim=0)  # (B, T)
-
-        # If kv cache is enabled
-        input_mask = input_mask.expand(-1, -1, codes.shape[1], -1)
-
-        # Calculate mask
-        if input_pos is not None:
-            attn_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]
-        else:
-            attn_mask = None
-
-        # if codes_mask is not None:
-        #     codes_mask = self.get_key_padding_mask(codes_mask)
-        #     attn_mask = attn_mask + codes_mask
-
-        # For kv cache
-        if input_pos is not None:
-            freqs_cis_q = self.freqs_cis[input_pos]
-        else:
-            freqs_cis_q = self.freqs_cis[: codes.shape[1]]
-
-        freqs_cis_kv = self.freqs_cis[: inputs.shape[1]]
-
-        for layer in self.decoder:
-            codes = layer(
-                codes,
-                inputs,
-                freqs_cis_q=freqs_cis_q,
-                freqs_cis_kv=freqs_cis_kv,
-                self_attn_mask=attn_mask,
-                cross_attn_mask=input_mask,
-                input_pos=input_pos,
-            )
-
-        codes = self.decoder_head(codes)
-        codes = rearrange(
-            codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
-        )
-
-        return codes
-
-    def forward(
-        self,
-        inputs,
-        codes,
-        input_mask=None,
-        codes_mask=None,
-        input_pos=None,
-    ):
-        # inputs: (B, T)
-        # codes: (B, C, T)
-        # input_mask: (B, T), bool mask
-        # codes_mask: (B, T), bool mask
-        # input_pos: (B, T), int mask
-
-        inputs, input_mask = self.forward_encoder(inputs, input_mask)
-        codes = self.forward_decoder(codes, inputs, input_mask, codes_mask, input_pos)
-
-        return codes
-
-
-if __name__ == "__main__":
-    mha = MultiheadAttention(512, 8, dropout=0, is_cross_attention=True)
-    mha.eval()
-    mha.cuda()
-
-    q, kv = torch.randn(2, 10, 16, 512)
-    q, kv = q.cuda(), kv.cuda()
-
-    mha.bfloat16()
-    q, kv = q.bfloat16(), kv.bfloat16()
-    freqs_cis = precompute_freqs_cis(512 // 8, 4096 * 2).cuda()[:16]
-
-    # Causual mask
-    attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
-    o = mha(q, freqs_cis, kv=kv, attn_mask=attn_mask)
-
-    trans = (
-        Transformer(
-            vocab_size=30000,
-            codebook_size=120,
-            num_codebooks=4,
-            hidden_size=1024,
-            intermediate_size=None,
-            nhead=16,
-            num_encoder_layers=12,
-            num_decoder_layers=12,
-        )
-        .bfloat16()
-        .cuda()
-    )
-    trans.eval()
-
-    # Print n param
-    print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
-    inputs = torch.randint(0, 1000, (2, 16)).cuda()
-    codes = torch.randint(0, 120, (2, 4, 128)).cuda()
-    x = trans(inputs, codes)
-    x1 = trans(inputs, codes)
-
-    assert torch.allclose(x, x1, atol=1e-4, rtol=1e-3), "Model is not deterministic"
-    print("Model is deterministic")
-
-    # Test kv cache
-    trans.setup_kv_caches(2, 1024)
-    inputs, inputs_mask = trans.forward_encoder(inputs)
-
-    outputs = []
-
-    for i in range(128):
-        code = codes[..., i].unsqueeze(-1)
-        code_mask = torch.tensor([[1], [1]], dtype=torch.bool, device=code.device)
-        input_pos = torch.tensor([i], dtype=torch.long, device=code.device)
-        outputs.append(
-            trans.forward_decoder(
-                code, inputs, inputs_mask, code_mask, input_pos=input_pos
-            )
-        )
-
-    outputs = torch.cat(outputs, dim=2)
-    print(x.shape, outputs.shape)
-    assert torch.allclose(x, outputs, atol=1e-4, rtol=1e-3), "KV cache is not working"

+ 2 - 1
tools/llama/build_dataset.py

@@ -13,6 +13,7 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 
 
 # Define datasets
 # Define datasets
 DATASETS = [
 DATASETS = [
+    # (root, name, languages, extension, group parent level)
     ("data/StarRail/Chinese", "StarRail", ["ZH", "EN"], ".lab", 1),
     ("data/StarRail/Chinese", "StarRail", ["ZH", "EN"], ".lab", 1),
     ("data/StarRail/English", "StarRail", ["EN"], ".lab", 1),
     ("data/StarRail/English", "StarRail", ["EN"], ".lab", 1),
     ("data/StarRail/Japanese", "StarRail", ["JP", "EN"], ".lab", 1),
     ("data/StarRail/Japanese", "StarRail", ["JP", "EN"], ".lab", 1),
@@ -94,7 +95,7 @@ def run_task(task):
 
 
 
 
 def main():
 def main():
-    dataset_fp = open("data/quantized-dataset-1205.protos", "wb")
+    dataset_fp = open("data/quantized-dataset-1208.protos", "wb")
     with Pool(16) as p:
     with Pool(16) as p:
         for result in tqdm(p.imap_unordered(run_task, task_generator())):
         for result in tqdm(p.imap_unordered(run_task, task_generator())):
             dataset_fp.write(result)
             dataset_fp.write(result)

+ 2 - 2
tools/llama/extract_model.py

@@ -1,7 +1,7 @@
 import torch
 import torch
 
 
 state_dict = torch.load(
 state_dict = torch.load(
-    "results/text2semantic_400m/checkpoints/step_000025000.ckpt", map_location="cpu"
+    "results/text2semantic_400m/checkpoints/step_000035000.ckpt", map_location="cpu"
 )["state_dict"]
 )["state_dict"]
 state_dict = {
 state_dict = {
     state_dict.replace("model.", ""): value
     state_dict.replace("model.", ""): value
@@ -9,4 +9,4 @@ state_dict = {
     if state_dict.startswith("model.")
     if state_dict.startswith("model.")
 }
 }
 
 
-torch.save(state_dict, "results/text2semantic_400m/step_000025000_weights.ckpt")
+torch.save(state_dict, "results/text2semantic_400m/step_000035000_weights.ckpt")

+ 3 - 1
tools/llama/rebuild_tokenizer.py

@@ -8,7 +8,9 @@ tokenizer = AutoTokenizer.from_pretrained(model_type)
 
 
 # new tokens
 # new tokens
 new_tokens = list(set(zh_symbols + jp_symbols + en_symbols))
 new_tokens = list(set(zh_symbols + jp_symbols + en_symbols))
-new_tokens = [f"<p:{token}>" for token in new_tokens]
+new_tokens = [f"<p:{token}>" for token in new_tokens] + [
+    f"<s:{i}>" for i in range(4096)
+]
 tokenizer.add_tokens(new_tokens)
 tokenizer.add_tokens(new_tokens)
 tokenizer.add_special_tokens({"pad_token": "<pad>"})
 tokenizer.add_special_tokens({"pad_token": "<pad>"})