Kaynağa Gözat

support s1 model structure

Lengyue 10 ay önce
ebeveyn
işleme
89474bbb3b

+ 2 - 2
API_FLAGS.txt

@@ -1,6 +1,6 @@
 # --infer
 --api
 --listen 0.0.0.0:8080 \
---llama-checkpoint-path "checkpoints/fish-speech-1.5" \
---decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+--llama-checkpoint-path "checkpoints/openaudio-s1-mini" \
+--decoder-checkpoint-path "checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
 --decoder-config-name firefly_gan_vq

+ 1 - 0
README.md

@@ -104,6 +104,7 @@ It should be noted that the current model **DOESN'T SUPPORT FINETUNE**.
 - [MQTTS](https://github.com/b04901014/MQTTS)
 - [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
 - [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+- [Qwen3](https://github.com/QwenLM/Qwen3)
 
 ## Tech Report (V1.4)
 ```bibtex

+ 1 - 1
fish_speech/configs/text2semantic_finetune.yaml

@@ -4,7 +4,7 @@ defaults:
 
 project: text2semantic_finetune_dual_ar
 max_length: 4096
-pretrained_ckpt_path: checkpoints/fish-speech-1.5
+pretrained_ckpt_path: checkpoints/openaudio-s1-mini
 
 # Lightning Trainer
 trainer:

+ 367 - 0
fish_speech/content_sequence.py

@@ -0,0 +1,367 @@
+from dataclasses import dataclass, field
+from typing import List, Literal, Union
+
+import numpy as np
+import torch
+
+from fish_speech.tokenizer import (
+    IM_END_TOKEN,
+    MODALITY_TOKENS,
+    FishTokenizer,
+)
+
+
+def restore_ndarray(obj, to_tensor: bool = False):
+    """Rebuild an ndarray that was serialized as a ``{"__ndarray__": ...}`` dict.
+
+    Args:
+        obj: Any object. Only dicts carrying the ``"__ndarray__"`` marker are
+            reconstructed (from their ``data``/``dtype``/``shape`` entries);
+            anything else passes through unchanged.
+        to_tensor: When True, additionally convert a resulting ``np.ndarray``
+            to a ``torch.Tensor``. The ``.copy()`` is required because
+            ``np.frombuffer`` returns a read-only view, which
+            ``torch.from_numpy`` rejects.
+
+    Returns:
+        The restored array/tensor, or ``obj`` unchanged.
+    """
+    if isinstance(obj, dict) and "__ndarray__" in obj:
+        obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])
+
+    if to_tensor and isinstance(obj, np.ndarray):
+        obj = torch.from_numpy(obj.copy())
+
+    return obj
+
+
+@dataclass
+class BasePart:
+    # Base class for every content part in a ContentSequence.
+    # `type` is a discriminator ("text" | "vq" | "audio") that each subclass
+    # overwrites in __post_init__; `cal_loss` marks whether this part's tokens
+    # should contribute to the training loss (consumed by ContentSequence.encode).
+    type: Literal["text", "vq", "audio"] | None = None
+    cal_loss: bool = False
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+    # Semantic VQ codes. Presumably shaped (num_codebooks, seq_len):
+    # encode() reads codes[0] as the first codebook and concatenates parts
+    # along dim 1 — TODO confirm against the producer.
+    # NOTE: the bare `type = "vq"` below is an unannotated class attribute,
+    # not a dataclass field; the inherited `type` field is set in __post_init__.
+    type = "vq"
+    codes: torch.Tensor
+
+    def __post_init__(self: "VQPart"):
+        self.type = "vq"
+        # Codes may arrive as a serialized ndarray dict; normalize to a tensor.
+        self.codes = restore_ndarray(self.codes, to_tensor=True)
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+    # Plain text content, supplied either as a raw string (`text`) or as
+    # pre-tokenized ids (`tokens`). At least one of the two must be given.
+    type = "text"
+    text: str | None = None
+    tokens: list[int] | None = None
+
+    def __post_init__(self: "TextPart"):
+        self.type = "text"
+        if self.text is None and self.tokens is None:
+            raise ValueError("Either text or tokens must be provided")
+
+
+@dataclass(kw_only=True)
+class AudioPart(BasePart):
+    # Continuous audio features as a tensor (restored from a serialized
+    # ndarray if necessary).
+    # NOTE(review): ContentSequence.encode() currently raises ValueError when
+    # it meets an AudioPart — only mask bookkeeping for it exists there.
+    # Confirm whether audio parts are meant to be encodable yet.
+    type = "audio"
+    features: torch.Tensor
+
+    def __post_init__(self: "AudioPart"):
+        self.type = "audio"
+        self.features = restore_ndarray(self.features, to_tensor=True)
+
+
+@dataclass(kw_only=True)
+class EncodedMessage:
+    # Result bundle produced by ContentSequence.encode(): flat token/label
+    # tensors plus per-position masks and the raw codebook tensors needed to
+    # reassemble the multimodal input for the model.
+    tokens: torch.Tensor  # 1-D int tensor of input token ids
+    labels: torch.Tensor  # 1-D target ids; -100 at positions excluded from loss
+    vq_mask_tokens: torch.Tensor | None = None  # True where `tokens` are VQ placeholders
+    vq_mask_labels: torch.Tensor | None = None  # True where `labels` correspond to VQ tokens
+    vq_parts: list[torch.Tensor]  # raw codebook tensors, one per VQPart
+    vq_require_losses: torch.Tensor | None = None  # bool per VQPart: its cal_loss flag
+    audio_parts: list[torch.Tensor]  # raw audio features (never filled by encode() as written)
+    audio_masks: torch.Tensor | None = None  # True where tokens correspond to audio features
+    metadata: dict | None = None  # passthrough of ContentSequence.metadata
+
+
+@dataclass
+class ContentSequence:
+    """
+    Flexible sequence of content parts that supports interleaved multimodal format.
+    Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|>
+    """
+
+    parts: list[BasePart] = field(default_factory=list)
+    modality: Literal["text", "voice", "interleave"] | None = None
+    metadata: dict | None = None
+
+    def __init__(
+        self: "ContentSequence",
+        parts: list[BasePart | dict] | None = None,
+        modality: Literal["text", "voice", "interleave"] | None = None,
+        metadata: dict | None = None,
+    ):
+        """Build a sequence, coercing dict-shaped parts into dataclasses.
+
+        Args:
+            parts: Parts as BasePart instances or plain dicts (e.g. from
+                deserialized JSON) keyed on their "type" field.
+            modality: If given, a matching modality token from MODALITY_TOKENS
+                is inserted as the first part (unless already present).
+            metadata: Free-form dict carried through to EncodedMessage.
+        """
+        self.modality = modality
+        self.metadata = metadata or {}
+
+        # Coerce dict parts into their dataclass equivalents by "type".
+        fixed_parts = []
+        for part in parts or []:
+            if isinstance(part, dict):
+                if part["type"] == "vq":
+                    part = VQPart(**part)
+                elif part["type"] == "audio":
+                    part = AudioPart(**part)
+                elif part["type"] == "text":
+                    part = TextPart(**part)
+                else:
+                    raise ValueError(f"Unsupported part type: {part['type']}")
+            fixed_parts.append(part)
+
+        self.parts = fixed_parts
+
+        # If modality is specified, add it at the beginning if it's not already there
+        # NOTE(review): the `isinstance(self.parts[0], dict) is False` clause is
+        # redundant — dict parts were already converted in the loop above.
+        if self.modality and not (
+            len(self.parts) > 0
+            and isinstance(self.parts[0], dict) is False
+            and isinstance(self.parts[0], TextPart)
+            and self.parts[0].text is not None
+            and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality])
+        ):
+            modality_token = MODALITY_TOKENS[self.modality]
+            self.parts.insert(0, TextPart(text=modality_token))
+
+    def append(
+        self: "ContentSequence",
+        part_or_parts: Union[BasePart, List[BasePart]],
+        add_end: bool = False,
+        speaker: Union[str, int] | None = None,
+    ):
+        """
+        Append a part or list of parts to the sequence.
+
+        Args:
+            part_or_parts: A single part or list of parts to add
+            add_end: Whether to add the IM_END_TOKEN after these parts
+            speaker: Optional speaker identifier (name or ID) to add before the parts
+        """
+        # Convert single part to list
+        parts_to_add = (
+            [part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts
+        )
+
+        # Add speaker token if specified
+        if speaker is not None:
+            speaker_token = f"<|speaker:{speaker}|>"
+            self.parts.append(TextPart(text=speaker_token))
+
+        # Add all the parts
+        self.parts.extend(parts_to_add)
+
+        # Add end token if requested
+        # (it inherits cal_loss from the last part so <|im_end|> is trained
+        # exactly when the preceding content is)
+        if add_end:
+            self.parts.append(
+                TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss)
+            )
+
+    def encode(
+        self: "ContentSequence",
+        tokenizer: FishTokenizer,
+        add_shift: bool = True,
+        ignore_loss_tokens: list[str] = [],
+    ) -> EncodedMessage:
+        """
+        Encode the sequence parts into tokens for the model.
+
+        Args:
+            tokenizer: The tokenizer to use
+            add_shift: Whether to shift tokens for next-token prediction
+            ignore_loss_tokens: List of token strings to ignore when calculating loss
+
+        Returns:
+            EncodedMessage with tensors ready for the model
+
+        Notes (review):
+            - `ignore_loss_tokens=[]` is a mutable default; it is only iterated
+              here, never mutated, so it is benign — but prefer `None`.
+            - AudioPart instances hit the `else` in the part loop and raise,
+              so the AudioPart mask branch below is currently unreachable.
+        """
+        all_tokens = []
+        all_labels = []
+
+        # Multi-modal elements
+        vq_parts = []
+        vq_masks = []
+        vq_require_losses = []
+
+        audio_parts = []
+        audio_masks = []
+
+        ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
+
+        for part in self.parts:
+            if isinstance(part, TextPart):
+                if part.tokens is None:
+                    assert part.text is not None
+                    tokens = tokenizer.encode(part.text)
+                else:
+                    tokens = part.tokens
+
+                tokens = torch.tensor(tokens, dtype=torch.int)
+            elif isinstance(part, VQPart):
+                # First codebook row is mapped to semantic token ids so the
+                # text stream has one placeholder token per VQ frame.
+                curr_codes = part.codes.clone().to(torch.int)
+                tokens = torch.tensor(
+                    [
+                        tokenizer.semantic_id_to_token_id[int(i.item())]
+                        for i in curr_codes[0].int()
+                    ],
+                    dtype=torch.int,
+                )
+                vq_parts.append(curr_codes)
+                vq_require_losses.append(part.cal_loss)
+            # NOTE(review): AudioPart falls through to this `else` and raises;
+            # confirm whether audio parts are meant to be encodable yet.
+            else:
+                raise ValueError(f"Unsupported part type: {type(part)}")
+
+            all_tokens.append(tokens)
+
+            # Set masks for different part types
+            if isinstance(part, VQPart):
+                vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
+                audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+            elif isinstance(part, AudioPart):  # currently unreachable (see note above)
+                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+                audio_mask = torch.ones_like(tokens, dtype=torch.bool)
+                audio_mask[0] = False  # Skip start token
+                audio_mask[-1] = False  # Skip end token
+                audio_masks.append(audio_mask)
+            else:
+                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+                audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+
+            # Set labels based on whether we want to calculate loss for this part
+            if part.cal_loss and not isinstance(part, AudioPart):
+                all_labels.append(tokens.clone())
+            else:
+                all_labels.append(torch.full_like(tokens, -100))
+
+        # Concatenate all tensors
+        tokens = torch.cat(all_tokens, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        vq_masks = torch.cat(vq_masks, dim=0)
+        audio_masks = torch.cat(audio_masks, dim=0)
+        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
+        # Apply shift if needed for next-token prediction
+        # (inputs drop the final position, labels drop the first, so position
+        # i of `tokens` predicts position i of `labels`)
+        vq_mask_tokens = vq_masks
+        vq_mask_labels = vq_masks
+
+        if add_shift:
+            tokens = tokens[:-1]
+            labels = labels[1:]
+            vq_masks = vq_masks[:-1]
+            vq_mask_tokens = vq_mask_tokens[:-1]
+            vq_mask_labels = vq_mask_labels[1:]
+            audio_masks = audio_masks[:-1]
+
+        # Ignore specified tokens
+        for i in ignore_loss_token_ids:
+            assert i != -100 and i is not None
+            labels[labels == i] = -100
+
+        assert tokens.dtype in [
+            torch.int,
+            torch.long,
+        ], f"Invalid dtype: {tokens.dtype}"
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            vq_mask_tokens=vq_mask_tokens,
+            vq_mask_labels=vq_mask_labels,
+            vq_require_losses=vq_require_losses,
+            audio_parts=audio_parts,
+            audio_masks=audio_masks,
+            metadata=self.metadata,
+        )
+
+    def encode_for_inference(
+        self: "ContentSequence",
+        tokenizer: FishTokenizer,
+        num_codebooks: int,
+    ) -> torch.Tensor:
+        """Encode without shifting and lay out codebook values for inference.
+
+        Returns an int tensor of shape (num_codebooks + 1, seq_len): row 0 is
+        the token-id stream (VQ positions rewritten as semantic token ids via
+        `semantic_begin_id`), rows 1.. hold the raw codebook values at the
+        positions flagged by `vq_mask_tokens`, zero elsewhere.
+
+        NOTE(review): `audio_parts` are checked in the early-return condition
+        but never written into `values` — only VQ parts are merged here.
+        """
+        encoded = self.encode(tokenizer, add_shift=False)
+        tokens = encoded.tokens
+        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+        values[0] = tokens
+
+        if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
+            encoded.audio_parts is None or len(encoded.audio_parts) == 0
+        ):
+            return values
+
+        if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
+            vq_parts = encoded.vq_parts
+            vq_parts = torch.cat(vq_parts, dim=1)
+            values[0, encoded.vq_mask_tokens] = (
+                vq_parts[0] + tokenizer.semantic_begin_id
+            )
+            values[1:, encoded.vq_mask_tokens] = vq_parts
+
+        return values
+
+    def visualize(
+        self: "ContentSequence",
+        tokenizer: FishTokenizer,
+        ignore_loss_tokens: list[str] = [],
+        merge_semantic_tokens: bool = False,
+    ):
+        """
+        Visualize the encoded sequence with color-coded tokens.
+        Blue/cyan tokens contribute to loss, green tokens do not.
+        When merge_semantic_tokens is True, consecutive semantic tokens that
+        share the same label are collapsed into one "[<|semantic|>xN]" marker.
+        """
+        encoded = self.encode(
+            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
+        )
+
+        # Colors for alternating tokens
+        colors = {
+            "blue": "\033[94m",  # Light blue
+            "cyan": "\033[96m",  # Cyan
+            "green": "\033[92m",  # Light green
+            "dark_green": "\033[32m",  # Dark green
+        }
+        blue_idx = 0
+        green_idx = 0
+
+        def print_in_blue(x):
+            nonlocal blue_idx
+            color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
+            print(f"{color}{x}\033[0m", end="")
+            blue_idx += 1
+
+        def print_in_green(x):
+            nonlocal green_idx
+            color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
+            print(f"{color}{x}\033[0m", end="")
+            green_idx += 1
+
+        def print_semantic_token(x, count):
+            # x is the run's label (-100 means no loss -> green).
+            val = f"[<|semantic|>x{count}]"
+            if x == -100:
+                print_in_green(val)
+            else:
+                print_in_blue(val)
+
+        count_semantic_tokens = 0
+        semantic_label = None  # label shared by the current run of semantic tokens
+
+        for tok, lab in zip(encoded.tokens, encoded.labels):
+            token_id = int(tok.item())
+
+            if merge_semantic_tokens:
+                if (
+                    tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id
+                    and (semantic_label is None or semantic_label == lab)
+                ):
+                    count_semantic_tokens += 1
+                    semantic_label = lab
+                    continue
+                elif count_semantic_tokens > 0:
+                    # Run ended (non-semantic token or label changed): flush it.
+                    print_semantic_token(semantic_label, count_semantic_tokens)
+                    count_semantic_tokens = 0
+                    semantic_label = None
+
+            val = tokenizer.decode([int(tok.item())])
+
+            if lab == -100:
+                print_in_green(val)
+            else:
+                print_in_blue(val)
+
+        # Flush a semantic run that reaches the end of the sequence.
+        if merge_semantic_tokens and count_semantic_tokens > 0:
+            print_semantic_token(semantic_label, count_semantic_tokens)
+
+        print()

+ 0 - 266
fish_speech/conversation.py

@@ -1,266 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Literal
-
-import torch
-
-from .tokenizer import MODALITY_TOKENS, FishTokenizer
-
-CODEBOOK_PAD_TOKEN_ID = 0
-
-
-@dataclass(kw_only=True)
-class BasePart:
-    pass
-
-
-@dataclass(kw_only=True)
-class VQPart(BasePart):
-    codes: torch.Tensor
-
-
-@dataclass(kw_only=True)
-class TextPart(BasePart):
-    text: str
-
-
-@dataclass(kw_only=True)
-class EncodedMessage:
-    tokens: torch.Tensor
-    labels: torch.Tensor
-    vq_mask_tokens: torch.Tensor | None = None
-    vq_mask_labels: torch.Tensor | None = None
-    vq_parts: list[torch.Tensor]
-    vq_require_losses: torch.Tensor | None = None
-
-
-@dataclass(kw_only=True)
-class Message:
-    role: Literal["system", "user", "assistant"]
-    parts: list[VQPart | TextPart] = field(default_factory=list)
-    add_im_start: bool = True
-    add_im_end: bool = True
-    cal_loss: bool = False
-    modality: Literal["text", "voice", "interleave"] | None = None
-
-    # By default, ignore the loss of the auto-generated im_start token
-    ignore_im_start_loss: bool = True
-
-    def encode(
-        self: "Message",
-        tokenizer: FishTokenizer,
-    ) -> EncodedMessage:
-        all_tokens = []
-        all_labels = []
-
-        # Multi-modal tokens
-        vq_parts = []
-        vq_masks = []
-
-        parts = self.parts.copy()
-        if self.add_im_start:
-            modality_token = MODALITY_TOKENS[self.modality] if self.modality else ""
-            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}"))
-
-        if self.add_im_end:
-            parts.append(TextPart(text="<|im_end|>"))
-
-        for part in parts:
-            if isinstance(part, TextPart):
-                tokens = torch.tensor(
-                    tokenizer.encode(part.text),
-                    dtype=torch.int,
-                )
-            elif isinstance(part, VQPart):
-                curr_codes = part.codes.clone()
-                tokens = torch.tensor(
-                    [
-                        tokenizer.semantic_id_to_token_id[i.item()]
-                        for i in curr_codes[0].int()
-                    ],
-                    dtype=torch.int,
-                )
-                vq_parts.append(curr_codes)
-            else:
-                raise ValueError(f"Unsupported part type: {type(part)}")
-
-            all_tokens.append(tokens)
-            if isinstance(part, VQPart):
-                vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
-            else:
-                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
-
-            if self.cal_loss:
-                all_labels.append(tokens.clone())
-            else:
-                all_labels.append(torch.full_like(tokens, -100))
-
-        tokens = torch.cat(all_tokens, dim=0)
-        labels = torch.cat(all_labels, dim=0)
-        vq_masks = torch.cat(vq_masks, dim=0)
-
-        assert tokens.shape == labels.shape == vq_masks.shape
-
-        if self.ignore_im_start_loss and self.add_im_start:
-            labels[: len(all_tokens[0])] = -100
-
-        return EncodedMessage(
-            tokens=tokens,
-            labels=labels,
-            vq_parts=vq_parts,
-            vq_mask_tokens=vq_masks,
-            vq_mask_labels=vq_masks,
-        )
-
-
-@dataclass
-class Conversation:
-    messages: list[Message]
-
-    def __init__(self: "Conversation", messages: list[Message] | None = None):
-        self.messages = messages or []
-
-    def encode(
-        self: "Conversation",
-        tokenizer: FishTokenizer,
-        add_shift: bool = True,
-        ignore_loss_tokens: list[str] = [],
-    ) -> EncodedMessage:
-        # Build the input_ids and labels
-        tokens = []
-        labels = []
-        vq_parts = []
-        vq_mask_tokens = []
-        vq_mask_labels = []
-        vq_require_losses = []
-        ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
-
-        for message in self.messages:
-            encoded = message.encode(
-                tokenizer,
-            )
-            tokens.append(encoded.tokens)
-            labels.append(encoded.labels)
-            vq_parts.extend(encoded.vq_parts)
-            vq_mask_tokens.append(encoded.vq_mask_tokens)
-            vq_mask_labels.append(encoded.vq_mask_labels)
-            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
-
-        tokens = torch.cat(tokens, dim=0)
-        labels = torch.cat(labels, dim=0)
-        vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
-        vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
-        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
-
-        if add_shift:
-            tokens = tokens[:-1]
-            labels = labels[1:]
-            vq_mask_tokens = vq_mask_tokens[:-1]
-            vq_mask_labels = vq_mask_labels[1:]
-
-        for i in ignore_loss_token_ids:
-            assert i != -100 and i is not None
-            labels[labels == i] = -100
-
-        assert tokens.dtype in [
-            torch.int,
-            torch.long,
-        ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
-
-        return EncodedMessage(
-            tokens=tokens,
-            labels=labels,
-            vq_parts=vq_parts,
-            vq_mask_tokens=vq_mask_tokens,
-            vq_mask_labels=vq_mask_labels,
-            vq_require_losses=vq_require_losses,
-        )
-
-    def encode_for_inference(
-        self: "Conversation",
-        tokenizer: FishTokenizer,
-        num_codebooks: int,
-    ) -> EncodedMessage:
-        # self.visualize(tokenizer)
-
-        encoded = self.encode(tokenizer, add_shift=False)
-        tokens = encoded.tokens
-        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
-        values[0] = tokens
-
-        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
-            return values
-
-        vq_parts = encoded.vq_parts
-        vq_parts = [part.to(values.device) for part in vq_parts]
-        vq_parts = torch.cat(vq_parts, dim=1)
-        values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
-        values[1:, encoded.vq_mask_tokens] = vq_parts
-
-        return values
-
-    def visualize(
-        self: "Conversation",
-        tokenizer: FishTokenizer,
-        ignore_loss_tokens: list[str] = [],
-    ):
-        encoded = self.encode(
-            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
-        )
-
-        colors = {
-            "purple": "\033[95m",
-            "yellow": "\033[93m",
-            "red": "\033[91m",
-            "cyan": "\033[96m",
-        }
-        first_idx = 0
-        second_idx = 0
-
-        def print_first_group(x):
-            nonlocal first_idx
-            color = colors["purple"] if first_idx % 2 == 0 else colors["yellow"]
-            print(f"{color}{x}\033[0m", end="")
-            first_idx += 1
-
-        def print_second_group(x):
-            nonlocal second_idx
-            color = colors["red"] if second_idx % 2 == 0 else colors["cyan"]
-            print(f"{color}{x}\033[0m", end="")
-            second_idx += 1
-
-        for tok, lab in zip(encoded.tokens, encoded.labels):
-            val = tokenizer.decode([tok])
-
-            if lab == -100:
-                print_second_group(val)
-            else:
-                print_first_group(val)
-
-        print()
-
-    def append(self: "Conversation", message: Message):
-        self.messages.append(message)
-
-
-if __name__ == "__main__":
-    message0 = Message(
-        role="user",
-        parts=[
-            TextPart(text="Hello, how are you?"),
-            VQPart(codes=torch.zeros((4, 10))),
-        ],
-        cal_loss=False,
-    )
-
-    message1 = Message(
-        role="assistant",
-        parts=[TextPart(text="I'm fine, thank you.")],
-        cal_loss=True,
-    )
-    conversation = Conversation([message0, message1])
-    tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
-    conversation.visualize(tokenizer)
-
-    encoded = conversation.encode(tokenizer)
-    print(encoded)
-    print(tokenizer.batch_decode(encoded.tokens))

+ 5 - 2
fish_speech/models/dac/inference.py

@@ -58,10 +58,10 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 @click.option(
     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
 )
-@click.option("--config-name", default="firefly_gan_vq")
+@click.option("--config-name", default="modded_dac_vq")
 @click.option(
     "--checkpoint-path",
-    default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    default="checkpoints/openaudio-s1-mini/codec.pth",
 )
 @click.option(
     "--device",
@@ -89,6 +89,9 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
         audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
         indices, indices_lens = model.encode(audios, audio_lengths)
 
+        if indices.ndim == 3:
+            indices = indices[0]
+
         logger.info(f"Generated indices of shape {indices.shape}")
 
         # Save indices

+ 3 - 0
fish_speech/models/dac/modded_dac.py

@@ -920,6 +920,9 @@ class DAC(BaseModel, CodecMixin):
         return indices, indices_lens
 
     def decode(self, indices: torch.Tensor, feature_lengths):
+        if indices.ndim == 2:
+            indices = indices[None]
+
         z = self.quantizer.decode(indices)
         audio_lengths = feature_lengths * self.frame_length
         return self.decoder(z), audio_lengths

+ 60 - 469
fish_speech/models/text2semantic/inference.py

@@ -16,10 +16,8 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from fish_speech.conversation import (
-    CODEBOOK_PAD_TOKEN_ID,
-    Conversation,
-    Message,
+from fish_speech.content_sequence import (
+    ContentSequence,
     TextPart,
     VQPart,
 )
@@ -84,45 +82,6 @@ def logits_to_probs(
     return probs
 
 
-def multinomial_sample_one_no_sync_agent(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs_agent(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    temperature: torch.Tensor = 1.0,
-    top_p: torch.Tensor = 1.0,
-    repetition_penalty: torch.Tensor = 1.0,
-) -> torch.Tensor:
-    # Apply repetition penalty
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=-1, index=previous_tokens)
-        score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
-        )
-        logits.scatter_(dim=-1, index=previous_tokens, src=score)
-
-    # Apply top-p sampling
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
-    sorted_indices_to_remove = cum_probs > top_p
-    sorted_indices_to_remove[..., 0] = False  # keep at least one option
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
-    )
-    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-    logits = logits / max(temperature, 1e-5)
-
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
 def sample(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
@@ -135,117 +94,6 @@ def sample(
     return idx_next, probs
 
 
-def sample_agent(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs_agent(
-        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
-    )
-    idx_next = multinomial_sample_one_no_sync_agent(probs)
-    return idx_next, probs
-
-
-def decode_one_token_ar_agent(
-    model: DualARTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    semantic_ids: list,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    # print(x, input_pos)
-    x = model.forward_generate(x, input_pos)
-    logits = x.logits  # [:, -1:]
-    hidden_states = x.hidden_states  # [:, -1:]
-
-    sampling_kwargs_main = sampling_kwargs.copy()
-    sampling_kwargs_main["temperature"] = 0.1
-    sampling_kwargs_main["top_p"] = 0.1
-    sampling_kwargs_main["repetition_penalty"] = 1.0
-
-    codebooks = [
-        sample_agent(
-            logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
-        )[0]
-    ]
-
-    # Cleanup the cache
-    for layer in model.fast_layers:
-        layer.attention.kv_cache.k_cache.fill_(0)
-        layer.attention.kv_cache.v_cache.fill_(0)
-
-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor(
-            [codebook_idx], device=hidden_states.device, dtype=torch.long
-        )
-        logits = model.forward_generate_fast(hidden_states, input_pos)
-        a = sample_agent(
-            logits,
-            previous_tokens=(
-                previous_tokens[:, codebook_idx + 1]
-                if previous_tokens is not None
-                else None
-            ),
-            **sampling_kwargs,
-        )[0]
-        hidden_states = model.fast_embeddings(a)
-        codebooks.append(a)
-
-    codebooks = torch.stack(codebooks, dim=1)
-    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :],
-        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
-        CODEBOOK_PAD_TOKEN_ID,
-    )
-
-    return codebooks
-
-
-def decode_one_token_naive_agent(
-    model: NaiveTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    semantic_ids: list,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    x = model.forward_generate(x, input_pos)
-
-    codebooks = [
-        sample(
-            x.token_logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs,
-        )[0]
-    ]
-
-    for i in range(model.config.num_codebooks):
-        codebooks.append(
-            sample_agent(
-                x.codebook_logits[:, :, i],
-                previous_tokens=(
-                    previous_tokens[:, i + 1] if previous_tokens is not None else None
-                ),
-                **sampling_kwargs,
-            )[0]
-        )
-
-    codebooks = torch.stack(codebooks, dim=1)
-    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :],
-        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
-        CODEBOOK_PAD_TOKEN_ID,
-    )
-
-    return codebooks
-
-
 def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
@@ -290,8 +138,9 @@ def decode_one_token_ar(
             [codebook_idx], device=hidden_states.device, dtype=torch.long
         )
         logits = model.forward_generate_fast(hidden_states, input_pos)
+        chunked_logits = logits[..., :1024]
         a = sample(
-            logits,
+            chunked_logits,
             previous_tokens=(
                 previous_tokens[codebook_idx + 1]
                 if previous_tokens is not None
@@ -312,49 +161,13 @@ def decode_one_token_ar(
     return codebooks
 
 
-def decode_one_token_naive(
-    model: NaiveTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    x = model.forward_generate(x, input_pos)
-
-    sampling_kwargs_main = sampling_kwargs.copy()
-    sampling_kwargs_main["temperature"] = 0.1
-    sampling_kwargs_main["top_p"] = 0.1
-    sampling_kwargs_main["repetition_penalty"] = 1.0
-
-    codebooks = [
-        sample(
-            x.logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
-        )[0]
-    ]
-
-    for i in range(model.config.num_codebooks):
-        codebooks.append(
-            sample(
-                x.codebook_logits[:, :, i],
-                previous_tokens=(
-                    previous_tokens[i + 1] if previous_tokens is not None else None
-                ),
-                **sampling_kwargs,
-            )[0]
-        )
-
-    return torch.stack(codebooks, dim=0)
-
-
 def decode_n_tokens(
     model: NaiveTransformer,
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
     semantic_ids: list,
-    decode_one_token=decode_one_token_naive,
+    decode_one_token=decode_one_token_ar,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -406,7 +219,7 @@ def generate(
     model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    decode_one_token=decode_one_token_naive,
+    decode_one_token=decode_one_token_ar,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -442,11 +255,7 @@ def generate(
     input_pos = torch.arange(0, T, device=device)
 
     # Use non-accelerated version for now, to avoid compilation overhead
-    prefill_decode = (
-        decode_one_token_naive
-        if isinstance(model, NaiveTransformer)
-        else decode_one_token_ar
-    )
+    prefill_decode = decode_one_token_ar
 
     next_token = prefill_decode(
         model,
@@ -474,222 +283,17 @@ def generate(
     return seq
 
 
-def decode_n_tokens_agent(
-    model: NaiveTransformer,
-    cur_token: torch.Tensor,
-    input_pos: torch.Tensor,
-    num_new_tokens: int,
-    semantic_ids: list,
-    im_end_id: int = 4,
-    decode_one_token=decode_one_token_naive_agent,
-    early_stop_threshold: float = 0.6,
-    **sampling_kwargs,
-):
-    batch_size = cur_token.size(0)
-    previous_tokens = torch.zeros(
-        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
-        dtype=torch.int,
-        device=cur_token.device,
-    )
-    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
-    finished = finished | (cur_token[:, 0, -1] == im_end_id)
-    start_time = time.time()
-
-    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
-        # We need to get windowed repeat penalty
-        win_size = 16
-        if i < win_size:
-            window = previous_tokens[:, :, :win_size]
-        else:
-            window = previous_tokens[:, :, i - win_size : i]
-
-        with sdpa_kernel(
-            SDPBackend.MATH
-        ):  # Actually better for Inductor to codegen attention here
-            next_token = decode_one_token(
-                model=model,
-                x=cur_token,
-                input_pos=input_pos,
-                previous_tokens=window,
-                semantic_ids=semantic_ids,
-                **sampling_kwargs,
-            )
-
-        input_pos += 1
-        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
-        previous_tokens[:, :, i : i + 1] = next_token.view(
-            batch_size, model.config.num_codebooks + 1, -1
-        )
-
-        yield cur_token.cpu()
-
-        finished = finished | (cur_token[:, 0, -1] == im_end_id)
-        if finished.all() or (
-            0 < early_stop_threshold < 1
-            and finished.sum() >= round(batch_size * early_stop_threshold)
-        ):
-            break
-
-    total_time = time.time() - start_time
-    generated_tokens = i + 1
-    tokens_per_second = (generated_tokens / total_time) * batch_size
-    logger.info(
-        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
-    )
-
-
-@torch.no_grad()
-@torch.inference_mode()
-def generate_agent(
-    *,
-    model: BaseTransformer,
-    prompt: torch.Tensor,
-    max_new_tokens: int,
-    semantic_ids: list,
-    im_end_id: int = 4,
-    decode_one_token=decode_one_token_naive_agent,
-    num_samples: int = 1,
-    early_stop_threshold: float = 0.6,
-    **sampling_kwargs,
-):
-    """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-    """
-
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(1)
-    prompt = prompt[None].repeat(num_samples, 1, 1)
-
-    if T >= model.config.max_seq_len:
-        raise ValueError(
-            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
-        )
-
-    if max_new_tokens:
-        if T + max_new_tokens > model.config.max_seq_len:
-            max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
-
-        T_new = T + max_new_tokens
-    else:
-        T_new = model.config.max_seq_len
-        max_new_tokens = T_new - T
-
-    device, dtype = prompt.device, prompt.dtype
-
-    codebook_dim = 1 + model.config.num_codebooks
-    input_pos = torch.arange(0, T, device=device)
-
-    # Use non-accelerated version for now, to avoid compilation overhead
-    prefill_decode = (
-        decode_one_token_naive_agent
-        if isinstance(model, NaiveTransformer)
-        else decode_one_token_ar_agent
-    )
-    next_token = prefill_decode(
-        model,
-        prompt,
-        input_pos,
-        semantic_ids=semantic_ids,
-        **sampling_kwargs,
-    ).view(num_samples, codebook_dim, -1)
-    yield next_token.cpu()
-
-    input_pos = torch.tensor([T], device=device, dtype=torch.int)
-
-    yield from decode_n_tokens_agent(
-        model,
-        next_token,
-        input_pos,
-        max_new_tokens - 1,
-        im_end_id=im_end_id,
-        semantic_ids=semantic_ids,
-        decode_one_token=decode_one_token,
-        early_stop_threshold=early_stop_threshold,
-        **sampling_kwargs,
-    )
-
-
-def encode_tokens(
-    tokenizer,
-    string,
-    device="cuda",
-    prompt_tokens=None,
-    num_codebooks=4,
-):
-    string = clean_text(string)
-
-    messages = []
-    messages.append(
-        Message(
-            role="user",
-            parts=[TextPart(text=string)],
-            cal_loss=False,
-        )
-    )
-
-    if prompt_tokens is not None:
-        if prompt_tokens.ndim == 3:
-            assert (
-                prompt_tokens.shape[0] == 1
-            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
-            prompt_tokens = prompt_tokens[0]
-
-        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
-
-        if prompt_tokens.shape[0] > num_codebooks:
-            logger.warning(
-                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
-            )
-            prompt_tokens = prompt_tokens[:num_codebooks]
-
-        vq_part = VQPart(codes=prompt_tokens.to(device))
-
-        messages.append(
-            Message(
-                role="assistant",
-                parts=[TextPart(text="<|voice|>"), vq_part],
-                cal_loss=False,
-            )
-        )
-    else:
-        messages.append(
-            Message(
-                role="assistant",
-                parts=[TextPart(text="<|voice|>")],
-                cal_loss=False,
-                add_im_end=False,
-            )
-        )
-
-    conversation = Conversation(messages=messages)
-    # conversation.visualize(tokenizer)
-    encoded = conversation.encode_for_inference(
-        tokenizer=tokenizer,
-        num_codebooks=num_codebooks,
-    )
-
-    return encoded.to(device)
-
-
-def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
-    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True, is_agent=is_agent
-    )
+def load_model(checkpoint_path, device, precision, compile=False):
+    model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
 
     model = model.to(device=device, dtype=precision)
     logger.info(f"Restored model from checkpoint")
 
     if isinstance(model, DualARTransformer):
-        decode_one_token = (
-            decode_one_token_ar_agent if is_agent else decode_one_token_ar
-        )
+        decode_one_token = decode_one_token_ar
         logger.info("Using DualARTransformer")
     else:
-        decode_one_token = (
-            decode_one_token_naive_agent if is_agent else decode_one_token_naive
-        )
-        logger.info("Using NaiveTransformer")
+        raise ValueError("Model is not a DualARTransformer")
 
     if compile:
         logger.info("Compiling function...")
@@ -723,7 +327,6 @@ def generate_long(
     temperature: float = 0.7,
     compile: bool = False,
     iterative_prompt: bool = True,
-    max_length: int = 2048,
     chunk_length: int = 150,
     prompt_text: Optional[str | list[str]] = None,
     prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
@@ -743,46 +346,36 @@ def generate_long(
 
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
-    im_end_id = tokenizer.get_token_id("<|im_end|>")
+    base_content_sequence = ContentSequence(modality="interleave")
 
-    encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = [
-        Conversation(
-            messages=[
-                Message(
-                    role="system",
-                    parts=[TextPart(text="Speak out the provided text.")],
-                    cal_loss=False,
-                )
-            ]
-        )
-        .encode_for_inference(
-            tokenizer=tokenizer,
-            num_codebooks=model.config.num_codebooks,
-        )
-        .to(device)
-    ]
+    max_length = model.config.max_seq_len
 
     if use_prompt:
-        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
-            encoded_prompts.append(
-                encode_tokens(
-                    tokenizer,
-                    string=t,
-                    device=device,
-                    prompt_tokens=c,
-                    num_codebooks=model.config.num_codebooks,
-                )
+        for t, c in zip(prompt_text, prompt_tokens):
+            base_content_sequence.append(
+                [
+                    TextPart(text=t),
+                    VQPart(codes=c),
+                ],
+                add_end=True,
             )
 
-    for idx, text in enumerate(texts):
+    encoded_prompts = base_content_sequence.encode_for_inference(
+        tokenizer, num_codebooks=model.config.num_codebooks
+    )
+    if encoded_prompts.size(1) > max_length - 2048:
+        raise ValueError(
+            f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
+        )
+
+    encoded = []
+    for text in texts:
+        content_sequence = ContentSequence(modality=None)
+        content_sequence.append(TextPart(text=text))
         encoded.append(
-            encode_tokens(
-                tokenizer,
-                string=text,
-                device=device,
-                num_codebooks=model.config.num_codebooks,
+            content_sequence.encode_for_inference(
+                tokenizer, num_codebooks=model.config.num_codebooks
             )
         )
         logger.info(f"Encoded text: {text}")
@@ -810,30 +403,28 @@ def generate_long(
             seg = encoded[seg_idx]
             global_encoded.append(seg)
 
-            lengths = reversed([seg.size(1) for seg in global_encoded])
-
-            # Pick last 2000 tokens
-            count = 0
-            for i, length in enumerate(lengths):
-                count += length
-                if count + length > max_length - 1024 - sum(
-                    t.shape[1] for t in encoded_prompts
-                ):
-                    break
+            # Do not use previous segments to generate current segment for now
+            # lengths = reversed([seg.size(1) for seg in global_encoded])
 
-            if i != 0 and i % 2 == 0:
-                i -= 1
+            # # Pick last 2000 tokens
+            # count = 0
+            # for i, length in enumerate(lengths):
+            #     count += length
+            #     if count + length > max_length - 2048 - encoded_prompts.size(1):
+            #         break
 
-            # Rotate the list, always make sure first segment is included to avoid drift
-            if i < len(global_encoded) - 2:
-                partial_encoded = global_encoded[:2] + global_encoded[-i:]
-            else:
-                partial_encoded = global_encoded
+            # if i != 0 and i % 2 == 0:
+            #     i -= 1
 
-            if use_prompt:
-                partial_encoded = encoded_prompts + partial_encoded
+            # # Rotate the list, always make sure first segment is included to avoid drift
+            # if i < len(global_encoded) - 2:
+            #     partial_encoded = global_encoded[:2] + global_encoded[-i:]
+            # else:
+            #     partial_encoded = global_encoded
 
-            cat_encoded = torch.cat(partial_encoded, dim=1)
+            # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
+            cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
+            cat_encoded = cat_encoded.to(device=device)
             prompt_length = cat_encoded.size(1)
 
             t0 = time.perf_counter()
@@ -871,13 +462,13 @@ def generate_long(
 
             # Put the generated tokens
             # since there is <im_end>, we remove last token
-            codes = y[1:, prompt_length + 1 :].clone()
+            codes = y[1:, prompt_length:-1].clone()
             assert (codes >= 0).all(), f"Negative code found"
 
             decoded = y[:, prompt_length:].clone()
             # But for global encoding, we should keep the <im_end> token
 
-            global_encoded.append(decoded)
+            global_encoded.append(decoded.cpu())
             assert (codes >= 0).all(), f"Negative code found: {codes}"
             yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
             seg_idx += 1
@@ -1012,20 +603,20 @@ def launch_thread_safe_queue_agent(
 )
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
-@click.option("--top-p", type=float, default=0.7)
-@click.option("--repetition-penalty", type=float, default=1.2)
-@click.option("--temperature", type=float, default=0.7)
+@click.option("--top-p", type=float, default=0.8)
+@click.option("--repetition-penalty", type=float, default=1.1)
+@click.option("--temperature", type=float, default=0.8)
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/fish-speech-1.5",
+    default="checkpoints/openaudio-s1-mini",
 )
 @click.option("--device", type=str, default="cuda")
 @click.option("--compile/--no-compile", default=False)
 @click.option("--seed", type=int, default=42)
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
-@click.option("--chunk-length", type=int, default=100)
+@click.option("--chunk-length", type=int, default=300)
 @click.option("--output-dir", type=Path, default="temp")
 def main(
     text: str,
@@ -1070,7 +661,7 @@ def main(
     logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     if prompt_tokens is not None:
-        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
+        prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
 
     torch.manual_seed(seed)
 

+ 59 - 40
fish_speech/models/text2semantic/llama.py

@@ -16,12 +16,8 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.utils.checkpoint import checkpoint
 from transformers import AutoTokenizer
 
+from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora
 from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
-from fish_speech.utils import RankedLogger
-
-from .lora import LoraConfig, setup_lora
-
-log = RankedLogger(__name__, rank_zero_only=True)
 
 
 def find_multiple(n: int, k: int) -> int:
@@ -47,6 +43,8 @@ class BaseModelArgs:
     dropout: float = 0.0
     tie_word_embeddings: bool = True
     attention_qkv_bias: bool = False
+    attention_o_bias: bool = False
+    attention_qk_norm: bool = False
 
     # Codebook configs
     codebook_size: int = 160
@@ -60,7 +58,6 @@ class BaseModelArgs:
 
     # Dummy vars
     is_reward_model: bool = False
-    share_codebook_embeddings: bool = True
     scale_codebook_embeddings: bool = False
 
     def __post_init__(self):
@@ -70,7 +67,8 @@ class BaseModelArgs:
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_head
 
     @staticmethod
     def from_pretrained(path: str):
@@ -112,6 +110,8 @@ class DualARModelArgs(BaseModelArgs):
     fast_head_dim: int | None = None
     fast_intermediate_size: int | None = None
     fast_attention_qkv_bias: bool | None = None
+    fast_attention_qk_norm: bool | None = None
+    fast_attention_o_bias: bool | None = None
 
     def __post_init__(self):
         super().__post_init__()
@@ -128,6 +128,16 @@ class DualARModelArgs(BaseModelArgs):
             if self.fast_attention_qkv_bias is not None
             else self.attention_qkv_bias
         )
+        self.fast_attention_qk_norm = (
+            self.fast_attention_qk_norm
+            if self.fast_attention_qk_norm is not None
+            else self.attention_qk_norm
+        )
+        self.fast_attention_o_bias = (
+            self.fast_attention_o_bias
+            if self.fast_attention_o_bias is not None
+            else self.attention_o_bias
+        )
 
 
 class KVCache(nn.Module):
@@ -173,9 +183,7 @@ class BaseTransformer(nn.Module):
         super().__init__()
         self.config = config
         self.tokenizer = tokenizer
-        self.semantic_token_ids = [
-            tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
-        ]
+        self.semantic_token_ids = list(tokenizer.semantic_id_to_token_id.values())
 
         # Slow transformer
         self.embeddings = nn.Embedding(
@@ -202,7 +210,7 @@ class BaseTransformer(nn.Module):
             "freqs_cis",
             precompute_freqs_cis(
                 config.max_seq_len,
-                config.dim // config.n_head,
+                config.head_dim,
                 config.rope_base,
             ),
             persistent=False,
@@ -232,7 +240,6 @@ class BaseTransformer(nn.Module):
         if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
             return
 
-        head_dim = self.config.dim // self.config.n_head
         max_seq_len = find_multiple(max_seq_len, 8)
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
@@ -242,23 +249,20 @@ class BaseTransformer(nn.Module):
                 max_batch_size,
                 max_seq_len,
                 self.config.n_local_heads,
-                head_dim,
+                self.config.head_dim,
                 dtype=dtype,
             )
 
-    def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor:
+    def embed(self, inp: Tensor) -> Tensor:
         embeds = []
         semantic_token_ids_tensor = torch.tensor(
             self.semantic_token_ids, device=inp.device, dtype=inp.dtype
         )
 
         for i in range(self.config.num_codebooks):
-            if share_codebook_embeddings:
-                emb = self.codebook_embeddings(
-                    inp[:, i + 1] + i * self.config.codebook_size
-                )
-            else:
-                emb = self.codebook_embeddings(inp[:, i + 1])
+            emb = self.codebook_embeddings(
+                inp[:, i + 1] + i * self.config.codebook_size
+            )
             embeds.append(emb)
 
         vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
@@ -318,9 +322,7 @@ class BaseTransformer(nn.Module):
         input_pos: Optional[Tensor] = None,
         return_all: bool = False,
     ) -> BaseTransformerForwardResult:
-        x = self.embed(
-            inp, share_codebook_embeddings=self.config.share_codebook_embeddings
-        )
+        x = self.embed(inp)
 
         if input_pos is None:
             input_pos = torch.arange(inp.shape[-1], device=x.device)
@@ -371,16 +373,15 @@ class BaseTransformer(nn.Module):
         max_length: int | None = None,
         lora_config: LoraConfig | None = None,
         rope_base: int | None = None,
-        is_agent: bool = False,
     ) -> "BaseTransformer":
         config = BaseModelArgs.from_pretrained(str(path))
         if max_length is not None:
             config.max_seq_len = max_length
-            log.info(f"Override max_seq_len to {max_length}")
+            logger.info(f"Override max_seq_len to {max_length}")
 
         if rope_base is not None:
             config.rope_base = rope_base
-            log.info(f"Override rope_base to {rope_base}")
+            logger.info(f"Override rope_base to {rope_base}")
 
         match config.model_type:
             case "naive":
@@ -390,18 +391,17 @@ class BaseTransformer(nn.Module):
             case _:
                 raise ValueError(f"Unknown model type: {config.model_type}")
 
-        tokenizer_path = str(path) + "/tokenizer.tiktoken"
-        tokenizer = FishTokenizer(tokenizer_path)
+        tokenizer = FishTokenizer.from_pretrained(path)
 
-        log.info(f"Loading model from {path}, config: {config}")
+        logger.info(f"Loading model from {path}, config: {config}")
         model = model_cls(config, tokenizer=tokenizer)
 
         if lora_config is not None:
             setup_lora(model, lora_config)
-            log.info(f"LoRA setup: {lora_config}")
+            logger.info(f"LoRA setup: {lora_config}")
 
         if load_weights is False:
-            log.info("Randomly initialized model")
+            logger.info("Randomly initialized model")
         else:
 
             if "int8" in str(Path(path)):
@@ -444,6 +444,11 @@ class BaseTransformer(nn.Module):
                     new_weights[k.replace("model.", "")] = v
                 weights = new_weights
 
+            # Remove audio related weights
+            for k in list(weights.keys()):
+                if "audio_" in k:
+                    weights.pop(k)
+
             # Verify the name and shape of parameters since strict=False in load_state_dict.
             for k, v in model.named_parameters():
                 if k not in weights:
@@ -454,7 +459,7 @@ class BaseTransformer(nn.Module):
                     )
 
             err = model.load_state_dict(weights, strict=False, assign=True)
-            log.info(f"Loaded weights with error: {err}")
+            logger.info(f"Loaded weights with error: {err}")
 
         return model
 
@@ -471,7 +476,7 @@ class BaseTransformer(nn.Module):
                     continue
 
                 state_dict.pop(key)
-                log.info(f"Drop LoRA parameter: {key}")
+                logger.info(f"Drop LoRA parameter: {key}")
 
         torch.save(state_dict, path / "model.pth")
         self.tokenizer.save_pretrained(path)
@@ -545,6 +550,8 @@ class DualARTransformer(BaseTransformer):
             head_dim=config.fast_head_dim,
             intermediate_size=config.fast_intermediate_size,
             attention_qkv_bias=config.fast_attention_qkv_bias,
+            attention_qk_norm=config.fast_attention_qk_norm,
+            attention_o_bias=config.fast_attention_o_bias,
         )
 
         self.fast_layers = nn.ModuleList(
@@ -562,7 +569,7 @@ class DualARTransformer(BaseTransformer):
             "fast_freqs_cis",
             precompute_freqs_cis(
                 config.num_codebooks,
-                config.fast_dim // config.fast_n_head,
+                config.fast_head_dim,
                 config.rope_base,
             ),
             persistent=False,
@@ -574,8 +581,6 @@ class DualARTransformer(BaseTransformer):
     ):
         super().setup_caches(max_batch_size, max_seq_len, dtype)
 
-        head_dim = self.config.fast_dim // self.config.fast_n_head
-
         # Fast transformer
         # The max seq len here is the number of codebooks
         for b in self.fast_layers:
@@ -583,7 +588,7 @@ class DualARTransformer(BaseTransformer):
                 max_batch_size,
                 self.config.num_codebooks,
                 self.config.fast_n_local_heads,
-                head_dim,
+                self.config.fast_head_dim,
                 dtype=dtype,
             )
 
@@ -716,15 +721,24 @@ class Attention(nn.Module):
         self.wqkv = nn.Linear(
             config.dim, total_head_dim, bias=config.attention_qkv_bias
         )
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(
+            config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias
+        )
         self.kv_cache = None
 
+        if config.attention_qk_norm:
+            self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
+            self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
+
         self.dropout = config.dropout
         self.n_head = config.n_head
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
         self.use_sdpa = use_sdpa
+        self.attention_qk_norm = config.attention_qk_norm
+        self.config = config
+
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def load_hook(self, state_dict, prefix, *args):
@@ -743,13 +757,18 @@ class Attention(nn.Module):
     ) -> Tensor:
         bsz, seqlen, _ = x.shape
 
+        q_size = self.n_head * self.head_dim
         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
 
+        if self.attention_qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)
 
@@ -789,7 +808,7 @@ class Attention(nn.Module):
                 dropout_p=self.dropout if self.training else 0.0,
             )
 
-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
 
         return self.wo(y)
 

+ 58 - 31
fish_speech/tokenizer.py

@@ -1,6 +1,7 @@
 import base64
 import json
 import logging
+import re
 from pathlib import Path
 
 import tiktoken
@@ -27,20 +28,23 @@ EOS_TOKEN = "<|end_of_text|>"
 PAD_TOKEN = "<|pad|>"
 IM_START_TOKEN = "<|im_start|>"
 IM_END_TOKEN = "<|im_end|>"
+PHONEME_START_TOKEN = "<|phoneme_start|>"
+PHONEME_END_TOKEN = "<|phoneme_end|>"
+TOOL_CALL_START_TOKEN = "<|tool_call_start|>"
+TOOL_CALL_END_TOKEN = "<|tool_call_end|>"
 
 MODALITY_TEXT_TOKEN = "<|text|>"
 MODALITY_VOICE_TOKEN = "<|voice|>"
 MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
+AUDIO_START_TOKEN = "<|audio_start|>"
+AUDIO_END_TOKEN = "<|audio_end|>"
+AUDIO_EMBED_TOKEN = "<|audio|>"
 MODALITY_TOKENS = {
     "text": MODALITY_TEXT_TOKEN,
     "voice": MODALITY_VOICE_TOKEN,
     "interleave": MODALITY_INTERLEAVE_TOKEN,
 }
 
-PLACEHOLDER_TOKEN = [""] * 4
-for i in range(4):
-    PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
-
 SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
 SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
 
@@ -51,30 +55,44 @@ ALL_SPECIAL_TOKENS = [
     PAD_TOKEN,
     IM_START_TOKEN,
     IM_END_TOKEN,
-    PLACEHOLDER_TOKEN[0],
-    PLACEHOLDER_TOKEN[1],
-    PLACEHOLDER_TOKEN[2],
-    PLACEHOLDER_TOKEN[3],
+    PHONEME_START_TOKEN,
+    PHONEME_END_TOKEN,
+    TOOL_CALL_START_TOKEN,
+    TOOL_CALL_END_TOKEN,
     MODALITY_TEXT_TOKEN,
     MODALITY_VOICE_TOKEN,
     MODALITY_INTERLEAVE_TOKEN,
+    AUDIO_START_TOKEN,
+    AUDIO_END_TOKEN,
+    AUDIO_EMBED_TOKEN,
     *SEMANTIC_TOKENS,
 ]
 
 
 class FishTokenizer:
-    def __init__(self, model_path: str) -> None:
+    def __init__(
+        self, model_path: str, special_tokens: list[str] = ALL_SPECIAL_TOKENS
+    ) -> None:
         mergeable_ranks = self.load_tiktoken_bpe(model_path)
         special_token_begin = len(mergeable_ranks)
         self.all_special_tokens_with_ids = {
-            token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
-        }
-        self.semantic_id_to_token_id = {
-            i: self.all_special_tokens_with_ids[token]
-            for i, token in enumerate(SEMANTIC_TOKENS)
+            token: special_token_begin + i for i, token in enumerate(special_tokens)
         }
-        self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
-        self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
+
+        self.semantic_id_to_token_id = {}
+        end_idx = 0
+        for token in special_tokens:
+            if token.startswith("<|semantic:"):
+                idx = int(re.match(r"<\|semantic:(\d+)\|>", token).group(1))
+                self.semantic_id_to_token_id[idx] = self.all_special_tokens_with_ids[
+                    token
+                ]
+
+                if idx > end_idx:
+                    end_idx = idx
+
+        self.semantic_begin_id = self.semantic_id_to_token_id[0]
+        self.semantic_end_id = self.semantic_id_to_token_id[end_idx]
 
         self.tkt_model = tiktoken.core.Encoding(
             name=Path(model_path).stem,
@@ -83,6 +101,14 @@ class FishTokenizer:
             special_tokens=self.all_special_tokens_with_ids,
         )
 
+    @property
+    def vocab_size(self):
+        return len(self.tkt_model._mergeable_ranks)
+
+    @property
+    def num_special_tokens(self):
+        return len(self.all_special_tokens_with_ids)
+
     @staticmethod
     def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
         data = {}
@@ -90,6 +116,8 @@ class FishTokenizer:
             if not line:
                 continue
             token, rank = line.split()
+            if token == "=":
+                continue
             data[base64.b64decode(token)] = int(rank)
         return data
 
@@ -124,7 +152,10 @@ class FishTokenizer:
 
         with open(path / "tokenizer.tiktoken", "w") as f:
             for token, rank in self.tkt_model._mergeable_ranks.items():
-                f.write(f"{base64.b64encode(token).decode()} {rank}\n")
+                a = base64.b64encode(token).decode()
+                if a == "":
+                    a = "="
+                f.write(f"{a} {rank}\n")
 
         with open(path / "special_tokens.json", "w") as f:
             json.dump(
@@ -136,17 +167,13 @@ class FishTokenizer:
 
     @staticmethod
     def from_pretrained(path: str):
-        return FishTokenizer(Path(path) / "tokenizer.tiktoken")
-
-
-if __name__ == "__main__":
-    tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
-    tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
-    tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
-
-    print(
-        [
-            tokenizer.decode([i])
-            for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
-        ]
-    )
+        special_tokens_path = Path(path) / "special_tokens.json"
+        if special_tokens_path.exists():
+            with open(special_tokens_path) as f:
+                all_special_tokens_with_ids = json.load(f)
+        else:
+            all_special_tokens_with_ids = ALL_SPECIAL_TOKENS
+
+        return FishTokenizer(
+            Path(path) / "tokenizer.tiktoken", all_special_tokens_with_ids
+        )

+ 10 - 31
fish_speech/utils/schema.py

@@ -5,11 +5,11 @@ from dataclasses import dataclass
 from typing import Literal
 
 import torch
-from pydantic import BaseModel, Field, conint, conlist, model_validator
+from pydantic import BaseModel, Field, conint, model_validator
 from pydantic.functional_validators import SkipValidation
 from typing_extensions import Annotated
 
-from fish_speech.conversation import Message, TextPart, VQPart
+from fish_speech.content_sequence import TextPart, VQPart
 
 
 class ServeVQPart(BaseModel):
@@ -63,31 +63,10 @@ class ServeASRResponse(BaseModel):
     transcriptions: list[ServeASRTranscription]
 
 
-class ServeMessage(BaseModel):
-    role: Literal["system", "assistant", "user"]
-    parts: list[ServeVQPart | ServeTextPart]
-
-    def to_conversation_message(self):
-        new_message = Message(role=self.role, parts=[])
-        if self.role == "assistant":
-            new_message.modality = "voice"
-
-        for part in self.parts:
-            if isinstance(part, ServeTextPart):
-                new_message.parts.append(TextPart(text=part.text))
-            elif isinstance(part, ServeVQPart):
-                new_message.parts.append(
-                    VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
-                )
-            else:
-                raise ValueError(f"Unsupported part type: {part}")
-
-        return new_message
-
-
-class ServeChatRequest(BaseModel):
-    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
-    max_new_tokens: int = 1024
+class ServeRequest(BaseModel):
+    # Raw content sequence dict that we can use with ContentSequence(**content)
+    content: dict
+    max_new_tokens: int = 600
     top_p: float = 0.7
     repetition_penalty: float = 1.2
     temperature: float = 0.7
@@ -114,15 +93,15 @@ class ServeVQGANDecodeResponse(BaseModel):
     audios: list[bytes]
 
 
-class ServeForwardMessage(BaseModel):
-    role: str
-    content: str
+class ServeContentSequenceParts(BaseModel):
+    parts: list[VQPart | TextPart]
 
 
 class ServeResponse(BaseModel):
-    messages: list[ServeMessage]
+    content_sequences: list[ServeContentSequenceParts]
     finish_reason: Literal["stop", "error"] | None = None
     stats: dict[str, int | float | str] = {}
+    finished: list[bool] | None = None
 
 
 class ServeStreamDelta(BaseModel):

+ 6 - 6
inference.ipynb

@@ -61,7 +61,7 @@
     "# !set HF_ENDPOINT=https://hf-mirror.com\n",
     "# !export HF_ENDPOINT=https://hf-mirror.com \n",
     "\n",
-    "!huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5/"
+    "!huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/openaudio-s1-mini/"
    ]
   },
   {
@@ -84,8 +84,8 @@
    "outputs": [],
    "source": [
     "!python tools/run_webui.py \\\n",
-    "    --llama-checkpoint-path checkpoints/fish-speech-1.5 \\\n",
-    "    --decoder-checkpoint-path checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
+    "    --llama-checkpoint-path checkpoints/openaudio-s1-mini \\\n",
+    "    --decoder-checkpoint-path checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
     "    # --compile"
    ]
   },
@@ -122,7 +122,7 @@
     "\n",
     "!python fish_speech/models/vqgan/inference.py \\\n",
     "    -i {src_audio} \\\n",
-    "    --checkpoint-path \"checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
+    "    --checkpoint-path \"checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
     "\n",
     "from IPython.display import Audio, display\n",
     "audio = Audio(filename=\"fake.wav\")\n",
@@ -158,7 +158,7 @@
     "    --text \"hello world\" \\\n",
     "    --prompt-text \"The text corresponding to reference audio\" \\\n",
     "    --prompt-tokens \"fake.npy\" \\\n",
-    "    --checkpoint-path \"checkpoints/fish-speech-1.5\" \\\n",
+    "    --checkpoint-path \"checkpoints/openaudio-s1-mini\" \\\n",
     "    --num-samples 2\n",
     "    # --compile"
    ]
@@ -182,7 +182,7 @@
    "source": [
     "!python fish_speech/models/vqgan/inference.py \\\n",
     "    -i \"codes_0.npy\" \\\n",
-    "    --checkpoint-path \"checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
+    "    --checkpoint-path \"checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
     "\n",
     "from IPython.display import Audio, display\n",
     "audio = Audio(filename=\"fake.wav\")\n",

+ 1 - 1
pyproject.toml

@@ -31,7 +31,6 @@ dependencies = [
     "loguru>=0.6.0",
     "loralib>=0.1.2",
     "pyrootutils>=1.0.4",
-    "vector_quantize_pytorch==1.14.24",
     "resampy>=0.4.3",
     "einx[torch]==0.2.2",
     "zstandard>=0.22.0",
@@ -46,6 +45,7 @@ dependencies = [
     "tiktoken>=0.8.0",
     "pydantic==2.9.2",
     "cachetools",
+    "descript-audio-codec"
 ]
 
 [project.optional-dependencies]

+ 1 - 1
tools/download_models.py

@@ -23,7 +23,7 @@ def check_and_download_files(repo_id, file_list, local_dir):
 
 # 1st
 repo_id_1 = "fishaudio/fish-speech-1.5"
-local_dir_1 = "./checkpoints/fish-speech-1.5"
+local_dir_1 = "./checkpoints/openaudio-s1-mini"
 files_1 = [
     ".gitattributes",
     "model.pth",

+ 2 - 2
tools/run_webui.py

@@ -24,12 +24,12 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=Path,
-        default="checkpoints/fish-speech-1.5",
+        default="checkpoints/openaudio-s1-mini",
     )
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=Path,
-        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+        default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--device", type=str, default="cuda")

+ 2 - 2
tools/server/api_utils.py

@@ -18,12 +18,12 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=str,
-        default="checkpoints/fish-speech-1.5",
+        default="checkpoints/openaudio-s1-mini",
     )
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=str,
-        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+        default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--device", type=str, default="cuda")

+ 2 - 2
tools/vqgan/extract_vq.py

@@ -47,7 +47,7 @@ logger.add(sys.stderr, format=logger_format)
 @lru_cache(maxsize=1)
 def get_model(
     config_name: str = "firefly_gan_vq",
-    checkpoint_path: str = "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    checkpoint_path: str = "checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     device: str | torch.device = "cuda",
 ):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
@@ -138,7 +138,7 @@ def process_batch(files: list[Path], model) -> float:
 @click.option("--config-name", default="firefly_gan_vq")
 @click.option(
     "--checkpoint-path",
-    default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
 )
 @click.option("--batch-size", default=64)
 @click.option("--filelist", default=None, type=Path)