Parcourir la source

Move inference code to library dir (#801)

* move inference engine

* move file functions

* move schema

* minor fix

* minor fix

* move inference code

* update docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* [feature]retain the interface to support old version codes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update inference.ipynb to support 1.5

* Remove unused package

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Whale-Dolphin <whaledolphin666@gmail.com>
Co-authored-by: Whale and Dolphin <70465000+Whale-Dolphin@users.noreply.github.com>
Picus303 il y a 1 an
Parent
commit
4fc8dbdfe0
44 fichiers modifiés avec 1505 ajouts et 3016 suppressions
  1. 12 3
      docs/en/inference.md
  2. 12 3
      docs/ja/inference.md
  3. 12 3
      docs/ko/inference.md
  4. 12 3
      docs/pt/inference.md
  5. 12 3
      docs/zh/inference.md
  6. 12 9
      fish_speech/inference_engine/__init__.py
  7. 7 2
      fish_speech/inference_engine/reference_loader.py
  8. 1 1
      fish_speech/inference_engine/utils.py
  9. 0 0
      fish_speech/inference_engine/vq_manager.py
  10. 1114 0
      fish_speech/models/text2semantic/inference.py
  11. 124 0
      fish_speech/models/vqgan/inference.py
  12. 123 0
      fish_speech/utils/file.py
  13. 0 0
      fish_speech/utils/schema.py
  14. 0 161
      fish_speech/webui/css/style.css
  15. 0 11
      fish_speech/webui/html/footer.html
  16. 0 69
      fish_speech/webui/js/animate.js
  17. 0 120
      fish_speech/webui/launch_utils.py
  18. 0 1239
      fish_speech/webui/manage.py
  19. 9 9
      inference.ipynb
  20. 2 2
      tools/api_client.py
  21. 2 2
      tools/e2e_webui.py
  22. 0 125
      tools/file.py
  23. 1 1
      tools/fish_e2e.py
  24. 1 2
      tools/llama/build_dataset.py
  25. 1 1
      tools/llama/eval_in_context.py
  26. 8 1106
      tools/llama/generate.py
  27. 1 1
      tools/llama/quantize.py
  28. 4 4
      tools/run_webui.py
  29. 1 1
      tools/sensevoice/fun_asr.py
  30. 1 1
      tools/server/agent/generate.py
  31. 1 1
      tools/server/agent/generation_utils.py
  32. 1 1
      tools/server/agent/pre_generation_utils.py
  33. 2 2
      tools/server/api_utils.py
  34. 2 2
      tools/server/inference.py
  35. 4 4
      tools/server/model_manager.py
  36. 1 1
      tools/server/model_utils.py
  37. 7 3
      tools/server/views.py
  38. 1 1
      tools/smart_pad.py
  39. 1 1
      tools/vqgan/create_train_split.py
  40. 1 2
      tools/vqgan/extract_vq.py
  41. 9 113
      tools/vqgan/inference.py
  42. 1 1
      tools/webui/__init__.py
  43. 1 1
      tools/webui/inference.py
  44. 1 1
      tools/whisper_asr.py

+ 12 - 3
docs/en/inference.md

@@ -23,8 +23,11 @@ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-
 !!! note
     If you plan to let the model randomly choose a voice timbre, you can skip this step.
 
+!!! warning "Future Warning"
+    We have kept the interface accessible from the original path (tools/vqgan/inference.py), but this interface may be removed in subsequent releases, so please change your code as soon as possible.
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "paimon.wav" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
@@ -33,8 +36,11 @@ You should get a `fake.npy` file.
 
 ### 2. Generate semantic tokens from text:
 
+!!! warning "Future Warning"
+    We have kept the interface accessible from the original path (tools/llama/generate.py), but this interface may be removed in subsequent releases, so please change your code as soon as possible.
+
 ```bash
-python tools/llama/generate.py \
+python fish_speech/models/text2semantic/inference.py \
     --text "The text you want to convert" \
     --prompt-text "Your reference text" \
     --prompt-tokens "fake.npy" \
@@ -56,8 +62,11 @@ This command will create a `codes_N` file in the working directory, where N is a
 
 #### VQGAN Decoder
 
+!!! warning "Future Warning"
+    We have kept the interface accessible from the original path (tools/vqgan/inference.py), but this interface may be removed in subsequent releases, so please change your code as soon as possible.
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "codes_0.npy" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```

+ 12 - 3
docs/ja/inference.md

@@ -23,8 +23,11 @@ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-
 !!! note
     モデルにランダムに音声の音色を選ばせる場合、このステップをスキップできます。
 
+!!! warning "将来のバージョンに関する警告"
+    元のパス(tools/vqgan/inference.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "paimon.wav" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
@@ -33,8 +36,11 @@ python tools/vqgan/inference.py \
 
 ### 2. テキストからセマンティックトークンを生成する:
 
+!!! warning "将来のバージョンに関する警告"
+    元のパス(tools/llama/generate.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。
+
 ```bash
-python tools/llama/generate.py \
+python fish_speech/models/text2semantic/inference.py \
     --text "変換したいテキスト" \
     --prompt-text "参照テキスト" \
     --prompt-tokens "fake.npy" \
@@ -56,8 +62,11 @@ python tools/llama/generate.py \
 
 #### VQGAN デコーダー
 
+!!! warning "将来のバージョンに関する警告"
+    元のパス(tools/vqgan/inference.py)からアクセスできるインターフェースは残していますが、このインターフェースは将来のいくつかのバージョンで削除される可能性があります。お早めにコードを変更してください。
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "codes_0.npy" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```

+ 12 - 3
docs/ko/inference.md

@@ -23,8 +23,11 @@ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-
 !!! note
     모델이 음색을 무작위로 선택하도록 하려면 이 단계를 건너뛸 수 있습니다.
 
+!!! warning "향후 버전 경고"
+    원래 경로(tools/vqgan/inference.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오.
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "paimon.wav" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
@@ -33,8 +36,11 @@ python tools/vqgan/inference.py \
 
 ### 2. 텍스트에서 시맨틱 토큰 생성:
 
+!!! warning "향후 버전 경고"
+    원래 경로(tools/llama/generate.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오.
+
 ```bash
-python tools/llama/generate.py \
+python fish_speech/models/text2semantic/inference.py \
     --text "변환할 텍스트" \
     --prompt-text "참고할 텍스트" \
     --prompt-tokens "fake.npy" \
@@ -56,8 +62,11 @@ python tools/llama/generate.py \
 
 #### VQGAN 디코더
 
+!!! warning "향후 버전 경고"
+    원래 경로(tools/vqgan/inference.py)에서 접근할 수 있는 인터페이스는 유지했지만, 이 인터페이스는 향후 몇몇 버전에서 삭제될 수 있습니다. 가능한 한 빨리 코드를 변경하십시오.
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "codes_0.npy" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```

+ 12 - 3
docs/pt/inference.md

@@ -23,8 +23,11 @@ huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-
 !!! note
     Se quiser permitir que o modelo escolha aleatoriamente um timbre de voz, pule esta etapa.
 
+!!! warning "Aviso de Versão Futura"
+    Mantivemos a interface acessível a partir do caminho original (tools/vqgan/inference.py), mas esta interface poderá ser removida em algumas versões futuras. Por favor, altere o seu código o mais breve possível.
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "paimon.wav" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
@@ -33,8 +36,11 @@ Você deverá obter um arquivo `fake.npy`.
 
 ### 2. Gerar tokens semânticos a partir do texto:
 
+!!! warning "Aviso de Versão Futura"
+    Mantivemos a interface acessível a partir do caminho original (tools/llama/generate.py), mas esta interface poderá ser removida em algumas versões futuras. Por favor, altere o seu código o mais breve possível.
+
 ```bash
-python tools/llama/generate.py \
+python fish_speech/models/text2semantic/inference.py \
     --text "O texto que você deseja converter" \
     --prompt-text "Seu texto de referência" \
     --prompt-tokens "fake.npy" \
@@ -56,8 +62,11 @@ Este comando criará um arquivo `codes_N` no diretório de trabalho, onde N é u
 
 #### Decodificador VQGAN
 
+!!! warning "Aviso de Versão Futura"
+    Mantivemos a interface acessível a partir do caminho original (tools/vqgan/inference.py), mas esta interface poderá ser removida em algumas versões futuras. Por favor, altere o seu código o mais breve possível.
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "codes_0.npy" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```

+ 12 - 3
docs/zh/inference.md

@@ -29,8 +29,11 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech
 !!! note
     如果你打算让模型随机选择音色, 你可以跳过这一步.
 
+!!! warning "未来版本警告"
+    我们保留了从原来路径(tools/vqgan/inference.py)访问的接口,但是这个接口可能在之后几个版本被删除,请尽快更改你的代码。
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "paimon.wav" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
@@ -39,8 +42,11 @@ python tools/vqgan/inference.py \
 
 ### 2. 从文本生成语义 token:
 
+!!! warning "未来版本警告"
+    我们保留了从原来路径(tools/llama/generate.py)访问的接口,但是这个接口可能在之后几个版本被删除,请尽快更改你的代码。
+
 ```bash
-python tools/llama/generate.py \
+python fish_speech/models/text2semantic/inference.py \
     --text "要转换的文本" \
     --prompt-text "你的参考文本" \
     --prompt-tokens "fake.npy" \
@@ -62,8 +68,11 @@ python tools/llama/generate.py \
 
 #### VQGAN 解码
 
+!!! warning "未来版本警告"
+    我们保留了从原来路径(tools/vqgan/inference.py)访问的接口,但是这个接口可能在之后几个版本被删除,请尽快更改你的代码。
+
 ```bash
-python tools/vqgan/inference.py \
+python fish_speech/models/vqgan/inference.py \
     -i "codes_0.npy" \
     --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```

+ 12 - 9
tools/inference_engine/__init__.py → fish_speech/inference_engine/__init__.py

@@ -6,18 +6,18 @@ import numpy as np
 import torch
 from loguru import logger
 
-from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps, set_seed
-from tools.inference_engine.reference_loader import ReferenceLoader
-from tools.inference_engine.utils import InferenceResult, wav_chunk_header
-from tools.inference_engine.vq_manager import VQManager
-from tools.llama.generate import (
+from fish_speech.inference_engine.reference_loader import ReferenceLoader
+from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
+from fish_speech.inference_engine.vq_manager import VQManager
+from fish_speech.models.text2semantic.inference import (
     GenerateRequest,
     GenerateResponse,
     WrappedGenerateResponse,
 )
-from tools.schema import ServeTTSRequest
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps, set_seed
+from fish_speech.utils.schema import ServeTTSRequest
 
 
 class TTSInferenceEngine(ReferenceLoader, VQManager):
@@ -72,7 +72,10 @@ class TTSInferenceEngine(ReferenceLoader, VQManager):
         if req.streaming:
             yield InferenceResult(
                 code="header",
-                audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
+                audio=(
+                    sample_rate,
+                    np.array(wav_chunk_header(sample_rate=sample_rate)),
+                ),
                 error=None,
             )
 

+ 7 - 2
tools/inference_engine/reference_loader.py → fish_speech/inference_engine/reference_loader.py

@@ -8,8 +8,13 @@ import torchaudio
 from loguru import logger
 
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
-from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
-from tools.schema import ServeReferenceAudio
+from fish_speech.utils.file import (
+    AUDIO_EXTENSIONS,
+    audio_to_bytes,
+    list_files,
+    read_ref_text,
+)
+from fish_speech.utils.schema import ServeReferenceAudio
 
 
 class ReferenceLoader:

+ 1 - 1
tools/inference_engine/utils.py → fish_speech/inference_engine/utils.py

@@ -11,7 +11,7 @@ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 @dataclass
 class InferenceResult:
     code: Literal["header", "segment", "error", "final"]
-    audio: Optional[Tuple[int, np.ndarray | bytes]]
+    audio: Optional[Tuple[int, np.ndarray]]
     error: Optional[Exception]
 
 

+ 0 - 0
tools/inference_engine/vq_manager.py → fish_speech/inference_engine/vq_manager.py


+ 1114 - 0
fish_speech/models/text2semantic/inference.py

@@ -0,0 +1,1114 @@
+import os
+import queue
+import threading
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal, Optional, Tuple, Union
+
+import click
+import numpy as np
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from loguru import logger
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    Conversation,
+    Message,
+    TextPart,
+    VQPart,
+)
+from fish_speech.models.text2semantic.llama import BaseModelArgs
+from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    # Experimental feature to reduce compilation times, will be on by default in future
+    torch._inductor.config.fx_graph_cache = True
+
+
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
+
+
+def multinomial_sample_one_no_sync(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=0, index=previous_tokens, src=score)
+
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+    logits = logits / max(temperature, 1e-5)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
+def multinomial_sample_one_no_sync_agent(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)
+
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[..., 0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+    logits = logits / max(temperature, 1e-5)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
+def sample(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs(
+        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+    )
+    idx_next = multinomial_sample_one_no_sync(probs)
+    return idx_next, probs
+
+
+def sample_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs_agent(
+        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
+    )
+    idx_next = multinomial_sample_one_no_sync_agent(probs)
+    return idx_next, probs
+
+
+def decode_one_token_ar_agent(
+    model: DualARTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    semantic_ids: list,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    # print(x, input_pos)
+    x = model.forward_generate(x, input_pos)
+    logits = x.logits  # [:, -1:]
+    hidden_states = x.hidden_states  # [:, -1:]
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    sampling_kwargs_main["temperature"] = 0.1
+    sampling_kwargs_main["top_p"] = 0.1
+    sampling_kwargs_main["repetition_penalty"] = 1.0
+
+    codebooks = [
+        sample_agent(
+            logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
+        )[0]
+    ]
+
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
+        a = sample_agent(
+            logits,
+            previous_tokens=(
+                previous_tokens[:, codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        hidden_states = model.fast_embeddings(a)
+        codebooks.append(a)
+
+    codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
+    )
+
+    return codebooks
+
+
+def decode_one_token_naive_agent(
+    model: NaiveTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    semantic_ids: list,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    codebooks = [
+        sample(
+            x.token_logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs,
+        )[0]
+    ]
+
+    for i in range(model.config.num_codebooks):
+        codebooks.append(
+            sample_agent(
+                x.codebook_logits[:, :, i],
+                previous_tokens=(
+                    previous_tokens[:, i + 1] if previous_tokens is not None else None
+                ),
+                **sampling_kwargs,
+            )[0]
+        )
+
+    codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
+    )
+
+    return codebooks
+
+
+def decode_one_token_ar(
+    model: DualARTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    semantic_ids: list,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    # sampling_kwargs_main["temperature"] = 0.1
+    # sampling_kwargs_main["top_p"] = 0.1
+    # sampling_kwargs_main["repetition_penalty"] = 1.0
+
+    codebooks = [
+        sample(
+            x.logits,
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
+        )[0]
+    ]
+
+    hidden_states = x.hidden_states
+
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+    model.forward_generate_fast(hidden_states, input_pos)
+    a = codebooks[0] - model.tokenizer.semantic_begin_id
+    a[a < 0] = 0
+    hidden_states = model.fast_embeddings(a)
+    codebooks.append(a)
+
+    for codebook_idx in range(1, model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
+        a = sample(
+            logits,
+            previous_tokens=(
+                previous_tokens[codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        hidden_states = model.fast_embeddings(a)
+        codebooks.append(a)
+
+    codebooks = torch.stack(codebooks, dim=0)
+    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    # codebooks[1:, :] = torch.masked_fill(
+    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+    # )
+
+    # print(codebooks)
+    return codebooks
+
+
+def decode_one_token_naive(
+    model: NaiveTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    sampling_kwargs_main["temperature"] = 0.1
+    sampling_kwargs_main["top_p"] = 0.1
+    sampling_kwargs_main["repetition_penalty"] = 1.0
+
+    codebooks = [
+        sample(
+            x.logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
+        )[0]
+    ]
+
+    for i in range(model.config.num_codebooks):
+        codebooks.append(
+            sample(
+                x.codebook_logits[:, :, i],
+                previous_tokens=(
+                    previous_tokens[i + 1] if previous_tokens is not None else None
+                ),
+                **sampling_kwargs,
+            )[0]
+        )
+
+    return torch.stack(codebooks, dim=0)
+
+
+def decode_n_tokens(
+    model: NaiveTransformer,
+    cur_token: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    semantic_ids: list,
+    decode_one_token=decode_one_token_naive,
+    **sampling_kwargs,
+):
+    previous_tokens = torch.zeros(
+        (model.config.num_codebooks + 1, model.config.max_seq_len),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
+
+    for i in tqdm(range(num_new_tokens)):
+        # We need to get windowed repeat penalty
+        win_size = 16
+        if i < win_size:
+            window = previous_tokens[:, :win_size]
+        else:
+            window = previous_tokens[:, i - win_size : i]
+
+        with (
+            torch.backends.cuda.sdp_kernel(
+                enable_flash=False, enable_mem_efficient=False, enable_math=True
+            )
+            if torch.cuda.is_available()
+            else nullcontext()
+        ):  # Actually better for Inductor to codegen attention here
+            next_token = decode_one_token(
+                model=model,
+                x=cur_token,
+                input_pos=input_pos,
+                previous_tokens=window,
+                semantic_ids=semantic_ids,
+                **sampling_kwargs,
+            )
+
+        input_pos += 1
+        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
+        previous_tokens[:, i : i + 1] = next_token.view(
+            model.config.num_codebooks + 1, -1
+        )
+
+        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
+            break
+
+    return previous_tokens[:, : i + 1]
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate(
+    *,
+    model: NaiveTransformer,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    decode_one_token=decode_one_token_naive,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(1)
+    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    semantic_ids = [
+        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+    ]
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
+
+    device, dtype = prompt.device, prompt.dtype
+
+    codebook_dim = 1 + model.config.num_codebooks
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    empty = torch.empty(
+        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
+    )
+    empty[:, :T] = prompt
+    seq = empty
+    input_pos = torch.arange(0, T, device=device)
+
+    # Use non-accelerated version for now, to avoid compilation overhead
+    prefill_decode = (
+        decode_one_token_naive
+        if isinstance(model, NaiveTransformer)
+        else decode_one_token_ar
+    )
+
+    next_token = prefill_decode(
+        model,
+        prompt.view(1, codebook_dim, -1),
+        input_pos,
+        semantic_ids=semantic_ids,
+        **sampling_kwargs,
+    )
+    seq[:, T : T + 1] = next_token
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+    x = decode_n_tokens(
+        model,
+        next_token.view(1, codebook_dim, -1),
+        input_pos,
+        max_new_tokens - 1,
+        decode_one_token=decode_one_token,
+        semantic_ids=semantic_ids,
+        **sampling_kwargs,
+    )
+    # x = torch.cat(generated_tokens, dim=1)
+    seq = seq[:, : T + 1 + x.size(1)]
+    seq[:, T + 1 :] = x
+
+    return seq
+
+
+def decode_n_tokens_agent(
+    model: NaiveTransformer,
+    cur_token: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    semantic_ids: list,
+    im_end_id: int = 4,
+    decode_one_token=decode_one_token_naive_agent,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    batch_size = cur_token.size(0)
+    previous_tokens = torch.zeros(
+        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
+    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
+    finished = finished | (cur_token[:, 0, -1] == im_end_id)
+    start_time = time.time()
+
+    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
+        # We need to get windowed repeat penalty
+        win_size = 16
+        if i < win_size:
+            window = previous_tokens[:, :, :win_size]
+        else:
+            window = previous_tokens[:, :, i - win_size : i]
+
+        with sdpa_kernel(
+            SDPBackend.MATH
+        ):  # Actually better for Inductor to codegen attention here
+            next_token = decode_one_token(
+                model=model,
+                x=cur_token,
+                input_pos=input_pos,
+                previous_tokens=window,
+                semantic_ids=semantic_ids,
+                **sampling_kwargs,
+            )
+
+        input_pos += 1
+        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
+        previous_tokens[:, :, i : i + 1] = next_token.view(
+            batch_size, model.config.num_codebooks + 1, -1
+        )
+
+        yield cur_token.cpu()
+
+        finished = finished | (cur_token[:, 0, -1] == im_end_id)
+        if finished.all() or (
+            0 < early_stop_threshold < 1
+            and finished.sum() >= round(batch_size * early_stop_threshold)
+        ):
+            break
+
+    total_time = time.time() - start_time
+    generated_tokens = i + 1
+    tokens_per_second = (generated_tokens / total_time) * batch_size
+    logger.info(
+        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
+    )
+
+
@torch.no_grad()
@torch.inference_mode()
def generate_agent(
    *,
    model: BaseTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    num_samples: int = 1,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    This is a generator: it yields token tensors on CPU as they are decoded —
    first the prefill step's tokens, then one tensor per step produced by
    ``decode_n_tokens_agent``. The prompt (presumably shaped
    (num_codebooks + 1, T) — confirm against callers) is repeated
    ``num_samples`` times to form a batch.

    Raises:
        ValueError: If the prompt alone already fills the model context.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    prompt = prompt[None].repeat(num_samples, 1, 1)  # batch the single prompt

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    if max_new_tokens:
        # Clamp so prompt + generation never exceeds the model context window.
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        # max_new_tokens == 0 means "fill the remaining context".
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive_agent
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar_agent
    )
    next_token = prefill_decode(
        model,
        prompt,
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    ).view(num_samples, codebook_dim, -1)
    yield next_token.cpu()

    # After prefill, decode one position at a time starting at index T.
    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    yield from decode_n_tokens_agent(
        model,
        next_token,
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        semantic_ids=semantic_ids,
        decode_one_token=decode_one_token,
        early_stop_threshold=early_stop_threshold,
        **sampling_kwargs,
    )
+
+
def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    """Encode a text request (and optional reference VQ codes) for inference.

    Builds a two-message conversation: a user turn carrying the cleaned text,
    and an assistant turn. When ``prompt_tokens`` is given, the assistant turn
    contains the reference codes (a closed example); otherwise it is left open
    (no <|im_end|>) so the model continues it.
    """
    cleaned = clean_text(string)

    user_message = Message(
        role="user",
        parts=[TextPart(text=cleaned)],
        cal_loss=False,
    )

    if prompt_tokens is None:
        # No reference audio: leave the assistant turn open for generation.
        assistant_message = Message(
            role="assistant",
            parts=[TextPart(text="<|voice|>")],
            cal_loss=False,
            add_im_end=False,
        )
    else:
        codes = prompt_tokens
        if codes.ndim == 3:
            assert (
                codes.shape[0] == 1
            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
            codes = codes[0]

        assert codes.ndim == 2, "Prompt tokens should be 2D tensor"

        # Extra codebooks (e.g. from a larger model) are dropped, not an error.
        if codes.shape[0] > num_codebooks:
            logger.warning(
                f"Prompt tokens shape {codes.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
            )
            codes = codes[:num_codebooks]

        assistant_message = Message(
            role="assistant",
            parts=[TextPart(text="<|voice|>"), VQPart(codes=codes.to(device))],
            cal_loss=False,
        )

    conversation = Conversation(messages=[user_message, assistant_message])
    encoded = conversation.encode_for_inference(
        tokenizer=tokenizer,
        num_codebooks=num_codebooks,
    )

    return encoded.to(device)
+
+
def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
    """Load a text2semantic checkpoint and pick the matching decode function.

    Returns:
        (model, decode_one_token): the model in eval mode on ``device`` with
        dtype ``precision``, and the single-step decoder for its family
        (Naive vs DualAR) and mode (agent vs TTS), optionally torch.compile'd.
    """
    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True, is_agent=is_agent
    )
    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        logger.info("Using DualARTransformer")
        decode_one_token = (
            decode_one_token_ar_agent if is_agent else decode_one_token_ar
        )
    else:
        logger.info("Using NaiveTransformer")
        decode_one_token = (
            decode_one_token_naive_agent if is_agent else decode_one_token_naive
        )

    if compile:
        logger.info("Compiling function...")
        has_cuda = torch.cuda.is_available()
        decode_one_token = torch.compile(
            decode_one_token,
            fullgraph=True,
            # inductor needs a GPU backend; fall back to aot_eager on CPU
            backend="inductor" if has_cuda else "aot_eager",
            mode="reduce-overhead" if has_cuda else None,
        )

    return model.eval(), decode_one_token
+
+
@dataclass
class GenerateResponse:
    """One streamed result from `generate_long`.

    action: "sample" when a text chunk has been decoded (codes/text are set);
            "next" when the current sample is finished (codes/text are None).
    """

    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None  # (num_codebooks, T) semantic codes for the chunk
    text: Optional[str] = None  # the source text chunk that produced `codes`
+
+
def generate_long(
    *,
    model,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: int = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    prompt_text: Optional[str | list[str]] = None,
    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
):
    """Stream semantic codes for `text`, chunk by chunk.

    When `iterative_prompt` is set, the text is split into chunks of at most
    `chunk_length`; each chunk is generated conditioned on a sliding window of
    previously generated segments so the context stays within `max_length`.
    Yields GenerateResponse(action="sample", codes=..., text=...) for every
    chunk, and GenerateResponse(action="next") after each of the `num_samples`
    passes over the text.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    # Reference prompt is optional; normalize single prompt to a list.
    use_prompt = prompt_text is not None and prompt_tokens is not None
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]

    assert use_prompt is False or len(prompt_text) == len(
        prompt_tokens
    ), "Prompt text and tokens must have the same length"

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
    im_end_id = tokenizer.get_token_id("<|im_end|>")

    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]
    # System message is always the first element of the conditioning prefix.
    encoded_prompts = [
        Conversation(
            messages=[
                Message(
                    role="system",
                    parts=[TextPart(text="Speak out the provided text.")],
                    cal_loss=False,
                )
            ]
        )
        .encode_for_inference(
            tokenizer=tokenizer,
            num_codebooks=model.config.num_codebooks,
        )
        .to(device)
    ]

    if use_prompt:
        # Reference (text, codes) pairs are appended after the system message.
        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
            encoded_prompts.append(
                encode_tokens(
                    tokenizer,
                    string=t,
                    device=device,
                    prompt_tokens=c,
                    num_codebooks=model.config.num_codebooks,
                )
            )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                device=device,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # History of (input or generated) segments for the current sample.
        global_encoded = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            # Walk history newest-first to decide how much context fits.
            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            # NOTE(review): `count` already includes `length` when the budget
            # check runs, so the current segment is effectively counted twice;
            # also `i` leaks out of this loop and is used below, which assumes
            # global_encoded is non-empty (true here since seg was appended).
            # Confirm whether the double-count is intentional slack.
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024 - sum(
                    t.shape[1] for t in encoded_prompts
                ):
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = encoded_prompts + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Put the generated tokens
            # since there is <im_end>, we remove last token
            codes = y[1:, prompt_length + 1 :].clone()
            assert (codes >= 0).all(), f"Negative code found"

            decoded = y[:, prompt_length:].clone()
            # But for global encoding, we should keep the <im_end> token

            global_encoded.append(decoded)
            assert (codes >= 0).all(), f"Negative code found: {codes}"
            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

        # This indicates the end of the current sample
        yield GenerateResponse(action="next")
+
+
@dataclass
class WrappedGenerateResponse:
    """Envelope streamed to a request's response queue by the worker thread.

    status: "success" carries a GenerateResponse chunk;
            "error" carries the Exception raised during generation.
    """

    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None
+
+
@dataclass
class GenerateRequest:
    """A unit of work submitted to the model worker thread."""

    request: dict  # keyword arguments forwarded to the generation function
    response_queue: queue.Queue  # per-request queue receiving streamed results
+
+
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Load the model on a daemon worker thread and return its input queue.

    Callers put `GenerateRequest` items on the returned queue; for each request
    the worker streams `WrappedGenerateResponse` objects into that request's
    `response_queue` (status "success" per chunk, or one "error" wrapping the
    exception). Putting None on the queue stops the worker. This call blocks
    until the model has finished loading.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()  # signal the caller that the model is ready

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:  # shutdown sentinel
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                # Errors are reported on the per-request queue, not raised here,
                # so the worker survives for the next request.
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue
+
+
def launch_thread_safe_queue_agent(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Agent-mode variant of `launch_thread_safe_queue`.

    Returns (input_queue, tokenizer, config). The worker streams raw token
    tensors (not WrappedGenerateResponse) to each request's response queue,
    followed by the string "stop" on success or "error" on failure. Putting
    None on the input queue stops the worker. Blocks until the model is loaded.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()

    # Tokenizer/config are loaded on the caller's thread so they can be
    # returned immediately alongside the queue.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    config = BaseModelArgs.from_pretrained(checkpoint_path)

    def worker():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile, is_agent=True
        )

        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()  # signal the caller that the model is ready

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:  # shutdown sentinel
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for token in generate_agent(
                    model=model,
                    decode_one_token=decode_one_token,
                    **kwargs,
                ):
                    response_queue.put(token)

                response_queue.put("stop")
            except Exception as e:
                import traceback

                logger.exception(f"Error in worker: {traceback.format_exc()}")
                response_queue.put("error")

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue, tokenizer, config
+
+
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.5",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=100)
def main(
    text: str,
    prompt_text: Optional[list[str]],
    prompt_tokens: Optional[list[Path]],
    num_samples: int,
    max_new_tokens: int,
    top_p: int,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
) -> None:
    """CLI: generate semantic codes from text, with optional voice prompts.

    Streams through `generate_long` and saves one `codes_{i}.npy` per sample,
    suitable as input to the VQ-GAN decoder.
    """

    precision = torch.half if half else torch.bfloat16

    # NOTE(review): with click's multiple=True, prompt_text/prompt_tokens are
    # tuples (possibly empty), never None, so the `is not None` guard is
    # always true — confirm the intended empty-tuple behavior.
    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    # Prompt tokens are .npy files produced by the VQ-GAN encoder.
    if prompt_tokens is not None:
        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]

    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    idx = 0
    codes = []

    # Accumulate chunk codes until the generator signals the end of a sample,
    # then flush them to codes_{idx}.npy.
    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to codes_{idx}.npy")
            logger.info(f"Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")

+ 124 - 0
fish_speech/models/vqgan/inference.py

@@ -0,0 +1,124 @@
+from pathlib import Path
+
+import click
+import hydra
+import numpy as np
+import pyrootutils
+import soundfile as sf
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from omegaconf import OmegaConf
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
def load_model(config_name, checkpoint_path, device="cuda"):
    """Instantiate the VQ-GAN from its Hydra config and load its weights.

    Accepts a raw state dict, a Lightning-style checkpoint (weights nested
    under "state_dict"), or a full-model checkpoint whose generator weights
    carry a "generator." prefix. Loading is non-strict; the load result is
    logged so missing/unexpected keys are visible.
    """
    # Reset Hydra so repeated calls within one process do not conflict.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../configs"):
        cfg = compose(config_name=config_name)

    model = instantiate(cfg)

    weights = torch.load(
        checkpoint_path, map_location=device, mmap=True, weights_only=True
    )
    if "state_dict" in weights:
        weights = weights["state_dict"]

    # Full-model checkpoint: keep only the generator, stripping its prefix.
    if any("generator" in key for key in weights):
        weights = {
            key.replace("generator.", ""): value
            for key, value in weights.items()
            if "generator." in key
        }

    load_result = model.load_state_dict(weights, strict=False, assign=True)
    model.eval()
    model.to(device)

    logger.info(f"Loaded model: {load_result}")
    return model
+
+
@torch.no_grad()
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="test.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
@click.option(
    "--device",
    "-d",
    default="cuda",
)
def main(input_path, output_path, config_name, checkpoint_path, device):
    """CLI: round-trip audio through the VQ-GAN.

    An audio input is encoded to codebook indices (saved alongside the output
    as .npy) and reconstructed; a .npy input is treated as precomputed indices
    and only decoded. The reconstructed waveform is written to --output-path.
    """
    model = load_model(config_name, checkpoint_path, device=device)

    if input_path.suffix in AUDIO_EXTENSIONS:
        logger.info(f"Processing in-place reconstruction of {input_path}")

        # Load audio
        audio, sr = torchaudio.load(str(input_path))
        if audio.shape[0] > 1:
            # Downmix multi-channel input to mono.
            audio = audio.mean(0, keepdim=True)
        audio = torchaudio.functional.resample(
            audio, sr, model.spec_transform.sample_rate
        )

        audios = audio[None].to(device)  # add batch dimension: (1, 1, samples)
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
        indices = model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Generated indices of shape {indices.shape}")

        # Save indices
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = np.load(input_path)
        indices = torch.from_numpy(indices).to(device).long()
        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
    else:
        raise ValueError(f"Unknown input type: {input_path}")

    # Restore
    feature_lengths = torch.tensor([indices.shape[1]], device=device)
    fake_audios, _ = model.decode(
        indices=indices[None], feature_lengths=feature_lengths
    )
    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate

    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
    )

    # Save audio
    fake_audio = fake_audios[0, 0].float().cpu().numpy()
    sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
    logger.info(f"Saved audio to {output_path}")

+ 123 - 0
fish_speech/utils/file.py

@@ -1,5 +1,27 @@
 import os
 from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
# Audio file suffixes recognized by the file helpers (lowercase, with dot).
AUDIO_EXTENSIONS = {
    ".mp3",
    ".wav",
    ".flac",
    ".ogg",
    ".m4a",
    ".wma",
    ".aac",
    ".aiff",
    ".aif",
    ".aifc",
}

# Video container suffixes recognized by the file helpers.
VIDEO_EXTENSIONS = {
    ".mp4",
    ".avi",
}
 
 
 def get_latest_checkpoint(path: Path | str) -> Path | None:
@@ -14,3 +36,104 @@ def get_latest_checkpoint(path: Path | str) -> Path | None:
         return None
 
     return ckpts[-1]
+
+
def audio_to_bytes(file_path):
    """Return the raw bytes of *file_path*, or None if it is falsy or missing."""
    if not file_path:
        return None

    path = Path(file_path)
    if not path.exists():
        return None

    return path.read_bytes()
+
+
def read_ref_text(ref_text):
    """If *ref_text* names an existing file, return its UTF-8 contents;
    otherwise return *ref_text* itself as the literal reference text."""
    candidate = Path(ref_text)
    if not (candidate.exists() and candidate.is_file()):
        return ref_text

    return candidate.read_text(encoding="utf-8")
+
+
def list_files(
    path: Union[Path, str],
    extensions: set[str] = set(),
    recursive: bool = False,
    sort: bool = True,
) -> list[Path]:
    """List files in a directory, filtered by extension.

    Args:
        path (Path): Path to the directory.
        extensions (set, optional): File suffixes to match, e.g. {".wav"}.
            Defaults to an empty set, which matches nothing.
        recursive (bool, optional): Whether to search subdirectories too.
            Defaults to False.
        sort (bool, optional): Whether to naturally sort the files.
            Defaults to True.

    Returns:
        list: Matching file paths.

    Raises:
        FileNotFoundError: If `path` does not exist.
    """

    if isinstance(path, str):
        path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Directory {path} does not exist.")

    # Bug fix: `recursive` was previously ignored — rglob was always used.
    # Honor the flag by choosing between shallow glob and recursive rglob.
    glob_fn = path.rglob if recursive else path.glob
    files = [file for ext in extensions for file in glob_fn(f"*{ext}")]

    if sort:
        files = natsorted(files)

    return files
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+    """
+    Load a Bert-VITS2 style filelist.
+    """
+
+    files = set()
+    results = []
+    count_duplicated, count_not_found = 0, 0
+
+    LANGUAGE_TO_LANGUAGES = {
+        "zh": ["zh", "en"],
+        "jp": ["jp", "en"],
+        "en": ["en"],
+    }
+
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f.readlines():
+            splits = line.strip().split("|", maxsplit=3)
+            if len(splits) != 4:
+                logger.warning(f"Invalid line: {line}")
+                continue
+
+            filename, speaker, language, text = splits
+            file = Path(filename)
+            language = language.strip().lower()
+
+            if language == "ja":
+                language = "jp"
+
+            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+            languages = LANGUAGE_TO_LANGUAGES[language]
+
+            if file in files:
+                logger.warning(f"Duplicated file: {file}")
+                count_duplicated += 1
+                continue
+
+            if not file.exists():
+                logger.warning(f"File not found: {file}")
+                count_not_found += 1
+                continue
+
+            results.append((file, speaker, languages, text))
+
+    if count_duplicated > 0:
+        logger.warning(f"Total duplicated files: {count_duplicated}")
+
+    if count_not_found > 0:
+        logger.warning(f"Total files not found: {count_not_found}")
+
+    return results

+ 0 - 0
tools/schema.py → fish_speech/utils/schema.py


+ 0 - 161
fish_speech/webui/css/style.css

@@ -1,161 +0,0 @@
-:root {
-  --my-200: #80eeee;
-  --my-50: #ecfdf5;
-  --water-width: 300px;
-  --water-heigh: 300px;
-}
-
-
-/* general styled components */
-.tools {
-  align-items: center;
-  justify-content: center;
-}
-
-.gradio-button {
-    max-width: 2.2em;
-    min-width: 2.2em !important;
-    height: 2.4em;
-    align-self: end;
-    line-height: 1em;
-    border-radius: 0.5em;
-
-}
-
-.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
-    box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
-}
-
-/* replace original footer with ours */
-a{
-    font-weight: bold;
-    cursor: pointer;
-    color: #030C14 !important;
-}
-
-footer {
-    display: none !important;
-}
-
-#footer{
-    text-align: center;
-}
-
-#footer div{
-    display: inline-block;
-}
-
-#footer .versions{
-    font-size: 85%;
-    opacity: 0.85;
-}
-
-/*@keyframes moveBackground {*/
-/*  0% {*/
-/*    background-position: 0 0;*/
-/*  }*/
-/*  100% {*/
-/*    background-position: -100px 100px;*/
-/*  }*/
-/*}*/
-@keyframes moveJellyBackground {
-  0% {
-    background-position: 0% 50%;
-  }
-  50% {
-    background-position: 100% 50%;
-  }
-  100% {
-    background-position: 0% 50%;
-  }
-}
-
-.gradio-container {
-  position: absolute;
-  z-index: 10;
-}
-
-
-.quan {
-  position: absolute;
-  bottom: 0;
-  width: var(--water-width);
-  height: var(--water-heigh);
-  border-radius: 0;
-  /*border: 3px solid rgb(246, 247, 248);*/
-  /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
-  z-index: 0;
-
-}
-
-.quan:last-child {
-  margin-right: 0;
-}
-
-.shui {
-  position: absolute;
-  top: 0;
-  left: 0;
-  width: 100%;
-  height: 100%;
-  background-color: rgb(23, 106, 201);
-  border-radius: 0;
-  overflow: hidden;
-  z-index: 0;
-}
-
-.shui::after {
-
-  content: '';
-  position: absolute;
-  top: 20%;
-  left: 50%;
-  width: 150%;
-  height: 150%;
-  border-radius: 40%;
-  background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
-  animation: shi 5s linear infinite;
-}
-
-@keyframes shi {
-  0% {
-    transform: translate(-50%, -65%) rotate(0deg);
-  }
-  100% {
-    transform: translate(-50%, -65%) rotate(360deg);
-  }
-}
-
-.shui::before {
-  content: '';
-  position: absolute;
-  top: 20%;
-  left: 50%;
-  width: 150%;
-  height: 150%;
-  border-radius: 42%;
-  background-color: rgb(240, 228, 228, 0.2);
-  animation: xu 7s linear infinite;
-}
-
-@keyframes xu {
-  0% {
-    transform: translate(-50%, -60%) rotate(0deg);
-  }
-  100% {
-    transform: translate(-50%, -60%) rotate(360deg);
-  }
-}
-
-fieldset.data_src div.wrap label {
-  background: #f8bffee0 !important;
-}
-
-.scrollable-component {
-  max-height: 100px;
-  overflow-y: auto;
-}
-
-#file_accordion {
-  max-height: 220px !important;
-}

+ 0 - 11
fish_speech/webui/html/footer.html

@@ -1,11 +0,0 @@
-<div style="color: rgba(25,255,205,0.7) !important;">
-        <a href="{api_docs}">API</a>
-         • 
-        <a href="https://github.com/fishaudio/fish-speech">Github</a>
-         • 
-        <a href="https://gradio.app">Gradio</a>
-</div>
-<br />
-<div class="versions" style="color: rgba(25,255,205,0.7) !important;">
-{versions}
-</div>

+ 0 - 69
fish_speech/webui/js/animate.js

@@ -1,69 +0,0 @@
-
-function createGradioAnimation() {
-    const params = new URLSearchParams(window.location.search);
-    if (!params.has('__theme')) {
-        params.set('__theme', 'light');
-        window.location.search = params.toString();
-    }
-
-    var gradioApp = document.querySelector('gradio-app');
-    if (gradioApp) {
-
-        document.documentElement.style.setProperty('--my-200', '#80eeee');
-        document.documentElement.style.setProperty('--my-50', '#ecfdf5');
-
-        // gradioApp.style.position = 'relative';
-        // gradioApp.style.backgroundSize = '200% 200%';
-        // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
-        // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
-        // gradioApp.style.display = 'flex';
-        // gradioApp.style.justifyContent = 'flex-start';
-        // gradioApp.style.flexWrap = 'nowrap';
-        // gradioApp.style.overflowX = 'auto';
-
-        // for (let i = 0; i < 6; i++) {
-        //     var quan = document.createElement('div');
-        //     quan.className = 'quan';
-        //     gradioApp.insertBefore(quan, gradioApp.firstChild);
-        //     quan.id = 'quan' + i.toString();
-        //     quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
-        //     var quanContainer = document.querySelector('.quan');
-        //     if (quanContainer) {
-        //         var shui = document.createElement('div');
-        //         shui.className = 'shui';
-        //         quanContainer.insertBefore(shui, quanContainer.firstChild)
-        //     }
-        // }
-    }
-
-    var container = document.createElement('div');
-    container.id = 'gradio-animation';
-    container.style.fontSize = '2em';
-    container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
-    container.style.fontWeight = 'bold';
-    container.style.textAlign = 'center';
-    container.style.marginBottom = '20px';
-
-    var text = 'Welcome to Fish-Speech!';
-    for (var i = 0; i < text.length; i++) {
-        (function(i){
-            setTimeout(function(){
-                var letter = document.createElement('span');
-                letter.style.opacity = '0';
-                letter.style.transition = 'opacity 0.5s';
-                letter.innerText = text[i];
-
-                container.appendChild(letter);
-
-                setTimeout(function() {
-                    letter.style.opacity = '1';
-                }, 50);
-            }, i * 200);
-        })(i);
-    }
-
-    var gradioContainer = document.querySelector('.gradio-container');
-    gradioContainer.insertBefore(container, gradioContainer.firstChild);
-
-    return 'Animation created';
-}

+ 0 - 120
fish_speech/webui/launch_utils.py

@@ -1,120 +0,0 @@
-import importlib.util
-import os
-import subprocess
-import sys
-from functools import lru_cache
-from pathlib import Path
-from typing import Iterable
-
-import gradio as gr
-from gradio.themes.base import Base
-from gradio.themes.utils import colors, fonts, sizes
-
-GIT = (
-    (Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
-    if sys.platform == "win32"
-    else "git"
-)
-GIT = str(GIT)
-
-
-def is_module_installed(module_name: str) -> bool:
-    spec = importlib.util.find_spec(module_name)
-    return spec is not None
-
-
-@lru_cache()
-def commit_hash():
-    try:
-        return subprocess.check_output(
-            [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
-        ).strip()
-    except Exception:
-        return "<none>"
-
-
-def versions_html():
-    import torch
-
-    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
-    commit = commit_hash()
-    hash = commit.strip("'").split(" ")[0]
-
-    return f"""
-version: <a href="https://github.com/fishaudio/fish-speech/commit/{hash}">{hash}</a>
-&#x2000;•&#x2000;
-python: <span title="{sys.version}">{python_version}</span>
-&#x2000;•&#x2000;
-torch: {getattr(torch, '__long_version__',torch.__version__)}
-&#x2000;•&#x2000;
-gradio: {gr.__version__}
-&#x2000;•&#x2000;
-author: <a href="https://github.com/fishaudio">fishaudio</a>
-"""
-
-
-def version_check(commit):
-    try:
-        import requests
-
-        commits = requests.get(
-            "https://api.github.com/repos/fishaudio/fish-speech/branches/main"
-        ).json()
-        if commit != "<none>" and commits["commit"]["sha"] != commit:
-            print("--------------------------------------------------------")
-            print("| You are not up to date with the most recent release. |")
-            print("| Consider running `git pull` to update.               |")
-            print("--------------------------------------------------------")
-        elif commits["commit"]["sha"] == commit:
-            print("You are up to date with the most recent release.")
-        else:
-            print("Not a git clone, can't perform version check.")
-    except Exception as e:
-        print("version check failed", e)
-
-
-class Seafoam(Base):
-    def __init__(
-        self,
-        *,
-        primary_hue: colors.Color | str = colors.emerald,
-        secondary_hue: colors.Color | str = colors.blue,
-        neutral_hue: colors.Color | str = colors.blue,
-        spacing_size: sizes.Size | str = sizes.spacing_md,
-        radius_size: sizes.Size | str = sizes.radius_md,
-        text_size: sizes.Size | str = sizes.text_lg,
-        font: fonts.Font | str | Iterable[fonts.Font | str] = (
-            fonts.GoogleFont("Quicksand"),
-            "ui-sans-serif",
-            "sans-serif",
-        ),
-        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
-            fonts.GoogleFont("IBM Plex Mono"),
-            "ui-monospace",
-            "monospace",
-        ),
-    ):
-        super().__init__(
-            primary_hue=primary_hue,
-            secondary_hue=secondary_hue,
-            neutral_hue=neutral_hue,
-            spacing_size=spacing_size,
-            radius_size=radius_size,
-            text_size=text_size,
-            font=font,
-            font_mono=font_mono,
-        )
-        super().set(
-            button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
-            button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
-            button_primary_text_color="white",
-            button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
-            slider_color="*secondary_300",
-            slider_color_dark="*secondary_600",
-            block_title_text_weight="600",
-            block_border_width="3px",
-            block_shadow="*shadow_drop_lg",
-            # button_shadow="*shadow_drop_lg",
-            button_small_padding="0px",
-            button_large_padding="3px",
-        )

+ 0 - 1239
fish_speech/webui/manage.py

@@ -1,1239 +0,0 @@
-from __future__ import annotations
-
-import os
-
-os.environ["USE_LIBUV"] = "0"
-import datetime
-import html
-import json
-import platform
-import shutil
-import signal
-import subprocess
-import sys
-from pathlib import Path
-
-import gradio as gr
-import psutil
-import yaml
-from loguru import logger
-from tqdm import tqdm
-
-PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
-sys.path.insert(0, "")
-print(sys.path)
-cur_work_dir = Path(os.getcwd()).resolve()
-print("You are in ", str(cur_work_dir))
-
-from fish_speech.i18n import i18n
-from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
-
-config_path = cur_work_dir / "fish_speech" / "configs"
-vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
-llama_yml_path = config_path / "text2semantic_finetune.yaml"
-
-env = os.environ.copy()
-env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
-
-seafoam = Seafoam()
-
-
-def build_html_error_message(error):
-    return f"""
-    <div style="color: red; font-weight: bold;">
-        {html.escape(error)}
-    </div>
-    """
-
-
-def build_html_ok_message(msg):
-    return f"""
-    <div style="color: green; font-weight: bold;">
-        {html.escape(msg)}
-    </div>
-    """
-
-
-def build_html_href(link, desc, msg):
-    return f"""
-    <span style="color: green; font-weight: bold; display: inline-block">
-        {html.escape(msg)}
-        <a href="{link}">{desc}</a>
-    </span>
-    """
-
-
-def load_data_in_raw(path):
-    with open(path, "r", encoding="utf-8") as file:
-        data = file.read()
-    return str(data)
-
-
-def kill_proc_tree(pid, including_parent=True):
-    try:
-        parent = psutil.Process(pid)
-    except psutil.NoSuchProcess:
-        # Process already terminated
-        return
-
-    children = parent.children(recursive=True)
-    for child in children:
-        try:
-            os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
-        except OSError:
-            pass
-    if including_parent:
-        try:
-            os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
-        except OSError:
-            pass
-
-
-system = platform.system()
-p_label = None
-p_infer = None
-p_tensorboard = None
-
-
-def kill_process(pid):
-    if system == "Windows":
-        cmd = "taskkill /t /f /pid %s" % pid
-        # os.system(cmd)
-        subprocess.run(cmd)
-    else:
-        kill_proc_tree(pid)
-
-
-def change_label(if_label):
-    global p_label
-    if if_label == True and p_label is None:
-        url = "http://localhost:3000"
-        remote_url = "https://text-labeler.pages.dev/"
-        try:
-            p_label = subprocess.Popen(
-                [
-                    (
-                        "asr-label-linux-x64"
-                        if sys.platform == "linux"
-                        else "asr-label-win-x64.exe"
-                    )
-                ]
-            )
-        except FileNotFoundError:
-            logger.warning("asr-label execution not found!")
-
-        yield build_html_href(
-            link=remote_url,
-            desc=i18n("Optional online ver"),
-            msg=i18n("Opened labeler in browser"),
-        )
-
-    elif if_label == False and p_label is not None:
-        kill_process(p_label.pid)
-        p_label = None
-        yield build_html_ok_message("Nothing")
-
-
-def clean_infer_cache():
-    import tempfile
-
-    temp_dir = Path(tempfile.gettempdir())
-    gradio_dir = str(temp_dir / "gradio")
-    try:
-        shutil.rmtree(gradio_dir)
-        logger.info(f"Deleted cached audios: {gradio_dir}")
-    except PermissionError:
-        logger.info(f"Permission denied: Unable to delete {gradio_dir}")
-    except FileNotFoundError:
-        logger.info(f"{gradio_dir} was not found")
-    except Exception as e:
-        logger.info(f"An error occurred: {e}")
-
-
-def change_infer(
-    if_infer,
-    host,
-    port,
-    infer_decoder_model,
-    infer_decoder_config,
-    infer_llama_model,
-    infer_compile,
-):
-    global p_infer
-    if if_infer == True and p_infer == None:
-        env = os.environ.copy()
-
-        env["GRADIO_SERVER_NAME"] = host
-        env["GRADIO_SERVER_PORT"] = port
-        # 启动第二个进程
-        url = f"http://{host}:{port}"
-        yield build_html_ok_message(
-            i18n("Inferring interface is launched at {}").format(url)
-        )
-
-        clean_infer_cache()
-
-        p_infer = subprocess.Popen(
-            [
-                PYTHON,
-                "tools/run_webui.py",
-                "--decoder-checkpoint-path",
-                infer_decoder_model,
-                "--decoder-config-name",
-                infer_decoder_config,
-                "--llama-checkpoint-path",
-                infer_llama_model,
-            ]
-            + (["--compile"] if infer_compile == "Yes" else []),
-            env=env,
-        )
-
-    elif if_infer == False and p_infer is not None:
-        kill_process(p_infer.pid)
-        p_infer = None
-        yield build_html_error_message(i18n("Infer interface is closed"))
-
-
-js = load_data_in_raw("fish_speech/webui/js/animate.js")
-css = load_data_in_raw("fish_speech/webui/css/style.css")
-
-data_pre_output = (cur_work_dir / "data").resolve()
-default_model_output = (cur_work_dir / "results").resolve()
-default_filelist = data_pre_output / "detect.list"
-data_pre_output.mkdir(parents=True, exist_ok=True)
-
-items = []
-dict_items = {}
-
-
-def load_yaml_data_in_fact(yml_path):
-    with open(yml_path, "r", encoding="utf-8") as file:
-        yml = yaml.safe_load(file)
-    return yml
-
-
-def write_yaml_data_in_fact(yml, yml_path):
-    with open(yml_path, "w", encoding="utf-8") as file:
-        yaml.safe_dump(yml, file, allow_unicode=True)
-    return yml
-
-
-def generate_tree(directory, depth=0, max_depth=None, prefix=""):
-    if max_depth is not None and depth > max_depth:
-        return ""
-
-    tree_str = ""
-    files = []
-    directories = []
-    for item in os.listdir(directory):
-        if os.path.isdir(os.path.join(directory, item)):
-            directories.append(item)
-        else:
-            files.append(item)
-
-    entries = directories + files
-    for i, entry in enumerate(entries):
-        connector = "├── " if i < len(entries) - 1 else "└── "
-        tree_str += f"{prefix}{connector}{entry}<br />"
-        if i < len(directories):
-            extension = "│   " if i < len(entries) - 1 else "    "
-            tree_str += generate_tree(
-                os.path.join(directory, entry),
-                depth + 1,
-                max_depth,
-                prefix=prefix + extension,
-            )
-    return tree_str
-
-
-def new_explorer(data_path, max_depth):
-    return gr.Markdown(
-        elem_classes=["scrollable-component"],
-        value=generate_tree(data_path, max_depth=max_depth),
-    )
-
-
-def add_item(
-    folder: str,
-    method: str,
-    label_lang: str,
-    if_initial_prompt: bool,
-    initial_prompt: str | None,
-):
-    folder = folder.strip(" ").strip('"')
-
-    folder_path = Path(folder)
-
-    if folder and folder not in items and data_pre_output not in folder_path.parents:
-        if folder_path.is_dir():
-            items.append(folder)
-            dict_items[folder] = dict(
-                type="folder",
-                method=method,
-                label_lang=label_lang,
-                initial_prompt=initial_prompt if if_initial_prompt else None,
-            )
-        elif folder:
-            err = folder
-            return gr.Checkboxgroup(choices=items), build_html_error_message(
-                i18n("Invalid path: {}").format(err)
-            )
-
-    formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
-    logger.info("After Adding: " + formatted_data)
-    gr.Info(formatted_data)
-    return gr.Checkboxgroup(choices=items), build_html_ok_message(
-        i18n("Added path successfully!")
-    )
-
-
-def remove_items(selected_items):
-    global items, dict_items
-    to_remove = [item for item in items if item in selected_items]
-    for item in to_remove:
-        del dict_items[item]
-    items = [item for item in items if item in dict_items.keys()]
-    formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
-    logger.info(formatted_data)
-    gr.Warning("After Removing: " + formatted_data)
-    return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
-        i18n("Removed path successfully!")
-    )
-
-
-def show_selected(options):
-    selected_options = ", ".join(options)
-
-    if options:
-        return i18n("Selected: {}").format(selected_options)
-    else:
-        return i18n("No selected options")
-
-
-from pydub import AudioSegment
-
-
-def convert_to_mono_in_place(audio_path: Path):
-    audio = AudioSegment.from_file(audio_path)
-    if audio.channels > 1:
-        mono_audio = audio.set_channels(1)
-        mono_audio.export(audio_path, format=audio_path.suffix[1:])
-        logger.info(f"Convert {audio_path} successfully")
-
-
-def list_copy(list_file_path, method):
-    wav_root = data_pre_output
-    lst = []
-    with list_file_path.open("r", encoding="utf-8") as file:
-        for line in tqdm(file, desc="Processing audio/transcript"):
-            wav_path, speaker_name, language, text = line.strip().split("|")
-            original_wav_path = Path(wav_path)
-            target_wav_path = (
-                wav_root / original_wav_path.parent.name / original_wav_path.name
-            )
-            lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
-            if target_wav_path.is_file():
-                continue
-            target_wav_path.parent.mkdir(parents=True, exist_ok=True)
-            if method == i18n("Copy"):
-                shutil.copy(original_wav_path, target_wav_path)
-            else:
-                shutil.move(original_wav_path, target_wav_path.parent)
-            convert_to_mono_in_place(target_wav_path)
-            original_lab_path = original_wav_path.with_suffix(".lab")
-            target_lab_path = (
-                wav_root
-                / original_wav_path.parent.name
-                / original_wav_path.with_suffix(".lab").name
-            )
-            if target_lab_path.is_file():
-                continue
-            if method == i18n("Copy"):
-                shutil.copy(original_lab_path, target_lab_path)
-            else:
-                shutil.move(original_lab_path, target_lab_path.parent)
-
-    if method == i18n("Move"):
-        with list_file_path.open("w", encoding="utf-8") as file:
-            file.writelines("\n".join(lst))
-
-    del lst
-    return build_html_ok_message(i18n("Use filelist"))
-
-
-def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
-    global dict_items
-    data_path = Path(data_path)
-    gr.Warning("Pre-processing begins...")
-    for item, content in dict_items.items():
-        item_path = Path(item)
-        tar_path = data_path / item_path.name
-
-        if content["type"] == "folder" and item_path.is_dir():
-            if content["method"] == i18n("Copy"):
-                os.makedirs(tar_path, exist_ok=True)
-                shutil.copytree(
-                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
-                )
-            elif not tar_path.is_dir():
-                shutil.move(src=str(item_path), dst=str(tar_path))
-
-            for suf in ["wav", "flac", "mp3"]:
-                for audio_path in tar_path.glob(f"**/*.{suf}"):
-                    convert_to_mono_in_place(audio_path)
-
-            cur_lang = content["label_lang"]
-            initial_prompt = content["initial_prompt"]
-
-            transcribe_cmd = [
-                PYTHON,
-                "tools/whisper_asr.py",
-                "--model-size",
-                label_model,
-                "--device",
-                label_device,
-                "--audio-dir",
-                tar_path,
-                "--save-dir",
-                tar_path,
-                "--language",
-                cur_lang,
-            ]
-
-            if initial_prompt is not None:
-                transcribe_cmd += ["--initial-prompt", initial_prompt]
-
-            if cur_lang != "IGNORE":
-                try:
-                    gr.Warning("Begin To Transcribe")
-                    subprocess.run(
-                        transcribe_cmd,
-                        env=env,
-                    )
-                except Exception:
-                    print("Transcription error occurred")
-
-        elif content["type"] == "file" and item_path.is_file():
-            list_copy(item_path, content["method"])
-
-    return build_html_ok_message(i18n("Move files successfully")), new_explorer(
-        data_path, max_depth=max_depth
-    )
-
-
-def generate_folder_name():
-    now = datetime.datetime.now()
-    folder_name = now.strftime("%Y%m%d_%H%M%S")
-    return folder_name
-
-
-def train_process(
-    data_path: str,
-    option: str,
-    # llama config
-    llama_ckpt,
-    llama_base_config,
-    llama_lr,
-    llama_maxsteps,
-    llama_data_num_workers,
-    llama_data_batch_size,
-    llama_data_max_length,
-    llama_precision,
-    llama_check_interval,
-    llama_grad_batches,
-    llama_use_speaker,
-    llama_use_lora,
-):
-
-    backend = "nccl" if sys.platform == "linux" else "gloo"
-
-    new_project = generate_folder_name()
-    print("New Project Name: ", new_project)
-
-    if option == "VQGAN":
-        msg = "Skipped VQGAN Training."
-        gr.Warning(msg)
-        logger.info(msg)
-
-    if option == "LLAMA":
-        msg = "LLAMA Training begins..."
-        gr.Warning(msg)
-        logger.info(msg)
-        subprocess.run(
-            [
-                PYTHON,
-                "tools/vqgan/extract_vq.py",
-                str(data_pre_output),
-                "--num-workers",
-                "1",
-                "--batch-size",
-                "16",
-                "--config-name",
-                "firefly_gan_vq",
-                "--checkpoint-path",
-                "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-            ]
-        )
-
-        subprocess.run(
-            [
-                PYTHON,
-                "tools/llama/build_dataset.py",
-                "--input",
-                str(data_pre_output),
-                "--text-extension",
-                ".lab",
-                "--num-workers",
-                "16",
-            ]
-        )
-        ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
-        lora_prefix = "lora_" if llama_use_lora else ""
-        llama_name = lora_prefix + "text2semantic_" + new_project
-        latest = next(
-            iter(
-                sorted(
-                    [
-                        str(p.relative_to("results"))
-                        for p in Path("results").glob(lora_prefix + "text2sem*/")
-                    ],
-                    reverse=True,
-                )
-            ),
-            llama_name,
-        )
-        project = (
-            llama_name
-            if llama_ckpt == i18n("new")
-            else (
-                latest
-                if llama_ckpt == i18n("latest")
-                else Path(llama_ckpt).relative_to("results")
-            )
-        )
-        logger.info(project)
-
-        if llama_check_interval > llama_maxsteps:
-            llama_check_interval = llama_maxsteps
-
-        train_cmd = [
-            PYTHON,
-            "fish_speech/train.py",
-            "--config-name",
-            "text2semantic_finetune",
-            f"project={project}",
-            f"trainer.strategy.process_group_backend={backend}",
-            f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
-            f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
-            f"model.optimizer.lr={llama_lr}",
-            f"trainer.max_steps={llama_maxsteps}",
-            f"data.num_workers={llama_data_num_workers}",
-            f"data.batch_size={llama_data_batch_size}",
-            f"max_length={llama_data_max_length}",
-            f"trainer.precision={llama_precision}",
-            f"trainer.val_check_interval={llama_check_interval}",
-            f"trainer.accumulate_grad_batches={llama_grad_batches}",
-            f"train_dataset.interactive_prob={llama_use_speaker}",
-        ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
-        logger.info(train_cmd)
-        subprocess.run(train_cmd)
-
-    return build_html_ok_message(i18n("Training stopped"))
-
-
-def tensorboard_process(
-    if_tensorboard: bool,
-    tensorboard_dir: str,
-    host: str,
-    port: str,
-):
-    global p_tensorboard
-    if if_tensorboard == True and p_tensorboard == None:
-        url = f"http://{host}:{port}"
-        yield build_html_ok_message(
-            i18n("Tensorboard interface is launched at {}").format(url)
-        )
-        prefix = ["tensorboard"]
-        if Path("fishenv").exists():
-            prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
-
-        p_tensorboard = subprocess.Popen(
-            prefix
-            + [
-                "--logdir",
-                tensorboard_dir,
-                "--host",
-                host,
-                "--port",
-                port,
-                "--reload_interval",
-                "120",
-            ]
-        )
-    elif if_tensorboard == False and p_tensorboard != None:
-        kill_process(p_tensorboard.pid)
-        p_tensorboard = None
-        yield build_html_error_message(i18n("Tensorboard interface is closed"))
-
-
-def fresh_tb_dir():
-    return gr.Dropdown(
-        choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
-    )
-
-
-def list_decoder_models():
-    paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
-    if not paths:
-        logger.warning("No decoder model found")
-    return paths
-
-
-def list_llama_models():
-    choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
-    choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
-    choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
-    choices = sorted(choices, reverse=True)
-    if not choices:
-        logger.warning("No LLaMA model found")
-    return choices
-
-
-def list_lora_llama_models():
-    choices = sorted(
-        [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
-    )
-    if not choices:
-        logger.warning("No LoRA LLaMA model found")
-    return choices
-
-
-def fresh_decoder_model():
-    return gr.Dropdown(choices=list_decoder_models())
-
-
-def fresh_llama_ckpt(llama_use_lora):
-    return gr.Dropdown(
-        choices=[i18n("latest"), i18n("new")]
-        + (
-            [str(p) for p in Path("results").glob("text2sem*/")]
-            if not llama_use_lora
-            else [str(p) for p in Path("results").glob("lora_*/")]
-        )
-    )
-
-
-def fresh_llama_model():
-    return gr.Dropdown(choices=list_llama_models())
-
-
-def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
-    if (
-        lora_weight is None
-        or not Path(lora_weight).exists()
-        or not Path(llama_weight).exists()
-    ):
-        return build_html_error_message(
-            i18n(
-                "Path error, please check the model file exists in the corresponding path"
-            )
-        )
-    gr.Warning("Merging begins...")
-    merge_cmd = [
-        PYTHON,
-        "tools/llama/merge_lora.py",
-        "--lora-config",
-        "r_8_alpha_16",
-        "--lora-weight",
-        lora_weight,
-        "--output",
-        llama_lora_output + "_" + generate_folder_name(),
-    ]
-    logger.info(merge_cmd)
-    subprocess.run(merge_cmd)
-    return build_html_ok_message(i18n("Merge successfully"))
-
-
-def llama_quantify(llama_weight, quantify_mode):
-    if llama_weight is None or not Path(llama_weight).exists():
-        return build_html_error_message(
-            i18n(
-                "Path error, please check the model file exists in the corresponding path"
-            )
-        )
-
-    gr.Warning("Quantifying begins...")
-
-    now = generate_folder_name()
-    quantify_cmd = [
-        PYTHON,
-        "tools/llama/quantize.py",
-        "--checkpoint-path",
-        llama_weight,
-        "--mode",
-        quantify_mode,
-        "--timestamp",
-        now,
-    ]
-    logger.info(quantify_cmd)
-    subprocess.run(quantify_cmd)
-    if quantify_mode == "int8":
-        quantize_path = str(
-            Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
-        )
-    else:
-        quantize_path = str(
-            Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
-        )
-    return build_html_ok_message(
-        i18n("Quantify successfully") + f"Path: {quantize_path}"
-    )
-
-
-init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
-init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
-
-with gr.Blocks(
-    head="<style>\n" + css + "\n</style>",
-    js=js,
-    theme=seafoam,
-    analytics_enabled=False,
-    title="Fish Speech",
-) as demo:
-    with gr.Row():
-        with gr.Column():
-            with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
-                with gr.Row():
-                    textbox = gr.Textbox(
-                        label="\U0000270F "
-                        + i18n("Input Audio & Source Path for Transcription"),
-                        info=i18n("Speaker is identified by the folder name"),
-                        interactive=True,
-                    )
-                with gr.Row(equal_height=False):
-                    with gr.Column():
-                        output_radio = gr.Radio(
-                            label="\U0001F4C1 "
-                            + i18n("Select source file processing method"),
-                            choices=[i18n("Copy"), i18n("Move")],
-                            value=i18n("Copy"),
-                            interactive=True,
-                        )
-                    with gr.Column():
-                        error = gr.HTML(label=i18n("Error Message"))
-                        if_label = gr.Checkbox(
-                            label=i18n("Open Labeler WebUI"), scale=0, show_label=True
-                        )
-
-                with gr.Row():
-                    label_device = gr.Dropdown(
-                        label=i18n("Labeling Device"),
-                        info=i18n(
-                            "It is recommended to use CUDA, if you have low configuration, use CPU"
-                        ),
-                        choices=["cpu", "cuda"],
-                        value="cuda",
-                        interactive=True,
-                    )
-                    label_model = gr.Dropdown(
-                        label=i18n("Whisper Model"),
-                        info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
-                        choices=["large-v3", "medium"],
-                        value="large-v3",
-                        interactive=True,
-                    )
-                    label_radio = gr.Dropdown(
-                        label=i18n("Optional Label Language"),
-                        info=i18n(
-                            "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
-                        ),
-                        choices=[
-                            (i18n("Chinese"), "zh"),
-                            (i18n("English"), "en"),
-                            (i18n("Japanese"), "ja"),
-                            (i18n("Disabled"), "IGNORE"),
-                            (i18n("auto"), "auto"),
-                        ],
-                        value="IGNORE",
-                        interactive=True,
-                    )
-
-                with gr.Row():
-                    if_initial_prompt = gr.Checkbox(
-                        value=False,
-                        label=i18n("Enable Initial Prompt"),
-                        min_width=120,
-                        scale=0,
-                    )
-                    initial_prompt = gr.Textbox(
-                        label=i18n("Initial Prompt"),
-                        info=i18n(
-                            "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
-                        ),
-                        placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
-                        interactive=False,
-                    )
-
-                with gr.Row():
-                    add_button = gr.Button(
-                        "\U000027A1 " + i18n("Add to Processing Area"),
-                        variant="primary",
-                    )
-                    remove_button = gr.Button(
-                        "\U000026D4 " + i18n("Remove Selected Data")
-                    )
-
-            with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
-                with gr.Row():
-                    model_type_radio = gr.Radio(
-                        label=i18n(
-                            "Select the model to be trained (Depending on the Tab page you are on)"
-                        ),
-                        interactive=False,
-                        choices=["VQGAN", "LLAMA"],
-                        value="VQGAN",
-                    )
-                with gr.Row():
-                    with gr.Column():
-                        with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
-                            gr.HTML("You don't need to train this model!")
-
-                        with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
-                            with gr.Row(equal_height=False):
-                                llama_use_lora = gr.Checkbox(
-                                    label=i18n("Use LoRA"),
-                                    info=i18n(
-                                        "Use LoRA can save GPU memory, but may reduce the quality of the model"
-                                    ),
-                                    value=True,
-                                    interactive=True,
-                                )
-                                llama_ckpt = gr.Dropdown(
-                                    label=i18n("Select LLAMA ckpt"),
-                                    choices=[i18n("latest"), i18n("new")]
-                                    + [
-                                        str(p)
-                                        for p in Path("results").glob("text2sem*/")
-                                    ]
-                                    + [str(p) for p in Path("results").glob("lora*/")],
-                                    value=i18n("latest"),
-                                    interactive=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_lr_slider = gr.Slider(
-                                    label=i18n("Initial Learning Rate"),
-                                    info=i18n(
-                                        "lr smaller -> usually train slower but more stable"
-                                    ),
-                                    interactive=True,
-                                    minimum=1e-5,
-                                    maximum=1e-4,
-                                    step=1e-5,
-                                    value=5e-5,
-                                )
-                                llama_maxsteps_slider = gr.Slider(
-                                    label=i18n("Maximum Training Steps"),
-                                    info=i18n(
-                                        "recommend: max_steps = num_audios // batch_size * (2 to 5)"
-                                    ),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=10000,
-                                    step=1,
-                                    value=50,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_base_config = gr.Dropdown(
-                                    label=i18n("Model Size"),
-                                    choices=[
-                                        "text2semantic_finetune",
-                                    ],
-                                    value="text2semantic_finetune",
-                                )
-                                llama_data_num_workers_slider = gr.Slider(
-                                    label=i18n("Number of Workers"),
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=4,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_data_batch_size_slider = gr.Slider(
-                                    label=i18n("Batch Size"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=2,
-                                )
-                                llama_data_max_length_slider = gr.Slider(
-                                    label=i18n("Maximum Length per Sample"),
-                                    interactive=True,
-                                    minimum=1024,
-                                    maximum=4096,
-                                    step=128,
-                                    value=2048,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_precision_dropdown = gr.Dropdown(
-                                    label=i18n("Precision"),
-                                    info=i18n(
-                                        "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
-                                    ),
-                                    interactive=True,
-                                    choices=["32", "bf16-true", "16-mixed"],
-                                    value="bf16-true",
-                                )
-                                llama_check_interval_slider = gr.Slider(
-                                    label=i18n("Save model every n steps"),
-                                    info=i18n(
-                                        "make sure that it's not greater than max_steps"
-                                    ),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=1000,
-                                    step=1,
-                                    value=50,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_grad_batches = gr.Slider(
-                                    label=i18n("Accumulate Gradient Batches"),
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=20,
-                                    step=1,
-                                    value=init_llama_yml["trainer"][
-                                        "accumulate_grad_batches"
-                                    ],
-                                )
-                                llama_use_speaker = gr.Slider(
-                                    label=i18n(
-                                        "Probability of applying Speaker Condition"
-                                    ),
-                                    interactive=True,
-                                    minimum=0.1,
-                                    maximum=1.0,
-                                    step=0.05,
-                                    value=init_llama_yml["train_dataset"][
-                                        "interactive_prob"
-                                    ],
-                                )
-
-                        with gr.Tab(label=i18n("Merge LoRA"), id=4):
-                            with gr.Row(equal_height=False):
-                                llama_weight = gr.Dropdown(
-                                    label=i18n("Base LLAMA Model"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
-                                    choices=[
-                                        "checkpoints/fish-speech-1.4/model.pth",
-                                    ],
-                                    value="checkpoints/fish-speech-1.4/model.pth",
-                                    allow_custom_value=True,
-                                    interactive=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                lora_weight = gr.Dropdown(
-                                    label=i18n("LoRA Model to be merged"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
-                                    choices=[
-                                        str(p)
-                                        for p in Path("results").glob("lora*/**/*.ckpt")
-                                    ],
-                                    allow_custom_value=True,
-                                    interactive=True,
-                                )
-                                lora_llama_config = gr.Dropdown(
-                                    label=i18n("LLAMA Model Config"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
-                                    choices=[
-                                        "text2semantic_finetune",
-                                    ],
-                                    value="text2semantic_finetune",
-                                    allow_custom_value=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_lora_output = gr.Dropdown(
-                                    label=i18n("Output Path"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
-                                    value="checkpoints/merged",
-                                    choices=["checkpoints/merged"],
-                                    allow_custom_value=True,
-                                    interactive=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_lora_merge_btn = gr.Button(
-                                    value=i18n("Merge"), variant="primary"
-                                )
-
-                        with gr.Tab(label=i18n("Model Quantization"), id=5):
-                            with gr.Row(equal_height=False):
-                                llama_weight_to_quantify = gr.Dropdown(
-                                    label=i18n("Base LLAMA Model"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
-                                    choices=list_llama_models(),
-                                    value="checkpoints/fish-speech-1.4",
-                                    allow_custom_value=True,
-                                    interactive=True,
-                                )
-                                quantify_mode = gr.Dropdown(
-                                    label=i18n("Post-quantification Precision"),
-                                    info=i18n(
-                                        "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
-                                    ),
-                                    choices=["int8", "int4"],
-                                    value="int8",
-                                    allow_custom_value=False,
-                                    interactive=True,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_quantify_btn = gr.Button(
-                                    value=i18n("Quantify"), variant="primary"
-                                )
-
-                        with gr.Tab(label="Tensorboard", id=6):
-                            with gr.Row(equal_height=False):
-                                tb_host = gr.Textbox(
-                                    label=i18n("Tensorboard Host"), value="127.0.0.1"
-                                )
-                                tb_port = gr.Textbox(
-                                    label=i18n("Tensorboard Port"), value="11451"
-                                )
-                            with gr.Row(equal_height=False):
-                                tb_dir = gr.Dropdown(
-                                    label=i18n("Tensorboard Log Path"),
-                                    allow_custom_value=True,
-                                    choices=[
-                                        str(p)
-                                        for p in Path("results").glob("**/tensorboard/")
-                                    ],
-                                )
-                            with gr.Row(equal_height=False):
-                                if_tb = gr.Checkbox(
-                                    label=i18n("Open Tensorboard"),
-                                )
-
-            with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
-                with gr.Column():
-                    with gr.Row():
-                        with gr.Accordion(
-                            label="\U0001F5A5 "
-                            + i18n("Inference Server Configuration"),
-                            open=False,
-                        ):
-                            with gr.Row():
-                                infer_host_textbox = gr.Textbox(
-                                    label=i18n("WebUI Host"), value="127.0.0.1"
-                                )
-                                infer_port_textbox = gr.Textbox(
-                                    label=i18n("WebUI Port"), value="7862"
-                                )
-                            with gr.Row():
-                                infer_decoder_model = gr.Dropdown(
-                                    label=i18n("Decoder Model Path"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
-                                    choices=list_decoder_models(),
-                                    value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-                                    allow_custom_value=True,
-                                )
-                                infer_decoder_config = gr.Dropdown(
-                                    label=i18n("Decoder Model Config"),
-                                    info=i18n("Changing with the Model Path"),
-                                    value="firefly_gan_vq",
-                                    choices=[
-                                        "firefly_gan_vq",
-                                    ],
-                                    allow_custom_value=True,
-                                )
-                            with gr.Row():
-                                infer_llama_model = gr.Dropdown(
-                                    label=i18n("LLAMA Model Path"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
-                                    value="checkpoints/fish-speech-1.4",
-                                    choices=list_llama_models(),
-                                    allow_custom_value=True,
-                                )
-
-                            with gr.Row():
-                                infer_compile = gr.Radio(
-                                    label=i18n("Compile Model"),
-                                    info=i18n(
-                                        "Compile the model can significantly reduce the inference time, but will increase cold start time"
-                                    ),
-                                    choices=["Yes", "No"],
-                                    value=(
-                                        "Yes" if (sys.platform == "linux") else "No"
-                                    ),
-                                    interactive=is_module_installed("triton"),
-                                )
-
-                    with gr.Row():
-                        infer_checkbox = gr.Checkbox(
-                            label=i18n("Open Inference Server")
-                        )
-                        infer_error = gr.HTML(label=i18n("Inference Server Error"))
-
-        with gr.Column():
-            train_error = gr.HTML(label=i18n("Training Error"))
-            checkbox_group = gr.CheckboxGroup(
-                label="\U0001F4CA " + i18n("Data Source"),
-                info=i18n(
-                    "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
-                ),
-                elem_classes=["data_src"],
-            )
-            train_box = gr.Textbox(
-                label=i18n("Data Preprocessing Path"),
-                value=str(data_pre_output),
-                interactive=False,
-            )
-            model_box = gr.Textbox(
-                label="\U0001F4BE " + i18n("Model Output Path"),
-                value=str(default_model_output),
-                interactive=False,
-            )
-
-            with gr.Accordion(
-                i18n(
-                    "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
-                ),
-                elem_classes=["scrollable-component"],
-                elem_id="file_accordion",
-            ):
-                tree_slider = gr.Slider(
-                    minimum=0,
-                    maximum=3,
-                    value=0,
-                    step=1,
-                    show_label=False,
-                    container=False,
-                )
-                file_markdown = new_explorer(str(data_pre_output), 0)
-            with gr.Row(equal_height=False):
-                admit_btn = gr.Button(
-                    "\U00002705 " + i18n("File Preprocessing"),
-                    variant="primary",
-                )
-                fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
-                help_button = gr.Button("\U00002753", scale=0, min_width=80)  # question
-                train_btn = gr.Button(i18n("Start Training"), variant="primary")
-
-    footer = load_data_in_raw("fish_speech/webui/html/footer.html")
-    footer = footer.format(
-        versions=versions_html(),
-        api_docs="https://speech.fish.audio/inference/#http-api",
-    )
-    gr.HTML(footer, elem_id="footer")
-    vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
-    llama_page.select(lambda: "LLAMA", None, model_type_radio)
-    add_button.click(
-        fn=add_item,
-        inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
-        outputs=[checkbox_group, error],
-    )
-    remove_button.click(
-        fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
-    )
-    checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
-    help_button.click(
-        fn=None,
-        js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
-        'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
-    )
-    if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
-    if_initial_prompt.change(
-        fn=lambda x: gr.Textbox(value="", interactive=x),
-        inputs=[if_initial_prompt],
-        outputs=[initial_prompt],
-    )
-    train_btn.click(
-        fn=train_process,
-        inputs=[
-            train_box,
-            model_type_radio,
-            # llama config
-            llama_ckpt,
-            llama_base_config,
-            llama_lr_slider,
-            llama_maxsteps_slider,
-            llama_data_num_workers_slider,
-            llama_data_batch_size_slider,
-            llama_data_max_length_slider,
-            llama_precision_dropdown,
-            llama_check_interval_slider,
-            llama_grad_batches,
-            llama_use_speaker,
-            llama_use_lora,
-        ],
-        outputs=[train_error],
-    )
-    if_tb.change(
-        fn=tensorboard_process,
-        inputs=[if_tb, tb_dir, tb_host, tb_port],
-        outputs=[train_error],
-    )
-    tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
-    infer_decoder_model.change(
-        fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
-    )
-    infer_llama_model.change(
-        fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
-    )
-    llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
-    admit_btn.click(
-        fn=check_files,
-        inputs=[train_box, tree_slider, label_model, label_device],
-        outputs=[error, file_markdown],
-    )
-    fresh_btn.click(
-        fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
-    )
-    llama_use_lora.change(
-        fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
-    )
-    llama_ckpt.change(
-        fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
-    )
-    lora_weight.change(
-        fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
-        inputs=[],
-        outputs=[lora_weight],
-    )
-    llama_lora_merge_btn.click(
-        fn=llama_lora_merge,
-        inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
-        outputs=[train_error],
-    )
-    llama_quantify_btn.click(
-        fn=llama_quantify,
-        inputs=[llama_weight_to_quantify, quantify_mode],
-        outputs=[train_error],
-    )
-    infer_checkbox.change(
-        fn=change_infer,
-        inputs=[
-            infer_checkbox,
-            infer_host_textbox,
-            infer_port_textbox,
-            infer_decoder_model,
-            infer_decoder_config,
-            infer_llama_model,
-            infer_compile,
-        ],
-        outputs=[infer_error],
-    )
-
-demo.launch(inbrowser=True)

+ 9 - 9
inference.ipynb

@@ -61,7 +61,7 @@
     "# !set HF_ENDPOINT=https://hf-mirror.com\n",
     "# !export HF_ENDPOINT=https://hf-mirror.com \n",
     "\n",
-    "!huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4/"
+    "!huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5/"
    ]
   },
   {
@@ -84,8 +84,8 @@
    "outputs": [],
    "source": [
     "!python tools/run_webui.py \\\n",
-    "    --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n",
-    "    --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
+    "    --llama-checkpoint-path checkpoints/fish-speech-1.5 \\\n",
+    "    --decoder-checkpoint-path checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
     "    # --compile"
    ]
   },
@@ -120,9 +120,9 @@
     "## Enter the path to the audio file here\n",
     "src_audio = r\"D:\\PythonProject\\vo_hutao_draw_appear.wav\"\n",
     "\n",
-    "!python tools/vqgan/inference.py \\\n",
+    "!python fish_speech/models/vqgan/inference.py \\\n",
     "    -i {src_audio} \\\n",
-    "    --checkpoint-path \"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
+    "    --checkpoint-path \"checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
     "\n",
     "from IPython.display import Audio, display\n",
     "audio = Audio(filename=\"fake.wav\")\n",
@@ -154,11 +154,11 @@
    },
    "outputs": [],
    "source": [
-    "!python tools/llama/generate.py \\\n",
+    "!python fish_speech/models/text2semantic/inference.py \\\n",
     "    --text \"hello world\" \\\n",
     "    --prompt-text \"The text corresponding to reference audio\" \\\n",
     "    --prompt-tokens \"fake.npy\" \\\n",
-    "    --checkpoint-path \"checkpoints/fish-speech-1.4\" \\\n",
+    "    --checkpoint-path \"checkpoints/fish-speech-1.5\" \\\n",
     "    --num-samples 2\n",
     "    # --compile"
    ]
@@ -180,9 +180,9 @@
    },
    "outputs": [],
    "source": [
-    "!python tools/vqgan/inference.py \\\n",
+    "!python fish_speech/models/vqgan/inference.py \\\n",
     "    -i \"codes_0.npy\" \\\n",
-    "    --checkpoint-path \"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
+    "    --checkpoint-path \"checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
     "\n",
     "from IPython.display import Audio, display\n",
     "audio = Audio(filename=\"fake.wav\")\n",

+ 2 - 2
tools/api_client.py

@@ -8,8 +8,8 @@ import requests
 from pydub import AudioSegment
 from pydub.playback import play
 
-from tools.file import audio_to_bytes, read_ref_text
-from tools.schema import ServeReferenceAudio, ServeTTSRequest
+from fish_speech.utils.file import audio_to_bytes, read_ref_text
+from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
 
 
 def parse_args():

+ 2 - 2
tools/e2e_webui.py

@@ -3,10 +3,10 @@ import re
 import wave
 
 import gradio as gr
-import numpy as np
+
+from fish_speech.utils.schema import ServeMessage, ServeTextPart, ServeVQPart
 
 from .fish_e2e import FishE2EAgent, FishE2EEventType
-from .schema import ServeMessage, ServeTextPart, ServeVQPart
 
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):

+ 0 - 125
tools/file.py

@@ -1,125 +0,0 @@
-import base64
-from pathlib import Path
-from typing import Union
-
-from loguru import logger
-from natsort import natsorted
-
-AUDIO_EXTENSIONS = {
-    ".mp3",
-    ".wav",
-    ".flac",
-    ".ogg",
-    ".m4a",
-    ".wma",
-    ".aac",
-    ".aiff",
-    ".aif",
-    ".aifc",
-}
-
-VIDEO_EXTENSIONS = {
-    ".mp4",
-    ".avi",
-}
-
-
-def audio_to_bytes(file_path):
-    if not file_path or not Path(file_path).exists():
-        return None
-    with open(file_path, "rb") as wav_file:
-        wav = wav_file.read()
-    return wav
-
-
-def read_ref_text(ref_text):
-    path = Path(ref_text)
-    if path.exists() and path.is_file():
-        with path.open("r", encoding="utf-8") as file:
-            return file.read()
-    return ref_text
-
-
-def list_files(
-    path: Union[Path, str],
-    extensions: set[str] = None,
-    recursive: bool = False,
-    sort: bool = True,
-) -> list[Path]:
-    """List files in a directory.
-
-    Args:
-        path (Path): Path to the directory.
-        extensions (set, optional): Extensions to filter. Defaults to None.
-        recursive (bool, optional): Whether to search recursively. Defaults to False.
-        sort (bool, optional): Whether to sort the files. Defaults to True.
-
-    Returns:
-        list: List of files.
-    """
-
-    if isinstance(path, str):
-        path = Path(path)
-
-    if not path.exists():
-        raise FileNotFoundError(f"Directory {path} does not exist.")
-
-    files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
-
-    if sort:
-        files = natsorted(files)
-
-    return files
-
-
-def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
-    """
-    Load a Bert-VITS2 style filelist.
-    """
-
-    files = set()
-    results = []
-    count_duplicated, count_not_found = 0, 0
-
-    LANGUAGE_TO_LANGUAGES = {
-        "zh": ["zh", "en"],
-        "jp": ["jp", "en"],
-        "en": ["en"],
-    }
-
-    with open(path, "r", encoding="utf-8") as f:
-        for line in f.readlines():
-            splits = line.strip().split("|", maxsplit=3)
-            if len(splits) != 4:
-                logger.warning(f"Invalid line: {line}")
-                continue
-
-            filename, speaker, language, text = splits
-            file = Path(filename)
-            language = language.strip().lower()
-
-            if language == "ja":
-                language = "jp"
-
-            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
-            languages = LANGUAGE_TO_LANGUAGES[language]
-
-            if file in files:
-                logger.warning(f"Duplicated file: {file}")
-                count_duplicated += 1
-                continue
-
-            if not file.exists():
-                logger.warning(f"File not found: {file}")
-                count_not_found += 1
-                continue
-
-            results.append((file, speaker, languages, text))
-
-    if count_duplicated > 0:
-        logger.warning(f"Total duplicated files: {count_duplicated}")
-
-    if count_not_found > 0:
-        logger.warning(f"Total files not found: {count_not_found}")
-
-    return results

+ 1 - 1
tools/fish_e2e.py

@@ -13,7 +13,7 @@ import numpy as np
 import ormsgpack
 import soundfile as sf
 
-from .schema import (
+from fish_speech.utils.schema import (
     ServeChatRequest,
     ServeMessage,
     ServeTextPart,

+ 1 - 2
tools/llama/build_dataset.py

@@ -12,8 +12,8 @@ from loguru import logger
 from tqdm import tqdm
 
 from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
-from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
-from tools.file import load_filelist
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
+from fish_speech.utils.file import load_filelist
 
 # To avoid CPU overload
 os.environ["MKL_NUM_THREADS"] = "1"

+ 1 - 1
tools/llama/eval_in_context.py

@@ -10,7 +10,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 from torch.utils.data import DataLoader
 
 from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
-from tools.llama.generate import load_model
+from fish_speech.models.text2semantic.inference import load_model
 
 
 def smooth(

+ 8 - 1106
tools/llama/generate.py

@@ -1,1114 +1,16 @@
 import os
-import queue
-import threading
-import time
-from contextlib import nullcontext
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Literal, Optional, Tuple, Union
+import subprocess
+import sys
 
-import click
-import hydra
-import numpy as np
-import torch
-import torch._dynamo.config
-import torch._inductor.config
-from loguru import logger
-from tqdm import tqdm
-from transformers import AutoTokenizer
+#!/usr/bin/env python
 
-from fish_speech.conversation import (
-    CODEBOOK_PAD_TOKEN_ID,
-    Conversation,
-    Message,
-    TextPart,
-    VQPart,
-)
-from fish_speech.models.text2semantic.llama import BaseModelArgs
-from fish_speech.text import clean_text, split_text
-from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.triton.unique_kernel_names = True
-
-if hasattr(torch._inductor.config, "fx_graph_cache"):
-    # Experimental feature to reduce compilation times, will be on by default in future
-    torch._inductor.config.fx_graph_cache = True
-
-
-from torch.nn.attention import SDPBackend, sdpa_kernel
-
-from fish_speech.models.text2semantic.llama import (
-    BaseTransformer,
-    DualARTransformer,
-    NaiveTransformer,
-)
-
-
-def multinomial_sample_one_no_sync(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    temperature: torch.Tensor = 1.0,
-    top_p: torch.Tensor = 1.0,
-    repetition_penalty: torch.Tensor = 1.0,
-) -> torch.Tensor:
-    # Apply repetition penalty
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
-        score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
-        )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
-
-    # Apply top-p sampling
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
-    sorted_indices_to_remove = cum_probs > top_p
-    sorted_indices_to_remove[0] = False  # keep at least one option
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        dim=0, index=sorted_indices, src=sorted_indices_to_remove
-    )
-    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-    logits = logits / max(temperature, 1e-5)
-
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
-def multinomial_sample_one_no_sync_agent(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs_agent(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    temperature: torch.Tensor = 1.0,
-    top_p: torch.Tensor = 1.0,
-    repetition_penalty: torch.Tensor = 1.0,
-) -> torch.Tensor:
-    # Apply repetition penalty
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=-1, index=previous_tokens)
-        score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
-        )
-        logits.scatter_(dim=-1, index=previous_tokens, src=score)
-
-    # Apply top-p sampling
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
-    sorted_indices_to_remove = cum_probs > top_p
-    sorted_indices_to_remove[..., 0] = False  # keep at least one option
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
-    )
-    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-    logits = logits / max(temperature, 1e-5)
-
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
-def sample(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs(
-        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
-    )
-    idx_next = multinomial_sample_one_no_sync(probs)
-    return idx_next, probs
-
-
-def sample_agent(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs_agent(
-        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
-    )
-    idx_next = multinomial_sample_one_no_sync_agent(probs)
-    return idx_next, probs
-
-
-def decode_one_token_ar_agent(
-    model: DualARTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    semantic_ids: list,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    # print(x, input_pos)
-    x = model.forward_generate(x, input_pos)
-    logits = x.logits  # [:, -1:]
-    hidden_states = x.hidden_states  # [:, -1:]
-
-    sampling_kwargs_main = sampling_kwargs.copy()
-    sampling_kwargs_main["temperature"] = 0.1
-    sampling_kwargs_main["top_p"] = 0.1
-    sampling_kwargs_main["repetition_penalty"] = 1.0
-
-    codebooks = [
-        sample_agent(
-            logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
-        )[0]
-    ]
-
-    # Cleanup the cache
-    for layer in model.fast_layers:
-        layer.attention.kv_cache.k_cache.fill_(0)
-        layer.attention.kv_cache.v_cache.fill_(0)
-
-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor(
-            [codebook_idx], device=hidden_states.device, dtype=torch.long
-        )
-        logits = model.forward_generate_fast(hidden_states, input_pos)
-        a = sample_agent(
-            logits,
-            previous_tokens=(
-                previous_tokens[:, codebook_idx + 1]
-                if previous_tokens is not None
-                else None
-            ),
-            **sampling_kwargs,
-        )[0]
-        hidden_states = model.fast_embeddings(a)
-        codebooks.append(a)
-
-    codebooks = torch.stack(codebooks, dim=1)
-    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :],
-        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
-        CODEBOOK_PAD_TOKEN_ID,
-    )
-
-    return codebooks
-
-
-def decode_one_token_naive_agent(
-    model: NaiveTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    semantic_ids: list,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    x = model.forward_generate(x, input_pos)
-
-    codebooks = [
-        sample(
-            x.token_logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs,
-        )[0]
-    ]
-
-    for i in range(model.config.num_codebooks):
-        codebooks.append(
-            sample_agent(
-                x.codebook_logits[:, :, i],
-                previous_tokens=(
-                    previous_tokens[:, i + 1] if previous_tokens is not None else None
-                ),
-                **sampling_kwargs,
-            )[0]
-        )
-
-    codebooks = torch.stack(codebooks, dim=1)
-    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :],
-        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
-        CODEBOOK_PAD_TOKEN_ID,
-    )
-
-    return codebooks
-
-
-def decode_one_token_ar(
-    model: DualARTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    semantic_ids: list,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    x = model.forward_generate(x, input_pos)
-
-    sampling_kwargs_main = sampling_kwargs.copy()
-    # sampling_kwargs_main["temperature"] = 0.1
-    # sampling_kwargs_main["top_p"] = 0.1
-    # sampling_kwargs_main["repetition_penalty"] = 1.0
-
-    codebooks = [
-        sample(
-            x.logits,
-            previous_tokens=(
-                previous_tokens[0] if previous_tokens is not None else None
-            ),  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
-        )[0]
-    ]
-
-    hidden_states = x.hidden_states
-
-    # Cleanup the cache
-    for layer in model.fast_layers:
-        layer.attention.kv_cache.k_cache.fill_(0)
-        layer.attention.kv_cache.v_cache.fill_(0)
-
-    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
-    model.forward_generate_fast(hidden_states, input_pos)
-    a = codebooks[0] - model.tokenizer.semantic_begin_id
-    a[a < 0] = 0
-    hidden_states = model.fast_embeddings(a)
-    codebooks.append(a)
-
-    for codebook_idx in range(1, model.config.num_codebooks):
-        input_pos = torch.tensor(
-            [codebook_idx], device=hidden_states.device, dtype=torch.long
-        )
-        logits = model.forward_generate_fast(hidden_states, input_pos)
-        a = sample(
-            logits,
-            previous_tokens=(
-                previous_tokens[codebook_idx + 1]
-                if previous_tokens is not None
-                else None
-            ),
-            **sampling_kwargs,
-        )[0]
-        hidden_states = model.fast_embeddings(a)
-        codebooks.append(a)
-
-    codebooks = torch.stack(codebooks, dim=0)
-    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    # codebooks[1:, :] = torch.masked_fill(
-    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
-    # )
-
-    # print(codebooks)
-    return codebooks
-
-
-def decode_one_token_naive(
-    model: NaiveTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    x = model.forward_generate(x, input_pos)
-
-    sampling_kwargs_main = sampling_kwargs.copy()
-    sampling_kwargs_main["temperature"] = 0.1
-    sampling_kwargs_main["top_p"] = 0.1
-    sampling_kwargs_main["repetition_penalty"] = 1.0
-
-    codebooks = [
-        sample(
-            x.logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
-        )[0]
-    ]
-
-    for i in range(model.config.num_codebooks):
-        codebooks.append(
-            sample(
-                x.codebook_logits[:, :, i],
-                previous_tokens=(
-                    previous_tokens[i + 1] if previous_tokens is not None else None
-                ),
-                **sampling_kwargs,
-            )[0]
-        )
-
-    return torch.stack(codebooks, dim=0)
-
-
-def decode_n_tokens(
-    model: NaiveTransformer,
-    cur_token: torch.Tensor,
-    input_pos: torch.Tensor,
-    num_new_tokens: int,
-    semantic_ids: list,
-    decode_one_token=decode_one_token_naive,
-    **sampling_kwargs,
-):
-    previous_tokens = torch.zeros(
-        (model.config.num_codebooks + 1, model.config.max_seq_len),
-        dtype=torch.int,
-        device=cur_token.device,
-    )
-
-    for i in tqdm(range(num_new_tokens)):
-        # We need to get windowed repeat penalty
-        win_size = 16
-        if i < win_size:
-            window = previous_tokens[:, :win_size]
-        else:
-            window = previous_tokens[:, i - win_size : i]
-
-        with (
-            torch.backends.cuda.sdp_kernel(
-                enable_flash=False, enable_mem_efficient=False, enable_math=True
-            )
-            if torch.cuda.is_available()
-            else nullcontext()
-        ):  # Actually better for Inductor to codegen attention here
-            next_token = decode_one_token(
-                model=model,
-                x=cur_token,
-                input_pos=input_pos,
-                previous_tokens=window,
-                semantic_ids=semantic_ids,
-                **sampling_kwargs,
-            )
-
-        input_pos += 1
-        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
-        previous_tokens[:, i : i + 1] = next_token.view(
-            model.config.num_codebooks + 1, -1
-        )
-
-        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
-            break
-
-    return previous_tokens[:, : i + 1]
-
-
-@torch.no_grad()
-@torch.inference_mode()
-def generate(
-    *,
-    model: NaiveTransformer,
-    prompt: torch.Tensor,
-    max_new_tokens: int,
-    decode_one_token=decode_one_token_naive,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-    """
-
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(1)
-    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
-    semantic_ids = [
-        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
-    ]
-
-    if max_new_tokens:
-        if T + max_new_tokens > model.config.max_seq_len:
-            max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
-
-        T_new = T + max_new_tokens
-    else:
-        T_new = model.config.max_seq_len
-        max_new_tokens = T_new - T
-
-    device, dtype = prompt.device, prompt.dtype
-
-    codebook_dim = 1 + model.config.num_codebooks
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(
-        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
-    )
-    empty[:, :T] = prompt
-    seq = empty
-    input_pos = torch.arange(0, T, device=device)
-
-    # Use non-accelerated version for now, to avoid compilation overhead
-    prefill_decode = (
-        decode_one_token_naive
-        if isinstance(model, NaiveTransformer)
-        else decode_one_token_ar
-    )
-
-    next_token = prefill_decode(
-        model,
-        prompt.view(1, codebook_dim, -1),
-        input_pos,
-        semantic_ids=semantic_ids,
-        **sampling_kwargs,
-    )
-    seq[:, T : T + 1] = next_token
-
-    input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    x = decode_n_tokens(
-        model,
-        next_token.view(1, codebook_dim, -1),
-        input_pos,
-        max_new_tokens - 1,
-        decode_one_token=decode_one_token,
-        semantic_ids=semantic_ids,
-        **sampling_kwargs,
-    )
-    # x = torch.cat(generated_tokens, dim=1)
-    seq = seq[:, : T + 1 + x.size(1)]
-    seq[:, T + 1 :] = x
-
-    return seq
-
-
-def decode_n_tokens_agent(
-    model: NaiveTransformer,
-    cur_token: torch.Tensor,
-    input_pos: torch.Tensor,
-    num_new_tokens: int,
-    semantic_ids: list,
-    im_end_id: int = 4,
-    decode_one_token=decode_one_token_naive_agent,
-    early_stop_threshold: float = 0.6,
-    **sampling_kwargs,
-):
-    batch_size = cur_token.size(0)
-    previous_tokens = torch.zeros(
-        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
-        dtype=torch.int,
-        device=cur_token.device,
-    )
-    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
-    finished = finished | (cur_token[:, 0, -1] == im_end_id)
-    start_time = time.time()
-
-    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
-        # We need to get windowed repeat penalty
-        win_size = 16
-        if i < win_size:
-            window = previous_tokens[:, :, :win_size]
-        else:
-            window = previous_tokens[:, :, i - win_size : i]
-
-        with sdpa_kernel(
-            SDPBackend.MATH
-        ):  # Actually better for Inductor to codegen attention here
-            next_token = decode_one_token(
-                model=model,
-                x=cur_token,
-                input_pos=input_pos,
-                previous_tokens=window,
-                semantic_ids=semantic_ids,
-                **sampling_kwargs,
-            )
-
-        input_pos += 1
-        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
-        previous_tokens[:, :, i : i + 1] = next_token.view(
-            batch_size, model.config.num_codebooks + 1, -1
-        )
-
-        yield cur_token.cpu()
-
-        finished = finished | (cur_token[:, 0, -1] == im_end_id)
-        if finished.all() or (
-            0 < early_stop_threshold < 1
-            and finished.sum() >= round(batch_size * early_stop_threshold)
-        ):
-            break
-
-    total_time = time.time() - start_time
-    generated_tokens = i + 1
-    tokens_per_second = (generated_tokens / total_time) * batch_size
-    logger.info(
-        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
-    )
-
-
-@torch.no_grad()
-@torch.inference_mode()
-def generate_agent(
-    *,
-    model: BaseTransformer,
-    prompt: torch.Tensor,
-    max_new_tokens: int,
-    semantic_ids: list,
-    im_end_id: int = 4,
-    decode_one_token=decode_one_token_naive_agent,
-    num_samples: int = 1,
-    early_stop_threshold: float = 0.6,
-    **sampling_kwargs,
-):
-    """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-    """
-
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(1)
-    prompt = prompt[None].repeat(num_samples, 1, 1)
-
-    if T >= model.config.max_seq_len:
-        raise ValueError(
-            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
-        )
-
-    if max_new_tokens:
-        if T + max_new_tokens > model.config.max_seq_len:
-            max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
-
-        T_new = T + max_new_tokens
-    else:
-        T_new = model.config.max_seq_len
-        max_new_tokens = T_new - T
-
-    device, dtype = prompt.device, prompt.dtype
-
-    codebook_dim = 1 + model.config.num_codebooks
-    input_pos = torch.arange(0, T, device=device)
-
-    # Use non-accelerated version for now, to avoid compilation overhead
-    prefill_decode = (
-        decode_one_token_naive_agent
-        if isinstance(model, NaiveTransformer)
-        else decode_one_token_ar_agent
-    )
-    next_token = prefill_decode(
-        model,
-        prompt,
-        input_pos,
-        semantic_ids=semantic_ids,
-        **sampling_kwargs,
-    ).view(num_samples, codebook_dim, -1)
-    yield next_token.cpu()
-
-    input_pos = torch.tensor([T], device=device, dtype=torch.int)
-
-    yield from decode_n_tokens_agent(
-        model,
-        next_token,
-        input_pos,
-        max_new_tokens - 1,
-        im_end_id=im_end_id,
-        semantic_ids=semantic_ids,
-        decode_one_token=decode_one_token,
-        early_stop_threshold=early_stop_threshold,
-        **sampling_kwargs,
-    )
-
-
-def encode_tokens(
-    tokenizer,
-    string,
-    device="cuda",
-    prompt_tokens=None,
-    num_codebooks=4,
-):
-    string = clean_text(string)
-
-    messages = []
-    messages.append(
-        Message(
-            role="user",
-            parts=[TextPart(text=string)],
-            cal_loss=False,
-        )
-    )
-
-    if prompt_tokens is not None:
-        if prompt_tokens.ndim == 3:
-            assert (
-                prompt_tokens.shape[0] == 1
-            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
-            prompt_tokens = prompt_tokens[0]
-
-        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
-
-        if prompt_tokens.shape[0] > num_codebooks:
-            logger.warning(
-                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
-            )
-            prompt_tokens = prompt_tokens[:num_codebooks]
-
-        vq_part = VQPart(codes=prompt_tokens.to(device))
-
-        messages.append(
-            Message(
-                role="assistant",
-                parts=[TextPart(text="<|voice|>"), vq_part],
-                cal_loss=False,
-            )
-        )
-    else:
-        messages.append(
-            Message(
-                role="assistant",
-                parts=[TextPart(text="<|voice|>")],
-                cal_loss=False,
-                add_im_end=False,
-            )
-        )
-
-    conversation = Conversation(messages=messages)
-    # conversation.visualize(tokenizer)
-    encoded = conversation.encode_for_inference(
-        tokenizer=tokenizer,
-        num_codebooks=num_codebooks,
-    )
-
-    return encoded.to(device)
-
-
-def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
-    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True, is_agent=is_agent
-    )
-
-    model = model.to(device=device, dtype=precision)
-    logger.info(f"Restored model from checkpoint")
-
-    if isinstance(model, DualARTransformer):
-        decode_one_token = (
-            decode_one_token_ar_agent if is_agent else decode_one_token_ar
-        )
-        logger.info("Using DualARTransformer")
-    else:
-        decode_one_token = (
-            decode_one_token_naive_agent if is_agent else decode_one_token_naive
-        )
-        logger.info("Using NaiveTransformer")
-
-    if compile:
-        logger.info("Compiling function...")
-        decode_one_token = torch.compile(
-            decode_one_token,
-            fullgraph=True,
-            backend="inductor" if torch.cuda.is_available() else "aot_eager",
-            mode="reduce-overhead" if torch.cuda.is_available() else None,
-        )
-
-    return model.eval(), decode_one_token
-
-
-@dataclass
-class GenerateResponse:
-    action: Literal["sample", "next"]
-    codes: Optional[torch.Tensor] = None
-    text: Optional[str] = None
-
-
-def generate_long(
-    *,
-    model,
-    device: str | torch.device,
-    decode_one_token: callable,
-    text: str,
-    num_samples: int = 1,
-    max_new_tokens: int = 0,
-    top_p: int = 0.7,
-    repetition_penalty: float = 1.5,
-    temperature: float = 0.7,
-    compile: bool = False,
-    iterative_prompt: bool = True,
-    max_length: int = 2048,
-    chunk_length: int = 150,
-    prompt_text: Optional[str | list[str]] = None,
-    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
-):
-    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
-    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
-    assert 0 < temperature < 2, "temperature must be in (0, 2)"
-
-    use_prompt = prompt_text is not None and prompt_tokens is not None
-    if use_prompt and isinstance(prompt_text, str):
-        prompt_text = [prompt_text]
-        prompt_tokens = [prompt_tokens]
-
-    assert use_prompt is False or len(prompt_text) == len(
-        prompt_tokens
-    ), "Prompt text and tokens must have the same length"
-
-    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    tokenizer = model.tokenizer
-    im_end_id = tokenizer.get_token_id("<|im_end|>")
-
-    encoded = []
-    texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = [
-        Conversation(
-            messages=[
-                Message(
-                    role="system",
-                    parts=[TextPart(text="Speak out the provided text.")],
-                    cal_loss=False,
-                )
-            ]
-        )
-        .encode_for_inference(
-            tokenizer=tokenizer,
-            num_codebooks=model.config.num_codebooks,
-        )
-        .to(device)
-    ]
-
-    if use_prompt:
-        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
-            encoded_prompts.append(
-                encode_tokens(
-                    tokenizer,
-                    string=t,
-                    device=device,
-                    prompt_tokens=c,
-                    num_codebooks=model.config.num_codebooks,
-                )
-            )
-
-    for idx, text in enumerate(texts):
-        encoded.append(
-            encode_tokens(
-                tokenizer,
-                string=text,
-                device=device,
-                num_codebooks=model.config.num_codebooks,
-            )
-        )
-        logger.info(f"Encoded text: {text}")
-
-    # Move temperature, top_p, repetition_penalty to device
-    # This is important so that changing params doesn't trigger recompile
-    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
-    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
-    repetition_penalty = torch.tensor(
-        repetition_penalty, device=device, dtype=torch.float
-    )
-
-    for sample_idx in range(num_samples):
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
-
-        global_encoded = []
-        seg_idx = 0
-
-        while seg_idx < len(encoded):
-            logger.info(
-                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
-            )
-
-            seg = encoded[seg_idx]
-            global_encoded.append(seg)
-
-            lengths = reversed([seg.size(1) for seg in global_encoded])
-
-            # Pick last 2000 tokens
-            count = 0
-            for i, length in enumerate(lengths):
-                count += length
-                if count + length > max_length - 1024 - sum(
-                    t.shape[1] for t in encoded_prompts
-                ):
-                    break
-
-            if i != 0 and i % 2 == 0:
-                i -= 1
-
-            # Rotate the list, always make sure first segment is included to avoid drift
-            if i < len(global_encoded) - 2:
-                partial_encoded = global_encoded[:2] + global_encoded[-i:]
-            else:
-                partial_encoded = global_encoded
-
-            if use_prompt:
-                partial_encoded = encoded_prompts + partial_encoded
-
-            cat_encoded = torch.cat(partial_encoded, dim=1)
-            prompt_length = cat_encoded.size(1)
-
-            t0 = time.perf_counter()
-            y = generate(
-                model=model,
-                prompt=cat_encoded,
-                max_new_tokens=max_new_tokens,
-                decode_one_token=decode_one_token,
-                temperature=temperature,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-            )
-
-            if sample_idx == 0 and seg_idx == 0 and compile:
-                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
-
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-
-            t = time.perf_counter() - t0
-
-            tokens_generated = y.size(1) - prompt_length
-            tokens_sec = tokens_generated / t
-            logger.info(
-                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
-            )
-            logger.info(
-                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
-            )
-
-            if torch.cuda.is_available():
-                logger.info(
-                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
-                )
-
-            # Put the generated tokens
-            # since there is <im_end>, we remove last token
-            codes = y[1:, prompt_length + 1 :].clone()
-            assert (codes >= 0).all(), f"Negative code found"
-
-            decoded = y[:, prompt_length:].clone()
-            # But for global encoding, we should keep the <im_end> token
-
-            global_encoded.append(decoded)
-            assert (codes >= 0).all(), f"Negative code found: {codes}"
-            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
-            seg_idx += 1
-
-        # This indicates the end of the current sample
-        yield GenerateResponse(action="next")
-
-
-@dataclass
-class WrappedGenerateResponse:
-    status: Literal["success", "error"]
-    response: Optional[GenerateResponse | Exception] = None
-
-
-@dataclass
-class GenerateRequest:
-    request: dict
-    response_queue: queue.Queue
-
-
-def launch_thread_safe_queue(
-    checkpoint_path,
-    device,
-    precision,
-    compile: bool = False,
-):
-    input_queue = queue.Queue()
-    init_event = threading.Event()
-
-    def worker():
-        model, decode_one_token = load_model(
-            checkpoint_path, device, precision, compile=compile
-        )
-        with torch.device(device):
-            model.setup_caches(
-                max_batch_size=1,
-                max_seq_len=model.config.max_seq_len,
-                dtype=next(model.parameters()).dtype,
-            )
-        init_event.set()
-
-        while True:
-            item: GenerateRequest | None = input_queue.get()
-            if item is None:
-                break
-
-            kwargs = item.request
-            response_queue = item.response_queue
-
-            try:
-                for chunk in generate_long(
-                    model=model, decode_one_token=decode_one_token, **kwargs
-                ):
-                    response_queue.put(
-                        WrappedGenerateResponse(status="success", response=chunk)
-                    )
-            except Exception as e:
-                response_queue.put(WrappedGenerateResponse(status="error", response=e))
-
-    threading.Thread(target=worker, daemon=True).start()
-    init_event.wait()
-
-    return input_queue
-
-
-def launch_thread_safe_queue_agent(
-    checkpoint_path,
-    device,
-    precision,
-    compile: bool = False,
-):
-    input_queue = queue.Queue()
-    init_event = threading.Event()
-
-    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
-    config = BaseModelArgs.from_pretrained(checkpoint_path)
-
-    def worker():
-        model, decode_one_token = load_model(
-            checkpoint_path, device, precision, compile=compile, is_agent=True
-        )
-
-        with torch.device(device):
-            model.setup_caches(
-                max_batch_size=1,
-                max_seq_len=model.config.max_seq_len,
-                dtype=next(model.parameters()).dtype,
-            )
-        init_event.set()
-
-        while True:
-            item: GenerateRequest | None = input_queue.get()
-            if item is None:
-                break
-
-            kwargs = item.request
-            response_queue = item.response_queue
-
-            try:
-                for token in generate_agent(
-                    model=model,
-                    decode_one_token=decode_one_token,
-                    **kwargs,
-                ):
-                    response_queue.put(token)
-
-                response_queue.put("stop")
-            except Exception as e:
-                import traceback
-
-                logger.exception(f"Error in worker: {traceback.format_exc()}")
-                response_queue.put("error")
-
-    threading.Thread(target=worker, daemon=True).start()
-    init_event.wait()
-
-    return input_queue, tokenizer, config
-
-
-@click.command()
-@click.option(
-    "--text",
-    type=str,
-    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
-)
-@click.option("--prompt-text", type=str, default=None, multiple=True)
-@click.option(
-    "--prompt-tokens",
-    type=click.Path(path_type=Path, exists=True),
-    default=None,
-    multiple=True,
-)
-@click.option("--num-samples", type=int, default=1)
-@click.option("--max-new-tokens", type=int, default=0)
-@click.option("--top-p", type=float, default=0.7)
-@click.option("--repetition-penalty", type=float, default=1.2)
-@click.option("--temperature", type=float, default=0.7)
-@click.option(
-    "--checkpoint-path",
-    type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/fish-speech-1.5",
-)
-@click.option("--device", type=str, default="cuda")
-@click.option("--compile/--no-compile", default=False)
-@click.option("--seed", type=int, default=42)
-@click.option("--half/--no-half", default=False)
-@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
-@click.option("--chunk-length", type=int, default=100)
-def main(
-    text: str,
-    prompt_text: Optional[list[str]],
-    prompt_tokens: Optional[list[Path]],
-    num_samples: int,
-    max_new_tokens: int,
-    top_p: int,
-    repetition_penalty: float,
-    temperature: float,
-    checkpoint_path: Path,
-    device: str,
-    compile: bool,
-    seed: int,
-    half: bool,
-    iterative_prompt: bool,
-    chunk_length: int,
-) -> None:
-
-    precision = torch.half if half else torch.bfloat16
-
-    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
-        raise ValueError(
-            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
-        )
-
-    logger.info("Loading model ...")
-    t0 = time.time()
-    model, decode_one_token = load_model(
-        checkpoint_path, device, precision, compile=compile
-    )
-    with torch.device(device):
-        model.setup_caches(
-            max_batch_size=1,
-            max_seq_len=model.config.max_seq_len,
-            dtype=next(model.parameters()).dtype,
-        )
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
-
-    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
-
-    if prompt_tokens is not None:
-        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
-
-    torch.manual_seed(seed)
-
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-
-    generator = generate_long(
-        model=model,
-        device=device,
-        decode_one_token=decode_one_token,
-        text=text,
-        num_samples=num_samples,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        temperature=temperature,
-        compile=compile,
-        iterative_prompt=iterative_prompt,
-        chunk_length=chunk_length,
-        prompt_text=prompt_text,
-        prompt_tokens=prompt_tokens,
+def main():
+    # Make path relative to this file
+    script_path = os.path.join(
+        os.path.dirname(__file__), "../../fish_speech/models/text2semantic/inference.py"
     )
-
-    idx = 0
-    codes = []
-
-    for response in generator:
-        if response.action == "sample":
-            codes.append(response.codes)
-            logger.info(f"Sampled text: {response.text}")
-        elif response.action == "next":
-            if codes:
-                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
-                logger.info(f"Saved codes to codes_{idx}.npy")
-            logger.info(f"Next sample")
-            codes = []
-            idx += 1
-        else:
-            logger.error(f"Error: {response}")
+    subprocess.run(["python", script_path] + sys.argv[1:])
 
 
 if __name__ == "__main__":

+ 1 - 1
tools/llama/quantize.py

@@ -13,8 +13,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from fish_speech.models.text2semantic.inference import load_model
 from fish_speech.models.text2semantic.llama import find_multiple
-from tools.llama.generate import load_model
 
 ##### Quantization Primitives ######
 

+ 4 - 4
tools/run_webui.py

@@ -8,10 +8,10 @@ from loguru import logger
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
-from tools.inference_engine import TTSInferenceEngine
-from tools.llama.generate import launch_thread_safe_queue
-from tools.schema import ServeTTSRequest
-from tools.vqgan.inference import load_model as load_decoder_model
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
+from fish_speech.models.vqgan.inference import load_model as load_decoder_model
+from fish_speech.utils.schema import ServeTTSRequest
 from tools.webui import build_app
 from tools.webui.inference import get_inference_wrapper
 

+ 1 - 1
tools/sensevoice/fun_asr.py

@@ -17,7 +17,7 @@ from pydub import AudioSegment
 from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
 from tqdm import tqdm
 
-from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
+from fish_speech.utils.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
 from tools.sensevoice.auto_model import AutoModel
 
 

+ 1 - 1
tools/server/agent/generate.py

@@ -1,6 +1,6 @@
 import time
 
-from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
+from fish_speech.utils.schema import ServeMessage, ServeResponse, ServeStreamResponse
 from tools.server.agent.generation_utils import (
     initialize_decode_buffers,
     process_response_tokens,

+ 1 - 1
tools/server/agent/generation_utils.py

@@ -1,6 +1,6 @@
 import time
 
-from tools.schema import (
+from fish_speech.utils.schema import (
     ServeStreamDelta,
     ServeStreamResponse,
     ServeTextPart,

+ 1 - 1
tools/server/agent/pre_generation_utils.py

@@ -1,8 +1,8 @@
 import queue
 
 from fish_speech.conversation import Conversation, Message
+from fish_speech.models.text2semantic.inference import GenerateRequest
 from fish_speech.tokenizer import IM_END_TOKEN
-from tools.llama.generate import GenerateRequest
 
 
 def prepare_messages(request, tokenizer, config):

+ 2 - 2
tools/server/api_utils.py

@@ -6,8 +6,8 @@ import ormsgpack
 from baize.datastructures import ContentType
 from kui.asgi import HTTPException, HttpRequest
 
-from tools.inference_engine import TTSInferenceEngine
-from tools.schema import ServeTTSRequest
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.utils.schema import ServeTTSRequest
 from tools.server.inference import inference_wrapper as inference
 
 

+ 2 - 2
tools/server/inference.py

@@ -3,8 +3,8 @@ from http import HTTPStatus
 import numpy as np
 from kui.asgi import HTTPException
 
-from tools.inference_engine import TTSInferenceEngine
-from tools.schema import ServeTTSRequest
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.utils.schema import ServeTTSRequest
 
 AMPLITUDE = 32768  # Needs an explaination
 

+ 4 - 4
tools/server/model_manager.py

@@ -2,14 +2,14 @@ import torch
 from funasr import AutoModel
 from loguru import logger
 
-from tools.inference_engine import TTSInferenceEngine
-from tools.llama.generate import (
+from fish_speech.inference_engine import TTSInferenceEngine
+from fish_speech.models.text2semantic.inference import (
     launch_thread_safe_queue,
     launch_thread_safe_queue_agent,
 )
-from tools.schema import ServeTTSRequest
+from fish_speech.models.vqgan.inference import load_model as load_decoder_model
+from fish_speech.utils.schema import ServeTTSRequest
 from tools.server.inference import inference_wrapper as inference
-from tools.vqgan.inference import load_model as load_decoder_model
 
 ASR_MODEL_NAME = "iic/SenseVoiceSmall"
 

+ 1 - 1
tools/server/model_utils.py

@@ -54,7 +54,7 @@ def cached_vqgan_batch_encode(model, audios: list[bytes]):
 
 @torch.no_grad()
 @torch.autocast(device_type="cuda", dtype=torch.half)
-def vqgan_decode(model, features):
+def batch_vqgan_decode(model, features):
     lengths = torch.tensor(
         [feature.shape[-1] for feature in features], device=model.device
     )

+ 7 - 3
tools/server/views.py

@@ -11,7 +11,7 @@ from kui.asgi import Body, HTTPException, JSONResponse, Routes, StreamResponse,
 from loguru import logger
 from typing_extensions import Annotated
 
-from tools.schema import (
+from fish_speech.utils.schema import (
     ServeASRRequest,
     ServeASRResponse,
     ServeChatRequest,
@@ -29,7 +29,11 @@ from tools.server.api_utils import (
 )
 from tools.server.inference import inference_wrapper as inference
 from tools.server.model_manager import ModelManager
-from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode
+from tools.server.model_utils import (
+    batch_asr,
+    batch_vqgan_decode,
+    cached_vqgan_batch_encode,
+)
 
 MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
 
@@ -68,7 +72,7 @@ async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=Tr
     # Decode the audio
     tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
     start_time = time.time()
-    audios = vqgan_decode(decoder_model, tokens)
+    audios = batch_vqgan_decode(decoder_model, tokens)
     logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
     audios = [audio.astype(np.float16).tobytes() for audio in audios]
 

+ 1 - 1
tools/smart_pad.py

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 import torchaudio
 from tqdm import tqdm
 
-from tools.file import AUDIO_EXTENSIONS, list_files
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 
 threshold = 10 ** (-50 / 20.0)
 

+ 1 - 1
tools/vqgan/create_train_split.py

@@ -7,7 +7,7 @@ from loguru import logger
 from pydub import AudioSegment
 from tqdm import tqdm
 
-from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 
 @click.command()

+ 1 - 2
tools/vqgan/extract_vq.py

@@ -13,11 +13,10 @@ import torch
 import torchaudio
 from hydra import compose, initialize
 from hydra.utils import instantiate
-from lightning import LightningModule
 from loguru import logger
 from omegaconf import OmegaConf
 
-from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 # register eval resolver
 OmegaConf.register_new_resolver("eval", eval)

+ 9 - 113
tools/vqgan/inference.py

@@ -1,120 +1,16 @@
-from pathlib import Path
+import os
+import subprocess
+import sys
 
-import click
-import hydra
-import numpy as np
-import soundfile as sf
-import torch
-import torchaudio
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from loguru import logger
-from omegaconf import OmegaConf
+#!/usr/bin/env python
 
-from tools.file import AUDIO_EXTENSIONS
 
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-
-
-def load_model(config_name, checkpoint_path, device="cuda"):
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
-        cfg = compose(config_name=config_name)
-
-    model = instantiate(cfg)
-    state_dict = torch.load(
-        checkpoint_path, map_location=device, mmap=True, weights_only=True
+def main():
+    # Make path relative to this file
+    script_path = os.path.join(
+        os.path.dirname(__file__), "../../fish_speech/models/vqgan/inference.py"
     )
-    if "state_dict" in state_dict:
-        state_dict = state_dict["state_dict"]
-
-    if any("generator" in k for k in state_dict):
-        state_dict = {
-            k.replace("generator.", ""): v
-            for k, v in state_dict.items()
-            if "generator." in k
-        }
-
-    result = model.load_state_dict(state_dict, strict=False, assign=True)
-    model.eval()
-    model.to(device)
-
-    logger.info(f"Loaded model: {result}")
-    return model
-
-
-@torch.no_grad()
-@click.command()
-@click.option(
-    "--input-path",
-    "-i",
-    default="test.wav",
-    type=click.Path(exists=True, path_type=Path),
-)
-@click.option(
-    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
-)
-@click.option("--config-name", default="firefly_gan_vq")
-@click.option(
-    "--checkpoint-path",
-    default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-)
-@click.option(
-    "--device",
-    "-d",
-    default="cuda",
-)
-def main(input_path, output_path, config_name, checkpoint_path, device):
-    model = load_model(config_name, checkpoint_path, device=device)
-
-    if input_path.suffix in AUDIO_EXTENSIONS:
-        logger.info(f"Processing in-place reconstruction of {input_path}")
-
-        # Load audio
-        audio, sr = torchaudio.load(str(input_path))
-        if audio.shape[0] > 1:
-            audio = audio.mean(0, keepdim=True)
-        audio = torchaudio.functional.resample(
-            audio, sr, model.spec_transform.sample_rate
-        )
-
-        audios = audio[None].to(device)
-        logger.info(
-            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
-        )
-
-        # VQ Encoder
-        audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
-        indices = model.encode(audios, audio_lengths)[0][0]
-
-        logger.info(f"Generated indices of shape {indices.shape}")
-
-        # Save indices
-        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
-    elif input_path.suffix == ".npy":
-        logger.info(f"Processing precomputed indices from {input_path}")
-        indices = np.load(input_path)
-        indices = torch.from_numpy(indices).to(device).long()
-        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
-    else:
-        raise ValueError(f"Unknown input type: {input_path}")
-
-    # Restore
-    feature_lengths = torch.tensor([indices.shape[1]], device=device)
-    fake_audios, _ = model.decode(
-        indices=indices[None], feature_lengths=feature_lengths
-    )
-    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
-
-    logger.info(
-        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
-    )
-
-    # Save audio
-    fake_audio = fake_audios[0, 0].float().cpu().numpy()
-    sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
-    logger.info(f"Saved audio to {output_path}")
+    subprocess.run(["python", script_path] + sys.argv[1:])
 
 
 if __name__ == "__main__":

+ 1 - 1
tools/webui/__init__.py

@@ -3,7 +3,7 @@ from typing import Callable
 import gradio as gr
 
 from fish_speech.i18n import i18n
-from tools.inference_engine.utils import normalize_text
+from fish_speech.inference_engine.utils import normalize_text
 from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
 
 

+ 1 - 1
tools/webui/inference.py

@@ -3,7 +3,7 @@ from functools import partial
 from typing import Any, Callable
 
 from fish_speech.i18n import i18n
-from tools.schema import ServeReferenceAudio, ServeTTSRequest
+from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
 
 
 def inference_wrapper(

+ 1 - 1
tools/whisper_asr.py

@@ -32,7 +32,7 @@ from loguru import logger
 from pydub import AudioSegment
 from tqdm import tqdm
 
-from tools.file import AUDIO_EXTENSIONS, list_files
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 
 
 @click.command()