Przeglądaj źródła

Add new text to semantic model

Lengyue 2 lat temu
rodzic
commit
895ed8e748

+ 1 - 0
README.md

@@ -20,3 +20,4 @@ pip3 install -e .
 - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
 - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
 - [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)

+ 1 - 1
dockerfile

@@ -24,7 +24,7 @@ ENV SHELL=/usr/bin/zsh
 # Setup flash-attn
 RUN pip3 install --upgrade pip && \
     pip3 install ninja packaging && \
-    MAX_JOBS=4 pip3 install git+https://github.com/facebookresearch/xformers.git@v0.0.22
+    pip3 install git+https://github.com/facebookresearch/xformers.git@v0.0.22
 
 # Project Env
 WORKDIR /exp

+ 72 - 0
fish_speech/configs/text2semantic.yaml

@@ -0,0 +1,72 @@
+defaults:
+  - base
+  - _self_
+
+project: text2semantic_400m
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 2
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  max_steps: 1_000_000
+  precision: bf16-true
+
+# Tokenizer Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: 01-ai/Yi-34B
+  padding_side: right
+  truncation_side: right
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.text.StreamTextDataset
+  repo: fishaudio/cn-hubert-25hz-vq
+  prefix: 'data/train'
+
+val_dataset:
+  _target_: fish_speech.datasets.text.StreamTextDataset
+  repo: fishaudio/cn-hubert-25hz-vq
+  prefix: 'data/test'
+
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 32
+  tokenizer: ${tokenizer}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+
+  model:
+    # ~130M parameters, for debugging purposes
+    _target_: fish_speech.models.text2semantic.modules.FishSpeechTransformer
+    vocab_size: 64000
+    codebook_size: 1032  # 1024 codes + 2 special tokens (bos, eos) = 1026, rounded up to 1032 to be divisible by 8
+    num_codebooks: 1
+    hidden_size: 1024
+    nhead: 16
+    num_encoder_layers: 12
+    num_decoder_layers: 12
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    weight_decay: 0.1
+    betas: [0.9, 0.95]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 2000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1

+ 61 - 3
fish_speech/datasets/text.py

@@ -1,5 +1,6 @@
 import json
 import random
+import re
 from dataclasses import dataclass
 from itertools import chain
 from pathlib import Path
@@ -9,6 +10,7 @@ from typing import Optional, Union
 import numpy as np
 import pyarrow.parquet as pq
 import torch
+import torch.nn.functional as F
 from datasets.download.streaming_download_manager import xopen
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
@@ -96,6 +98,25 @@ class StreamTextDataset(IterableDataset):
                 log.exception(f"Failed to parse {filename}: {e}")
 
     def parse_data(self, filename: str):
+        for data in self.parse_data_internal(filename):
+            text = data["text"]
+            expression = re.compile(r"\[INST\] (.*) \[/INST\] (.*) </s>")
+            match = expression.match(text)
+
+            if match is None:
+                continue
+
+            text = match.group(1)
+            semantic = match.group(2)
+
+            # Convert semantic to ids
+            expression = re.compile(r"<semantic_(\d+)>")
+            # 0 and 1 are reserved for <s> and </s>
+            semantic = [0] + [int(i) + 2 for i in expression.findall(semantic)] + [1]
+
+            yield {"text": text, "semantic": [semantic]}
+
+    def parse_data_internal(self, filename: str):
         url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
 
         with xopen(url, mode="rb") as stream:
@@ -242,16 +263,53 @@ class TextDataCollator:
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
-        data = self.tokenizer(
+        encoded_texts = self.tokenizer(
             texts,
             truncation=True,
             padding=True,
             max_length=self.max_length,
             return_tensors="pt",
+            pad_to_multiple_of=8,
         )
 
-        data["labels"] = data["input_ids"].clone()
-        data["labels"][data["attention_mask"] == 0] = -100
+        semantic = [i["semantic"] for i in examples]
+        max_semantic_length = max([len(i[0]) for i in semantic])
+
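+        # Round up so (max_semantic_length - 1) is a multiple of 8: the training step feeds codes[..., :-1] to the model, and xformers presumably needs aligned lengths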
+        if (max_semantic_length - 1) % 8 != 0:
+            max_semantic_length += 8 - (max_semantic_length - 1) % 8
+
+        if max_semantic_length > self.max_length + 1:
+            max_semantic_length = self.max_length + 1
+
+        codes, codes_mask = [], []
+
+        for i in semantic:
+            t = torch.tensor(i)
+            if t.shape[-1] >= max_semantic_length:
+                t = t[..., :max_semantic_length]
+
+            codes.append(
+                F.pad(
+                    t,
+                    (0, max_semantic_length - t.shape[-1]),
+                    value=1,
+                )
+            )
+
+            mask = torch.zeros(max_semantic_length, dtype=torch.long)
+            mask[t.shape[-1] :] = 1
+            codes_mask.append(mask.bool())
+
+        codes = torch.stack(codes)
+        codes_mask = torch.stack(codes_mask)
+
+        data = {
+            "inputs": encoded_texts["input_ids"],
+            "input_mask": encoded_texts["attention_mask"] == 0,
+            "codes": codes,
+            "codes_mask": codes_mask,
+        }
 
         return data
 

+ 22 - 5
fish_speech/models/text2semantic/lit_module.py

@@ -1,6 +1,7 @@
 from typing import Any
 
 import lightning as L
+import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 from transformers import LlamaForCausalLM
 
@@ -29,9 +30,25 @@ class TextToSemantic(L.LightningModule):
         }
 
     def _step(self, batch, batch_idx, stage: str):
-        result = self.model(**batch)
-        loss = result.loss
-        logits = result.logits
+        logits = self.model(
+            inputs=batch["inputs"],
+            input_mask=batch["input_mask"],
+            codes=batch["codes"][..., :-1],
+            codes_mask=batch["codes_mask"][..., :-1],
+        )
+
+        # Generate labels
+        labels = batch["codes"][..., 1:].contiguous()
+        label_mask = batch["codes_mask"][..., 1:]
+        label_mask = label_mask[:, None, :]
+        label_mask = label_mask.expand(-1, labels.size(1), -1)
+        labels = labels.masked_fill(label_mask, -100)
+
+        loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            labels.view(-1),
+            ignore_index=-100,
+        )
 
         self.log(
             f"{stage}/loss",
@@ -44,8 +61,8 @@ class TextToSemantic(L.LightningModule):
 
         # Top-5 accuracy
         _, indices = logits.topk(5, dim=-1)
-        correct = indices.eq(batch["labels"].unsqueeze(-1)).sum()
-        accuracy = correct / batch["labels"].numel()
+        correct = indices.eq(labels.unsqueeze(-1)).sum()
+        accuracy = correct / labels.numel()
         self.log(
             f"{stage}/top_5_accuracy",
             accuracy,

+ 129 - 7
fish_speech/models/text2semantic/modules.py

@@ -9,11 +9,12 @@ try:
     from xformers.ops import memory_efficient_attention
 except ImportError as e:
     memory_efficient_attention = None
-# memory_efficient_attention = None
 
 
-class AlibiPostionEmbedding:
+class AlibiPostionEmbedding(nn.Module):
     def __init__(self, nheads, maxpos):
+        super().__init__()
+
         context_position = torch.arange(maxpos)[:, None]
         memory_position = torch.arange(maxpos)[None, :]
         relative_position = memory_position - context_position
@@ -21,8 +22,10 @@ class AlibiPostionEmbedding:
             torch.abs(relative_position).unsqueeze(0).expand(nheads, -1, -1)
         )
         self.slopes = torch.Tensor(self.get_slopes(nheads)) * -1
-        self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
-        self.alibi = self.alibi.view(nheads, maxpos, maxpos)
+        alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
+        alibi = alibi.view(nheads, maxpos, maxpos)
+
+        self.register_buffer("alibi", alibi)
 
     @staticmethod
     def get_slopes_power_of_2(n):
@@ -128,8 +131,14 @@ class MultiheadAttention(nn.Module):
 
             if attn_mask is not None:
                 attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
+
+                if attn_bias is None:
+                    attn_bias = torch.zeros_like(
+                        attn_mask, dtype=q.dtype, device=q.device
+                    )
                 attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
 
+            attn_bias = attn_bias.to(q.dtype)
             attn_output = memory_efficient_attention(
                 q,
                 k,
@@ -222,7 +231,12 @@ class CrossAttentionLayer(nn.Module):
         self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
         self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
 
-    def forward(self, tgt, memory, memory_key_padding_mask=None):
+    def forward(
+        self,
+        tgt,
+        memory,
+        memory_key_padding_mask=None,
+    ):
         residual = tgt
         tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
         x, attn_weights = self.attn(
@@ -283,9 +297,97 @@ class FishSpeechTransformer(nn.Module):
         num_encoder_layers=12,
         num_decoder_layers=12,
         dropout=0.1,
+        alignment_position=-2,
+        max_position=8192,
     ):
-        self.embedding = nn.Embedding(vocab_size, hidden_size)
-        self.lm_head = nn.Linear(hidden_size, vocab_size * num_codebooks)
+        super().__init__()
+
+        self.encoder_embedding = nn.Embedding(vocab_size, hidden_size)
+        self.decoder_embeddings = nn.ModuleList(
+            [nn.Embedding(codebook_size, hidden_size) for _ in range(num_codebooks)]
+        )
+        self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
+        self.codebook_size = codebook_size
+        self.num_codebooks = num_codebooks
+
+        self.encoder = nn.ModuleList(
+            [
+                TransformerEncoderLayer(
+                    hidden_size=hidden_size,
+                    intermediate_size=intermediate_size,
+                    nhead=nhead,
+                    dropout=dropout,
+                )
+                for _ in range(num_encoder_layers)
+            ]
+        )
+
+        self.alignment = CrossAttentionLayer(
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            dropout=dropout,
+        )
+
+        if alignment_position < 0:
+            alignment_position = num_decoder_layers + alignment_position
+
+        self.alignment_position = alignment_position
+        assert 0 <= alignment_position < num_decoder_layers
+
+        self.decoder = nn.ModuleList(
+            [
+                TransformerEncoderLayer(
+                    hidden_size=hidden_size,
+                    intermediate_size=intermediate_size,
+                    nhead=nhead,
+                    dropout=dropout,
+                )
+                for _ in range(num_decoder_layers)
+            ]
+        )
+
+        self.alibi = AlibiPostionEmbedding(nhead, max_position)
+        self.register_buffer(
+            "causual_mask",
+            torch.triu(torch.ones(max_position, max_position), diagonal=1).bool(),
+        )
+
+    def forward(self, inputs, codes, input_mask=None, codes_mask=None):
+        # inputs: (B, T)
+        # codes: (B, C, T)
+        inputs = self.encoder_embedding(inputs)
+        codes = rearrange(codes, "b c t -> c b t")
+        codes = torch.stack(
+            [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
+        )
+        codes = torch.mean(codes, dim=0)  # (B, T, D)
+
+        attn_bias = self.alibi(inputs)
+        for layer in self.encoder:
+            inputs = layer(inputs, attn_bias=attn_bias, key_padding_mask=input_mask)
+
+        attn_bias = self.alibi(codes)
+        causual_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]
+
+        for idx, layer in enumerate(self.decoder):
+            if idx == self.alignment_position:
+                codes, _ = self.alignment(
+                    codes, inputs, memory_key_padding_mask=input_mask
+                )
+
+            codes = layer(
+                codes,
+                attn_bias=attn_bias,
+                key_padding_mask=codes_mask,
+                tgt_mask=causual_mask,
+            )
+
+        codes = self.decoder_head(codes)
+        codes = rearrange(
+            codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
+        )
+
+        return codes
 
 
 if __name__ == "__main__":
@@ -334,3 +436,23 @@ if __name__ == "__main__":
     tgt = torch.randn(3, 10, 512).cuda()
     o = ten(tgt)
     print(o.size())
+
+    trans = (
+        FishSpeechTransformer(
+            vocab_size=30000,
+            codebook_size=120,
+            num_codebooks=4,
+            hidden_size=1024,
+            intermediate_size=None,
+            nhead=16,
+            num_encoder_layers=12,
+            num_decoder_layers=12,
+        )
+        .bfloat16()
+        .cuda()
+    )
+    # Print the total parameter count, in units of 2**20 (1024 * 1024)
+    print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
+    inputs = torch.randint(0, 1000, (3, 16)).cuda()
+    codes = torch.randint(0, 120, (3, 4, 128)).cuda()
+    print(trans(inputs, codes).size())

+ 1 - 2
pyproject.toml

@@ -13,7 +13,7 @@ classifiers = [
     "Programming Language :: Python :: 3",
 ]
 dependencies = [
-    "transformers>=4.34.1",
+    "transformers>=4.35.2",
     "datasets>=2.14.5",
     "bitsandbytes>=0.41.1",
     "peft>=0.5.0",
@@ -26,7 +26,6 @@ dependencies = [
     "vector-quantize-pytorch>=1.10.0",
     "rich>=13.5.3",
     "gradio>=4.0.0",
-    "diffusers@git+https://github.com/huggingface/diffusers",
     "cn2an",
     "pypinyin",
     "jieba",