
[fix] Fix problems to let version 1.5 support SFT (#774)

* [docs] Add docs for Fish Agent.

* [docs] Fix some issues

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [docs] Add Chinese docs for Fish Agent

* [docs] Fix some issues

* [docs] Fix the bug where the Chinese page displayed incorrectly

* [docs] Fix bugs in the Chinese docs and add translated agent docs for the other languages

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [feature] Change conversation.visualize colors and the semantic encoding method

* [feature] Change the collate_fn tokenizer to FishTokenizer

* [fix] Fix some dimension problems in semantic.py

* [feature] Change conf to tiktoken

* [fix] Fix a DDP training problem

* [feature] Use Conversation to replace manual token and label generation

* [fix] Fix the embedding calculation in BaseTransformer forward

* [fix] Use einops for tensor operations to avoid bugs

* [fix] Fix bugs in generate and llama for SFT

* [fix] Delete unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: whaledolphin <whaledolphin@github.com>
Whale and Dolphin, 1 year ago
commit 40665e1a39

+ 1 - 0
.gitignore

@@ -30,3 +30,4 @@ asr-label*
 /example
 /faster_whisper
 /.gradio
+*log

+ 8 - 5
fish_speech/configs/text2semantic_finetune.yaml

@@ -4,22 +4,25 @@ defaults:
 
 project: text2semantic_finetune_dual_ar
 max_length: 4096
-pretrained_ckpt_path: checkpoints/fish-speech-1.4
+pretrained_ckpt_path: checkpoints/fish-speech-1.5
 
 # Lightning Trainer
 trainer:
   accumulate_grad_batches: 1
   gradient_clip_val: 1.0
   gradient_clip_algorithm: "norm"
-  max_steps: 1000
+  max_steps: 10000
   precision: bf16-true
   limit_val_batches: 10
   val_check_interval: 100
+  # strategy:
+  #   find_unused_parameters: true
+  #   static_graph: true
 
 # Dataset Configuration
 tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: ${pretrained_ckpt_path}
+  _target_: fish_speech.tokenizer.FishTokenizer
+  model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken
 
 # Dataset Configuration
 train_dataset:

@@ -47,7 +50,7 @@ data:
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 8
+  batch_size: 4
   tokenizer: ${tokenizer}
   max_length: ${max_length}
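The tokenizer swap above is the heart of the SFT fix: Hydra now instantiates the tiktoken-based FishTokenizer directly instead of going through transformers.AutoTokenizer. A minimal sketch of what the resolved `tokenizer` node amounts to (illustrative; assumes the standard Hydra instantiate machinery and a local 1.5 checkpoint):

# Hedged sketch, not part of the commit: how the new config node resolves.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "fish_speech.tokenizer.FishTokenizer",
        "model_path": "checkpoints/fish-speech-1.5/tokenizer.tiktoken",
    }
)
tokenizer = instantiate(cfg)  # equivalent to FishTokenizer(model_path=...)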

+ 16 - 17
fish_speech/conversation.py

@@ -207,35 +207,34 @@ class Conversation:
             tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
         )
 
-        # Colors for alternating tokens
         colors = {
-            "blue": "\033[94m",  # Light blue
-            "cyan": "\033[96m",  # Cyan
-            "green": "\033[92m",  # Light green
-            "dark_green": "\033[32m",  # Dark green
+            "purple": "\033[95m",
+            "yellow": "\033[93m",
+            "red": "\033[91m",
+            "cyan": "\033[96m",
         }
-        blue_idx = 0
-        green_idx = 0
+        first_idx = 0
+        second_idx = 0
 
-        def print_in_blue(x):
-            nonlocal blue_idx
-            color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
+        def print_first_group(x):
+            nonlocal first_idx
+            color = colors["purple"] if first_idx % 2 == 0 else colors["yellow"]
             print(f"{color}{x}\033[0m", end="")
-            blue_idx += 1
+            first_idx += 1
 
-        def print_in_green(x):
-            nonlocal green_idx
-            color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
+        def print_second_group(x):
+            nonlocal second_idx
+            color = colors["red"] if second_idx % 2 == 0 else colors["cyan"]
             print(f"{color}{x}\033[0m", end="")
-            green_idx += 1
+            second_idx += 1
 
         for tok, lab in zip(encoded.tokens, encoded.labels):
             val = tokenizer.decode([tok])
 
             if lab == -100:
-                print_in_green(val)
+                print_second_group(val)
             else:
-                print_in_blue(val)
+                print_first_group(val)
 
         print()
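The renamed print helpers now encode loss membership rather than fixed roles: loss-bearing tokens alternate purple/yellow, while ignored tokens (label -100) alternate red/cyan. A hedged usage sketch, reusing the Conversation/Message API that appears in the semantic.py diff below (the message contents are made up):

# Illustrative only; mirrors the commented-out visualize() call in
# fish_speech/datasets/semantic.py further down in this commit.
from fish_speech.conversation import Conversation, Message, TextPart
from fish_speech.tokenizer import FishTokenizer

tokenizer = FishTokenizer("checkpoints/fish-speech-1.5/tokenizer.tiktoken")
conv = Conversation(
    messages=[
        Message(role="user", parts=[TextPart(text="Hello!")]),
        Message(role="assistant", parts=[TextPart(text="Hi.")], cal_loss=True),
    ]
)
conv.visualize(tokenizer=tokenizer)  # -100 labels print red/cyan, the rest purple/yellow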

+ 100 - 142
fish_speech/datasets/semantic.py

@@ -14,12 +14,18 @@ from huggingface_hub import HfApi
 from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
 from torch.utils.data import DataLoader, IterableDataset, get_worker_info
-from transformers import AutoTokenizer
 
-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    Conversation,
+    Message,
+    TextPart,
+    VQPart,
+)
 from fish_speech.datasets.protos.text_data_pb2 import SampledData
 from fish_speech.datasets.protos.text_data_stream import read_pb_stream
 from fish_speech.text.clean import clean_text
+from fish_speech.tokenizer import FishTokenizer
 from fish_speech.utils import RankedLogger
 from fish_speech.utils.braceexpand import braceexpand
 
@@ -73,7 +79,7 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
         seed: int = 42,
         interactive_prob: float = 0.5,
         max_length: int = 1024,
-        tokenizer: AutoTokenizer = None,
+        tokenizer: FishTokenizer = None,
         use_speaker: bool | float = True,
         causal: bool = True,
         num_codebooks: Optional[int] = None,
@@ -106,9 +112,12 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
         self.num_codebooks = num_codebooks
         self.skip_text_prob = skip_text_prob
 
-        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
         self.groups = None
 
+    def __iter__(self):
+        while True:
+            yield self.augment()
+
     def init_mock_data_server(self):
         if self.groups is not None:
             return
@@ -148,20 +157,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
         Random(self.seed).shuffle(self.groups)
         self.group_weights = [len(i.sentences) for i in self.groups]
 
-    def __iter__(self):
-        while True:
-            yield self.augment()
-
-    def tokenize_sentence(self, sentence: str):
-        sentence = clean_text(sentence)
-        tokens = self.tokenizer.encode(
-            f"{sentence}",
-            max_length=10**6,
-            add_special_tokens=False,
-            truncation=False,
-        )
-        return sentence, len(tokens)
-
     def sample_data(self):
         if self.groups is None:
             self.init_mock_data_server()
@@ -190,155 +185,119 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
             samples=samples,
         )
 
-    def augment(self):
-        final_text, final_semantic = [], []
-        response = self.sample_data()
-        if len(response.samples) == 0:
-            # Invalid group
-            return None
-
-        samples = list(response.samples)
-        idx = 0
-        use_interactive = random.random() < self.interactive_prob
-
-        if use_interactive is False:
-            # Random sample based on speaker using a truncated normal distribution
-            a = torch.tensor([0], dtype=torch.float32)
-            torch.nn.init.trunc_normal_(
-                a,
-                mean=self.max_length // 2,
-                std=self.max_length // 4,
-                a=10,
-                b=self.max_length,
-            )
-            remaining_tokens = a.long().item() - 4
-        else:
-            remaining_tokens = self.max_length
-
-        # Use speaker
-        if isinstance(self.use_speaker, float):
-            use_speaker = random.random() < self.use_speaker
-        else:
-            use_speaker = self.use_speaker
-
-        all_tokens, all_labels = [], []
-        while remaining_tokens > 0 and len(samples) > 0:
-            sentence = samples.pop(0)
-
-            text = random.choice(sentence.texts)
-            text, length = self.tokenize_sentence(text)
-            remaining_tokens -= length + len(sentence.semantics[0].values)
-
-            if use_interactive is False:
-                final_text.append(text)
-                final_semantic.append(sentence.semantics)
-            else:
-                # For interactive mode, we only apply speaker for the first sentence
-                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
-                tokens, labels = self.pack_sentences(
-                    sentences=[text],
-                    semantics=[sentence.semantics],
-                    speaker=response.name if use_speaker else None,
-                    skip_text=random.random() < self.skip_text_prob,
-                )
-
-                all_tokens.append(tokens)
-                all_labels.append(labels)
-
-            idx += 1
-
-        if use_interactive is False:
-            tokens, labels = self.pack_sentences(
-                final_text,
-                semantics=final_semantic,
-                speaker=response.name if use_speaker else None,
-            )
-            all_tokens.append(tokens)
-            all_labels.append(labels)
-
-        tokens = torch.cat(all_tokens, dim=1)
-        labels = torch.cat(all_labels, dim=1)
-
-        # Verify that the length is correct
-        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
-
-        data = {"tokens": tokens, "labels": labels}
-
-        return data
-
     def pack_sentences(
         self,
         sentences: list[str],
         semantics: list,
-        speaker: Optional[str] = None,
+        # speaker: Optional[str] = None,
         skip_text: bool = False,
     ):
-        if speaker is None:
-            speaker = "assistant"
+        # if speaker is None:
+        #     speaker = "assistant"
+
+        messages = [
+            Message(
+                role="system",
+                parts=[TextPart(text="Speak out the provided text.")],
+                # add_im_end=False,
+                # cal_loss=True,
+            )
+        ]
 
         cated_sentences = " ".join(sentences)
         if skip_text:
             cated_sentences = "<|skip_text|>"
 
-        final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
-        final_text = final_text + f"<|im_start|>{speaker}\n"
+        messages.append(
+            Message(
+                role="user",
+                parts=[TextPart(text=cated_sentences)],
+                # cal_loss=True,
+            )
+        )
 
-        encoded = self.tokenizer.encode(
-            final_text,
-            add_special_tokens=False,
-            truncation=False,
-            max_length=10**6,
+        vq_codes = [x.values for x in semantics[0]]
+        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
+        vqpart = VQPart(codes=vq_codes_tensor)
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vqpart],
+                cal_loss=True,
+            )
         )
-        semantic_length = sum([len(i[0].values) for i in semantics])
-        prompt_length = len(encoded)
+
         num_codebooks = (
             len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
         )
 
-        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-        tokens = (
-            encoded
-            + [self.semantic_token_id] * semantic_length
-            + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
+        conversation = Conversation(messages=messages)
+        # conversation.visualize(tokenizer=self.tokenizer)
+        encoded = conversation.encode(
+            tokenizer=self.tokenizer,
         )
 
-        # Codebook bos/padding: 0, eos: 1
-        codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
-        for segment in semantics:
-            for book_idx, book in zip(range(num_codebooks), segment):
-                for j in book.values:
-                    codes[book_idx].append(int(j) + 1)
+        tokens_raw = encoded.tokens
+        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
+        tokens[0] = tokens_raw
 
-        for book in codes:
-            book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
+        vq_parts = encoded.vq_parts
+        vq_parts = [part.to(tokens.device) for part in vq_parts]
+        vq_parts = torch.cat(vq_parts, dim=1)
+        tokens[1:, encoded.vq_mask_tokens] = vq_parts
 
-        tokens = [tokens] + codes
+        labels_raw = encoded.labels
+        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
+        labels[0, :] = labels_raw
+        labels[1:, encoded.vq_mask_labels] = vq_parts
+        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
 
-        tokens = torch.tensor(tokens, dtype=torch.long)
-        labels = tokens.clone()
-
-        if skip_text:
-            # If text is not provided, the sentence is used for condition only, all labels are -100
-            torch.fill_(labels, -100)
-            return tokens, labels
-
-        # Mask out the <s> tokens for semantic, predict semantic tokens only
-        # Since we don't mask out the input tokens, the language modeling still works
-        labels[1:, :prompt_length] = -100
-
-        tokens = tokens[:, :-1]
-        labels = labels[:, 1:]
+        tokens = tokens.long()
+        labels = labels.long()
 
         # Verify the padding is correct, and the last token is eos
-        assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
         assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
 
         return tokens, labels
 
+    def augment(self):
+        response = self.sample_data()
+        if len(response.samples) == 0:
+            # Invalid group
+            return None
+
+        samples = list(response.samples)
+        all_tokens, all_labels = [], []
+
+        while len(samples) > 0:
+            sentence = samples.pop(0)
+            text = clean_text(random.choice(sentence.texts))
+
+            tokens, labels = self.pack_sentences(
+                sentences=[text],
+                semantics=[sentence.semantics],
+                # speaker=response.name if use_speaker else None,
+                skip_text=random.random() < self.skip_text_prob,
+            )
+
+            all_tokens.append(tokens)
+            all_labels.append(labels)
+
+        tokens = torch.cat(all_tokens, dim=1)
+        labels = torch.cat(all_labels, dim=1)
+
+        # Verify that the length is correct
+        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+        data = {"tokens": tokens, "labels": labels}
+
+        return data
+
 
 @dataclass
 class TextDataCollator:
-    tokenizer: AutoTokenizer
+    tokenizer: FishTokenizer
     max_length: int = 1024
 
     def __call__(self, examples):
@@ -388,7 +347,7 @@ class TextDataCollator:
                 _tokens = F.pad(
                     _tokens,
                     (0, max_tokens_length - tokens_length),
-                    value=self.tokenizer.eos_token_id,
+                    value=self.tokenizer.get_token_id("<|end_of_text|>"),
                 )
                 _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                 _labels = F.pad(
@@ -446,7 +405,7 @@ class SemanticDataModule(LightningDataModule):
         train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
         val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
         batch_size: int = 32,
-        tokenizer: AutoTokenizer = None,
+        tokenizer: FishTokenizer = None,
         max_length: int = 1024,
         num_workers: int = 4,
     ):
@@ -483,14 +442,13 @@ if __name__ == "__main__":
 
     ds = AutoTextSemanticInstructionDataset(
         ["data/protos"],
-        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
+        tokenizer=FishTokenizer("checkpoints/fish-speech-1.5/tokenizer.tiktoken"),
        use_speaker=False,
         interactive_prob=1.0,
         skip_text_prob=0.5,
     )
 
     for i in ds:
-        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
-        # i["labels"][0][i["labels"][0] == -100] = 0
-        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
+        # Please uncomment line 235 to visualize the tokenized message
+        print(i)
         break
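The rewritten pack_sentences hands prompt construction to Conversation.encode and then scatters the VQ codes into codebook rows beneath the text row. A toy sketch of the resulting (num_codebooks + 1, seq_len) layout; the ids and masks are illustrative, and the pad id of 0 follows the old "# Codebook bos/padding: 0" comment removed in this diff:

import torch

CODEBOOK_PAD_TOKEN_ID = 0  # assumed pad id, per the removed comment
num_codebooks, seq_len = 2, 6

# Row 0: text-token ids from Conversation.encode; rows 1..N: VQ codes,
# written only where the (illustrative) vq mask is True.
tokens = torch.zeros((num_codebooks + 1, seq_len), dtype=torch.int)
tokens[0] = torch.tensor([101, 7, 8, 5, 5, 102], dtype=torch.int)
vq_mask_tokens = torch.tensor([False, False, False, True, True, False])
vq_parts = torch.tensor([[11, 12], [21, 22]], dtype=torch.int)  # (codebooks, n_vq)
tokens[1:, vq_mask_tokens] = vq_parts

# Labels mirror the layout, defaulting to -100 (ignored by the loss);
# the last position of the codebook rows is set to the pad id, as in the diff.
labels = torch.full((num_codebooks + 1, seq_len), -100, dtype=torch.int)
labels[1:, vq_mask_tokens] = vq_parts
labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID

assert (tokens[1:, ~vq_mask_tokens] == CODEBOOK_PAD_TOKEN_ID).all()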

+ 40 - 43
fish_speech/models/text2semantic/llama.py

@@ -167,7 +167,7 @@ class BaseTransformer(nn.Module):
     def __init__(
         self,
         config: BaseModelArgs,
-        tokenizer: FishTokenizer | AutoTokenizer,
+        tokenizer: FishTokenizer,
         init_weights: bool = True,
     ) -> None:
         super().__init__()
@@ -246,17 +246,24 @@
                 dtype=dtype,
             )
 
-    def embed(self, x: Tensor) -> Tensor:
-        vocab_embeds = [self.embeddings(x[:, 0])]
+    def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor:
+        embeds = []
+        semantic_token_ids_tensor = torch.tensor(
+            self.semantic_token_ids, device=inp.device
+        )
+
         for i in range(self.config.num_codebooks):
-            emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
-            semantic_token_ids_tensor = torch.tensor(
-                self.semantic_token_ids, device=x.device
-            )
-            emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
+            if share_codebook_embeddings:
+                emb = self.codebook_embeddings(
+                    inp[:, i + 1] + i * self.config.codebook_size
+                )
+            else:
+                emb = self.codebook_embeddings(inp[:, i + 1])
+            embeds.append(emb)
 
-        x = torch.stack(vocab_embeds, dim=3)
-        x = x.sum(dim=3)
+        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
+        vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0
+        x = self.embeddings(inp[:, 0]) + vq_embeds_sum
 
         return x
 
@@ -277,8 +284,14 @@
         # To maintain consistency, key_padding_mask use TRUE to mask out
         mask = None
         if key_padding_mask is not None:
-            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
-            mask = mask & key_padding_mask[:, None, None, :].logical_not()
+            causal = self.causal_mask[:seq_len, :seq_len]
+            causal = rearrange(causal, "q k -> 1 1 q k")
+
+            atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
+            atten_mask = atten_mask.logical_not()
+            mask = causal & atten_mask
+
+        # return freqs_cis, mask
 
         for layer in self.layers:
             if self.config.use_gradient_checkpointing and self.training:
@@ -303,36 +316,12 @@
         self,
         inp: Tensor,
         input_pos: Optional[Tensor] = None,
-        vq_masks: Optional[Tensor] = None,  # this is not used in fact
         return_all: bool = False,
     ) -> BaseTransformerForwardResult:
-        # This is used for generation, optimized for torch compile
-        # assert (
-        #     self.max_seq_len != -1 and self.max_batch_size != -1
-        # ), "Please call setup_caches before forward_generate"
-
-        embeds = []
-        for i in range(self.config.num_codebooks):
-            if self.config.share_codebook_embeddings:
-                _tokens = inp[:, i + 1] + i * self.config.codebook_size
-            else:
-                _tokens = inp[:, i + 1]
-
-            emb = self.codebook_embeddings(_tokens)
-            embeds.append(emb)
-
-        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
-        # if self.config.use_codebook_mlp:
-        #     vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
-        #     vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
-
-        vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
-            inp[:, 0] <= self.tokenizer.semantic_end_id
+        x = self.embed(
+            inp, share_codebook_embeddings=self.config.share_codebook_embeddings
         )
 
-        vq_embeds_sum[~vq_masks] = 0
-        x = self.embeddings(inp[:, 0]) + vq_embeds_sum
-
         if input_pos is None:
             input_pos = torch.arange(inp.shape[-1], device=x.device)
             max_seq_len = inp.shape[-1]
@@ -401,11 +390,8 @@
             case _:
                 raise ValueError(f"Unknown model type: {config.model_type}")
 
-        if is_agent:
-            tokenizer = AutoTokenizer.from_pretrained(str(path))
-        else:
-            tokenizer_path = str(path) + "/tokenizer.tiktoken"
-            tokenizer = FishTokenizer(tokenizer_path)
+        tokenizer_path = str(path) + "/tokenizer.tiktoken"
+        tokenizer = FishTokenizer(tokenizer_path)
 
         log.info(f"Loading model from {path}, config: {config}")
         model = model_cls(config, tokenizer=tokenizer)
@@ -862,6 +848,17 @@ class RMSNorm(nn.Module):
 
 
 def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+    """
+    Precomputes frequency tensors for complex exponentials (cis)
+
+    Args:
+        seq_len: Length of the sequence for which positional embeddings are needed.
+        n_elem: Number of elements in the frequency tensor.
+        base: Base value for the frequency scaling (default: 10000).
+
+    Returns:
+        A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
+    """
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
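Two details above are worth unpacking: embed() now zeroes the summed codebook embeddings at non-semantic positions before adding the text embedding, and the attention mask is rebuilt with einops so the broadcast shapes stay explicit. A standalone sketch of the mask logic (toy sizes; True in key_padding_mask marks padding, as the comment in the diff states):

import torch
from einops import rearrange

batch, seq_len = 2, 4
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
key_padding_mask = torch.tensor(
    [[False, False, False, True],
     [False, False, True, True]]
)  # True marks padded key positions to mask out

causal = rearrange(causal_mask, "q k -> 1 1 q k")           # (1, 1, Q, K)
atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")  # (B, 1, 1, K)
mask = causal & atten_mask.logical_not()                    # (B, 1, Q, K)

assert mask.shape == (batch, 1, seq_len, seq_len)
# Query q in batch b may attend to key k iff k <= q and k is not padding.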

+ 9 - 8
tools/llama/merge_lora.py

@@ -76,19 +76,20 @@ def merge(lora_config, base_weight, lora_weight, output):
 
     new_state_dict = torch.load(output / "model.pth", map_location="cpu")
     original_keys = set(llama_state_dict_copy.keys())
-    merged_keys = set(new_state_dict.keys())
-
-    assert original_keys == merged_keys, "Keys should be same"
 
+    tolerance = 1e-5
     for key in original_keys:
         diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
-        if diff_l1 != 0:
+        if diff_l1 > tolerance:
+            logger.info(f"Significant difference found in key: {key}")
             break
-    else:
-        logger.error("Merged model is same as the original model")
-        exit(1)
 
-    logger.info("Merged model is different from the original model, check passed")
+    if diff_l1 <= tolerance:
+        logger.warning(
+            "Merged model seems identical to the original model. Further validation might be needed."
+        )
+    else:
+        logger.info("Merged model is different from the original model, check passed")
 
 
 if __name__ == "__main__":
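The merge check now tolerates floating-point noise instead of demanding an exact mismatch. A self-contained sketch of the same comparison pattern (toy state dicts; tolerance value as in the diff):

import torch

tolerance = 1e-5  # same threshold as the diff
original = {"w": torch.zeros(4)}
merged = {"w": torch.zeros(4) + 1e-7}  # differs only by numerical noise

changed = any(
    (merged[k] - original[k]).abs().sum().item() > tolerance for k in original
)
print("check passed" if changed else "merge may have been a no-op")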