Просмотр исходного кода

Update tokenizer & add decoding

Lengyue 2 лет назад
Родитель
Сommit
a0302c9dbb

+ 4 - 3
fish_speech/configs/text2semantic.yaml

@@ -15,7 +15,8 @@ trainer:
 # Dataset Configuration
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: 01-ai/Yi-34B
+  pretrained_model_name_or_path: fishaudio/speech-lm-300m
+  revision: mqtts-phones
   padding_side: right
   truncation_side: right
 
@@ -35,7 +36,7 @@ data:
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 32
+  batch_size: 16
   tokenizer: ${tokenizer}
 
 # Model Configuration
@@ -45,7 +46,7 @@ model:
   model:
     # ~ 130M parameters, for debug purpose
     _target_: fish_speech.models.text2semantic.modules.FishSpeechTransformer
-    vocab_size: 64000
+    vocab_size: 32248
     codebook_size: 1032  # 1024 + 2 (bos, eos), make it divisible by 8
     num_codebooks: 1
     hidden_size: 1024

+ 221 - 43
fish_speech/models/text2semantic/modules.py

@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 from einops import rearrange
@@ -48,6 +49,26 @@ class AlibiPostionEmbedding(nn.Module):
         return self.alibi[:, : x.size(1), : x.size(1)].to(x.device)
 
 
+class KVCache(nn.Module):
+    def __init__(
+        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
+    ):
+        super().__init__()
+        cache_shape = (max_batch_size, max_seq_length, n_heads * head_dim)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        assert input_pos is not None, "input_pos should not be None"
+
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, input_pos] = k_val
+        v_out[:, input_pos] = v_val
+
+        return k_out, v_out
+
+
 class MultiheadAttention(nn.Module):
     def __init__(self, d_model, nhead, dropout=0.1):
         super().__init__()
@@ -61,6 +82,7 @@ class MultiheadAttention(nn.Module):
         self.v_proj = nn.Linear(d_model, d_model)
         self.out_proj = nn.Linear(d_model, d_model)
         self.dropout = nn.Dropout(dropout)
+        self.kv_cache = None
 
     def forward(
         self,
@@ -70,16 +92,19 @@ class MultiheadAttention(nn.Module):
         attn_mask=None,
         key_padding_mask=None,
         attn_bias=None,
-        past_kv=None,
         return_weights=False,
+        input_pos=None,
     ):
         # (B, T, C)
         batch_size = q.size(0)
         q_length = q.size(1)
-        k_length = k.size(1)
 
-        if past_kv is not None:
-            k, v = torch.cat([past_kv, k], 1), torch.cat([past_kv, v], 1)
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        k_length = k.size(1)
 
         if attn_bias is not None:
             assert attn_bias.size() == (
@@ -117,8 +142,6 @@ class MultiheadAttention(nn.Module):
             else:
                 attn_mask = attn_mask.logical_or(key_padding_mask)
 
-        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
-
         if (
             return_weights is False
             and memory_efficient_attention is not None
@@ -138,7 +161,9 @@ class MultiheadAttention(nn.Module):
                     )
                 attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
 
-            attn_bias = attn_bias.to(q.dtype)
+            if attn_bias is not None:
+                attn_bias = attn_bias.to(q.dtype)
+
             attn_output = memory_efficient_attention(
                 q,
                 k,
@@ -236,6 +261,7 @@ class CrossAttentionLayer(nn.Module):
         tgt,
         memory,
         memory_key_padding_mask=None,
+        input_pos=None,
     ):
         residual = tgt
         tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
@@ -245,6 +271,7 @@ class CrossAttentionLayer(nn.Module):
             memory,
             key_padding_mask=memory_key_padding_mask,
             return_weights=True,
+            input_pos=input_pos,
         )
         residual = x + residual
 
@@ -264,7 +291,9 @@ class TransformerEncoderLayer(nn.Module):
         self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
         self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
 
-    def forward(self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None):
+    def forward(
+        self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None, input_pos=None
+    ):
         residual = x
         x = self.input_layernorm(x)
         x, _ = self.attn(
@@ -275,6 +304,7 @@ class TransformerEncoderLayer(nn.Module):
             key_padding_mask=key_padding_mask,
             attn_mask=tgt_mask,
             return_weights=False,
+            input_pos=input_pos,
         )
         residual = x + residual
 
@@ -352,6 +382,27 @@ class FishSpeechTransformer(nn.Module):
             torch.triu(torch.ones(max_position, max_position), diagonal=1).bool(),
         )
 
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+
+    def setup_kv_caches(self, max_batch_size, max_seq_length):
+        if (
+            self.max_seq_length >= max_seq_length
+            and self.max_batch_size >= max_batch_size
+        ):
+            return
+
+        if max_seq_length % 8 != 0:
+            max_seq_length = max_seq_length + (8 - max_seq_length % 8)
+
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+
+        for b in self.decoder:
+            b.attn.kv_cache = KVCache(
+                max_batch_size, max_seq_length, b.attn.nhead, b.attn.head_dim
+            )
+
     def forward(self, inputs, codes, input_mask=None, codes_mask=None):
         # x: (B, T)
         # y: (B, C, T)
@@ -389,53 +440,177 @@ class FishSpeechTransformer(nn.Module):
 
         return codes
 
+    def decode_one_token(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        input_pos: torch.Tensor,
+        **sampling_kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # input_pos: [B, 1]
+        assert input_pos.shape[-1] == 1
+        attn_bias = self.alibi.alibi[:, input_pos, : self.max_seq_length]
+        causual_mask = self.causual_mask[input_pos, : self.max_seq_length]
+
+        x = rearrange(x, "b c t -> c b t")
+        x = torch.stack(
+            [emb(code) for emb, code in zip(self.decoder_embeddings, x)], dim=0
+        )
+        x = torch.mean(x, dim=0)  # (B, T)
+
+        for idx, layer in enumerate(self.decoder):
+            if idx == self.alignment_position:
+                x, _ = self.alignment(x, context)
+
+            x = layer(
+                x, attn_bias=attn_bias, input_pos=input_pos, tgt_mask=causual_mask
+            )
+
+        x = self.decoder_head(x)
+        x = rearrange(
+            x, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
+        )
+
+        # Never predict EOS or BOS for sub-codebooks
+        x[:, 1:, :2] = -float("Inf")
+
+        next_token, probs = [], []
+        for i in range(self.num_codebooks):
+            next_token_i, probs_i = self.sample(x[:, i], **sampling_kwargs)
+            next_token.append(next_token_i)
+            probs.append(probs_i)
+
+        return torch.stack(next_token, dim=1), torch.stack(probs, dim=1)
+
+    @staticmethod
+    def multinomial_sample_one_no_sync(
+        probs_sort,
+    ):  # Does multinomial sampling without a cuda synchronization
+        q = torch.empty_like(probs_sort).exponential_(1)
+        return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+    @staticmethod
+    def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+        logits = logits / max(temperature, 1e-5)
+
+        if top_k is not None:
+            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+            pivot = v.select(-1, -1).unsqueeze(-1)
+            logits = torch.where(logits < pivot, -float("Inf"), logits)
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        return probs
+
+    def sample(self, logits, temperature: float = 1.0, top_k: Optional[int] = None):
+        probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
+        idx_next = self.multinomial_sample_one_no_sync(probs)
+        return idx_next, probs
+
+    def decode_n_tokens(
+        self,
+        cur_token: torch.Tensor,
+        context: torch.Tensor,
+        input_pos: torch.Tensor,
+        num_new_tokens: int,
+        callback=lambda _: _,
+        **sampling_kwargs,
+    ):
+        new_tokens, new_probs = [], []
+
+        for i in range(num_new_tokens):
+            next_token, next_prob = self.decode_one_token(
+                cur_token, context, input_pos, **sampling_kwargs
+            )
+            input_pos += 1
+            new_tokens.append(next_token.clone())
+            callback(new_tokens[-1])
+            new_probs.append(next_prob.clone())
+            cur_token = next_token.view(1, self.num_codebooks, -1)
+
+        return new_tokens, new_probs
+
+    @torch.no_grad()
+    def inference(self, inputs, max_new_tokens=1024, top_k=5, temperature=1.0):
+        # x: (B, T)
+        # y: (B, C, T)
+
+        assert inputs.size(0) == 1, "Only support batch size 1 for now"
+
+        # Encode Features
+        inputs = self.encoder_embedding(inputs)
+        attn_bias = self.alibi(inputs)
+        for layer in self.encoder:
+            inputs = layer(inputs, attn_bias=attn_bias)
+
+        device, dtype = inputs.device, inputs.dtype
+
+        # Decode
+        with torch.device(inputs.device):
+            self.setup_kv_caches(max_batch_size=1, max_seq_length=max_new_tokens)
+
+        # create an empty tensor of the expected final shape and fill in the current tokens
+        input_pos = torch.tensor([0], device=device, dtype=torch.long)
+        next_token = torch.tensor(
+            [[0] * self.num_codebooks], device=device, dtype=torch.long
+        )  # BOS of decoder
+
+        generated_tokens, _ = self.decode_n_tokens(
+            next_token.view(1, self.num_codebooks, -1),
+            context=inputs,
+            input_pos=input_pos,
+            num_new_tokens=max_new_tokens - 1,
+            top_k=top_k,
+            temperature=temperature,
+        )
+
+        return [i[0, 0].item() for i in generated_tokens]
+
 
 if __name__ == "__main__":
-    mha = MultiheadAttention(512, 8, dropout=0)
-    mha.eval()
-    mha.cuda()
+    # mha = MultiheadAttention(512, 8, dropout=0)
+    # mha.eval()
+    # mha.cuda()
 
-    q, k, v = torch.randn(3, 10, 16, 512)
-    q, k, v = q.cuda(), k.cuda(), v.cuda()
-    alibi = AlibiPostionEmbedding(8, 1024)
+    # q, k, v = torch.randn(3, 10, 16, 512)
+    # q, k, v = q.cuda(), k.cuda(), v.cuda()
+    # alibi = AlibiPostionEmbedding(8, 1024)
 
-    mha.bfloat16()
-    q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
-    bias = alibi(q).bfloat16()
+    # mha.bfloat16()
+    # q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
+    # bias = alibi(q).bfloat16()
 
-    # Causual mask
-    attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
-    o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)
+    # # Causual mask
+    # attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
+    # o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)
 
-    print(o.size())
-    print(w.size())
+    # print(o.size())
+    # print(w.size())
 
-    o1, w = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
-    print(o1.size())
+    # o1, w = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
+    # print(o1.size())
 
-    print(o[0], o1.float()[0])
+    # print(o[0], o1.float()[0])
 
-    assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
-    print("ok")
+    # assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
+    # print("ok")
 
-    cross = CrossAttentionLayer(512, 1024, dropout=0)
-    cross.eval()
-    cross.cuda()
+    # cross = CrossAttentionLayer(512, 1024, dropout=0)
+    # cross.eval()
+    # cross.cuda()
 
-    tgt = torch.randn(3, 10, 512).cuda()
-    memory = torch.randn(3, 20, 512).cuda()
-    o, w = cross(tgt, memory)
+    # tgt = torch.randn(3, 10, 512).cuda()
+    # memory = torch.randn(3, 20, 512).cuda()
+    # o, w = cross(tgt, memory)
 
-    print(o.size())
-    print(w.size())
+    # print(o.size())
+    # print(w.size())
 
-    ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
-    ten.eval()
-    ten.cuda()
+    # ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
+    # ten.eval()
+    # ten.cuda()
 
-    tgt = torch.randn(3, 10, 512).cuda()
-    o = ten(tgt)
-    print(o.size())
+    # tgt = torch.randn(3, 10, 512).cuda()
+    # o = ten(tgt)
+    # print(o.size())
 
     trans = (
         FishSpeechTransformer(
@@ -453,6 +628,9 @@ if __name__ == "__main__":
     )
     # Print n param
     print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
-    inputs = torch.randint(0, 1000, (3, 16)).cuda()
-    codes = torch.randint(0, 120, (3, 4, 128)).cuda()
+    inputs = torch.randint(0, 1000, (1, 16)).cuda()
+    codes = torch.randint(0, 120, (1, 4, 128)).cuda()
     print(trans(inputs, codes).size())
+
+    r = trans.inference(inputs, max_new_tokens=1024, top_k=5, temperature=0.3)
+    print(r)

+ 13 - 15
tools/llama/rebuild_tokenizer.py

@@ -7,27 +7,27 @@ model_type = "meta-llama/Llama-2-7b-hf"
 tokenizer = AutoTokenizer.from_pretrained(model_type)
 
 # new tokens
-new_tokens = [f"<semantic_{i}>" for i in range(4096)] + list(
-    set(zh_symbols + jp_symbols + en_symbols)
-)
+new_tokens = list(set(zh_symbols + jp_symbols + en_symbols))
 tokenizer.add_tokens(new_tokens)
 
 # pad token
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.pad_token_id = tokenizer.eos_token_id
+tokenizer.padding_side = "right"
+tokenizer.truncation_side = "right"
 
 print(f"Vocab size: {len(tokenizer)}")
 
-model = AutoModelForCausalLM.from_pretrained(
-    "fishaudio/speech-lm-300m", revision="text-pretrain-10k"
-)
+# model = AutoModelForCausalLM.from_pretrained(
+#     "fishaudio/speech-lm-300m", revision="mqtts-proto"
+# )
 
 # Resize the token embeddings to include the new tokens
 # Make sure it's a multiple of 8 for faster training
-model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
+# model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
 
-total_params = sum(p.numel() for p in model.parameters())
-print(f"Total parameters: {total_params / 1e6:.2f}M")
+# total_params = sum(p.numel() for p in model.parameters())
+# print(f"Total parameters: {total_params / 1e6:.2f}M")
 
 # Try tokenizing a new sequence
 sequence = "Test <semantic_0> <semantic_1023> </s> uang1 iang5 AA an"
@@ -37,9 +37,7 @@ print(f"\tSentence: {sequence}")
 print(f"\tEncoded: {encoded}")
 print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
 
-model.push_to_hub(
-    "fishaudio/speech-lm-300m", private=True, revision="text-pretrain-10k-phones"
-)
-tokenizer.push_to_hub(
-    "fishaudio/speech-lm-300m", private=True, revision="text-pretrain-10k-phones"
-)
+# model.push_to_hub(
+#     "fishaudio/speech-lm-300m", private=True, revision="text-pretrain-10k-phones"
+# )
+tokenizer.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="mqtts-phones")