Browse source code

Agent inference (#650)

* support basic TTS inference

* Agent (#648)

* agent

* rm fastapi

* routes

* dry run: tts

* api_invoke_chat

* .gradio ignore

* small fix

* Fix llama generate

* add lots

* add agent

* fix agent

* fix agent

* fix route

* fix compile

* Add fixed timbre

* Fix duplicated audio

* Fix

* remove unused

* Improve ui

* okok

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Update Agent Webui and doc

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Lengyue <lengyue@lengyue.me>
Co-authored-by: spicysama <a2983352531@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo 1 year ago
parent
commit
834b07257c
13 changed files with 1875 additions and 86 deletions
  1. .gitignore (+1, -0)
  2. API_FLAGS.txt (+1, -1)
  3. Start_Agent.md (+45, -0)
  4. fish_speech/conversation.py (+254, -0)
  5. fish_speech/models/text2semantic/llama.py (+72, -10)
  6. tools/api.py (+475, -24)
  7. tools/commons.py (+0, -36)
  8. tools/e2e_webui.py (+210, -0)
  9. tools/fish_e2e.py (+256, -0)
  10. tools/llama/generate.py (+372, -7)
  11. tools/msgpack_api.py (+1, -1)
  12. tools/post_api.py (+1, -7)
  13. tools/schema.py (+187, -0)

+ 1 - 0
.gitignore

@@ -29,3 +29,4 @@ asr-label*
 /references
 /example
 /faster_whisper
+/.gradio

+ 1 - 1
API_FLAGS.txt

@@ -1,5 +1,5 @@
 # --infer
-# --api
+--api
 --listen 0.0.0.0:8080 \
 --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
 --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \

+ 45 - 0
Start_Agent.md

@@ -0,0 +1,45 @@
+# How To Start?
+
+### Environment Preparation
+
+If you haven't set up the Fish-speech environment yet, please use:
+
+```bash
+pip install -e .[stable]
+```
+
+Then use:
+
+```bash
+pip install livekit livekit-agents
+```
+
+### Launch the Agent Demo
+
+Please run the command below from the main folder:
+
+```bash
+python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-3b-pretrain/ --mode agent --compile
+```
+
+The ``--compile`` arg is only supported on Python < 3.12; it greatly speeds up token generation.
+
+Note that it won't compile immediately at startup.
+
+Then please use the command:
+
+```bash
+python -m tools.e2e_webui
+```
+
+This will launch a Gradio WebUI on your machine.
+
+When you use the model for the first time, it will take a short while to compile (if ``--compile`` is enabled), so please be patient.
+
+Have a good time!
+
+# About Agent
+
+This model is currently undergoing testing. We welcome suggestions and assistance in improving it.
+
+We are considering refining the tutorial and incorporating it into the main documentation after the testing phase is complete.

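For reference, a minimal sketch of calling the new agent endpoint directly over its msgpack streaming protocol instead of going through the WebUI. It assumes the server above is listening on localhost:8080 and that `ServeRequest`, `ServeMessage`, and `ServeTextPart` behave as defined in `tools/schema.py`; the framing mirrors the client logic in `tools/fish_e2e.py` from this commit.

```python
# Minimal sketch: stream a text-only request to /v1/chat (added in tools/api.py).
# Assumes the agent server from the commands above is running on localhost:8080.
import struct

import httpx
import ormsgpack

from tools.schema import ServeMessage, ServeRequest, ServeTextPart

request = ServeRequest(
    messages=[ServeMessage(role="user", parts=[ServeTextPart(text="Hello!")])],
    streaming=True,
    num_samples=1,  # the server asserts num_samples == GLOBAL_NUM_SAMPLES; fish_e2e.py uses 1
)

with httpx.Client(timeout=None) as client, client.stream(
    "POST",
    "http://localhost:8080/v1/chat",
    content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
    headers={"Content-Type": "application/msgpack"},
) as response:
    buffer = b""
    for chunk in response.iter_bytes():
        buffer += chunk
        # Each event is a 4-byte length prefix followed by a msgpack-encoded body.
        while len(buffer) >= 4:
            (length,) = struct.unpack("I", buffer[:4])
            if len(buffer) < 4 + length:
                break
            event = ormsgpack.unpackb(buffer[4 : 4 + length])
            buffer = buffer[4 + length :]
            print(event)
```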
+ 254 - 0
fish_speech/conversation.py

@@ -1,2 +1,256 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
+
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
 SEMANTIC_TOKEN = "<|semantic|>"
+MEL_TOKEN = "<|mel|>"
+PHONEME_START_TOKEN = "<|phoneme_start|>"
+PHONEME_END_TOKEN = "<|phoneme_end|>"
+ALL_SPECIAL_TOKENS = [
+    IM_START_TOKEN,
+    IM_END_TOKEN,
+    SEMANTIC_TOKEN,
+    MEL_TOKEN,
+    PHONEME_START_TOKEN,
+    PHONEME_END_TOKEN,
+]
+
 CODEBOOK_PAD_TOKEN_ID = 0
+
+
+class FishTokenizerConfig(PretrainedConfig):
+    share_codebook_embeddings: bool = True
+    codebook_size: int = 1024
+    num_codebooks: int = 8
+
+
+class FishTokenizerFast(PreTrainedTokenizerFast):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
+        self.codebook_size = kwargs.pop("codebook_size", 1024)
+        self.num_codebooks = kwargs.pop("num_codebooks", 8)
+
+
+AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
+
+
+@dataclass(kw_only=True)
+class BasePart:
+    pass
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+    codes: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+    text: str
+
+
+@dataclass(kw_only=True)
+class MelPart(BasePart):
+    mels: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class EncodedMessage:
+    tokens: torch.Tensor
+    labels: torch.Tensor
+    vq_parts: list[torch.Tensor]
+    mel_parts: list[torch.Tensor]
+    vq_require_losses: torch.Tensor | None = None
+
+
+@dataclass(kw_only=True)
+class Message:
+    role: Literal["system", "user", "assistant"]
+    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
+    add_im_start: bool = True
+    add_im_end: bool = True
+    cal_loss: bool = False
+
+    # By default, ignore the loss of the auto-generated im_start token
+    ignore_im_start_loss: bool = True
+
+    def encode(
+        self: "Message",
+        tokenizer: AutoTokenizer,
+    ) -> EncodedMessage:
+        all_tokens = []
+        all_labels = []
+
+        # Multi-modal tokens
+        vq_parts = []
+        mel_parts = []
+
+        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+            [SEMANTIC_TOKEN, MEL_TOKEN]
+        )
+
+        parts = self.parts.copy()
+        if self.add_im_start:
+            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n"))
+
+        if self.add_im_end:
+            parts.append(TextPart(text="<|im_end|>"))
+
+        for part in parts:
+            if isinstance(part, TextPart):
+                tokens = tokenizer.encode(
+                    part.text,
+                    add_special_tokens=False,
+                    truncation=False,
+                    return_tensors="pt",
+                ).int()[0]
+            elif isinstance(part, VQPart):
+                tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
+                codes = part.codes.clone() + 1
+
+                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
+                    for i in range(len(codes)):
+                        codes[i] += tokenizer.codebook_size * i
+
+                vq_parts.append(codes)
+            elif isinstance(part, MelPart):
+                tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
+                mel_parts.append(part.mels)
+            else:
+                raise ValueError(f"Unsupported part type: {type(part)}")
+
+            all_tokens.append(tokens)
+            if self.cal_loss:
+                all_labels.append(tokens.clone())
+            else:
+                all_labels.append(torch.full_like(tokens, -100))
+
+        tokens = torch.cat(all_tokens, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        assert tokens.shape == labels.shape
+
+        if self.ignore_im_start_loss and self.add_im_start:
+            labels[: len(all_tokens[0])] = -100
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            mel_parts=mel_parts,
+        )
+
+
+@dataclass
+class Conversation:
+    messages: list[Message]
+
+    def encode(
+        self: "Conversation",
+        tokenizer: AutoTokenizer,
+        add_shift: bool = True,
+    ) -> EncodedMessage:
+        # Build the input_ids and labels
+        tokens = []
+        labels = []
+        vq_parts = []
+        mel_parts = []
+        vq_require_losses = []
+
+        for message in self.messages:
+            encoded = message.encode(
+                tokenizer,
+            )
+            tokens.append(encoded.tokens)
+            labels.append(encoded.labels)
+            vq_parts.extend(encoded.vq_parts)
+            mel_parts.extend(encoded.mel_parts)
+            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
+
+        tokens = torch.cat(tokens, dim=0)
+        labels = torch.cat(labels, dim=0)
+        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
+        if add_shift:
+            tokens = tokens[:-1]
+            labels = labels[1:]
+
+        assert tokens.dtype in [
+            torch.int,
+            torch.long,
+        ], f"Invalid dtype: {tokens.dtype}, conv: {self}"
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            mel_parts=mel_parts,
+            vq_require_losses=vq_require_losses,
+        )
+
+    def encode_for_inference(
+        self: "Conversation",
+        tokenizer: AutoTokenizer,
+        num_codebooks: int,
+    ) -> EncodedMessage:
+        encoded = self.encode(tokenizer, add_shift=False)
+        tokens = encoded.tokens
+        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+        values[0] = tokens
+
+        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
+            return values
+
+        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+            [SEMANTIC_TOKEN, MEL_TOKEN]
+        )
+        vq_parts = encoded.vq_parts
+        vq_parts = torch.cat(vq_parts, dim=1)
+        values[1:, tokens == semantic_id] = vq_parts
+        return values
+
+    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
+        encoded = self.encode(tokenizer, add_shift=False)
+
+        print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
+        print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
+
+        for tok, lab in zip(encoded.tokens, encoded.labels):
+            val = tokenizer.decode(tok, skip_special_tokens=False)
+            if val == "\n":
+                val = "\\n\n"
+
+            if lab == -100:
+                print_in_green(val)
+            else:
+                print_in_blue(val)
+
+        print()
+
+
+if __name__ == "__main__":
+    message0 = Message(
+        role="user",
+        parts=[
+            TextPart(text="Hello, how are you?"),
+            VQPart(codes=torch.zeros((4, 10))),
+        ],
+        cal_loss=False,
+    )
+
+    message1 = Message(
+        role="assistant",
+        parts=[TextPart(text="I'm fine, thank you.")],
+        cal_loss=True,
+    )
+    conversation = Conversation([message0, message1])
+    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+    conversation.visualize(tokenizer)
+
+    encoded = conversation.encode(tokenizer)
+    print(encoded)
+    print(tokenizer.batch_decode(encoded.tokens))

+ 72 - 10
fish_speech/models/text2semantic/llama.py

@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import math
 from collections import OrderedDict
@@ -57,6 +58,10 @@ class BaseModelArgs:
     # Initialize the model
     initializer_range: float = 0.02
 
+    # Dummy vars
+    is_reward_model: bool = False
+    share_codebook_embeddings: bool = True
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -100,6 +105,28 @@ class NaiveModelArgs(BaseModelArgs):
 class DualARModelArgs(BaseModelArgs):
     model_type: str = "dual_ar"
     n_fast_layer: int = 4
+    fast_dim: int | None = None
+    fast_n_head: int | None = None
+    fast_n_local_heads: int | None = None
+    fast_head_dim: int | None = None
+    fast_intermediate_size: int | None = None
+    fast_attention_qkv_bias: bool | None = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        self.fast_dim = self.fast_dim or self.dim
+        self.fast_n_head = self.fast_n_head or self.n_head
+        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
+        self.fast_head_dim = self.fast_head_dim or self.head_dim
+        self.fast_intermediate_size = (
+            self.fast_intermediate_size or self.intermediate_size
+        )
+        self.fast_attention_qkv_bias = (
+            self.fast_attention_qkv_bias
+            if self.fast_attention_qkv_bias is not None
+            else self.attention_qkv_bias
+        )
 
 
 class KVCache(nn.Module):
@@ -474,20 +501,46 @@ class DualARTransformer(BaseTransformer):
     def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
+        # Project to fast dim if needed
+        if config.fast_dim is not None and config.fast_dim != config.dim:
+            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
+        else:
+            self.fast_project_in = nn.Identity()
+
         # Fast transformer
-        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
+        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
 
         # The equivalent bs is so large that sdpa doesn't work
+        override_config = dataclasses.replace(
+            config,
+            dim=config.fast_dim,
+            n_head=config.fast_n_head,
+            n_local_heads=config.fast_n_local_heads,
+            head_dim=config.fast_head_dim,
+            intermediate_size=config.fast_intermediate_size,
+            attention_qkv_bias=config.fast_attention_qkv_bias,
+        )
+
         self.fast_layers = nn.ModuleList(
-            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+            TransformerBlock(override_config, use_sdpa=False)
+            for _ in range(config.n_fast_layer)
         )
-        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
         self.fast_output = nn.Linear(
-            config.dim,
+            config.fast_dim,
             config.codebook_size,
             bias=False,
         )
 
+        self.register_buffer(
+            "fast_freqs_cis",
+            precompute_freqs_cis(
+                config.num_codebooks,
+                config.fast_dim // config.fast_n_head,
+                config.rope_base,
+            ),
+            persistent=False,
+        )
         self.apply(self._init_weights)
 
     def setup_caches(
@@ -495,7 +548,7 @@ class DualARTransformer(BaseTransformer):
     ):
         super().setup_caches(max_batch_size, max_seq_len, dtype)
 
-        head_dim = self.config.dim // self.config.n_head
+        head_dim = self.config.fast_dim // self.config.fast_n_head
 
         # Fast transformer
         # The max seq len here is the number of codebooks
@@ -503,7 +556,7 @@ class DualARTransformer(BaseTransformer):
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 self.config.num_codebooks,
-                self.config.n_local_heads,
+                self.config.fast_n_local_heads,
                 head_dim,
                 dtype=dtype,
             )
@@ -516,13 +569,13 @@ class DualARTransformer(BaseTransformer):
         parent_result = super().forward(inp, key_padding_mask)
         token_logits = parent_result.logits
         x = parent_result.hidden_states
+        x = self.fast_project_in(x)
 
         # Fast transformer
         fast_seq_len = self.config.num_codebooks
         fast_mask = self.causal_mask[
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
 
         # Drop the last token and rotate left
         codebooks = inp[:, 1:-1, 1:]
@@ -545,9 +598,11 @@ class DualARTransformer(BaseTransformer):
 
         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+                x = checkpoint(
+                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
+                )
             else:
-                x = layer(x, fast_freqs_cis, fast_mask)
+                x = layer(x, self.fast_freqs_cis, fast_mask)
 
         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
@@ -587,7 +642,7 @@ class DualARTransformer(BaseTransformer):
         fast_mask = self.causal_mask[
             None, None, input_pos, : self.config.num_codebooks
         ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[input_pos]
+        fast_freqs_cis = self.fast_freqs_cis[input_pos]
 
         for layer in self.fast_layers:
             x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
@@ -598,6 +653,13 @@ class DualARTransformer(BaseTransformer):
 
         return codebook_logits
 
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        x = super().forward_generate(x, input_pos)
+        x.hidden_states = self.fast_project_in(x.hidden_states)
+        return x
+
 
 class TransformerBlock(nn.Module):
     def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:

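The `fast_*` fields added above let the codebook ("fast") transformer use a different width and head count than the main transformer. A small sketch of how the overrides resolve, using hypothetical dimensions and assuming the remaining `BaseModelArgs` defaults:

```python
# Sketch with hypothetical dimensions: the fast (codebook) transformer can now be
# narrower than the slow transformer; any fast_* field left as None falls back to
# the corresponding slow-path value in __post_init__.
from fish_speech.models.text2semantic.llama import DualARModelArgs

args = DualARModelArgs(
    dim=1536,
    n_head=12,
    n_fast_layer=4,
    fast_dim=768,    # override: fast transformer is half as wide
    fast_n_head=6,
    # fast_n_local_heads, fast_head_dim, fast_intermediate_size and
    # fast_attention_qkv_bias stay None and inherit the slow-path values
)
assert args.fast_intermediate_size == args.intermediate_size
assert args.fast_n_local_heads == args.n_local_heads
```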
+ 475 - 24
tools/api.py

@@ -1,7 +1,8 @@
 import io
 import os
 import queue
-import sys
+import re
+import time
 import traceback
 import wave
 from argparse import ArgumentParser
@@ -9,6 +10,7 @@ from http import HTTPStatus
 from pathlib import Path
 from typing import Annotated, Any
 
+import librosa
 import numpy as np
 import ormsgpack
 import pyrootutils
@@ -26,26 +28,67 @@ from kui.asgi import (
     Kui,
     OpenAPI,
     StreamResponse,
+    request,
 )
 from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
+from transformers import AutoTokenizer
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+import struct
+from threading import Lock
+
+import httpx
+from cachetools import LRUCache, cached
+from funasr import AutoModel
+from silero_vad import get_speech_timestamps, load_silero_vad
+
+from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
+from fish_speech.models.text2semantic.llama import BaseModelArgs
 
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps, set_seed
-from tools.commons import ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
     WrappedGenerateResponse,
     launch_thread_safe_queue,
+    launch_thread_safe_queue_agent,
+)
+from tools.schema import (
+    GLOBAL_NUM_SAMPLES,
+    ASRPackRequest,
+    ServeASRRequest,
+    ServeASRResponse,
+    ServeASRSegment,
+    ServeAudioPart,
+    ServeForwardMessage,
+    ServeMessage,
+    ServeRequest,
+    ServeResponse,
+    ServeStreamDelta,
+    ServeStreamResponse,
+    ServeTextPart,
+    ServeTimedASRResponse,
+    ServeTTSRequest,
+    ServeVQGANDecodeRequest,
+    ServeVQGANDecodeResponse,
+    ServeVQGANEncodeRequest,
+    ServeVQGANEncodeResponse,
+    ServeVQPart,
 )
 from tools.vqgan.inference import load_model as load_decoder_model
 
+global_lock = Lock()
+
+# Whether to disable keepalive (which is helpful if the server is in the same cluster)
+DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
+async_client = httpx.AsyncClient(
+    timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
+)
 backends = torchaudio.list_audio_backends()
 
 if "ffmpeg" in backends:
@@ -169,6 +212,385 @@ def get_content_type(audio_format):
         return "application/octet-stream"
 
 
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def batch_encode(model, audios: list[bytes | torch.Tensor]):
+    audios = [
+        (
+            torch.from_numpy(
+                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
+            )[None]
+            if isinstance(audio, bytes)
+            else audio
+        )
+        for audio in audios
+    ]
+
+    # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
+    #     raise ValueError("Single audio length is too long (>120s)")
+
+    max_length = max(audio.shape[-1] for audio in audios)
+    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+
+    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
+    max_length = lengths.max().item()
+    padded = torch.stack(
+        [
+            torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
+            for audio in audios
+        ]
+    ).to(model.device)
+
+    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
+    features, feature_lengths = features.cpu(), feature_lengths.cpu()
+
+    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
+
+
+@cached(
+    cache=LRUCache(maxsize=10000),
+    key=lambda model, audios: (model.device, tuple(audios)),
+)
+def cached_vqgan_batch_encode(model, audios: list[bytes]):
+    return batch_encode(model, audios)
+
+
+@routes.http.post("/v1/vqgan/encode")
+def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
+
+    start_time = time.time()
+    tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
+    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
+
+    return ormsgpack.packb(
+        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    )
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def vqgan_decode(model, features):
+    lengths = torch.tensor(
+        [feature.shape[-1] for feature in features], device=model.device
+    )
+    max_length = lengths.max().item()
+    padded = torch.stack(
+        [
+            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
+            for feature in features
+        ]
+    ).to(model.device)
+
+    # If bs too large, we do micro batch decode
+    audios, audio_lengths = [], []
+    for i in range(0, padded.shape[0], 8):
+        audio, audio_length = model.decode(
+            padded[i : i + 8], feature_lengths=lengths[i : i + 8]
+        )
+        audios.append(audio)
+        audio_lengths.append(audio_length)
+    audios = torch.cat(audios, dim=0)
+    audio_lengths = torch.cat(audio_lengths, dim=0)
+    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
+
+    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
+
+
+@routes.http.post("/v1/vqgan/decode")
+def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
+    tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
+    start_time = time.time()
+    audios = vqgan_decode(decoder_model, tokens)
+    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
+    audios = [audio.astype(np.float16).tobytes() for audio in audios]
+    return ormsgpack.packb(
+        ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+    )
+
+
+@torch.no_grad()
+def batch_asr(model, audios, sr, language="auto"):
+    resampled_audios = []
+    for audio in audios:
+        audio = torchaudio.functional.resample(audio, sr, 16000)
+        assert audio.ndim == 1
+        resampled_audios.append(audio)
+
+    with global_lock:
+        res = model.generate(
+            input=resampled_audios,
+            batch_size=len(resampled_audios),
+            language=language,
+            use_itn=True,
+        )
+
+    results = []
+    for r, audio in zip(res, audios):
+        text = r["text"]
+        text = re.sub(r"<\|.*?\|>", "", text)
+        duration = len(audio) / sr * 1000
+        huge_gap = False
+
+        if "timestamp" in r and len(r["timestamp"]) > 2:
+            for timestamp_a, timestamp_b in zip(
+                r["timestamp"][:-1], r["timestamp"][1:]
+            ):
+                # If there is a gap of more than 5 seconds, we consider it as a huge gap
+                if timestamp_b[0] - timestamp_a[1] > 5000:
+                    huge_gap = True
+                    break
+
+            # Doesn't make sense to have a huge gap at the end
+            if duration - r["timestamp"][-1][1] > 3000:
+                huge_gap = True
+
+        results.append(
+            {
+                "text": text,
+                "duration": duration,
+                "huge_gap": huge_gap,
+            }
+        )
+
+    return results
+
+
+@routes.http.post("/v1/asr")
+def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
+    start_time = time.time()
+    audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
+    audios = [torch.from_numpy(audio).float() for audio in audios]
+
+    if any(audio.shape[-1] >= 30 * payload.sample_rate for audio in audios):
+        raise HTTPException(status_code=400, detail="Audio length is too long")
+
+    transcriptions = batch_asr(
+        asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
+    )
+    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+
+    return ormsgpack.packb(
+        ServeASRResponse(transcriptions=transcriptions),
+        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    )
+
+
+from fish_speech.conversation import Conversation, Message
+
+
+def execute_request(
+    input_queue: queue.Queue,
+    tokenizer: AutoTokenizer,
+    config: BaseModelArgs,
+    request: ServeRequest,
+    device: str = "cuda:0",
+):
+    semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
+        [SEMANTIC_TOKEN, IM_END_TOKEN]
+    )
+    messages = []
+    for message in request.messages:
+        messages.append(message.to_conversation_message())
+
+    assert len(messages) >= 1, "At least one message is required"
+    # assert messages[-1].role == "user", "The last message must be from the user"
+
+    if messages[-1].role == "user":
+        messages.append(Message(role="assistant", parts=[], add_im_end=False))
+    else:
+        assert (
+            messages[-1].role == "assistant"
+        ), "The last message must be from the assistant"
+        messages[-1].add_im_end = False
+
+    conv = Conversation(messages=messages)
+    prompt = conv.encode_for_inference(
+        tokenizer=tokenizer, num_codebooks=config.num_codebooks
+    ).to(device)
+
+    if request.streaming:
+        for i in range(request.num_samples):
+            yield ServeStreamResponse(
+                sample_id=i,
+                delta=ServeStreamDelta(
+                    role="assistant",
+                ),
+            )
+
+    req = {
+        "prompt": prompt,
+        "max_new_tokens": request.max_new_tokens,
+        "im_end_id": im_end_id,
+        "semantic_id": semantic_id,
+        "temperature": request.temperature,
+        "top_p": request.top_p,
+        "repetition_penalty": request.repetition_penalty,
+        "num_samples": request.num_samples,
+        "early_stop_threshold": request.early_stop_threshold,
+    }
+
+    start = time.time()
+    response_queue = queue.Queue()
+    input_queue.put(GenerateRequest(req, response_queue))
+
+    # Decoding
+    decode_buffer = [[] for _ in range(request.num_samples)]
+    parts = [[] for _ in range(request.num_samples)]
+
+    def send_reset_buffer(sample_id):
+        nonlocal decode_buffer
+        if len(decode_buffer[sample_id]) == 0:
+            return
+
+        decoded = tokenizer.decode(decode_buffer[sample_id])
+        part = ServeTextPart(text=decoded)
+
+        if request.streaming:
+            yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
+        else:
+            parts[sample_id].append(part)
+
+        decode_buffer[sample_id] = []
+
+    # Decode process
+    finished = [False for _ in range(request.num_samples)]
+    stats = {}
+    idx = 0
+    while True:
+        response = response_queue.get()
+
+        if response in ["stop", "error"]:
+            break
+
+        for sample_id, tokens in enumerate(response):
+            if finished[sample_id]:
+                continue
+
+            if tokens[0] == im_end_id:
+                finished[sample_id] = True
+                if request.streaming:
+                    yield from send_reset_buffer(sample_id)
+                    yield ServeStreamResponse(
+                        sample_id=sample_id,
+                        finish_reason="stop",
+                        stats=stats,
+                    )
+                continue
+
+            if tokens[0] == semantic_id and request.streaming:
+                yield from send_reset_buffer(sample_id)
+                # Streaming vq
+                _tokens = tokens[1:].clone() - 1
+
+                if config.share_codebook_embeddings is False:
+                    for i in range(len(_tokens)):
+                        _tokens[i] -= config.codebook_size * i
+
+                yield ServeStreamResponse(
+                    sample_id=sample_id,
+                    delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
+                )
+                continue
+
+            # Not streaming vq
+            if tokens[0] == semantic_id:
+                yield from send_reset_buffer(sample_id)
+                # Non-streaming vq
+                if len(parts[sample_id]) == 0 or not isinstance(
+                    parts[sample_id][-1], ServeVQPart
+                ):
+                    _tokens = tokens[1:].clone() - 1
+
+                    if config.share_codebook_embeddings is False:
+                        for i in range(len(_tokens)):
+                            _tokens[i] -= config.codebook_size * i
+
+                    parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
+                else:
+                    for codebook_id, value in enumerate(tokens[1:, :]):
+                        val = value.item() - 1
+                        if config.share_codebook_embeddings is False:
+                            val -= config.codebook_size * codebook_id
+
+                        parts[sample_id][-1].codes[codebook_id].append(val)
+                continue
+
+            if tokens[0] != semantic_id:
+                # Stream text decode is not supported now
+                decode_buffer[sample_id].append(tokens[0, 0])
+
+        if idx == 0:
+            stats["time_to_first_token"] = (time.time() - start) * 1000
+
+        idx += 1
+
+    for sample_id in range(request.num_samples):
+        yield from send_reset_buffer(sample_id)
+
+    stats["total_time"] = (time.time() - start) * 1000
+    stats["total_tokens"] = idx
+
+    if request.streaming:
+        for sample_id in range(request.num_samples):
+            if finished[sample_id]:
+                continue
+            yield ServeStreamResponse(
+                finish_reason=response, stats=stats, sample_id=sample_id
+            )
+        return
+
+    yield ServeResponse(
+        messages=[
+            ServeMessage(role="assistant", parts=parts[i])
+            for i in range(request.num_samples)
+        ],
+        finish_reason=response,
+        stats=stats,
+    )
+
+
+@routes.http.post("/v1/chat")
+def api_invoke_chat(
+    req: Annotated[ServeRequest, Body(exclusive=True)],
+):
+    """
+    Invoke model and generate audio
+    """
+
+    # This makes torch compile happy
+    assert (
+        req.num_samples == GLOBAL_NUM_SAMPLES
+    ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
+
+    content_type = request.headers.get("Content-Type", "application/json")
+    json_mode = "application/json" in content_type
+
+    async def wrapped_generator():
+        generator = execute_request(llama_queue, tokenizer, config, req, args.device)
+
+        for i in generator:
+            if json_mode:
+                body = i.model_dump_json().encode("utf-8")
+                yield b"data: " + body + b"\n\n"
+            else:
+                body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+                yield struct.pack("I", len(body)) + body
+
+    # Naive mode
+    if req.streaming is False:
+        result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
+
+        if json_mode:
+            return JSONResponse(result.model_dump())
+        else:
+            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+
+    return StreamResponse(
+        iterable=wrapped_generator(), content_type="text/event-stream"
+    )
+
+
 @torch.inference_mode()
 def inference(req: ServeTTSRequest):
 
@@ -360,6 +782,8 @@ async def api_health():
 
 def parse_args():
     parser = ArgumentParser()
+    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+    parser.add_argument("--load-asr-model", action="store_true")
     parser.add_argument(
         "--llama-checkpoint-path",
         type=str,
@@ -419,6 +843,15 @@ app = Kui(
 )
 
 
+def load_asr_model(*, device="cuda", hub="ms"):
+    return AutoModel(
+        model="iic/SenseVoiceSmall",
+        device=device,
+        disable_pbar=True,
+        hub=hub,
+    )
+
+
 # Each worker process created by Uvicorn has its own memory space,
 # meaning that models and variables are not shared between processes.
 # Therefore, any global variables (like `llama_queue` or `decoder_model`)
@@ -431,20 +864,33 @@ app = Kui(
 @app.on_startup
 def initialize_app(app: Kui):
 
-    global args, llama_queue, decoder_model, prompt_tokens, prompt_texts
+    global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
 
     prompt_tokens, prompt_texts = [], []
 
     args = parse_args()  # args same as ones in other processes
     args.precision = torch.half if args.half else torch.bfloat16
 
+    if args.load_asr_model:
+        logger.info(f"Loading ASR model...")
+        asr_model = load_asr_model(device=args.device)
+
     logger.info("Loading Llama model...")
-    llama_queue = launch_thread_safe_queue(
-        checkpoint_path=args.llama_checkpoint_path,
-        device=args.device,
-        precision=args.precision,
-        compile=args.compile,
-    )
+
+    if args.mode == "tts":
+        llama_queue = launch_thread_safe_queue(
+            checkpoint_path=args.llama_checkpoint_path,
+            device=args.device,
+            precision=args.precision,
+            compile=args.compile,
+        )
+    else:
+        llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
+            checkpoint_path=args.llama_checkpoint_path,
+            device=args.device,
+            precision=args.precision,
+            compile=args.compile,
+        )
 
     logger.info("Llama model loaded, loading VQ-GAN model...")
 
@@ -456,23 +902,28 @@ def initialize_app(app: Kui):
 
     logger.info("VQ-GAN model loaded, warming up...")
 
-    # Dry run to ensure models work and avoid first-time latency
-    list(
-        inference(
-            ServeTTSRequest(
-                text="Hello world.",
-                references=[],
-                reference_id=None,
-                max_new_tokens=0,
-                chunk_length=200,
-                top_p=0.7,
-                repetition_penalty=1.2,
-                temperature=0.7,
-                emotion=None,
-                format="wav",
+    vad_model = load_silero_vad()
+
+    logger.info("VAD model loaded, warming up...")
+
+    if args.mode == "tts":
+        # Dry run to ensure models work and avoid first-time latency
+        list(
+            inference(
+                ServeTTSRequest(
+                    text="Hello world.",
+                    references=[],
+                    reference_id=None,
+                    max_new_tokens=0,
+                    chunk_length=200,
+                    top_p=0.7,
+                    repetition_penalty=1.2,
+                    temperature=0.7,
+                    emotion=None,
+                    format="wav",
+                )
             )
         )
-    )
 
     logger.info(f"Warming up done, starting server at http://{args.listen}")
 

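The new `/v1/asr` route is only active when `tools.api` is started with `--load-asr-model`. A minimal client sketch, assuming the same localhost:8080 server and msgpack content type used by the other routes:

```python
# Sketch: call the new /v1/asr route (requires starting tools.api with --load-asr-model).
# The server reads each audio as a float16 byte buffer and rejects clips of 30 s or more.
import httpx
import numpy as np
import ormsgpack

from tools.schema import ServeASRRequest

sample_rate = 16000
audio = np.zeros(sample_rate, dtype=np.float32)  # placeholder: 1 s of silence

req = ServeASRRequest(
    audios=[audio.astype(np.float16).tobytes()],
    sample_rate=sample_rate,
    language="auto",
)
resp = httpx.post(
    "http://localhost:8080/v1/asr",
    content=ormsgpack.packb(req, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
    headers={"Content-Type": "application/msgpack"},
    timeout=120,
)
# Unpacks to {"transcriptions": [{"text": ..., "duration": ..., "huge_gap": ...}]}
print(ormsgpack.unpackb(resp.content))
```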
+ 0 - 36
tools/commons.py

@@ -1,36 +0,0 @@
-from typing import Annotated, Literal, Optional
-
-from pydantic import BaseModel, Field, conint
-
-
-class ServeReferenceAudio(BaseModel):
-    audio: bytes
-    text: str
-
-
-class ServeTTSRequest(BaseModel):
-    text: str
-    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
-    # Audio format
-    format: Literal["wav", "pcm", "mp3"] = "wav"
-    mp3_bitrate: Literal[64, 128, 192] = 128
-    # References audios for in-context learning
-    references: list[ServeReferenceAudio] = []
-    # Reference id
-    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
-    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
-    reference_id: str | None = None
-    seed: int | None = None
-    use_memory_cache: Literal["on-demand", "never"] = "never"
-    # Normalize text for en & zh, this increase stability for numbers
-    normalize: bool = True
-    mp3_bitrate: Optional[int] = 64
-    opus_bitrate: Optional[int] = -1000
-    # Balance mode will reduce latency to 300ms, but may decrease stability
-    latency: Literal["normal", "balanced"] = "normal"
-    # not usually used below
-    streaming: bool = False
-    max_new_tokens: int = 1024
-    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
-    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7

+ 210 - 0
tools/e2e_webui.py

@@ -0,0 +1,210 @@
+import re
+
+import gradio as gr
+import numpy as np
+
+from .fish_e2e import FishE2EAgent, FishE2EEventType
+from .schema import ServeMessage, ServeTextPart, ServeVQPart
+
+
+class ChatState:
+    def __init__(self):
+        self.conversation = []
+        self.added_systext = False
+        self.added_sysaudio = False
+
+    def get_history(self):
+        results = []
+        for msg in self.conversation:
+            results.append({"role": msg.role, "content": self.repr_message(msg)})
+
+        # Process assistant messages to extract questions and update user messages
+        for i, msg in enumerate(results):
+            if msg["role"] == "assistant":
+                match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
+                if match and i > 0 and results[i - 1]["role"] == "user":
+                    # Update previous user message with extracted question
+                    results[i - 1]["content"] += "\n" + match.group(1)
+                    # Remove the Question/Answer format from assistant message
+                    msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
+        return results
+
+    def repr_message(self, msg: ServeMessage):
+        response = ""
+        for part in msg.parts:
+            if isinstance(part, ServeTextPart):
+                response += part.text
+            elif isinstance(part, ServeVQPart):
+                response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
+        return response
+
+
+def clear_fn():
+    return [], ChatState(), None, None, None
+
+
+async def process_audio_input(
+    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
+):
+    if audio_input is None and not text_input:
+        raise gr.Error("No input provided")
+
+    agent = FishE2EAgent()  # Create new agent instance for each request
+
+    # Convert audio input to numpy array
+    if isinstance(audio_input, tuple):
+        sr, audio_data = audio_input
+    elif text_input:
+        sr = 44100
+        audio_data = None
+    else:
+        raise gr.Error("Invalid audio format")
+
+    if isinstance(sys_audio_input, tuple):
+        sr, sys_audio_data = sys_audio_input
+    elif text_input:
+        sr = 44100
+        sys_audio_data = None
+    else:
+        raise gr.Error("Invalid audio format")
+
+    def append_to_chat_ctx(
+        part: ServeTextPart | ServeVQPart, role: str = "assistant"
+    ) -> None:
+        if not state.conversation or state.conversation[-1].role != role:
+            state.conversation.append(ServeMessage(role=role, parts=[part]))
+        else:
+            state.conversation[-1].parts.append(part)
+
+    if state.added_systext is False and sys_text_input:
+        state.added_systext = True
+        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
+    if text_input:
+        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
+        audio_data = None
+
+    result_audio = b""
+    async for event in agent.stream(
+        sys_audio_data,
+        audio_data,
+        sr,
+        1,
+        chat_ctx={
+            "messages": state.conversation,
+            "added_sysaudio": state.added_sysaudio,
+        },
+    ):
+        if event.type == FishE2EEventType.USER_CODES:
+            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
+        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
+            result_audio += event.frame.data
+            np_audio = np.frombuffer(result_audio, dtype=np.int16)
+            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
+
+            yield state.get_history(), (44100, np_audio), None, None
+        elif event.type == FishE2EEventType.TEXT_SEGMENT:
+            append_to_chat_ctx(ServeTextPart(text=event.text))
+            if result_audio:
+                np_audio = np.frombuffer(result_audio, dtype=np.int16)
+                yield state.get_history(), (44100, np_audio), None, None
+            else:
+                yield state.get_history(), None, None, None
+
+    np_audio = np.frombuffer(result_audio, dtype=np.int16)
+    yield state.get_history(), (44100, np_audio), None, None
+
+
+async def process_text_input(
+    sys_audio_input, sys_text_input, state: ChatState, text_input: str
+):
+    async for event in process_audio_input(
+        sys_audio_input, sys_text_input, None, state, text_input
+    ):
+        yield event
+
+
+def create_demo():
+    with gr.Blocks() as demo:
+        state = gr.State(ChatState())
+
+        with gr.Row():
+            # Left column (70%) for chatbot and notes
+            with gr.Column(scale=7):
+                chatbot = gr.Chatbot(
+                    [],
+                    elem_id="chatbot",
+                    bubble_full_width=False,
+                    height=600,
+                    type="messages",
+                )
+
+                notes = gr.Markdown(
+                    """
+                # Fish Agent
+                1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
+                2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
+                3. Demo为早期灰度测试版本,推理速度尚待优化.
+                # 特色
+                1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
+                2. 模型可以使用reference audio控制说话音色.
+                3. 可以生成具有较强情感与韵律的音频.
+                """
+                )
+
+            # Right column (30%) for controls
+            with gr.Column(scale=3):
+                sys_audio_input = gr.Audio(
+                    sources=["upload"],
+                    type="numpy",
+                    label="Give a timbre for your assistant",
+                )
+                sys_text_input = gr.Textbox(
+                    label="What is your assistant's role?",
+                    value='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nResponse: [你的回答]\n"。',
+                    type="text",
+                )
+                audio_input = gr.Audio(
+                    sources=["microphone"], type="numpy", label="Speak your message"
+                )
+
+                text_input = gr.Textbox(label="Or type your message", type="text")
+
+                output_audio = gr.Audio(label="Assistant's Voice", type="numpy")
+
+                send_button = gr.Button("Send", variant="primary")
+                clear_button = gr.Button("Clear")
+
+        # Event handlers
+        audio_input.stop_recording(
+            process_audio_input,
+            inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
+            outputs=[chatbot, output_audio, audio_input, text_input],
+            show_progress=True,
+        )
+
+        send_button.click(
+            process_text_input,
+            inputs=[sys_audio_input, sys_text_input, state, text_input],
+            outputs=[chatbot, output_audio, audio_input, text_input],
+            show_progress=True,
+        )
+
+        text_input.submit(
+            process_text_input,
+            inputs=[sys_audio_input, sys_text_input, state, text_input],
+            outputs=[chatbot, output_audio, audio_input, text_input],
+            show_progress=True,
+        )
+
+        clear_button.click(
+            clear_fn,
+            inputs=[],
+            outputs=[chatbot, state, audio_input, output_audio, text_input],
+        )
+
+    return demo
+
+
+if __name__ == "__main__":
+    demo = create_demo()
+    demo.launch(server_name="127.0.0.1", server_port=7860, share=True)

+ 256 - 0
tools/fish_e2e.py

@@ -0,0 +1,256 @@
+import base64
+import io
+import json
+import os
+import struct
+from dataclasses import dataclass
+from enum import Enum
+from typing import AsyncGenerator
+
+import httpx
+import numpy as np
+import ormsgpack
+import soundfile as sf
+from livekit import rtc
+from livekit.agents.llm.chat_context import ChatContext
+
+from .schema import (
+    ServeMessage,
+    ServeRequest,
+    ServeTextPart,
+    ServeVQGANDecodeRequest,
+    ServeVQGANEncodeRequest,
+    ServeVQPart,
+)
+
+
+class FishE2EEventType(Enum):
+    SPEECH_SEGMENT = 1
+    TEXT_SEGMENT = 2
+    END_OF_TEXT = 3
+    END_OF_SPEECH = 4
+    ASR_RESULT = 5
+    USER_CODES = 6
+
+
+@dataclass
+class FishE2EEvent:
+    type: FishE2EEventType
+    frame: rtc.AudioFrame = None
+    text: str = None
+    vq_codes: list[list[int]] = None
+
+
+client = httpx.AsyncClient(
+    timeout=None,
+    limits=httpx.Limits(
+        max_connections=None,
+        max_keepalive_connections=None,
+        keepalive_expiry=None,
+    ),
+)
+
+
+class FishE2EAgent:
+    def __init__(self):
+        self.llm_url = "http://localhost:8080/v1/chat"
+        self.vqgan_url = "http://localhost:8080"
+        self.client = httpx.AsyncClient(timeout=None)
+
+    async def get_codes(self, audio_data, sample_rate):
+        audio_buffer = io.BytesIO()
+        sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
+        audio_buffer.seek(0)
+        # Step 1: Encode audio using VQGAN
+        encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
+        encode_request_bytes = ormsgpack.packb(
+            encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+        )
+        encode_response = await self.client.post(
+            f"{self.vqgan_url}/v1/vqgan/encode",
+            data=encode_request_bytes,
+            headers={"Content-Type": "application/msgpack"},
+        )
+        encode_response_data = ormsgpack.unpackb(encode_response.content)
+        codes = encode_response_data["tokens"][0]
+        return codes
+
+    async def stream(
+        self,
+        system_audio_data: np.ndarray | None,
+        user_audio_data: np.ndarray | None,
+        sample_rate: int,
+        num_channels: int,
+        chat_ctx: ChatContext | None = None,
+    ) -> AsyncGenerator[bytes, None]:
+
+        if system_audio_data is not None:
+            sys_codes = await self.get_codes(system_audio_data, sample_rate)
+        else:
+            sys_codes = None
+        if user_audio_data is not None:
+            user_codes = await self.get_codes(user_audio_data, sample_rate)
+        # Step 2: Prepare LLM request
+        if chat_ctx is None:
+            sys_parts = [
+                ServeTextPart(
+                    text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
+                ),
+            ]
+            if system_audio_data is not None:
+                sys_parts.append(ServeVQPart(codes=sys_codes))
+            chat_ctx = {
+                "messages": [
+                    ServeMessage(
+                        role="system",
+                        parts=sys_parts,
+                    ),
+                ],
+            }
+        else:
+            if chat_ctx["added_sysaudio"] is False and sys_codes:
+                chat_ctx["added_sysaudio"] = True
+                chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
+
+        prev_messages = chat_ctx["messages"].copy()
+        if user_audio_data is not None:
+            yield FishE2EEvent(
+                type=FishE2EEventType.USER_CODES,
+                vq_codes=user_codes,
+            )
+        else:
+            user_codes = None
+
+        request = ServeRequest(
+            messages=prev_messages
+            + (
+                [
+                    ServeMessage(
+                        role="user",
+                        parts=[ServeVQPart(codes=user_codes)],
+                    )
+                ]
+                if user_codes
+                else []
+            ),
+            streaming=True,
+            num_samples=1,
+        )
+
+        # Step 3: Stream LLM response and decode audio
+        buffer = b""
+        vq_codes = []
+        current_vq = False
+
+        async def decode_send():
+            nonlocal current_vq
+            nonlocal vq_codes
+
+            data = np.concatenate(vq_codes, axis=1).tolist()
+            # Decode VQ codes to audio
+            decode_request = ServeVQGANDecodeRequest(tokens=[data])
+            decode_response = await self.client.post(
+                f"{self.vqgan_url}/v1/vqgan/decode",
+                data=ormsgpack.packb(
+                    decode_request,
+                    option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+                ),
+                headers={"Content-Type": "application/msgpack"},
+            )
+            decode_data = ormsgpack.unpackb(decode_response.content)
+
+            # Convert float16 audio data to int16
+            audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
+            audio_data = (audio_data * 32768).astype(np.int16).tobytes()
+
+            audio_frame = rtc.AudioFrame(
+                data=audio_data,
+                samples_per_channel=len(audio_data) // 2,
+                sample_rate=44100,
+                num_channels=1,
+            )
+            yield FishE2EEvent(
+                type=FishE2EEventType.SPEECH_SEGMENT,
+                frame=audio_frame,
+                vq_codes=data,
+            )
+
+            current_vq = False
+            vq_codes = []
+
+        async with self.client.stream(
+            "POST",
+            self.llm_url,
+            data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+            headers={"Content-Type": "application/msgpack"},
+        ) as response:
+
+            async for chunk in response.aiter_bytes():
+                buffer += chunk
+
+                while len(buffer) >= 4:
+                    read_length = struct.unpack("I", buffer[:4])[0]
+                    if len(buffer) < 4 + read_length:
+                        break
+
+                    body = buffer[4 : 4 + read_length]
+                    buffer = buffer[4 + read_length :]
+                    data = ormsgpack.unpackb(body)
+
+                    if data["delta"] and data["delta"]["part"]:
+                        if current_vq and data["delta"]["part"]["type"] == "text":
+                            async for event in decode_send():
+                                yield event
+                        if data["delta"]["part"]["type"] == "text":
+                            yield FishE2EEvent(
+                                type=FishE2EEventType.TEXT_SEGMENT,
+                                text=data["delta"]["part"]["text"],
+                            )
+                        elif data["delta"]["part"]["type"] == "vq":
+                            vq_codes.append(np.array(data["delta"]["part"]["codes"]))
+                            current_vq = True
+
+        if current_vq and vq_codes:
+            async for event in decode_send():
+                yield event
+
+        yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
+        yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
+
+
+# Example usage:
+async def main():
+    import torchaudio
+
+    agent = FishE2EAgent()
+
+    # Replace this with actual audio data loading
+    with open("uz_story_en.m4a", "rb") as f:
+        audio_data = f.read()
+
+    audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
+    audio_data = (audio_data.numpy() * 32768).astype(np.int16)
+
+    stream = agent.stream(None, audio_data, sample_rate, 1)
+    if os.path.exists("audio_segment.wav"):
+        os.remove("audio_segment.wav")
+
+    async for event in stream:
+        if event.type == FishE2EEventType.SPEECH_SEGMENT:
+            # Handle speech segment (e.g., play audio or save to file)
+            with open("audio_segment.wav", "ab+") as f:
+                f.write(event.frame.data)
+        elif event.type == FishE2EEventType.ASR_RESULT:
+            print(event.text, flush=True)
+        elif event.type == FishE2EEventType.TEXT_SEGMENT:
+            print(event.text, flush=True, end="")
+        elif event.type == FishE2EEventType.END_OF_TEXT:
+            print("\nEnd of text reached.")
+        elif event.type == FishE2EEventType.END_OF_SPEECH:
+            print("End of speech reached.")
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())

+ 372 - 7
tools/llama/generate.py

@@ -15,8 +15,10 @@ import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -28,6 +30,8 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True
 
 
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 from fish_speech.models.text2semantic.llama import (
     BaseTransformer,
     DualARTransformer,
@@ -74,6 +78,45 @@ def logits_to_probs(
     return probs
 
 
+def multinomial_sample_one_no_sync_agent(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)
+
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[..., 0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+    logits = logits / max(temperature, 1e-5)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
 def sample(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
@@ -86,20 +129,135 @@ def sample(
     return idx_next, probs
 
 
-def decode_one_token_ar(
+def sample_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs_agent(
+        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
+    )
+    idx_next = multinomial_sample_one_no_sync_agent(probs)
+    return idx_next, probs
+
+
+def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
     previous_tokens: torch.Tensor = None,
+    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
+    # print(x, input_pos)
     x = model.forward_generate(x, input_pos)
+    logits = x.logits  # [:, -1:]
+    hidden_states = x.hidden_states  # [:, -1:]
 
     sampling_kwargs_main = sampling_kwargs.copy()
     sampling_kwargs_main["temperature"] = 0.1
     sampling_kwargs_main["top_p"] = 0.1
     sampling_kwargs_main["repetition_penalty"] = 1.0
 
+    codebooks = [
+        sample_agent(
+            logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
+        )[0]
+    ]
+
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
+        a = sample_agent(
+            logits,
+            previous_tokens=(
+                previous_tokens[:, codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        hidden_states = model.fast_embeddings(a)
+        codebooks.append(a)
+
+    codebooks = torch.stack(codebooks, dim=1)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    # for i in range(codebooks.size(1) - 1):
+    #     codebooks[:, i + 1, :] = torch.masked_fill(
+    #         codebooks[:, i + 1, :],
+    #         codebooks[:, :1, :] != semantic_id,
+    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
+    #     )
+
+    # print(codebooks)
+
+    return codebooks
+
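The `masked_fill` above resets the acoustic codebooks wherever the first (token) codebook did not emit `<|semantic|>`; a toy illustration with reduced shapes (the pad value here is illustrative, the real constant comes from `fish_speech.conversation`):

```python
import torch

CODEBOOK_PAD_TOKEN_ID = 0  # illustrative value
semantic_id = 32003

# (batch=1, 1 + num_codebooks=3, seq=4): row 0 is the token codebook,
# rows 1..2 are acoustic codebooks.
codebooks = torch.tensor(
    [[[32003, 15, 32003, 7],
      [  101, 202,  303, 404],
      [  111, 222,  333, 444]]]
)

codebooks[:, 1:, :] = torch.masked_fill(
    codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
)
print(codebooks)
# Columns where the token row is not <|semantic|> (15 and 7) now carry the pad
# id in the acoustic rows; the <|semantic|> columns keep their codes.
```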
+
+def decode_one_token_naive_agent(
+    model: NaiveTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    semantic_id: int = 32003,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    codebooks = [
+        sample(
+            x.token_logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs,
+        )[0]
+    ]
+
+    for i in range(model.config.num_codebooks):
+        codebooks.append(
+            sample_agent(
+                x.codebook_logits[:, :, i],
+                previous_tokens=(
+                    previous_tokens[:, i + 1] if previous_tokens is not None else None
+                ),
+                **sampling_kwargs,
+            )[0]
+        )
+
+    codebooks = torch.stack(codebooks, dim=1)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    return codebooks
+
+
+def decode_one_token_ar(
+    model: DualARTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    semantic_id: int = 0,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    # sampling_kwargs_main["temperature"] = 0.1
+    # sampling_kwargs_main["top_p"] = 0.1
+    # sampling_kwargs_main["repetition_penalty"] = 1.0
+
     codebooks = [
         sample(
             x.logits,
@@ -130,7 +288,12 @@ def decode_one_token_ar(
         x = model.fast_embeddings(a)
         codebooks.append(a)
 
-    return torch.stack(codebooks, dim=0)
+    codebooks = torch.stack(codebooks, dim=0)
+    codebooks[1:, :] = torch.masked_fill(
+        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    return codebooks
 
 
 def decode_one_token_naive(
@@ -176,6 +339,7 @@ def decode_n_tokens(
     num_new_tokens: int,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
+    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -204,6 +368,7 @@ def decode_n_tokens(
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
+                semantic_id=semantic_id,
                 **sampling_kwargs,
             )
 
@@ -236,6 +401,7 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
 
     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
@@ -266,7 +432,11 @@ def generate(
     )
 
     next_token = prefill_decode(
-        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
+        model,
+        prompt.view(1, codebook_dim, -1),
+        input_pos,
+        semantic_id=semantic_id,
+        **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token
 
@@ -278,6 +448,7 @@ def generate(
         max_new_tokens - 1,
         im_end_id=im_end_id,
         decode_one_token=decode_one_token,
+        semantic_id=semantic_id,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -287,6 +458,142 @@ def generate(
     return seq
 
 
+def decode_n_tokens_agent(
+    model: NaiveTransformer,
+    cur_token: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    im_end_id: int = 4,
+    semantic_id: int = 32003,
+    decode_one_token=decode_one_token_naive_agent,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    batch_size = cur_token.size(0)
+    previous_tokens = torch.zeros(
+        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
+    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
+    finished = finished | (cur_token[:, 0, -1] == im_end_id)
+    start_time = time.time()
+
+    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
+        # Apply the repetition penalty over a sliding window of recent tokens
+        win_size = 16
+        if i < win_size:
+            window = previous_tokens[:, :, :win_size]
+        else:
+            window = previous_tokens[:, :, i - win_size : i]
+
+        with sdpa_kernel(
+            SDPBackend.MATH
+        ):  # Actually better for Inductor to codegen attention here
+            next_token = decode_one_token(
+                model=model,
+                x=cur_token,
+                input_pos=input_pos,
+                previous_tokens=window,
+                semantic_id=semantic_id,
+                **sampling_kwargs,
+            )
+
+        input_pos += 1
+        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
+        previous_tokens[:, :, i : i + 1] = next_token.view(
+            batch_size, model.config.num_codebooks + 1, -1
+        )
+
+        yield cur_token.cpu()
+
+        finished = finished | (cur_token[:, 0, -1] == im_end_id)
+        if finished.all() or (
+            0 < early_stop_threshold < 1
+            and finished.sum() >= round(batch_size * early_stop_threshold)
+        ):
+            break
+
+    total_time = time.time() - start_time
+    generated_tokens = i + 1
+    tokens_per_second = (generated_tokens / total_time) * batch_size
+    logger.info(
+        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
+    )
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate_agent(
+    *,
+    model: BaseTransformer,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    im_end_id: int = 4,
+    semantic_id: int = 32003,
+    decode_one_token=decode_one_token_naive_agent,
+    num_samples: int = 1,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(1)
+    prompt = prompt[None].repeat(num_samples, 1, 1)
+
+    if T >= model.config.max_seq_len:
+        raise ValueError(
+            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
+        )
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
+
+    device, dtype = prompt.device, prompt.dtype
+
+    codebook_dim = 1 + model.config.num_codebooks
+    input_pos = torch.arange(0, T, device=device)
+
+    # Use non-accelerated version for now, to avoid compilation overhead
+    prefill_decode = (
+        decode_one_token_naive_agent
+        if isinstance(model, NaiveTransformer)
+        else decode_one_token_ar_agent
+    )
+    next_token = prefill_decode(
+        model,
+        prompt,
+        input_pos,
+        semantic_id=semantic_id,
+        **sampling_kwargs,
+    ).view(num_samples, codebook_dim, -1)
+    yield next_token.cpu()
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+    yield from decode_n_tokens_agent(
+        model,
+        next_token,
+        input_pos,
+        max_new_tokens - 1,
+        im_end_id=im_end_id,
+        semantic_id=semantic_id,
+        decode_one_token=decode_one_token,
+        early_stop_threshold=early_stop_threshold,
+        **sampling_kwargs,
+    )
+
+
 def encode_tokens(
     tokenizer,
     string,
@@ -295,7 +602,7 @@ def encode_tokens(
     num_codebooks=4,
 ):
     string = clean_text(string)
-    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
+    string = f"<|im_start|>user\nSpeak: {string}<|im_end|><|im_start|>assistant\n"
 
     new_tokens = tokenizer.encode(
         string,
@@ -351,7 +658,7 @@ def encode_tokens(
     return prompt
 
 
-def load_model(checkpoint_path, device, precision, compile=False):
+def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
     model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
         checkpoint_path, load_weights=True
     )
@@ -360,10 +667,14 @@ def load_model(checkpoint_path, device, precision, compile=False):
     logger.info(f"Restored model from checkpoint")
 
     if isinstance(model, DualARTransformer):
-        decode_one_token = decode_one_token_ar
+        decode_one_token = (
+            decode_one_token_ar_agent if is_agent else decode_one_token_ar
+        )
         logger.info("Using DualARTransformer")
     else:
-        decode_one_token = decode_one_token_naive
+        decode_one_token = (
+            decode_one_token_naive_agent if is_agent else decode_one_token_naive
+        )
         logger.info("Using NaiveTransformer")
 
     if compile:
@@ -605,6 +916,60 @@ def launch_thread_safe_queue(
     return input_queue
 
 
+def launch_thread_safe_queue_agent(
+    checkpoint_path,
+    device,
+    precision,
+    compile: bool = False,
+):
+    input_queue = queue.Queue()
+    init_event = threading.Event()
+
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+    config = BaseModelArgs.from_pretrained(checkpoint_path)
+
+    def worker():
+        model, decode_one_token = load_model(
+            checkpoint_path, device, precision, compile=compile, is_agent=True
+        )
+
+        with torch.device(device):
+            model.setup_caches(
+                max_batch_size=1,
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
+            )
+        init_event.set()
+
+        while True:
+            item: GenerateRequest | None = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item.request
+            response_queue = item.response_queue
+
+            try:
+                for token in generate_agent(
+                    model=model,
+                    decode_one_token=decode_one_token,
+                    **kwargs,
+                ):
+                    response_queue.put(token)
+
+                response_queue.put("stop")
+            except Exception as e:
+                import traceback
+
+                logger.exception(f"Error in worker: {traceback.format_exc()}")
+                response_queue.put("error")
+
+    threading.Thread(target=worker, daemon=True).start()
+    init_event.wait()
+
+    return input_queue, tokenizer, config
+
+
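A sketch of how a caller drives this queue, mirroring the worker loop above: `GenerateRequest` carries the kwargs dict and the per-request response queue the worker reads, and `"stop"` / `"error"` are the sentinel strings it pushes when finished. The checkpoint path and the prompt tensor are placeholders, not values taken from this diff:

```python
import queue

import torch

from tools.llama.generate import GenerateRequest, launch_thread_safe_queue_agent

input_queue, tokenizer, config = launch_thread_safe_queue_agent(
    checkpoint_path="checkpoints/fish-agent-3b",  # placeholder path
    device="cuda",
    precision=torch.bfloat16,
    compile=False,
)

# Placeholder prompt; real prompts are built from the encoded conversation.
prompt = torch.zeros((1 + config.num_codebooks, 1), dtype=torch.int, device="cuda")

response_queue = queue.Queue()
input_queue.put(
    GenerateRequest(
        request=dict(
            prompt=prompt,
            max_new_tokens=512,
            im_end_id=tokenizer.convert_tokens_to_ids("<|im_end|>"),
            semantic_id=tokenizer.convert_tokens_to_ids("<|semantic|>"),
            temperature=0.7,
            top_p=0.7,
            repetition_penalty=1.2,
        ),
        response_queue=response_queue,
    )
)

while True:
    chunk = response_queue.get()
    if isinstance(chunk, str):  # "stop" or "error" sentinel from the worker
        break
    # chunk is a (num_samples, 1 + num_codebooks, T) tensor of new tokens
```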
 @click.command()
 @click.option(
     "--text",

+ 1 - 1
tools/msgpack_api.py

@@ -5,7 +5,7 @@ from pathlib import Path
 import httpx
 import ormsgpack
 
-from tools.commons import ServeReferenceAudio, ServeTTSRequest
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
 
 api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
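Only the import path changes in this file. For context, a minimal sketch of how these schema models are packed for the msgpack endpoint; the local URL, header, and file names are assumptions rather than values taken from this diff:

```python
from pathlib import Path

import httpx
import ormsgpack

from tools.schema import ServeReferenceAudio, ServeTTSRequest

request = ServeTTSRequest(
    text="Hello, world!",
    references=[
        ServeReferenceAudio(audio=Path("ref.wav").read_bytes(), text="reference text")
    ],
)

# OPT_SERIALIZE_PYDANTIC lets ormsgpack serialize the pydantic model directly.
payload = ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

response = httpx.post(
    "http://127.0.0.1:8080/v1/tts",  # assumed local API address
    content=payload,
    headers={"content-type": "application/msgpack"},
)
Path("output.wav").write_bytes(response.content)
```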
 

+ 1 - 7
tools/post_api.py

@@ -8,8 +8,8 @@ import requests
 from pydub import AudioSegment
 from pydub.playback import play
 
-from tools.commons import ServeReferenceAudio, ServeTTSRequest
 from tools.file import audio_to_bytes, read_ref_text
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
 
 
 def parse_args():
@@ -125,12 +125,6 @@ def parse_args():
         help="`None` means randomized inference, otherwise deterministic.\n"
         "It can't be used for fixing a timbre.",
     )
-    parser.add_argument(
-        "--seed",
-        type=int,
-        default=None,
-        help="None means randomized inference, otherwise deterministic",
-    )
 
     return parser.parse_args()
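The removed block registered `--seed` a second time; argparse's default conflict handler rejects duplicate option strings at startup, which is presumably why the extra definition is dropped. A small standalone demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=None)
try:
    # Registering the same option string twice is rejected by argparse.
    parser.add_argument("--seed", type=int, default=None)
except argparse.ArgumentError as exc:
    print(exc)  # argument --seed: conflicting option string: --seed
```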
 

+ 187 - 0
tools/schema.py

@@ -0,0 +1,187 @@
+import os
+import queue
+from dataclasses import dataclass
+from typing import Annotated, Literal, Optional
+
+import torch
+from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
+from pydantic.functional_validators import SkipValidation
+
+from fish_speech.conversation import Message, TextPart, VQPart
+
+GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
+
+
+class ServeVQPart(BaseModel):
+    type: Literal["vq"] = "vq"
+    codes: SkipValidation[list[list[int]]]
+
+
+class ServeTextPart(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ServeAudioPart(BaseModel):
+    type: Literal["audio"] = "audio"
+    audio: bytes
+
+
+@dataclass
+class ASRPackRequest:
+    audio: torch.Tensor
+    result_queue: queue.Queue
+    language: str
+
+
+class ServeASRRequest(BaseModel):
+    # The audio should be uncompressed float16 PCM
+    audios: list[bytes]
+    sample_rate: int = 44100
+    language: Literal["zh", "en", "ja", "auto"] = "auto"
+
+
+class ServeASRTranscription(BaseModel):
+    text: str
+    duration: float
+    huge_gap: bool
+
+
+class ServeASRSegment(BaseModel):
+    text: str
+    start: float
+    end: float
+
+
+class ServeTimedASRResponse(BaseModel):
+    text: str
+    segments: list[ServeASRSegment]
+    duration: float
+
+
+class ServeASRResponse(BaseModel):
+    transcriptions: list[ServeASRTranscription]
+
+
+class ServeMessage(BaseModel):
+    role: Literal["system", "assistant", "user"]
+    parts: list[ServeVQPart | ServeTextPart]
+
+    def to_conversation_message(self):
+        new_message = Message(role=self.role, parts=[])
+        for part in self.parts:
+            if isinstance(part, ServeTextPart):
+                new_message.parts.append(TextPart(text=part.text))
+            elif isinstance(part, ServeVQPart):
+                new_message.parts.append(
+                    VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
+                )
+            else:
+                raise ValueError(f"Unsupported part type: {part}")
+
+        return new_message
+
+
+class ServeRequest(BaseModel):
+    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
+    max_new_tokens: int = 1024
+    top_p: float = 0.7
+    repetition_penalty: float = 1.2
+    temperature: float = 0.7
+    streaming: bool = False
+    num_samples: int = 1
+    early_stop_threshold: float = 1.0
+
+
+class ServeVQGANEncodeRequest(BaseModel):
+    # The audio here should be in wav, mp3, etc
+    audios: list[bytes]
+
+
+class ServeVQGANEncodeResponse(BaseModel):
+    tokens: SkipValidation[list[list[list[int]]]]
+
+
+class ServeVQGANDecodeRequest(BaseModel):
+    tokens: SkipValidation[list[list[list[int]]]]
+
+
+class ServeVQGANDecodeResponse(BaseModel):
+    # The audio here should be in PCM float16 format
+    audios: list[bytes]
+
+
+class ServeForwardMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ServeResponse(BaseModel):
+    messages: list[ServeMessage]
+    finish_reason: Literal["stop", "error"] | None = None
+    stats: dict[str, int | float | str] = {}
+
+
+class ServeStreamDelta(BaseModel):
+    role: Literal["system", "assistant", "user"] | None = None
+    part: ServeVQPart | ServeTextPart | None = None
+
+
+class ServeStreamResponse(BaseModel):
+    sample_id: int = 0
+    delta: ServeStreamDelta | None = None
+    finish_reason: Literal["stop", "error"] | None = None
+    stats: dict[str, int | float | str] | None = None
+
+
+class ServeReferenceAudio(BaseModel):
+    audio: bytes
+    text: str
+
+    def __repr__(self) -> str:
+        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
+
+
+class ServeChatRequestV1(BaseModel):
+    model: str = "llama3-8b"
+    messages: list[ServeForwardMessage] = []
+    audio: bytes | None = None
+    temperature: float = 1.0
+    top_p: float = 1.0
+    max_tokens: int = 256
+    voice: str = "jessica"
+    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
+    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
+
+
+class ServeTTSRequest(BaseModel):
+    text: str
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # Reference audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    seed: int | None = None
+    use_memory_cache: Literal["on-demand", "never"] = "never"
+    # Normalize text for en & zh; this increases stability for numbers
+    normalize: bool = True
+    opus_bitrate: Optional[int] = -1000
+    # Balance mode will reduce latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # The options below are not usually needed
+    streaming: bool = False
+    max_new_tokens: int = 1024
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
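A minimal sketch of building an agent chat request from these models; the message content is illustrative, and how the request is transported to the server (e.g. msgpack over HTTP) is outside this file:

```python
from tools.schema import ServeMessage, ServeRequest, ServeTextPart

request = ServeRequest(
    messages=[
        ServeMessage(
            role="system",
            parts=[ServeTextPart(text="You are a helpful voice assistant.")],
        ),
        ServeMessage(role="user", parts=[ServeTextPart(text="Hi, who are you?")]),
    ],
    max_new_tokens=512,
    streaming=True,
)

# ServeRequest is a pydantic model, so it serializes cleanly for transport.
payload = request.model_dump()
```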