Implement parallel decoding llama

Lengyue 2 years ago
parent
commit
4b22991668

+ 23 - 19
fish_speech/configs/text2semantic.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_400m
+project: text2semantic_100m
 
 # Lightning Trainer
 trainer:
@@ -15,21 +15,20 @@ trainer:
 # Dataset Configuration
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-300m
-  revision: mqtts-phones
-  padding_side: right
-  truncation_side: right
+  pretrained_model_name_or_path: fishaudio/speech-lm-v1
 
 # Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.text.StreamTextDataset
-  repo: fishaudio/cn-hubert-25hz-vq
-  prefix: 'data/train'
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  files: [ data/quantized-dataset-1205.protos ]
+  tokenizer: ${tokenizer}
+  split: train
 
 val_dataset:
-  _target_: fish_speech.datasets.text.StreamTextDataset
-  repo: fishaudio/cn-hubert-25hz-vq
-  prefix: 'data/test'
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  files: [ data/quantized-dataset-1205.protos ]
+  tokenizer: ${tokenizer}
+  split: val
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule
@@ -38,6 +37,7 @@ data:
   num_workers: 4
   batch_size: 16
   tokenizer: ${tokenizer}
+  max_length: 1024
 
 # Model Configuration
 model:
@@ -45,14 +45,18 @@ model:
 
   model:
     # ~ 130M parameters, for debug purpose
-    _target_: fish_speech.models.text2semantic.modules.FishSpeechTransformer
-    vocab_size: 32248
-    codebook_size: 1032  # 1024 + 2 (bos, eos), make it divisible by 8
-    num_codebooks: 1
-    hidden_size: 1024
-    nhead: 16
-    num_encoder_layers: 12
-    num_decoder_layers: 12
+    _target_: fish_speech.models.text2semantic.llama.Transformer
+    config:
+      _target_: fish_speech.models.text2semantic.llama.ModelArgs
+      max_seq_len: 4096
+      vocab_size: 32312
+      n_layer: 12
+      n_head: 12
+      dim: 768
+      rope_base: 10000
+      norm_eps: 1e-5
+      codebook_size: 168
+      num_codebooks: 4
 
   optimizer:
     _target_: torch.optim.AdamW
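
For reference, the new model block resolves to roughly the following plain-Python construction (a sketch mirroring the __main__ smoke test in llama.py further down; at runtime Hydra performs the equivalent via _target_):

from fish_speech.models.text2semantic.llama import ModelArgs, Transformer

# Values copied from the YAML above.
config = ModelArgs(
    max_seq_len=4096,
    vocab_size=32312,
    n_layer=12,
    n_head=12,
    dim=768,
    rope_base=10000,
    norm_eps=1e-5,
    codebook_size=168,
    num_codebooks=4,
)
model = Transformer(config)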

+ 116 - 121
fish_speech/datasets/text.py

@@ -8,7 +8,6 @@ from random import Random
 from typing import Optional, Union
 
 import numpy as np
-import orjson
 import pyarrow.parquet as pq
 import torch
 import torch.nn.functional as F
@@ -16,10 +15,13 @@ from datasets.download.streaming_download_manager import xopen
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 
-from fish_speech.text import clean_text, g2p
+from fish_speech.datasets.protos.text_data_pb2 import Semantics
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.symbols import pad as pad_symbol
+from fish_speech.text.symbols import pu_symbols
 from fish_speech.utils import RankedLogger
 from fish_speech.utils.braceexpand import braceexpand
 
@@ -132,13 +134,6 @@ class StreamTextDataset(IterableDataset):
                 yield from texts
 
 
-# @dataclass
-# class DatasetLine:
-#     text: str
-#     semantic: str
-#     speaker: str
-
-
 class AutoAugTextDataset(IterableDataset):
     """
     Auto Augment Dataset by Speaker
@@ -150,87 +145,79 @@ class AutoAugTextDataset(IterableDataset):
 
     def __init__(
         self,
-        jsonl_files: list[str],
+        files: list[str],
         seed: int = 42,
-        phones_prob: float = 0.5,
+        phones_prob: float = 0.3,
         max_length: int = 1024,
-        order: Optional[list[str]] = None,
         tokenizer: AutoTokenizer = None,
+        split: Optional[str] = None,
     ):
         super().__init__()
 
-        self.jsonl_files = jsonl_files
+        self.files = files
         self.seed = seed
         self.phones_prob = phones_prob
         self.max_length = max_length
-        self.order = order
         self.tokenizer = tokenizer
 
         # Read all lines, and group by speaker
         self.groups = []
-        from tqdm import tqdm
-
-        for filename in self.jsonl_files:
-            with open(filename, "r") as f:
-                for json_line in tqdm(f):
-                    if json_line.strip() == "":
-                        continue
-
-                    line = orjson.loads(json_line)
-                    # for i in line["sentences"]:
-                    #     # Save memory
-                    #     i["semantics"] = np.array(i["semantics"], dtype=np.uint16)
-                    self.groups.append(line)
+        count = 0
+        for filename in self.files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    self.groups.append(text_data)
+                    count += 1
 
-        import sys
+                    if count % 10000 == 0:
+                        log.info(f"Read {count} groups of text data")
 
-        print(sys.getsizeof(self.groups) / 1024 / 1024)
         # Shuffle the lines
-        Random(seed).shuffle(self.lines)
+        Random(seed).shuffle(self.groups)
 
-    def __iter__(self):
-        lines = split_by_rank_worker(self.lines)
-        random.shuffle(lines)
+        if split == "train":
+            self.groups = self.groups[:-500]
+        elif split == "val":
+            self.groups = self.groups[-500:]
 
-        for line in lines:
-            yield self.augment(line)
+    def __iter__(self):
+        groups = split_by_rank_worker(self.groups)
+        random.shuffle(groups)
 
-    def tokenize_sentence(
-        self, sentence: str, semantic: list[int], mode: str = "sample"
-    ):
-        sentence = clean_text(sentence)
+        for group in groups:
+            x = self.augment(group)
+            if x is not None:
+                yield x
 
+    def tokenize_sentence(self, sentence: str, phones: list[str], mode: str = "sample"):
         if (
             mode == "sample" and (random.random() < self.phones_prob)
         ) or mode == "phones":
-            sentence = " ".join([t for _, t in g2p(sentence, order=self.order)])
-
-        semantic = " ".join([f"<semantic_{i}>" for i in semantic])
+            sentence = " ".join(
+                [
+                    (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
+                    for i in phones
+                ]
+            )
 
         tokens = self.tokenizer.encode(
-            f"{sentence} {semantic}", max_length=10**6, add_special_tokens=False
+            f"{sentence}", max_length=10**6, add_special_tokens=False
         )
-        return sentence, semantic, len(tokens)
+        return sentence, len(tokens)
 
-    def augment(self, line):
-        speaker = line.get("speaker", None)
-
-        # 20% to pure text or pure phones
+    def augment(self, group):
+        # 50% to pure text or pure phones
         mode = "sample"
-        if random.random() < 0.2:
+        if random.random() < 0.5:
             mode = random.choice(["text", "phones"])
 
-        if speaker is None:
-            a, b, _ = self.tokenize_sentence(line["text"], line["semantic"], mode=mode)
-            return {"text": f"[INST] {a} [/INST] {b} </s>"}
-
         # Random sample based on speaker using a truncated normal distribution
         a = torch.tensor([0], dtype=torch.float32)
         torch.nn.init.trunc_normal_(
             a,
             mean=self.max_length // 2,
             std=self.max_length // 4,
-            a=0,
+            a=10,
             b=self.max_length,
         )
         remaining_tokens = a.long().item() - 4
@@ -238,85 +225,97 @@ class AutoAugTextDataset(IterableDataset):
         final_text, final_semantic = [], []
 
         # Shuffle unique lines
-        idxs = list(range(len(self.speakers[speaker])))
+        idxs = list(range(len(group.sentences)))
         random.shuffle(idxs)
 
+        if len(idxs) == 0:
+            # Invalid group
+            return None
+
         while remaining_tokens > 0 and len(idxs) > 0:
-            line = self.speakers[speaker][idxs.pop()]
-            text, semantic, length = self.tokenize_sentence(
-                line["text"], line["semantic"], mode=mode
+            sentence = group.sentences[idxs.pop()]
+            text, length = self.tokenize_sentence(
+                sentence.text, sentence.phones, mode=mode
             )
-            remaining_tokens -= length
+            remaining_tokens -= length + len(sentence.semantics[0].values)
             final_text.append(text)
-            final_semantic.append(semantic)
+            final_semantic.append(sentence.semantics)
+
+        final_text = "[INST] " + "<pad>".join(final_text) + " [/INST]"
+        encoded = self.tokenizer.encode(
+            final_text, max_length=self.max_length, add_special_tokens=False
+        )
+        semantic_length = sum([len(i[0].values) for i in final_semantic])
+
+        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
+        tokens = (
+            [self.tokenizer.bos_token_id]
+            + encoded
+            + [self.tokenizer.pad_token_id] * semantic_length
+            + [self.tokenizer.eos_token_id]
+        )
+        codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
+        for segment in final_semantic:
+            for book_idx, book in enumerate(segment):
+                for j in book.values:
+                    codes[book_idx].append(int(j) + 2)
+
+        for book in codes:
+            book.append(1)
 
-        final_text = " ".join(final_text)
-        final_semantic = " ".join(final_semantic)
+        tokens = [tokens] + codes
+        tokens = torch.tensor(tokens, dtype=torch.long)
 
-        return {"text": f"[INST] {final_text} [/INST] {final_semantic} </s>"}
+        labels = tokens.clone()
+        labels[1:, : len(encoded) + 1] = -100  # Mask out the <s> tokens for semantic
+
+        return {
+            "tokens": tokens[:, :-1],
+            "labels": labels[:, 1:],
+        }
 
 
 @dataclass
 class TextDataCollator:
     tokenizer: AutoTokenizer
-    max_length: int = 512
+    max_length: int = 1024
 
     def __call__(self, examples):
-        texts = [i["text"] for i in examples]
-
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        encoded_texts = self.tokenizer(
-            texts,
-            truncation=True,
-            padding=True,
-            max_length=self.max_length,
-            return_tensors="pt",
-            pad_to_multiple_of=8,
-        )
-
-        semantic = [i["semantic"] for i in examples]
-        max_semantic_length = max([len(i[0]) for i in semantic])
-
-        # Make xformers happy
-        if (max_semantic_length - 1) % 8 != 0:
-            max_semantic_length += 8 - (max_semantic_length - 1) % 8
-
-        if max_semantic_length > self.max_length + 1:
-            max_semantic_length = self.max_length + 1
-
-        codes, codes_mask = [], []
-
-        for i in semantic:
-            t = torch.tensor(i)
-            if t.shape[-1] >= max_semantic_length:
-                t = t[..., :max_semantic_length]
-
-            codes.append(
-                F.pad(
-                    t,
-                    (0, max_semantic_length - t.shape[-1]),
-                    value=1,
+        tokens, attention_masks, labels = [], [], []
+        for example in examples:
+            _tokens = example["tokens"][:, : self.max_length]
+            _labels = example["labels"][:, : self.max_length]
+            _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
+            _attention_mask[: _tokens.size(1)] = False
+
+            assert _tokens.size(1) == _labels.size(
+                1
+            ), f"{_tokens.size(1)} != {_labels.size(1)}"
+
+            if _tokens.size(1) < self.max_length:
+                _tokens = F.pad(
+                    _tokens,
+                    (0, self.max_length - _tokens.size(1)),
+                    value=self.tokenizer.eos_token_id,
+                )
+                _labels = F.pad(
+                    _labels, (0, self.max_length - _labels.size(1)), value=-100
                 )
-            )
 
-            mask = torch.zeros(max_semantic_length, dtype=torch.long)
-            mask[t.shape[-1] :] = 1
-            codes_mask.append(mask.bool())
+            tokens.append(_tokens)
+            attention_masks.append(_attention_mask)
+            labels.append(_labels)
 
-        codes = torch.stack(codes)
-        codes_mask = torch.stack(codes_mask)
+        tokens = torch.stack(tokens, dim=0)
+        attention_masks = torch.stack(attention_masks, dim=0)
+        labels = torch.stack(labels, dim=0)
 
-        data = {
-            "inputs": encoded_texts["input_ids"],
-            "input_mask": encoded_texts["attention_mask"] == 0,
-            "codes": codes,
-            "codes_mask": codes_mask,
+        return {
+            "inputs": tokens,
+            "attention_masks": attention_masks,
+            "labels": labels,
         }
 
-        return data
-
 
 class InterleaveDataset(IterableDataset):
     def __init__(
@@ -398,22 +397,18 @@ if __name__ == "__main__":
     #         f.write(json.dumps({"text": Path(i).read_text(), "speaker": i.parent.name, "semantic": fake_tokens}) + "\n")
 
     ds = AutoAugTextDataset(
-        jsonl_files=["data/quantized-dataset-1205.json"],
-        order=["en"],
-        tokenizer=AutoTokenizer.from_pretrained(
-            "fishaudio/speech-lm-300m", revision="text-pretrain-10k-phones"
-        ),
+        files=["data/quantized-dataset-1205.protos"],
+        tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
     )
 
     dm = TextDataModule(
         train_dataset=ds,
         val_dataset=ds,
         tokenizer=ds.tokenizer,
-        batch_size=2,
+        batch_size=16,
         max_length=1024,
         num_workers=0,
     )
 
     for batch in dm.train_dataloader():
         print(batch)
-        break
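
Note on the new packed format: augment() now returns a (num_codebooks + 1, seq_len) tensor instead of a prompt string. Row 0 carries text tokens, with <pad> placeholders spanning the semantic region; rows 1..num_codebooks carry VQ codes shifted by +2 so that 0 can mark "text position" and 1 can mark end-of-sequence. A toy sketch with hypothetical ids (bos=1, eos=2, pad=3; the real ids come from the tokenizer) and two codebooks:

import torch

bos, eos, pad = 1, 2, 3             # hypothetical special-token ids
encoded = [101, 102, 103]           # tokenized "[INST] ... [/INST]" prompt
semantics = [[5, 6], [7, 8]]        # per-codebook VQ codes, 2 frames each
sem_len = len(semantics[0])

tokens = [bos] + encoded + [pad] * sem_len + [eos]
codes = [[0] * (len(encoded) + 1) for _ in semantics]   # 0 = text position
for row, book in zip(codes, semantics):
    row.extend(int(v) + 2 for v in book)                # +2 offset for 0/1 markers
    row.append(1)                                       # 1 = end of sequence

packed = torch.tensor([tokens] + codes)                 # (3, 7) in this toy case
labels = packed.clone()
labels[1:, : len(encoded) + 1] = -100                   # no codebook loss on text span

example = {"tokens": packed[:, :-1], "labels": labels[:, 1:]}  # teacher-forcing shift

TextDataCollator then right-pads each pack to max_length (eos for tokens, -100 for labels) and builds a boolean key-padding mask where True marks padded positions.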

+ 23 - 19
fish_speech/models/text2semantic/lit_module.py

@@ -3,11 +3,10 @@ from typing import Any
 import lightning as L
 import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
-from transformers import LlamaForCausalLM
 
 
 class TextToSemantic(L.LightningModule):
-    def __init__(self, model: LlamaForCausalLM, optimizer: Any, lr_scheduler: Any):
+    def __init__(self, model, optimizer: Any, lr_scheduler: Any):
         super().__init__()
 
         self.model = model
@@ -30,26 +29,28 @@ class TextToSemantic(L.LightningModule):
         }
 
     def _step(self, batch, batch_idx, stage: str):
-        logits = self.model(
-            inputs=batch["inputs"],
-            input_mask=batch["input_mask"],
-            codes=batch["codes"][..., :-1],
-            codes_mask=batch["codes_mask"][..., :-1],
+        outputs = self.model(
+            x=batch["inputs"],
+            key_padding_mask=batch["attention_masks"],
         )
 
         # Generate labels
-        labels = batch["codes"][..., 1:].contiguous()
-        label_mask = batch["codes_mask"][..., 1:]
-        label_mask = label_mask[:, None, :]
-        label_mask = label_mask.expand(-1, labels.size(1), -1)
-        labels = labels.masked_fill(label_mask, -100)
-
-        loss = F.cross_entropy(
-            logits.view(-1, logits.size(-1)),
-            labels.view(-1),
+        labels = batch["labels"]
+        token_loss = F.cross_entropy(
+            outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
+            labels[:, 0].reshape(-1),
             ignore_index=-100,
         )
 
+        codebook_labels = labels[:, 1:].mT
+        semantic_loss = F.cross_entropy(
+            outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
+            codebook_labels.reshape(-1),
+            ignore_index=-100,
+        )
+
+        loss = token_loss + semantic_loss
+
         self.log(
             f"{stage}/loss",
             loss,
@@ -60,9 +61,12 @@ class TextToSemantic(L.LightningModule):
         )
 
         # Top-5 accuracy
-        _, indices = logits.topk(5, dim=-1)
-        correct = indices.eq(labels.unsqueeze(-1)).sum()
-        accuracy = correct / labels.numel()
+        _, indices = outputs.codebook_logits.topk(5, dim=-1)
+        correct = indices.eq(codebook_labels.unsqueeze(-1))
+        correct[codebook_labels == -100] = 0
+        correct = correct.sum()
+        accuracy = correct / (codebook_labels != -100).sum()
+
         self.log(
             f"{stage}/top_5_accuracy",
             accuracy,
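
Shape note on the dual loss: token_logits is (batch, seq, vocab_size) and is scored against label row 0, while codebook_logits is (batch, seq, num_codebooks, codebook_size) and is scored against the remaining rows after a transpose. A self-contained sketch with random tensors (sizes are illustrative; -100 positions are ignored):

import torch
import torch.nn.functional as F

b, n, vocab, books, codebook = 2, 16, 32312, 4, 168
token_logits = torch.randn(b, n, vocab)
codebook_logits = torch.randn(b, n, books, codebook)
labels = torch.randint(0, codebook, (b, books + 1, n))  # random targets, shape check only

token_loss = F.cross_entropy(
    token_logits.reshape(-1, vocab), labels[:, 0].reshape(-1), ignore_index=-100
)
codebook_labels = labels[:, 1:].mT                      # (b, books, n) -> (b, n, books)
semantic_loss = F.cross_entropy(
    codebook_logits.reshape(-1, codebook), codebook_labels.reshape(-1), ignore_index=-100
)
loss = token_loss + semantic_loss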

+ 10 - 8
fish_speech/models/text2semantic/llama.py

@@ -124,20 +124,20 @@ class Transformer(nn.Module):
             )
 
     def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
-        # x: (batch, seq_len, num_codebooks + 1)
+        # x: (batch, num_codebooks + 1, seq_len)
+        seq_len = x.size(2)
 
         # Here we want to merge the embeddings of the codebooks
-        vocab_embeds = [self.embeddings(x[:, :, 0])]
+        vocab_embeds = [self.embeddings(x[:, 0])]
         for i in range(self.config.num_codebooks):
             emb = self.embeddings(
-                x[:, :, i + 1] + i * self.config.codebook_size + self.config.vocab_size
+                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
             )
             vocab_embeds.append(emb)
 
         x = torch.stack(vocab_embeds, dim=3)
         x = x.mean(dim=3)
 
-        seq_len = x.size(1)
         mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[:seq_len]
 
@@ -145,7 +145,7 @@ class Transformer(nn.Module):
         # That is, FALSE means masked out
         # To maintain consistency, key_padding_mask use TRUE to mask out
         if key_padding_mask is not None:
-            mask = mask & key_padding_mask[:, None, :, None].logical_not()
+            mask = mask & key_padding_mask[:, None, None, :].logical_not()
 
         for layer in self.layers:
             x = layer(x, freqs_cis, mask)
@@ -156,7 +156,7 @@ class Transformer(nn.Module):
         codebook_logits = logits[:, :, self.config.vocab_size :]
 
         codebook_logits = rearrange(
-            codebook_logits, "b n (d c) -> b n d c", c=self.config.codebook_size
+            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
         )
 
         return TransformerForwardResult(
@@ -293,19 +293,21 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 if __name__ == "__main__":
     args = ModelArgs(
         max_seq_len=4096,
-        vocab_size=32000,
+        vocab_size=32312,
         n_layer=12,
         n_head=12,
         dim=768,
         rope_base=10000,
         norm_eps=1e-5,
+        codebook_size=168,
+        num_codebooks=4,
     )
 
     model = Transformer(args)
     model = model.cuda().bfloat16()
     print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
 
-    inputs = torch.randint(0, 100, (2, 128, 5)).cuda()
+    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
     key_padding_mask = torch.zeros(2, 128).bool().cuda()
     key_padding_mask[0, 2:] = True
     x1 = model(inputs, key_padding_mask=key_padding_mask)
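
The key_padding_mask fix above changes which axis of the (B, H, Q, K) attention mask the padding broadcasts over: padding must hide keys, not queries. A quick shape check under the same convention (True = masked out):

import torch

B, H, Q = 2, 1, 6
causal = torch.tril(torch.ones(Q, Q, dtype=torch.bool))[None, None]  # (1, 1, Q, K)
key_padding_mask = torch.zeros(B, Q, dtype=torch.bool)
key_padding_mask[0, 2:] = True            # positions 2+ of sample 0 are padding

mask = causal & key_padding_mask[:, None, None, :].logical_not()
print(mask[0, 0])                         # every query row only attends to keys 0 and 1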

+ 1 - 1
fish_speech/train.py

@@ -33,7 +33,7 @@ def train(cfg: DictConfig) -> tuple[dict, dict]:
 
     # set seed for random number generators in pytorch, numpy and python.random
     if cfg.get("seed"):
-        L.seed_everything(cfg.seed, workers=True)
+        L.seed_everything(cfg.seed, workers=False)
 
     if cfg.get("deterministic"):
         torch.use_deterministic_algorithms(True)

+ 2 - 1
tools/llama/rebuild_tokenizer.py

@@ -8,6 +8,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_type)
 
 # new tokens
 new_tokens = list(set(zh_symbols + jp_symbols + en_symbols))
+new_tokens = [f"<p:{token}>" for token in new_tokens]
 tokenizer.add_tokens(new_tokens)
 tokenizer.add_special_tokens({"pad_token": "<pad>"})
 
@@ -33,7 +34,7 @@ print(f"Vocab size: {len(tokenizer)}, padded to {length}")
 # print(f"Total parameters: {total_params / 1e6:.2f}M")
 
 # Try tokenizing a new sequence
-sequence = "[INST] Test uang1 iang5 AA an 你好 [/INST]<s>[PAD]</s>"
+sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene."
 encoded = tokenizer.encode(sequence)
 print("Test encoding....")
 print(f"\tSentence: {sequence}")