Преглед изворни кода

This PR brings V1.2 inference into main (#300)

* Add model converter & agent test code

* Optimize training speed

* Fix agent nan loss

* Add new dataset & conversation system

* handle codebook bias in conversation

* Fix compile

* Implement new conversation system

* Support new vq

* Update VQ config

* optm:(build dataset) use mp to speedup grouping

* fix p not defined

* rollback due to thread lock

* remove vits decoder & add firefly vq

* remove unused configs, test generate.py

---------

Co-authored-by: Stardust·减 <stardust@fish.audio>
Leng Yue пре 1 година
родитељ
комит
5e7914472f
49 измењених фајлова са 1975 додато и 3625 уклоњено
  1. 1 0
      .gitignore
  2. 2 2
      API_FLAGS.txt
  3. 5 5
      docs/en/finetune.md
  4. 5 5
      docs/en/inference.md
  5. 5 5
      docs/zh/finetune.md
  6. 6 6
      docs/zh/inference.md
  7. 34 0
      fish_speech/configs/firefly_gan_vq.yaml
  8. 0 9
      fish_speech/configs/model/dual_ar_2_codebook_large.yaml
  9. 0 9
      fish_speech/configs/model/dual_ar_2_codebook_medium.yaml
  10. 0 13
      fish_speech/configs/model/dual_ar_2_codebook_small.yaml
  11. 0 12
      fish_speech/configs/model/naive_2_codebook_small.yaml
  12. 66 0
      fish_speech/configs/text2semantic_agent.yaml
  13. 0 128
      fish_speech/configs/vits_decoder_finetune.yaml
  14. 0 127
      fish_speech/configs/vits_decoder_pretrain.yaml
  15. 0 137
      fish_speech/configs/vqgan_finetune.yaml
  16. 0 140
      fish_speech/configs/vqgan_pretrain.yaml
  17. 2 0
      fish_speech/conversation.py
  18. 0 35
      fish_speech/datasets/concat_repeat.py
  19. 381 0
      fish_speech/datasets/prompts.py
  20. 263 298
      fish_speech/datasets/text.py
  21. 0 3
      fish_speech/models/text2semantic/__init__.py
  22. 11 7
      fish_speech/models/text2semantic/lit_module.py
  23. 206 68
      fish_speech/models/text2semantic/llama.py
  24. 8 0
      fish_speech/models/text2semantic/lora.py
  25. 0 3
      fish_speech/models/vits_decoder/__init__.py
  26. 0 394
      fish_speech/models/vits_decoder/lit_module.py
  27. 0 67
      fish_speech/models/vits_decoder/losses.py
  28. 0 350
      fish_speech/models/vits_decoder/modules/attentions.py
  29. 0 190
      fish_speech/models/vits_decoder/modules/commons.py
  30. 0 686
      fish_speech/models/vits_decoder/modules/models.py
  31. 0 619
      fish_speech/models/vits_decoder/modules/modules.py
  32. 0 58
      fish_speech/models/vits_decoder/modules/mrte.py
  33. 0 101
      fish_speech/models/vits_decoder/modules/vq_encoder.py
  34. 86 0
      fish_speech/models/vqgan/modules/firefly.py
  35. 1 1
      fish_speech/models/vqgan/modules/fsq.py
  36. 7 7
      fish_speech/webui/manage.py
  37. 2 1
      pyproject.toml
  38. 112 0
      run.py
  39. 412 0
      stream_service.py
  40. 183 0
      test_echo.py
  41. 3 9
      tools/api.py
  42. 101 0
      tools/llama/convert_hf_weights_to_llama.py
  43. 24 85
      tools/llama/generate.py
  44. 1 3
      tools/llama/merge_lora.py
  45. 1 1
      tools/llama/quantize.py
  46. 1 1
      tools/vits_decoder/inference.py
  47. 19 9
      tools/vqgan/extract_vq.py
  48. 24 22
      tools/vqgan/inference.py
  49. 3 9
      tools/webui.py

+ 1 - 0
.gitignore

@@ -25,3 +25,4 @@ asr-label*
 /.locale
 /demo-audios
 ref_data*
+/example

+ 2 - 2
API_FLAGS.txt

@@ -3,5 +3,5 @@
 --listen 0.0.0.0:8000 \
 --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
 --llama-config-name dual_ar_2_codebook_medium \
---decoder-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
---decoder-config-name vqgan_finetune
+--decoder-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+--decoder-config-name firefly_gan_vq

+ 5 - 5
docs/en/finetune.md

@@ -59,8 +59,8 @@ You can then run the following command to extract semantic tokens:
 ```bash
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
-    --config-name "vqgan_pretrain" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --config-name "firefly_gan_vq" \
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 
 !!! note
@@ -233,16 +233,16 @@ This command will create `data/vq_train_filelist.txt` and `data/vq_val_filelist.
 ### 3. Start Training
 
 ```bash
-python fish_speech/train.py --config-name vqgan_finetune
+python fish_speech/train.py --config-name firefly_gan_vq
 ```
 
 !!! note
-    You can modify training parameters by editing `fish_speech/configs/vqgan_finetune.yaml`, but in most cases, this won't be necessary.
+    You can modify training parameters by editing `fish_speech/configs/firefly_gan_vq.yaml`, but in most cases, this won't be necessary.
 
 ### 4. Test the Audio
     
 ```bash
-python tools/vqgan/inference.py -i test.wav --checkpoint-path results/vqgan_finetune/checkpoints/step_000010000.ckpt
+python tools/vqgan/inference.py -i test.wav --checkpoint-path results/firefly_gan_vq/checkpoints/step_000010000.ckpt
 ```
 
 You can review `fake.wav` to assess the fine-tuning results.

+ 5 - 5
docs/en/inference.md

@@ -31,7 +31,7 @@ huggingface-cli download fishaudio/fish-speech-1 firefly-gan-base-generator.ckpt
 ```bash
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 You should get a `fake.npy` file.
 
@@ -73,7 +73,7 @@ python tools/vits_decoder/inference.py \
 ```bash
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 
 ## HTTP API Inference
@@ -85,8 +85,8 @@ python -m tools.api \
     --listen 0.0.0.0:8000 \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_medium \
-    --decoder-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
-    --decoder-config-name vqgan_pretrain
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+    --decoder-config-name firefly_gan_vq
 ```
 
 After that, you can view and test the API at http://127.0.0.1:8000/.  
@@ -107,7 +107,7 @@ You can start the WebUI using the following command:
 python -m tools.webui \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_medium \
-    --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
+    --vqgan-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
     --vits-checkpoint-path "checkpoints/vits_decoder_v1.1.ckpt"
 ```
 

+ 5 - 5
docs/zh/finetune.md

@@ -63,8 +63,8 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech
 ```bash
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
-    --config-name "vqgan_pretrain" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --config-name "firefly_gan_vq" \
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 
 !!! note
@@ -239,16 +239,16 @@ python tools/vqgan/create_train_split.py data
 ### 3. 启动训练
 
 ```bash
-python fish_speech/train.py --config-name vqgan_finetune
+python fish_speech/train.py --config-name firefly_gan_vq
 ```
 
 !!! note
-    你可以通过修改 `fish_speech/configs/vqgan_finetune.yaml` 来修改训练参数, 但大部分情况下, 你不需要这么做.
+    你可以通过修改 `fish_speech/configs/firefly_gan_vq.yaml` 来修改训练参数, 但大部分情况下, 你不需要这么做.
 
 ### 4. 测试音频
     
 ```bash
-python tools/vqgan/inference.py -i test.wav --checkpoint-path results/vqgan_finetune/checkpoints/step_000010000.ckpt
+python tools/vqgan/inference.py -i test.wav --checkpoint-path results/firefly_gan_vq/checkpoints/step_000010000.ckpt
 ```
 
 你可以查看 `fake.wav` 来判断微调效果.

+ 6 - 6
docs/zh/inference.md

@@ -41,7 +41,7 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech
 ```bash
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 你应该能得到一个 `fake.npy` 文件.
 
@@ -83,7 +83,7 @@ python tools/vits_decoder/inference.py \
 ```bash
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 
 ## HTTP API 推理
@@ -95,8 +95,8 @@ python -m tools.api \
     --listen 0.0.0.0:8000 \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_medium \
-    --decoder-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
-    --decoder-config-name vqgan_pretrain
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+    --decoder-config-name firefly_gan_vq
 
 # 推荐中国大陆用户运行以下命令来启动 HTTP 服务:
 HF_ENDPOINT=https://hf-mirror.com python -m ...
@@ -120,8 +120,8 @@ HF_ENDPOINT=https://hf-mirror.com python -m ...
 python -m tools.webui \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_medium \
-    --decoder-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
-    --decoder-config-name vqgan_pretrain
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+    --decoder-config-name firefly_gan_vq
 ```
 
 !!! info

+ 34 - 0
fish_speech/configs/firefly_gan_vq.yaml

@@ -0,0 +1,34 @@
+_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
+spec_transform:
+  _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
+  sample_rate: 44100
+  n_mels: 160
+  n_fft: 2048
+  hop_length: 512
+  win_length: 2048
+backbone:
+  _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
+  input_channels: 160
+  depths: [3, 3, 9, 3]
+  dims: [128, 256, 384, 512]
+  drop_path_rate: 0.2
+  kernel_size: 7
+head:
+  _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
+  hop_length: 512
+  upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
+  upsample_kernel_sizes: [16, 16, 4, 4, 4]
+  resblock_kernel_sizes: [3, 7, 11]
+  resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+  num_mels: 512
+  upsample_initial_channel: 512
+  use_template: false
+  pre_conv_kernel_size: 13
+  post_conv_kernel_size: 13
+quantizer:
+  _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+  input_dim: 512
+  n_groups: 4
+  n_codebooks: 1
+  levels: [8, 5, 5, 5]
+  downsample_factor: [2]

+ 0 - 9
fish_speech/configs/model/dual_ar_2_codebook_large.yaml

@@ -1,9 +0,0 @@
-defaults:
-  - dual_ar_2_codebook_small
-  - _self_
-
-config:
-  n_layer: 30
-  n_fast_layer: 6
-  n_head: 24
-  dim: 1536

+ 0 - 9
fish_speech/configs/model/dual_ar_2_codebook_medium.yaml

@@ -1,9 +0,0 @@
-defaults:
-  - dual_ar_2_codebook_small
-  - _self_
-
-config:
-  n_layer: 24
-  n_fast_layer: 6
-  n_head: 16
-  dim: 1024

+ 0 - 13
fish_speech/configs/model/dual_ar_2_codebook_small.yaml

@@ -1,13 +0,0 @@
-_target_: fish_speech.models.text2semantic.llama.DualARTransformer
-config:
-  _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
-  max_seq_len: ${max_length}
-  vocab_size: 264 # pad 262 to 8x
-  n_layer: 12
-  n_fast_layer: 4
-  n_head: 12
-  dim: 768
-  rope_base: 10000
-  norm_eps: 1e-5
-  num_codebooks: 2  # input/output codebook size
-  codebook_size: 1032 # codebook size 1024 + 2 special tokens

+ 0 - 12
fish_speech/configs/model/naive_2_codebook_small.yaml

@@ -1,12 +0,0 @@
-_target_: fish_speech.models.text2semantic.llama.NaiveTransformer
-config:
-  _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs
-  max_seq_len: ${max_length}
-  vocab_size: 36408
-  n_layer: 12
-  n_head: 12
-  dim: 768
-  rope_base: 10000
-  norm_eps: 1e-5
-  num_codebooks: 2  # input/output codebook size
-  codebook_size: 1032 # codebook size 1024 + 2 special tokens

+ 66 - 0
fish_speech/configs/text2semantic_agent.yaml

@@ -0,0 +1,66 @@
+defaults:
+  - base
+  - model@model.model: dual_ar_2_codebook_1.3b
+  - _self_
+
+project: text2semantic_agent_dual_ar_debug
+max_length: 2048
+ckpt_path: checkpoints/fish-speech-agent-1/step_000013000.ckpt
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  max_steps: 1_000_000
+  precision: bf16-true
+  log_every_n_steps: 10
+  limit_val_batches: 10
+  val_check_interval: 1000
+
+# Dataset Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: checkpoints/fish-speech-agent-1
+
+# Dataset Configuration
+train_dataset: {}
+val_dataset: {}
+
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 8
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+  model: {}
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 3e-4
+    weight_decay: 0.01
+    betas: [0.9, 0.95]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 1000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1
+
+# Callbacks
+callbacks:
+  model_checkpoint:
+    every_n_train_steps: ${trainer.val_check_interval}

+ 0 - 128
fish_speech/configs/vits_decoder_finetune.yaml

@@ -1,128 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vits_decoder
-ckpt_path: checkpoints/vits_decoder_v1.1.ckpt
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  strategy:
-    find_unused_parameters: true
-  precision: 32
-  max_steps: 100_000
-  val_check_interval: 100
-  benchmark: false
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/fish-speech-1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vits.VITSDataset
-  filelist: data/source/Genshin/filelist.train.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  suffix: ".lab"
-  tokenizer: ${tokenizer}
-  sentence_mask_ratio: 0.2
-
-val_dataset:
-  _target_: fish_speech.datasets.vits.VITSDataset
-  filelist: data/source/Genshin/filelist.test.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  suffix: ".lab"
-  tokenizer: ${tokenizer}
-
-data:
-  _target_: fish_speech.datasets.vits.VITSDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  val_batch_size: 4
-  tokenizer: ${tokenizer}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vits_decoder.VITSDecoder
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  freeze_discriminator: false
-
-  weight_mel: 45.0
-  weight_kl: 1.0
-
-  generator:
-    _target_: fish_speech.models.vits_decoder.modules.models.SynthesizerTrn
-    spec_channels: 1025
-    segment_size: 32
-    inter_channels: 192
-    hidden_channels: 192
-    filter_channels: 768
-    n_heads: 2
-    n_layers: 6
-    kernel_size: 3
-    p_dropout: 0.1
-    resblock: "1"
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [8, 8, 2, 2, 2]
-    upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 8, 2, 2]
-    gin_channels: 512
-    vq_mask_ratio: 0.2
-    ref_mask_ratio: 0.2
-
-  discriminator:
-    _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator
-    periods: [2, 3, 5, 7, 11]
-
-  mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  spec_transform:
-    _target_: fish_speech.utils.spectrogram.LinearSpectrogram
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    mode: pow2_sqrt
-  
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 1e-4
-    betas: [0.8, 0.99]
-    eps: 1e-6
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
-    _partial_: true
-    gamma: 0.999999
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - generator
-      - discriminator
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
-    save_top_k: 10

+ 0 - 127
fish_speech/configs/vits_decoder_pretrain.yaml

@@ -1,127 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vits_decoder
-ckpt_path: checkpoints/Bert-VITS2/ensemble.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  strategy: ddp_find_unused_parameters_true
-  precision: 32
-  max_steps: 1_000_000
-  val_check_interval: 1000
-  benchmark: false
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/fish-speech-1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vits.VITSDataset
-  filelist: data/source/Genshin/filelist.train.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  suffix: ".lab"
-  tokenizer: ${tokenizer}
-  sentence_mask_ratio: 0.2
-
-val_dataset:
-  _target_: fish_speech.datasets.vits.VITSDataset
-  filelist: data/source/Genshin/filelist.test.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  suffix: ".lab"
-  tokenizer: ${tokenizer}
-
-data:
-  _target_: fish_speech.datasets.vits.VITSDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  val_batch_size: 4
-  tokenizer: ${tokenizer}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vits_decoder.VITSDecoder
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  freeze_discriminator: false
-
-  weight_mel: 45.0
-  weight_kl: 1.0
-
-  generator:
-    _target_: fish_speech.models.vits_decoder.modules.models.SynthesizerTrn
-    spec_channels: 1025
-    segment_size: 32
-    inter_channels: 192
-    hidden_channels: 192
-    filter_channels: 768
-    n_heads: 2
-    n_layers: 6
-    kernel_size: 3
-    p_dropout: 0.1
-    resblock: "1"
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [8, 8, 2, 2, 2]
-    upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 8, 2, 2]
-    gin_channels: 512
-    vq_mask_ratio: 0.2
-    ref_mask_ratio: 0.2
-
-  discriminator:
-    _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator
-    periods: [2, 3, 5, 7, 11]
-
-  mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  spec_transform:
-    _target_: fish_speech.utils.spectrogram.LinearSpectrogram
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    mode: pow2_sqrt
-  
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 1e-4
-    betas: [0.8, 0.99]
-    eps: 1e-6
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
-    _partial_: true
-    gamma: 0.999999
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - generator
-      - discriminator
-
-  model_checkpoint:
-    every_n_train_steps: 1000
-    save_top_k: 10

+ 0 - 137
fish_speech/configs/vqgan_finetune.yaml

@@ -1,137 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vq-gan-finetune
-ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  precision: bf16-mixed
-  max_steps: 100_000
-  val_check_interval: 5000
-  strategy:
-    find_unused_parameters: true
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_train_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_val_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 16
-  val_batch_size: 16
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-
-  sampling_rate: ${sample_rate}
-  weight_adv: 0.2
-  weight_vq: 1.0
-  weight_mel: 1.0
-
-  # Important: Set the freeze_encoder to true to only train the decoder
-  freeze_encoder: true
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    input_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-  
-  quantizer:
-    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-    input_dim: 768
-    n_codebooks: 1
-    n_groups: 2
-    levels: [8, 5, 5, 5]
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    output_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-    condition_channels: 768
-  
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
-
-  vocoder:
-    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
-    ckpt_path: null # You may download the pretrained vocoder and set the path here
-
-  encode_mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-    f_min: 0.0
-    f_max: 8000.0
-
-  gt_mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 4e-5
-    betas: [0.8, 0.99]
-    eps: 1e-5
-    weight_decay: 0.01
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 0
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0
-
-callbacks:
-  model_summary:
-    _target_: lightning.pytorch.callbacks.ModelSummary
-    max_depth: 1
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
-
-  grad_norm_monitor:
-    sub_module: 
-      - encoder
-      - decoder
-      - quantizer
-      - discriminator

+ 0 - 140
fish_speech/configs/vqgan_pretrain.yaml

@@ -1,140 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vq-gan-pretrain
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  precision: bf16-mixed
-  max_steps: 1_000_000
-  val_check_interval: 5000
-  strategy:
-    find_unused_parameters: true
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-train_dataset:
-  _target_: torch.utils.data.ConcatDataset
-  datasets:
-    - _target_: fish_speech.datasets.vqgan.VQGANDataset
-      filelist: data/gigaspeech/vq_train_filelist.txt
-      sample_rate: ${sample_rate}
-      hop_length: ${hop_length}
-      slice_frames: 512
-    - _target_: fish_speech.datasets.vqgan.VQGANDataset
-      filelist: data/sft/vq_train_filelist.txt
-      sample_rate: ${sample_rate}
-      hop_length: ${hop_length}
-      slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/sft/vq_val_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 32
-  val_batch_size: 32
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-
-  sampling_rate: ${sample_rate}
-  weight_adv: 0.2
-  weight_vq: 1.0
-  weight_mel: 1.0
-  freeze_encoder: false
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    input_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-  
-  quantizer:
-    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-    input_dim: 768
-    n_codebooks: 1
-    n_groups: 2
-    levels: [8, 5, 5, 5]
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    output_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-    condition_channels: 768
-  
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
-
-  vocoder:
-    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
-    ckpt_path: null # You may download the pretrained vocoder and set the path here
-
-  encode_mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-    f_min: 0.0
-    f_max: 8000.0
-
-  gt_mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 1e-4
-    betas: [0.8, 0.99]
-    eps: 1e-5
-    weight_decay: 0.01
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 100
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0
-
-callbacks:
-  model_summary:
-    _target_: lightning.pytorch.callbacks.ModelSummary
-    max_depth: 1
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
-
-  grad_norm_monitor:
-    sub_module: 
-      - encoder
-      - decoder
-      - quantizer
-      - discriminator

+ 2 - 0
fish_speech/conversation.py

@@ -0,0 +1,2 @@
+SEMANTIC_TOKEN = "<|semantic|>"
+CODEBOOK_PAD_TOKEN_ID = 0

+ 0 - 35
fish_speech/datasets/concat_repeat.py

@@ -51,38 +51,3 @@ class ConcatRepeatDataset(Dataset):
         dataset = self.datasets[dataset_idx]
 
         return dataset[sample_idx % len(dataset)]
-
-
-class ConcatWeightedIterableDataset(IterableDataset):
-    datasets: list[IterableDataset]
-    weights: list[float]
-
-    def __init__(self, datasets: Iterable[IterableDataset], weights: list[float]):
-        super().__init__()
-
-        total_weight = sum(weights)
-        self.weights = [w / total_weight for w in weights]
-        self.datasets = list(datasets)
-
-        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
-        assert len(self.datasets) == len(
-            weights
-        ), "datasets and repeats should have the same length"
-
-        for d in self.datasets:
-            assert isinstance(
-                d, IterableDataset
-            ), "ConcatRepeatIterableDataset only supports IterableDataset"
-
-    def __iter__(self):
-        all_datasets = [iter(dataset) for dataset in self.datasets]
-        ids = list(range(len(self.datasets)))
-
-        while True:
-            chosen_dataset = random.choices(ids, self.weights)[0]
-
-            try:
-                yield next(all_datasets[chosen_dataset])
-            except StopIteration:
-                all_datasets[chosen_dataset] = iter(self.datasets[chosen_dataset])
-                yield next(all_datasets[chosen_dataset])

+ 381 - 0
fish_speech/datasets/prompts.py

@@ -0,0 +1,381 @@
+# "Transcribe the following audio into text."
+# "Transcribe what you will hear."
+
+asr_instructions = [
+    "Transcribe:",
+    "Transcribe the following audio into text.",
+    "Convert the audio you're about to hear into written text.",
+    "Please write down what you hear in the audio file.",
+    "Listen to the audio and type out its contents.",
+    "Your task is to write the audio's content in text form.",
+    "Transcribe the content of the audio into text.",
+    "Transform the given audio into a textual format.",
+    "Listen to the following sound clip and transcribe it.",
+    "The audio provided should be converted into written words.",
+    "Document the audio in text.",
+    "Put the audio's dialogue into written form.",
+    "Capture the audio's message in text.",
+    "Turn the sound file's speech into text.",
+    "Render the audio into a text version.",
+    "Translate the audio recording to text.",
+    "Write out the dialogue from the audio.",
+    "Listen and transcribe the audio into words.",
+    "Change the audio into a written transcript.",
+    "Your job is to transcribe the audio to text.",
+    "Please transcribe the spoken words into text.",
+    "The task is to convert audio speech into written text.",
+    "Make a text transcript of the following audio.",
+    "Decode the audio into a written document.",
+    "Write down the transcription of the audio.",
+    "Please provide a text version of this audio.",
+    "The objective is to transcribe the audio into readable text.",
+    "Listen carefully and type out the audio.",
+    "Transform this audio clip into a text document.",
+    "Your assignment is to transcribe this audio.",
+    "Transcribe this sound recording into text format.",
+    "The goal is to turn the audio into text.",
+    "Your duty is to document the audio in written form.",
+    "Listen to this audio piece and write down its contents.",
+    "The task is converting the audio into text.",
+    "Please create a textual transcription of the audio.",
+    "Capture in writing what is said in the audio.",
+    "Transcribe the audible content into a text format.",
+    "The mission is to transcribe the audio into text.",
+    "Your task: convert the audio to text.",
+    "Write the contents of the audio as text.",
+    "Listen to the clip and transcribe its audio to text.",
+    "Transcribe the given audio track into written words.",
+    "The assignment is to write out the audio in text.",
+    "Convert the spoken words into text.",
+    "Transcribe the voice recording into text.",
+    "Your task is to make a written record of the audio.",
+    "Listen to the audio and reproduce it in text.",
+    "Transcribe the following sound into written text.",
+    "Your challenge is to transcribe the audio into written form.",
+    "Make a written version of the audio.",
+    "Take the audio and transcribe it to text.",
+    "Write down everything you hear in the audio.",
+    "Please put the audio into text format.",
+    "Your role is to transcribe the following audio into text.",
+    "Convert the audio message into written text.",
+    "Provide a written transcription of the audio.",
+    "Listen and convert the audio to text.",
+    "The requirement is to transcribe the audio into text form.",
+    "Document in text what the audio says.",
+    "Transcribe into text what you hear in the audio.",
+    "Translate the audio file's contents into text.",
+    "The task is to create a text transcript of the audio.",
+    "Your assignment: Translate the audio into written words.",
+    "Write a textual representation of the audio.",
+    "Capture the essence of the audio in text.",
+    "Your job: Listen to the audio and transcribe it.",
+    "Turn the audio content into a text transcript.",
+    "The task at hand is to transcribe the audio to text.",
+    "Reproduce the audio in text form.",
+    "Your mission: Convert the audio into a textual format.",
+    "Transcribe what is spoken in the audio into text.",
+    "Create a written version of what's in the audio.",
+    "Transform the spoken audio into text.",
+    "Document the spoken words in the audio as text.",
+    "The objective is to write down the audio in text.",
+    "Your goal: Transcribe the audio into text.",
+    "Please convert the audio file into text.",
+    "Transcribe the audio clip into written text.",
+    "Listen to the audio and transcribe the speech into text.",
+    "Transform the voice from the audio into written words.",
+    "The task is to write the audio's speech in text form.",
+    "Your duty: Write down what the audio says.",
+    "Turn the given audio into a written format.",
+    "Write in text form what is said in the audio.",
+    "Your task: Document the audio in text.",
+    "Provide a text transcription of the audio.",
+    "Provide a text transcription of the audio.",
+    "Write down the audio you listen to.",
+    "Type out the spoken words you hear.",
+    "Document the audio content verbatim.",
+    "Transcribe the spoken content accurately.",
+    "Convert the audio you hear into text.",
+    "Record in writing what is said in the audio.",
+    "Capture the spoken words in written form.",
+    "Translate the audio into written text.",
+    "Jot down the words you hear in the audio.",
+    "Put into writing the spoken words you hear.",
+    "Transcribe the auditory information verbatim.",
+    "Note down the dialogue from the audio.",
+    "Write out the spoken words from the audio.",
+    "Transcribe the oral presentation into text.",
+    "Render the spoken audio into written form.",
+    "Reproduce the spoken words in text form.",
+    "Document what is being said in the audio.",
+    "Translate the spoken word into written form.",
+    "Write verbatim what you hear in the audio.",
+    "Capture in writing the contents of the audio.",
+    "Transcribe verbatim the spoken words.",
+    "Write down verbatim what is spoken.",
+    "Transcribe the sounds into words on paper.",
+    "Translate the sounds you hear into words.",
+    "Write the spoken words in text form.",
+    "Reproduce the audio content in writing.",
+    "Note verbatim what is said in the audio.",
+    "Put the audio content into written words.",
+    "Record the spoken words into text format.",
+    "Transcribe the audio into a written document.",
+    "Write down exactly what you hear.",
+    "Type out the content of the audio.",
+    "Document the words spoken in the audio.",
+    "Translate the verbal content into text.",
+    "Convert what you hear into written words.",
+    "Capture the essence of the audio in writing.",
+    "Reproduce the spoken content in written form.",
+    "Jot down exactly what is said in the audio.",
+    "Document every word you hear in the audio.",
+    "Record the audio content by writing it down.",
+    "Capture the audio's spoken words in text.",
+    "Turn the spoken audio into a written transcript.",
+    "Write down the contents of the audio verbatim.",
+    "Transcribe the voice you hear into text.",
+    "Convert the spoken audio into text format.",
+    "Type what is being spoken in the audio.",
+    "Translate the audio speech into written words.",
+    "Write the audio's dialogue in written form.",
+    "Record the verbal content as written text.",
+    "Transcribe the spoken parts of the audio.",
+    "Note down everything you hear in the audio.",
+    "Capture every word from the audio in text.",
+    "Put the spoken audio into text form.",
+    "Transcribe the audible content into words.",
+    "Translate the oral content into written text.",
+    "Type out everything heard in the audio.",
+    "Write down the spoken parts verbatim.",
+    "Document the spoken audio in text form.",
+    "Capture the verbal exchanges in written text.",
+    "Transcribe each word you hear accurately.",
+    "Turn the audio into a textual document.",
+    "Transcribe the sound into written words.",
+    "Write the audio transcript in your own words.",
+    "Document in text what you hear in the audio.",
+    "Record in text the spoken parts of the audio.",
+    "Transcribe the narrative you hear into text.",
+    "Capture the spoken narrative in written form.",
+    "Convert the verbal audio into written script.",
+    "Note down the spoken words in the audio.",
+    "Write in text form what is spoken in the audio.",
+    "Record the audio's spoken words verbatim.",
+    "Jot down the audio's dialogue accurately.",
+    "Transcribe the verbal parts into written words.",
+    "Translate the audio's spoken content into text.",
+    "Document the audio dialogue in written form.",
+    "Type out the words spoken in the audio verbatim.",
+    "Write down word for word what is said in the audio.",
+    "Transcribe the entire audio content into text.",
+    "Note down precisely what is said in the audio.",
+    "Capture in text the spoken content of the audio.",
+    "Record the spoken audio into written language.",
+    "Write the essence of the audio in text form.",
+    "Transcribe the words you hear in the audio.",
+    "Translate every spoken word into written text.",
+    "Convert the oral speech into a written format.",
+    "Jot down the words spoken in the audio.",
+    "Record every word from the audio in writing.",
+    "Document the entire audio in written form.",
+    "Transcribe the spoken language into text.",
+    "Write down the audio's words exactly as spoken.",
+    "Capture the spoken word in written format.",
+    "Type out verbatim the spoken audio content.",
+    "Write precisely what you hear from the audio.",
+]
+
+# "Read the following text with emotion."
+# "Read the following text."
+
+tts_instructions = [
+    "Speak:",
+    "Expressively read the text that follows.",
+    "Convey the upcoming text with emotion.",
+    "Deliver the following passage with heartfelt expression.",
+    "Evoke emotion while reading the text below.",
+    "With feeling, please read the text that comes next.",
+    "Infuse the upcoming words with emotional depth as you read.",
+    "Let your emotions guide you as you read the following lines.",
+    "Channel emotion into your reading of the next passage.",
+    "Read the text below with a sense of emotion.",
+    "Bring the following words to life with emotional expression.",
+    "Engage emotionally with the text as you read it aloud.",
+    "Imbue the subsequent text with feeling as you read.",
+    "Read the following content with genuine emotion.",
+    "Allow your feelings to resonate through the upcoming text.",
+    "Emotionally interpret the text that follows.",
+    "Read the ensuing passage with deep feeling.",
+    "Convey the text below with genuine emotional depth.",
+    "Read the text that comes next, letting your emotions flow.",
+    "With emotion, present the following words.",
+    "Let your emotional expression enhance the next text.",
+    "Embrace emotion as you read the following passage.",
+    "Read aloud the text below with emotive expression.",
+    "Infuse the upcoming lines with emotional intensity.",
+    "With sincerity, read the following text with emotion.",
+    "Project emotion as you deliver the text that follows.",
+    "Let the next words be read with a wealth of emotion.",
+    "Give the upcoming text an emotional rendition.",
+    "With emotion, read the text that is presented next.",
+    "Convey the essence of the following text with heartfelt emotion.",
+    "Inject emotional depth into your reading of the next passage.",
+    "Bring out the emotional undertones in the following text.",
+    "Embody the emotions as you read the text below.",
+    "Express the following narrative with emotional depth.",
+    "Let emotion permeate your reading of the upcoming passage.",
+    "Interpret the following text with a rich emotional tone.",
+    "Elicit emotion through your reading of the next content.",
+    "Read the subsequent text with a deep emotional connection.",
+    "Emote the essence of the text that follows in your reading.",
+    "Render the following lines with emotional expression.",
+    "Expressively interpret the upcoming text.",
+    "Immerse in emotion as you read the following passage.",
+    "Engage with the text below on an emotional level as you read.",
+    "With emotional clarity, read the next text.",
+    "Let an emotional depth inform your reading of the following words.",
+    "Express the following content with deep emotional resonance.",
+    "Deliver the upcoming text with a range of emotions.",
+    "Narrate the following lines with emotional expressiveness.",
+    "Convey emotional texture as you read the text below.",
+    "Instill the next passage with emotive power as you read.",
+    "Read the ensuing text with a palette of emotions.",
+    "With a depth of feeling, present the next text.",
+    "Inflect the upcoming words with emotional vibrancy.",
+    "Emotionally engage with the text that follows in your reading.",
+    "Lend emotional expression to the passage below.",
+    "Evoke a spectrum of emotions as you read the next lines.",
+    "Channel a rich emotional tone into the following text.",
+    "With feeling, convey the essence of the upcoming passage.",
+    "Read the text that comes next with emotional fervor.",
+    "Render the following words with emotional authenticity.",
+    "Give the upcoming passage an emotive interpretation.",
+    "Allow your reading of the text below to be emotionally driven.",
+    "Imbue the next lines with a sense of emotion.",
+    "Emotionally animate the following text as you read.",
+    "Bring emotional depth to the passage that follows.",
+    "Articulate the text below with emotional nuance.",
+    "Project a range of emotions as you read the upcoming text.",
+    "With emotion, breathe life into the following words.",
+    "Narrate the ensuing text with heartfelt emotion.",
+    "Convey the text that follows with emotional richness.",
+    "Read aloud the next passage with a depth of emotion.",
+    "Emphasize emotional expression in your reading of the text below.",
+    "Let your reading of the following lines be emotionally charged.",
+    "With a heartfelt approach, read the upcoming text.",
+    "Express the essence of emotion as you deliver the next passage.",
+    "Read the following text, infused with emotional energy.",
+    "Allow the text that comes next to be expressed with emotion.",
+    "Convey the following passage with an emotional depth.",
+    "Emotionally render the text that follows.",
+    "With an emotional undertone, read the upcoming words.",
+    "Read the text below, letting emotion guide your expression.",
+    "Elicit an emotional response through your reading of the next passage.",
+    "Give the following lines an emotive delivery.",
+    "Read the upcoming text with emotional sincerity.",
+    "Narrate the text that follows with an emotional touch.",
+    "Deliver the following words with an emotive clarity.",
+    "Express the next passage with a range of emotional tones.",
+    "Immerse yourself emotionally in the text below as you read.",
+    "Let the ensuing text be conveyed with profound emotion.",
+    "Infuse the following lines with a sense of heartfelt emotion.",
+    "Emotionally engage with the upcoming text in your reading.",
+    "Convey deep emotion as you read the text that follows.",
+    "Let your reading of the next passage be rich in emotion.",
+    "With emotional depth, narrate the following text.",
+    "Read the text below, capturing its emotional essence.",
+    "Emote through your reading of the upcoming lines.",
+    "Please read the text that follows aloud.",
+    "Proceed to vocalize the upcoming text.",
+    "Kindly articulate the subsequent text.",
+    "Go ahead and pronounce the text below.",
+    "Could you recite the forthcoming passage?",
+    "Start reading the text below out loud.",
+    "Announce the following text audibly.",
+    "Voice the text that comes next.",
+    "Read through the following lines aloud.",
+    "Narrate the text presented below.",
+    "Elevate your voice for the upcoming script.",
+    "Broadcast the text that follows.",
+    "Project the subsequent lines audibly.",
+    "Give voice to the text underneath.",
+    "Unfold the following text with your voice.",
+    "Engage in reading the next piece of text aloud.",
+    "Orate the following series of words.",
+    "Enunciate the text appearing next.",
+    "Verbally present the upcoming text.",
+    "Articulate the passage that follows.",
+    "Read aloud the text that's coming up.",
+    "Proclaim the subsequent words.",
+    "Vocalize the narrative below.",
+    "Bring the following text to life by reading it aloud.",
+    "Express the next text with your voice.",
+    "Render the following text audibly.",
+    "Voice out the lines that follow.",
+    "Orally deliver the upcoming text.",
+    "Loudly read out the text below.",
+    "Share the next text by reading it out loud.",
+    "Speak the following passage aloud.",
+    "Let your voice carry the upcoming words.",
+    "Annunciate the text that follows.",
+    "Sound out the subsequent text.",
+    "Aurally present the text below.",
+    "Elocute the forthcoming lines.",
+    "Recite the text below with clarity.",
+    "Make the next text heard by reading aloud.",
+    "Bring forth your voice for the following script.",
+    "Read the text that ensues out loud.",
+    "Deliver the following lines vocally.",
+    "Voice the ensuing text.",
+    "Publicly read the text that follows.",
+    "Loudly narrate the subsequent text.",
+    "Express the following text through your voice.",
+    "Verbally articulate the next passage.",
+    "Read the forthcoming text clearly.",
+    "Announce the next set of words aloud.",
+    "Broadcast the following narrative.",
+    "Articulate the text coming up next.",
+    "Enunciate the passage that follows clearly.",
+    "Recite the subsequent text audibly.",
+    "Speak out the text below.",
+    "Project your voice with the following words.",
+    "Read the next lines aloud.",
+    "Vocalize the text that is to follow.",
+    "Narrate aloud the text below.",
+    "Orate the forthcoming script.",
+    "Pronounce the next passage.",
+    "Read out the subsequent text.",
+    "Let the following words be heard by reading them aloud.",
+    "Express the text that follows with your voice.",
+    "Give audible life to the text below.",
+    "Speak the ensuing text clearly.",
+    "Make the forthcoming text audible.",
+    "Project the next series of words audibly.",
+    "Voice out the following narrative.",
+    "Elevate the subsequent text with your voice.",
+    "Bring the next passage to audible life.",
+    "Read the lines that come next out loud.",
+    "Announce the text below with clarity.",
+    "Vocalize the script that follows.",
+    "Narrate the following text with emphasis.",
+    "Deliver the upcoming words with your voice.",
+    "Articulate the next set of lines.",
+    "Verbally convey the following text.",
+    "Present the subsequent text vocally.",
+    "Enunciate the upcoming passage loudly.",
+    "Orally render the text that follows.",
+    "Speak out the subsequent narrative.",
+    "Proclaim the next text audibly.",
+    "Elocute the following lines with clarity.",
+    "Give voice to the upcoming script.",
+    "Let your voice express the text below.",
+    "Annunciate the following words clearly.",
+    "Sound out the text that is next.",
+    "Aurally convey the subsequent passage.",
+    "Read the text up next aloud.",
+]
+
+prompt_dict = {
+    "asr": asr_instructions,
+    "tts": tts_instructions,
+}

+ 263 - 298
fish_speech/datasets/text.py

@@ -1,21 +1,29 @@
+import gzip
+import io
+import json
 import random
 from dataclasses import dataclass
-from itertools import chain
 from pathlib import Path
 from random import Random
 from typing import Optional, Union
 
 import numpy as np
-import pyarrow.parquet as pq
 import torch
 import torch.nn.functional as F
-from datasets.download.streaming_download_manager import xopen
-from huggingface_hub import HfApi
+import zstandard as zstd
 from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
 from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    SKIP_TEXT_STRING,
+    Conversation,
+    Message,
+    encode_conversation,
+)
+from fish_speech.datasets.prompts import asr_instructions, tts_instructions
 from fish_speech.datasets.protos.text_data_pb2 import SampledData
 from fish_speech.datasets.protos.text_data_stream import read_pb_stream
 from fish_speech.text.clean import clean_text
@@ -24,9 +32,7 @@ from fish_speech.utils.braceexpand import braceexpand
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
-CODEBOOK_PAD_TOKEN_ID = 0
-CODEBOOK_EOS_TOKEN_ID = 1
-SKIP_TEXT_STRING = "<|skip_text|>"
+DCTX = zstd.ZstdDecompressor(max_window_size=2**31)
 
 
 def split_by_rank_worker(files):
@@ -56,43 +62,55 @@ def split_by_rank_worker(files):
     return files
 
 
-class StreamTextDataset(IterableDataset):
+def expand_split_proto_files(proto_files, seed: int = 42):
+    # Expand the proto files
+    expanded_proto_files = []
+    for filename in proto_files:
+        for i in braceexpand(filename):
+            i = Path(i)
+            if i.is_file():
+                expanded_proto_files.append(i)
+            elif i.is_dir():
+                expanded_proto_files.extend(i.rglob("*.proto"))
+                expanded_proto_files.extend(i.rglob("*.protos"))
+            else:
+                raise ValueError(f"{i} is not a file or directory")
+
+    expanded_proto_files = sorted(expanded_proto_files)
+    Random(seed).shuffle(expanded_proto_files)
+    return split_by_rank_worker(expanded_proto_files)
+
+
+class TextPretrainDataset(IterableDataset):
     def __init__(
         self,
-        files: Optional[Union[list[str], str]] = None,
-        prefix: Optional[str] = None,
+        source: str,
         seed: int = 42,
-        parquet_batch_size: int = 10000,
-        repo: str = "uonlp/CulturaX",
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
+        num_codebooks: int = 2,
     ):
         super().__init__()
 
+        self.source = Path(source)
         self.seed = seed
-        self.parquet_batch_size = parquet_batch_size
-        self.repo = repo
         self.max_length = max_length
         self.tokenizer = tokenizer
+        self.num_codebooks = num_codebooks
 
-        if files is None and prefix is None:
-            raise ValueError("Either files or prefix must be specified")
-
-        if prefix is not None:
-            files = HfApi().list_repo_files(repo, repo_type="dataset")
+        if self.source.is_file():
+            with open(self.source, "r") as f:
+                files = f.read().strip().split("\n")
+            self.root = self.source.parent
+        else:
             files = [
-                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
+                str(i.relative_to(self.source)) for i in self.source.rglob("*.jsonl")
             ]
-            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
-        else:
-            if isinstance(files, str):
-                files = [files]
-
-            files = list(chain.from_iterable(map(braceexpand, files)))
-            log.info(f"Expanded {len(files)} files in {repo}")
+            self.root = self.source
 
         # Get sharded files
         self.files = sorted(files)
+
         Random(seed).shuffle(self.files)
 
     def __iter__(self):
@@ -105,142 +123,147 @@ class StreamTextDataset(IterableDataset):
             except Exception as e:
                 log.exception(f"Failed to parse {filename}: {e}")
 
-    def parse_data(self, filename: str):
-        for data in self.parse_data_internal(filename):
-            text = data["text"]
+    def read_jsonl(self, filename: str):
+        with open(self.root / filename, "rb") as f:
+            if filename.endswith(".zst"):
+                stream_reader = DCTX.stream_reader(f)
+            elif filename.endswith(".gz"):
+                stream_reader = gzip.open(f, "rb")
+            elif filename.endswith(".jsonl"):
+                stream_reader = f
+            else:
+                raise ValueError(f"Unknown file type: {filename}")
 
+            stream = io.TextIOWrapper(stream_reader, encoding="utf-8")
+
+            # Parse jsonl
+            for line in stream:
+                line = json.loads(line)
+                yield line
+
+    def parse_data(self, filename: str):
+        for line in self.read_jsonl(filename):
             # encode
             tokens = self.tokenizer.encode(
-                text,
+                line["text"],
                 add_special_tokens=False,
                 truncation=False,
                 max_length=10**6,
             )
 
-            # Random choice self.max_length
-            if len(tokens) > self.max_length:
-                start = random.randint(0, len(tokens) - self.max_length)
-                tokens = tokens[start : start + self.max_length - 1]
-
             tokens = (
                 [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
             )
-            # Pad dims
-            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
-
-            tokens = torch.concat(
-                [
-                    torch.tensor([tokens], dtype=torch.long),
-                    placeholder_multi_codebook,
-                ],
-                dim=0,
-            )
+
+            if len(tokens) > self.max_length:
+                tokens = tokens[: self.max_length]
+
+            tokens = self.pad_codebooks(tokens)
             labels = tokens.clone()
             tokens = tokens[:, :-1]
             labels = labels[:, 1:]
-            labels[1:] = -100  # remove all placeholders
+            labels[1:] = -100  # no loss on codebook
 
             yield {"tokens": tokens, "labels": labels}
 
-    def parse_data_internal(self, filename: str):
-        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
+    def pad_codebooks(self, tokens):
+        placeholder_multi_codebook = (
+            torch.zeros((self.num_codebooks, len(tokens)), dtype=torch.long)
+            + CODEBOOK_PAD_TOKEN_ID
+        )
+        return torch.concat(
+            [
+                torch.tensor([tokens], dtype=torch.long),
+                placeholder_multi_codebook,
+            ],
+            dim=0,
+        )
+
 
-        with xopen(url, mode="rb") as stream:
-            parquet_file = pq.ParquetFile(stream)
+class TextInstructionDataset(TextPretrainDataset):
+    def parse_data(self, filename: str):
+        for line in self.read_jsonl(filename):
+            messages = []
+            for conversation in line["conversations"]:
+                role = {
+                    "human": "user",
+                    "gpt": "assistant",
+                    "system": "system",
+                }[conversation["from"]]
+
+                message = Message(
+                    role=role,
+                    parts=[conversation["value"]],
+                )
+                messages.append(message)
+
+            conversation = Conversation(messages=messages)
+            tokens, labels = encode_conversation(
+                conversation,
+                self.tokenizer,
+                num_codebooks=self.num_codebooks,
+            )
 
-            for batch in parquet_file.iter_batches(
-                batch_size=self.parquet_batch_size, columns=["text"]
-            ):
-                # In-batch shuffling
-                texts = [{"text": text.as_py()} for text in batch["text"]]
-                random.shuffle(texts)
-                yield from texts
+            yield {"tokens": tokens, "labels": labels}
 
 
-class AutoAugTextDataset(IterableDataset):
-    """
-    Auto Augment Dataset by Speaker
+def semantic_to_tensor(semantics):
+    num_codebooks = len(semantics)
+    codes = [[] for _ in range(num_codebooks)]
 
-    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
-    2. Automatically normalize the text
+    for book_idx, book in zip(range(num_codebooks), semantics):
+        for j in book.values:
+            codes[book_idx].append(int(j))
 
-    For interactive mode, we use the following format (multiple sequences):
-    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+    return torch.tensor(codes, dtype=torch.int)
 
-    For non-interactive mode, we use the following format (one long sequence):
-    <s> [INST] text [/INST] ... </s>
-    """
 
+class AutoTextSemanticInstructionDataset(IterableDataset):
     def __init__(
         self,
         proto_files: list[str],
         seed: int = 42,
-        interactive_prob: float = 0.5,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
-        use_speaker: bool | float = True,
-        causual: bool = True,
-        use_negative_samples: bool = False,
+        causual: Union[bool, float] = True,
         num_codebooks: Optional[int] = None,
         skip_text_prob: float = 0.0,
+        asr_prob: float = 0.0,
     ):
         """
         Args:
             proto_files: proto buf files if using local data
             seed: random seed
-            interactive_prob: probability to use interactive mode
             max_length: max length of the text
             tokenizer: tokenizer
-            use_speaker: include speaker information in the prompt
             causual: use causual sampling when using local data, disable will lead to random sampling
-            use_negative_samples: generate negative samples
             num_codebooks: number of codebooks, if None, it will be automatically detected
             skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+            asr_prob: probability to use ASR
         """
 
         super().__init__()
 
-        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+        assert 0 <= skip_text_prob <= 1, "skip_text_prob must be in [0, 1]"
+        assert 0 <= asr_prob <= 1, "asr_prob must be in [0, 1]"
 
         self.seed = seed
         self.max_length = max_length
         self.tokenizer = tokenizer
-        self.interactive_prob = interactive_prob
-        self.use_speaker = use_speaker
         self.proto_files = proto_files
         self.causual = causual
-        self.use_negative_samples = use_negative_samples
         self.num_codebooks = num_codebooks
         self.skip_text_prob = skip_text_prob
-
-        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+        self.asr_prob = asr_prob
         self.groups = None
 
     def init_mock_data_server(self):
         if self.groups is not None:
             return
 
-        # Expand the proto files
-        expanded_proto_files = []
-        for filename in self.proto_files:
-            for i in braceexpand(filename):
-                i = Path(i)
-                if i.is_file():
-                    expanded_proto_files.append(i)
-                elif i.is_dir():
-                    expanded_proto_files.extend(i.rglob("*.proto"))
-                    expanded_proto_files.extend(i.rglob("*.protos"))
-                else:
-                    raise ValueError(f"{i} is not a file or directory")
-
-        expanded_proto_files = sorted(expanded_proto_files)
-        Random(self.seed).shuffle(expanded_proto_files)
-
         self.groups = []
-        shard_proto_files = split_by_rank_worker(expanded_proto_files)
-        log.info(
-            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
-        )
+        shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
+        log.info(f"Reading {len(shard_proto_files)} files")
 
         count = 0
         for filename in shard_proto_files:
@@ -279,7 +302,11 @@ class AutoAugTextDataset(IterableDataset):
         # choice group based on their number of samples
         group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
 
-        if self.causual:
+        causual = self.causual
+        if isinstance(self.causual, float):
+            causual = random.random() < self.causual
+
+        if causual:
             # Sample in order
             if num_samples >= len(group.sentences):
                 samples = group.sentences
@@ -298,7 +325,6 @@ class AutoAugTextDataset(IterableDataset):
         )
 
     def augment(self):
-        final_text, final_semantic = [], []
         response = self.sample_data()
         if len(response.samples) == 0:
             # Invalid group
@@ -306,29 +332,9 @@ class AutoAugTextDataset(IterableDataset):
 
         samples = list(response.samples)
         idx = 0
-        use_interactive = random.random() < self.interactive_prob
-
-        if use_interactive is False:
-            # Random sample based on speaker using a truncated normal distribution
-            a = torch.tensor([0], dtype=torch.float32)
-            torch.nn.init.trunc_normal_(
-                a,
-                mean=self.max_length // 2,
-                std=self.max_length // 4,
-                a=10,
-                b=self.max_length,
-            )
-            remaining_tokens = a.long().item() - 4
-        else:
-            remaining_tokens = self.max_length
-
-        # Use speaker
-        if isinstance(self.use_speaker, float):
-            use_speaker = random.random() < self.use_speaker
-        else:
-            use_speaker = self.use_speaker
+        remaining_tokens = self.max_length
 
-        all_tokens, all_labels = [], []
+        all_messages = []
         while remaining_tokens > 0 and len(samples) > 0:
             sentence = samples.pop(0)
 
@@ -336,37 +342,52 @@ class AutoAugTextDataset(IterableDataset):
             text, length = self.tokenize_sentence(text)
             remaining_tokens -= length + len(sentence.semantics[0].values)
 
-            if use_interactive is False:
-                final_text.append(text)
-                final_semantic.append(sentence.semantics)
+            # For interactive mode, we only apply speaker for the first sentence
+            # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+
+            if random.random() < self.asr_prob:
+                all_messages.append(
+                    Message(
+                        role="user",
+                        parts=[
+                            random.choice(asr_instructions),
+                            semantic_to_tensor(sentence.semantics),
+                        ],
+                    )
+                )
+                all_messages.append(
+                    Message(
+                        role="assistant",
+                        parts=[text],
+                    )
+                )
             else:
-                # For interactive mode, we only apply speaker for the first sentence
-                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
-                tokens, labels = self.pack_sentences(
-                    sentences=[text],
-                    semantics=[sentence.semantics],
-                    speaker=response.name if use_speaker else None,
-                    add_bos=idx == 0,
-                    skip_text=random.random() < self.skip_text_prob,
+                skip_text = random.random() < self.skip_text_prob
+                if skip_text:
+                    text = SKIP_TEXT_STRING
+
+                all_messages.append(
+                    Message(
+                        role="user",
+                        parts=[random.choice(tts_instructions) + text],
+                        mask_labels=skip_text,
+                    )
+                )
+                all_messages.append(
+                    Message(
+                        role="assistant",
+                        parts=[semantic_to_tensor(sentence.semantics)],
+                        mask_labels=skip_text,
+                    )
                 )
-
-                all_tokens.append(tokens)
-                all_labels.append(labels)
 
             idx += 1
 
-        if use_interactive is False:
-            tokens, labels = self.pack_sentences(
-                final_text,
-                semantics=final_semantic,
-                speaker=response.name if use_speaker else None,
-                add_bos=True,
-            )
-            all_tokens.append(tokens)
-            all_labels.append(labels)
-
-        tokens = torch.cat(all_tokens, dim=1)
-        labels = torch.cat(all_labels, dim=1)
+        tokens, labels = encode_conversation(
+            Conversation(messages=all_messages),
+            self.tokenizer,
+            num_codebooks=self.num_codebooks,
+        )
 
         # Verify that the length is correct
         assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
@@ -374,156 +395,71 @@ class AutoAugTextDataset(IterableDataset):
         # Verify bos token
         assert tokens[0, 0] == self.tokenizer.bos_token_id
 
-        data = {"tokens": tokens, "labels": labels}
-
-        if self.use_negative_samples:
-            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
-            data.update(negative_samples)
-
-        return data
-
-    def generate_negative_samples(self, all_tokens, all_labels):
-        new_tokens, new_labels = [], []
-
-        for tokens, labels in zip(all_tokens, all_labels):
-            # If all codebooks are not -100, we find where it starts
-            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
-            assert (labels[1:, start:] != -100).all()  # This shouldn't happen
+        return {"tokens": tokens, "labels": labels}
 
-            mode = random.choice(["repeat", "lost", "noise"])
-            begin = random.randint(start, labels.size(1) - 1)
-            end = random.randint(begin, labels.size(1) - 1)
 
-            if mode == "repeat":
-                tokens = torch.cat(
-                    [
-                        tokens[:, :begin],
-                        tokens[:, begin:end],
-                        tokens[:, begin:end],
-                        tokens[:, end:],
-                    ],
-                    dim=1,
-                )
-                labels = torch.cat(
-                    [
-                        labels[:, :begin],
-                        labels[:, begin:end],
-                        labels[:, begin:end],
-                        labels[:, end:],
-                    ],
-                    dim=1,
-                )
-            elif mode == "lost":
-                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
-                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
-            elif mode == "noise":
-                middle_tokens, middle_labels = (
-                    tokens[:, begin:end],
-                    labels[:, begin:end],
-                )
-                random_order0 = torch.randperm(middle_tokens.size(1))
-                random_order1 = torch.randperm(middle_tokens.size(1))
-                middle_tokens = middle_tokens[:, random_order0]
-                middle_labels = middle_labels[:, random_order1]
-                tokens = torch.cat(
-                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
-                )
-                labels = torch.cat(
-                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
-                )
-
-            new_tokens.append(tokens)
-            new_labels.append(labels)
-
-        tokens = torch.cat(new_tokens, dim=1)
-        labels = torch.cat(new_labels, dim=1)
-
-        # Verify that the length is correct
-        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
-
-        return {"negative_tokens": tokens, "negative_labels": labels}
-
-    def pack_sentences(
+class SemanticInstructionDataset(IterableDataset):
+    def __init__(
         self,
-        sentences: list[str],
-        semantics: list,
-        speaker: Optional[str] = None,
-        add_bos: bool = True,
-        skip_text: bool = False,
+        proto_files: list[str],
+        seed: int = 42,
+        max_length: int = 1024,
+        tokenizer: AutoTokenizer = None,
+        num_codebooks: Optional[int] = None,
     ):
-        if speaker is None:
-            speaker = "assistant"
+        super().__init__()
 
-        cated_sentences = " ".join(sentences)
-        if skip_text:
-            cated_sentences = SKIP_TEXT_STRING
+        self.seed = seed
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.proto_files = proto_files
+        self.num_codebooks = num_codebooks
 
-        final_text = "<|im_start|>user<|im_sep|>" + cated_sentences + "<|im_end|>"
-        final_text = final_text + f"<|im_start|>{speaker}<|im_sep|>"
+    def get_data_generator(self):
+        shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
+        random.shuffle(shard_proto_files)
+        log.info(f"Fetched {len(shard_proto_files)} files")
 
-        encoded = self.tokenizer.encode(
-            final_text,
-            add_special_tokens=False,
-            truncation=False,
-            max_length=10**6,
-        )
-        semantic_length = sum([len(i[0].values) for i in semantics])
-        prompt_length = len(encoded)
-        num_codebooks = (
-            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
-        )
+        for filename in shard_proto_files:
+            with open(filename, "rb") as f:
+                for group in read_pb_stream(f):
+                    yield group
 
-        bos_bias = 1 if add_bos else 0
+    def pack_one_group(self, group):
+        sentences = group.sentences
 
-        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-        tokens = (
-            encoded
-            + [self.semantic_token_id] * semantic_length
-            + self.tokenizer.convert_tokens_to_ids(
-                ["<|im_end|>", "<|end_of_sequence|>"]
+        messages = []
+        for idx, sentence in enumerate(sentences):
+            role = "user" if idx % 2 == 0 else "assistant"
+            semantic = semantic_to_tensor(sentence.semantics)
+            text = random.choice(sentence.texts)
+            parts = [semantic]
+            if role == "assistant":
+                # Let model to predict the text first
+                prev_text = random.choice(sentences[idx - 1].texts)
+                # parts.insert(0, f"Q: {prev_text}\nA: {text}")
+            messages.append(
+                Message(
+                    role=role,
+                    parts=parts,
+                )
             )
-        )
-
-        if add_bos:
-            tokens = [self.tokenizer.bos_token_id] + tokens
-
-        # Codebook bos/padding: 0, eos: 1
-        codes = [
-            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
-            for _ in range(num_codebooks)
-        ]
-        for segment in semantics:
-            for book_idx, book in zip(range(num_codebooks), segment):
-                for j in book.values:
-                    codes[book_idx].append(int(j) + 2)
 
-        for book in codes:
-            book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)
-
-        tokens = [tokens] + codes
-
-        tokens = torch.tensor(tokens, dtype=torch.long)
-        labels = tokens.clone()
-
-        if skip_text:
-            # If text is not provided, the sentence is used for condition only, all labels are -100
-            torch.fill_(labels, -100)
-            return tokens, labels
-
-        # Mask out the <s> tokens for semantic, predict semantic tokens only
-        # Since we don't mask out the input tokens, the language modeling still works
-        labels[1:, : (prompt_length + bos_bias)] = -100
-
-        tokens = tokens[:, :-1]
-        labels = labels[:, 1:]
+        conversation = Conversation(messages=messages)
+        tokens, labels = encode_conversation(
+            conversation,
+            self.tokenizer,
+            num_codebooks=self.num_codebooks,
+        )
 
-        # Verify the padding is correct, and the last token is eos
-        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
-        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
-        assert labels[0, -1] == self.tokenizer.eos_token_id
-        assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()
+        return {"tokens": tokens, "labels": labels}
 
-        return tokens, labels
+    def __iter__(self):
+        for group in self.get_data_generator():
+            try:
+                yield self.pack_one_group(group)
+            except Exception as e:
+                log.exception(f"Failed to parse {group}: {e}")
 
 
 @dataclass
@@ -633,8 +569,18 @@ class InterleaveDataset(IterableDataset):
 class TextDataModule(LightningDataModule):
     def __init__(
         self,
-        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
-        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
+        train_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            TextPretrainDataset,
+            TextInstructionDataset,
+            InterleaveDataset,
+        ],
+        val_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            TextPretrainDataset,
+            TextInstructionDataset,
+            InterleaveDataset,
+        ],
         batch_size: int = 32,
         tokenizer: AutoTokenizer = None,
         max_length: int = 1024,
@@ -671,17 +617,36 @@ class TextDataModule(LightningDataModule):
 if __name__ == "__main__":
     from tqdm import tqdm
 
-    ds = AutoAugTextDataset(
-        ["data/protos"],
-        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
-        use_speaker=False,
-        interactive_prob=1.0,
-        use_negative_samples=False,
-        skip_text_prob=0.5,
+    # ds = AutoTextSemanticInstructionDataset(
+    #     ["data/protos/sft/val/11labs"],
+    #     tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
+    #     skip_text_prob=1.0,
+    #     asr_prob=0.0,
+    #     num_codebooks=2,
+    # )
+    # ds = TextInstructionDataset(
+    #     source="data/openhermes2_5",
+    #     tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
+    # )
+
+    ds = SemanticInstructionDataset(
+        proto_files=["data/protos/sft/val/ultrachat_200k_spoken_openai"],
+        tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
+        num_codebooks=2,
     )
 
     for i in ds:
-        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
+        # print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
         # i["labels"][0][i["labels"][0] == -100] = 0
         # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
+
+        length = i["tokens"].size(1)
+        print(i["tokens"].size(), i["tokens"].dtype)
+        for j in range(length):
+            print(
+                ds.tokenizer.decode(i["tokens"][0, j]),
+                i["tokens"][:, j],
+                i["labels"][:, j],
+            )
+            input()
         break

+ 0 - 3
fish_speech/models/text2semantic/__init__.py

@@ -1,3 +0,0 @@
-from .lit_module import TextToSemantic
-
-__all__ = ["TextToSemantic"]

+ 11 - 7
fish_speech/models/text2semantic/lit_module.py

@@ -6,8 +6,8 @@ import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 
 import fish_speech.utils as utils
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from fish_speech.models.text2semantic.llama import NaiveTransformer
-from fish_speech.models.text2semantic.lora_utils import LoraConfig, setup_lora
 
 log = utils.RankedLogger(__name__, rank_zero_only=True)
 
@@ -137,15 +137,15 @@ class TextToSemantic(L.LightningModule):
             labels, negative_labels = labels.chunk(2)
 
         # Generate labels
-        base_loss = F.cross_entropy(
-            token_logits.reshape(-1, token_logits.size(-1)),
+        base_loss = fast_cross_entropy_loss(
+            token_logits.view(-1, token_logits.size(-1)),
             labels[:, 0].reshape(-1),
             ignore_index=-100,
         )
 
         codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
-        semantic_loss = F.cross_entropy(
-            codebook_logits.reshape(-1, codebook_logits.size(-1)),
+        semantic_loss = fast_cross_entropy_loss(
+            codebook_logits.view(-1, codebook_logits.size(-1)),
             codebook_labels.reshape(-1),
             ignore_index=-100,
         )
@@ -281,11 +281,15 @@ class TextToSemantic(L.LightningModule):
         return loss
 
     def get_accuracy(self, logits, labels):
+        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
+        if mask.sum() == 0:
+            return torch.tensor(0.0, device=logits.device)
+
         _, indices = logits.topk(5, dim=-1)
         correct = indices.eq(labels.unsqueeze(-1))
-        correct[labels == -100] = 0
+        correct[~mask] = 0
         correct = correct.sum()
-        accuracy = correct / (labels != -100).sum()
+        accuracy = correct / mask.sum()
 
         return accuracy
 

+ 206 - 68
fish_speech/models/text2semantic/llama.py

@@ -1,5 +1,7 @@
+import json
 import math
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Optional
 
 import torch
@@ -7,7 +9,16 @@ import torch.nn as nn
 from einops import rearrange
 from torch import Tensor
 from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.utils.checkpoint import checkpoint
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import SEMANTIC_TOKEN
+from fish_speech.utils import RankedLogger
+
+from .lora import LoraConfig, setup_lora
+
+log = RankedLogger(__name__, rank_zero_only=True)
 
 
 def find_multiple(n: int, k: int) -> int:
@@ -18,6 +29,8 @@ def find_multiple(n: int, k: int) -> int:
 
 @dataclass
 class BaseModelArgs:
+    model_type: str = "base"
+
     vocab_size: int = 32000
     n_layer: int = 32
     n_head: int = 32
@@ -29,16 +42,19 @@ class BaseModelArgs:
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
     dropout: float = 0.0
+    tie_word_embeddings: bool = True
+    attention_qkv_bias: bool = False
 
     # Codebook configs
     codebook_size: int = 160
     num_codebooks: int = 4
-    num_in_codebooks: Optional[int] = None
-    codebook_padding_idx: int = 0
 
     # Gradient checkpointing
     use_gradient_checkpointing: bool = True
 
+    # Initialize the model
+    initializer_range: float = 0.02
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -46,18 +62,41 @@ class BaseModelArgs:
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        if self.num_in_codebooks is None:
-            self.num_in_codebooks = self.num_codebooks
         self.head_dim = self.dim // self.n_head
 
+    @staticmethod
+    def from_pretrained(path: str):
+        path = Path(path)
+
+        if path.is_dir():
+            path = path / "config.json"
+
+        with open(path, "r") as f:
+            data = json.load(f)
+
+        match data["model_type"]:
+            case "naive":
+                cls = NaiveModelArgs
+            case "dual_ar":
+                cls = DualARModelArgs
+            case _:
+                raise ValueError(f"Unknown model type: {data['model_type']}")
+
+        return cls(**data)
+
+    def save(self, path: str):
+        with open(path, "w") as f:
+            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
 
 @dataclass
 class NaiveModelArgs(BaseModelArgs):
-    pass
+    model_type: str = "naive"
 
 
 @dataclass
 class DualARModelArgs(BaseModelArgs):
+    model_type: str = "dual_ar"
     n_fast_layer: int = 4
 
 
@@ -95,24 +134,35 @@ class BaseTransformerForwardResult:
 
 
 class BaseTransformer(nn.Module):
-    def __init__(self, config: BaseModelArgs) -> None:
+    def __init__(
+        self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
+    ) -> None:
         super().__init__()
         self.config = config
+        self.tokenizer = tokenizer
+
+        self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
 
         # Slow transformer
         self.embeddings = nn.Embedding(
-            config.vocab_size + config.codebook_size * config.num_in_codebooks,
+            config.vocab_size,
+            config.dim,
+        )
+        self.codebook_embeddings = nn.Embedding(
+            config.codebook_size * config.num_codebooks,
             config.dim,
         )
         self.layers = nn.ModuleList(
             TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
         )
         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.output = nn.Linear(
-            config.dim,
-            config.vocab_size,
-            bias=False,
-        )
+
+        if self.config.tie_word_embeddings is False:
+            self.output = nn.Linear(
+                config.dim,
+                config.vocab_size,
+                bias=False,
+            )
 
         self.register_buffer(
             "freqs_cis",
@@ -139,6 +189,9 @@ class BaseTransformer(nn.Module):
         self.max_batch_size = -1
         self.max_seq_len = -1
 
+        if init_weights:
+            self.apply(self._init_weights)
+
     def setup_caches(
         self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
     ):
@@ -161,11 +214,9 @@ class BaseTransformer(nn.Module):
 
     def embed(self, x: Tensor) -> Tensor:
         vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_in_codebooks):
-            emb = self.embeddings(
-                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
-            )
-            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
+        for i in range(self.config.num_codebooks):
+            emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
+            emb[x[:, 0] != self.semantic_token_id] = 0
             vocab_embeds.append(emb)
 
         x = torch.stack(vocab_embeds, dim=3)
@@ -174,21 +225,23 @@ class BaseTransformer(nn.Module):
         return x
 
     def forward(
-        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
     ) -> BaseTransformerForwardResult:
-        # x: (batch, num_codebooks + 1, seq_len)
         seq_len = inp.size(2)
 
         # Here we want to merge the embeddings of the codebooks
         x = self.embed(inp)
 
-        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[:seq_len]
 
         # Not that the causal mask here follows the definition of scaled_dot_product_attention
         # That is, FALSE means masked out
         # To maintain consistency, key_padding_mask use TRUE to mask out
+        mask = None
         if key_padding_mask is not None:
+            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
 
         for layer in self.layers:
@@ -199,7 +252,11 @@ class BaseTransformer(nn.Module):
 
         # We got slow_out here
         slow_out = self.norm(x)
-        token_logits = self.output(slow_out)
+
+        if self.config.tie_word_embeddings:
+            token_logits = F.linear(slow_out, self.embeddings.weight)
+        else:
+            token_logits = self.output(slow_out)
 
         return BaseTransformerForwardResult(
             logits=token_logits,
@@ -207,7 +264,10 @@ class BaseTransformer(nn.Module):
         )
 
     def forward_generate(
-        self, x: Tensor, input_pos: Optional[Tensor] = None
+        self,
+        x: Tensor,
+        input_pos: Optional[Tensor] = None,
+        return_all: bool = False,
     ) -> BaseTransformerForwardResult:
         # This is used for generation, optimized for torch compile
         assert (
@@ -225,22 +285,99 @@ class BaseTransformer(nn.Module):
             x = layer(x, freqs_cis, mask, input_pos=input_pos)
 
         # If prefill, we only calculate the logits of last token
-        if x.size(1) > 1:
+        if x.size(1) > 1 and not return_all:
             x = x[:, -1:]
 
         # We got slow_out here
         slow_out = self.norm(x)
-        token_logits = self.output(slow_out)
+
+        if self.config.tie_word_embeddings:
+            token_logits = F.linear(slow_out, self.embeddings.weight)
+        else:
+            token_logits = self.output(slow_out)
 
         return BaseTransformerForwardResult(
             logits=token_logits,
             hidden_states=x,
         )
 
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @staticmethod
+    def from_pretrained(
+        path: str,
+        load_weights: bool = False,
+        max_length: int | None = None,
+        lora_config: LoraConfig | None = None,
+        rope_base: int | None = None,
+    ) -> "BaseTransformer":
+        config = BaseModelArgs.from_pretrained(path)
+        if max_length is not None:
+            config.max_seq_len = max_length
+            log.info(f"Override max_seq_len to {max_length}")
+
+        if rope_base is not None:
+            config.rope_base = rope_base
+            log.info(f"Override rope_base to {rope_base}")
+
+        match config.model_type:
+            case "naive":
+                model_cls = NaiveTransformer
+            case "dual_ar":
+                model_cls = DualARTransformer
+            case _:
+                raise ValueError(f"Unknown model type: {config.model_type}")
+
+        tokenizer = AutoTokenizer.from_pretrained(str(path))
+        log.info(f"Loading model from {path}, config: {config}")
+        model = model_cls(config, tokenizer=tokenizer)
+
+        if lora_config is not None:
+            setup_lora(model, lora_config)
+            log.info(f"LoRA setup: {lora_config}")
+
+        if load_weights is False:
+            log.info("Randomly initialized model")
+        else:
+            weights = torch.load(
+                Path(path) / "model.pth", map_location="cpu", mmap=True
+            )
+            err = model.load_state_dict(weights, strict=False, assign=True)
+            log.info(f"Loaded weights with error: {err}")
+
+        return model
+
+    def save_pretrained(self, path: str, drop_lora: bool = False):
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        self.config.save(path / "config.json")
+        state_dict = self.state_dict()
+
+        if drop_lora:
+            for key in list(state_dict.keys()):
+                if "lora" not in key:
+                    continue
+
+                state_dict.pop(key)
+                log.info(f"Drop LoRA parameter: {key}")
+
+        torch.save(state_dict, path / "model.pth")
+        self.tokenizer.save_pretrained(path)
+
 
 class NaiveTransformer(BaseTransformer):
-    def __init__(self, config: NaiveModelArgs) -> None:
-        super().__init__(config)
+    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+        super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
         self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.codebook_output = nn.Linear(
@@ -249,6 +386,8 @@ class NaiveTransformer(BaseTransformer):
             bias=False,
         )
 
+        self.apply(self._init_weights)
+
     def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
         token_logits = result.logits
         x = result.hidden_states
@@ -265,9 +404,14 @@ class NaiveTransformer(BaseTransformer):
         )
 
     def forward(
-        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
-        result = super().forward(inp, key_padding_mask)
+        result = super().forward(
+            inp=inp,
+            key_padding_mask=key_padding_mask,
+        )
         return self.decode(result)
 
     def forward_generate(
@@ -278,13 +422,11 @@ class NaiveTransformer(BaseTransformer):
 
 
 class DualARTransformer(BaseTransformer):
-    def __init__(self, config: DualARModelArgs) -> None:
-        super().__init__(config)
+    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+        super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
         # Fast transformer
-        self.fast_embeddings = nn.Embedding(
-            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
-        )
+        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
 
         # The equivalent bs is so large that sdpa doesn't work
         self.fast_layers = nn.ModuleList(
@@ -297,6 +439,8 @@ class DualARTransformer(BaseTransformer):
             bias=False,
         )
 
+        self.apply(self._init_weights)
+
     def setup_caches(
         self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
     ):
@@ -316,7 +460,9 @@ class DualARTransformer(BaseTransformer):
             )
 
     def forward(
-        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
         parent_result = super().forward(inp, key_padding_mask)
         token_logits = parent_result.logits
@@ -340,6 +486,11 @@ class DualARTransformer(BaseTransformer):
         # Remove padded part
         codebooks = rearrange(codebooks, "b n s -> (b s) n")
         codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
+
+        if torch.all(codebook_mask):
+            # If all codebooks are padded, we keep first 8 to make sure the model runs
+            codebook_mask[:8] = False
+
         x_bs, x_len = x.size(0), x.size(1)
         x = x[~codebook_mask]
 
@@ -422,7 +573,9 @@ class Attention(nn.Module):
 
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         # key, query, value projections for all heads, but in a batch
-        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wqkv = nn.Linear(
+            config.dim, total_head_dim, bias=config.attention_qkv_bias
+        )
         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None
 
@@ -469,13 +622,24 @@ class Attention(nn.Module):
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
         if self.use_sdpa:
-            y = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=mask,
-                dropout_p=self.dropout if self.training else 0.0,
-            )
+            if mask is None:
+                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+                    y = F.scaled_dot_product_attention(
+                        q,
+                        k,
+                        v,
+                        dropout_p=self.dropout if self.training else 0.0,
+                        is_causal=True,
+                        # No thirdparty attn_mask here to use flash_attention
+                    )
+            else:
+                y = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=mask,
+                    dropout_p=self.dropout if self.training else 0.0,
+                )
         else:
             y = self.eq_scaled_dot_product_attention(
                 q,
@@ -567,29 +731,3 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 
     x_out2 = x_out2.flatten(3)
     return x_out2.type_as(x)
-
-
-if __name__ == "__main__":
-    args = DualARModelArgs(
-        max_seq_len=4096,
-        vocab_size=32312,
-        n_layer=12,
-        n_fast_layer=4,
-        n_head=12,
-        dim=768,
-        rope_base=10000,
-        norm_eps=1e-5,
-        codebook_size=128,
-        num_codebooks=4,
-    )
-
-    model = DualARTransformer(args)
-    model = model.cuda().bfloat16()
-    print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
-
-    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
-    key_padding_mask = torch.zeros(2, 128).bool().cuda()
-    key_padding_mask[0, 2:] = True
-    x1 = model(inputs, key_padding_mask=key_padding_mask)
-    print(x1.token_logits.shape)
-    print(x1.codebook_logits.shape)

+ 8 - 0
fish_speech/models/text2semantic/lora_utils.py → fish_speech/models/text2semantic/lora.py

@@ -20,6 +20,14 @@ def setup_lora(model, lora_config):
         lora_alpha=lora_config.lora_alpha,
     )
 
+    model.codebook_embeddings = lora.Embedding(
+        num_embeddings=model.codebook_embeddings.num_embeddings,
+        embedding_dim=model.codebook_embeddings.embedding_dim,
+        padding_idx=model.codebook_embeddings.padding_idx,
+        r=lora_config.r,
+        lora_alpha=lora_config.lora_alpha,
+    )
+
     # Replace output layer with a LoRA layer
     linears = [(model, "output")]
 

+ 0 - 3
fish_speech/models/vits_decoder/__init__.py

@@ -1,3 +0,0 @@
-from .lit_module import VITSDecoder
-
-__all__ = ["VITSDecoder"]

+ 0 - 394
fish_speech/models/vits_decoder/lit_module.py

@@ -1,394 +0,0 @@
-from typing import Any, Callable
-
-import lightning as L
-import torch
-import torch.nn.functional as F
-import wandb
-from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
-from matplotlib import pyplot as plt
-from torch import nn
-
-from fish_speech.models.vits_decoder.losses import (
-    discriminator_loss,
-    feature_loss,
-    generator_loss,
-    kl_loss,
-)
-from fish_speech.models.vqgan.utils import (
-    avg_with_mask,
-    plot_mel,
-    sequence_mask,
-    slice_segments,
-)
-
-
-class VITSDecoder(L.LightningModule):
-    def __init__(
-        self,
-        optimizer: Callable,
-        lr_scheduler: Callable,
-        generator: nn.Module,
-        discriminator: nn.Module,
-        mel_transform: nn.Module,
-        spec_transform: nn.Module,
-        hop_length: int = 512,
-        sample_rate: int = 44100,
-        freeze_discriminator: bool = False,
-        weight_mel: float = 45,
-        weight_kl: float = 0.1,
-    ):
-        super().__init__()
-
-        # Model parameters
-        self.optimizer_builder = optimizer
-        self.lr_scheduler_builder = lr_scheduler
-
-        # Generator and discriminator
-        self.generator = generator
-        self.discriminator = discriminator
-        self.mel_transform = mel_transform
-        self.spec_transform = spec_transform
-        self.freeze_discriminator = freeze_discriminator
-
-        # Loss weights
-        self.weight_mel = weight_mel
-        self.weight_kl = weight_kl
-
-        # Other parameters
-        self.hop_length = hop_length
-        self.sampling_rate = sample_rate
-
-        # Disable automatic optimization
-        self.automatic_optimization = False
-
-        if self.freeze_discriminator:
-            for p in self.discriminator.parameters():
-                p.requires_grad = False
-
-    def configure_optimizers(self):
-        # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
-        optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
-        )
-
-        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
-        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
-
-        return (
-            {
-                "optimizer": optimizer_generator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_generator,
-                    "interval": "step",
-                    "name": "optimizer/generator",
-                },
-            },
-            {
-                "optimizer": optimizer_discriminator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_discriminator,
-                    "interval": "step",
-                    "name": "optimizer/discriminator",
-                },
-            },
-        )
-
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
-
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        texts, text_lengths = batch["texts"], batch["text_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        with torch.no_grad():
-            gt_mels = self.mel_transform(audios)
-            gt_specs = self.spec_transform(audios)
-
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        (
-            fake_audios,
-            ids_slice,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-        ) = self.generator(
-            audios,
-            audio_lengths,
-            gt_specs,
-            spec_lengths,
-            texts,
-            text_lengths,
-        )
-
-        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
-        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
-        audios = slice_segments(
-            audios,
-            ids_slice * self.hop_length,
-            self.generator.segment_size * self.hop_length,
-        )
-        fake_mels = self.mel_transform(fake_audios.squeeze(1))
-
-        assert (
-            audios.shape == fake_audios.shape
-        ), f"{audios.shape} != {fake_audios.shape}"
-
-        # Discriminator
-        if self.freeze_discriminator is False:
-            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
-                audios, fake_audios.detach()
-            )
-
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
-
-            self.log(
-                f"train/discriminator/loss",
-                loss_disc,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
-            optim_d.zero_grad()
-            self.manual_backward(loss_disc)
-            self.clip_gradients(
-                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-            )
-            optim_d.step()
-
-        # Adv Loss
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)
-
-        # Adversarial Loss
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_adv, _ = generator_loss(y_d_hat_g)
-
-        self.log(
-            f"train/generator/adv",
-            loss_adv,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)
-
-        self.log(
-            f"train/generator/adv_fm",
-            loss_fm,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = avg_with_mask(
-                F.l1_loss(gt_mels, fake_mels, reduction="none"), spec_masks
-            )
-
-        self.log(
-            "train/generator/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
-
-        self.log(
-            "train/generator/loss_kl",
-            loss_kl,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        loss = (
-            loss_mel * self.weight_mel + loss_kl * self.weight_kl + loss_adv + loss_fm
-        )
-        self.log(
-            "train/generator/loss",
-            loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
-        # Backward
-        optim_g.zero_grad()
-
-        self.manual_backward(loss)
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_g.step()
-
-        # Manual LR Scheduler
-        scheduler_g, scheduler_d = self.lr_schedulers()
-        scheduler_g.step()
-        scheduler_d.step()
-
-    def validation_step(self, batch: Any, batch_idx: int):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        texts, text_lengths = batch["texts"], batch["text_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        gt_mels = self.mel_transform(audios)
-        gt_specs = self.spec_transform(audios)
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        prior_audios = self.generator.infer(
-            audios, audio_lengths, gt_specs, spec_lengths, texts, text_lengths
-        )
-        posterior_audios = self.generator.infer_posterior(gt_specs, spec_lengths)
-        prior_mels = self.mel_transform(prior_audios.squeeze(1))
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
-
-        min_mel_length = min(
-            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
-        )
-        gt_mels = gt_mels[:, :, :min_mel_length]
-        prior_mels = prior_mels[:, :, :min_mel_length]
-        posterior_mels = posterior_mels[:, :, :min_mel_length]
-
-        prior_mel_loss = avg_with_mask(
-            F.l1_loss(gt_mels, prior_mels, reduction="none"), spec_masks
-        )
-        posterior_mel_loss = avg_with_mask(
-            F.l1_loss(gt_mels, posterior_mels, reduction="none"), spec_masks
-        )
-
-        self.log(
-            "val/prior_mel_loss",
-            prior_mel_loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        self.log(
-            "val/posterior_mel_loss",
-            posterior_mel_loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        # only log the first batch
-        if batch_idx != 0:
-            return
-
-        for idx, (
-            mel,
-            prior_mel,
-            posterior_mel,
-            audio,
-            prior_audio,
-            posterior_audio,
-            audio_len,
-        ) in enumerate(
-            zip(
-                gt_mels,
-                prior_mels,
-                posterior_mels,
-                audios.detach().float(),
-                prior_audios.detach().float(),
-                posterior_audios.detach().float(),
-                audio_lengths,
-            )
-        ):
-            mel_len = audio_len // self.hop_length
-
-            image_mels = plot_mel(
-                [
-                    prior_mel[:, :mel_len],
-                    posterior_mel[:, :mel_len],
-                    mel[:, :mel_len],
-                ],
-                [
-                    "Prior (VQ)",
-                    "Posterior (Reconstruction)",
-                    "Ground-Truth",
-                ],
-            )
-
-            if isinstance(self.logger, WandbLogger):
-                self.logger.experiment.log(
-                    {
-                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
-                        "wavs": [
-                            wandb.Audio(
-                                audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="gt",
-                            ),
-                            wandb.Audio(
-                                prior_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="prior",
-                            ),
-                            wandb.Audio(
-                                posterior_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="posterior",
-                            ),
-                        ],
-                    },
-                )
-
-            if isinstance(self.logger, TensorBoardLogger):
-                self.logger.experiment.add_figure(
-                    f"sample-{idx}/mels",
-                    image_mels,
-                    global_step=self.global_step,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gt",
-                    audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/prior",
-                    prior_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/posterior",
-                    posterior_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-
-            plt.close(image_mels)

+ 0 - 67
fish_speech/models/vits_decoder/losses.py

@@ -1,67 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
-    loss = 0
-    for dr, dg in zip(fmap_r, fmap_g):
-        dr = dr.float().detach()
-        dg = dg.float()
-        loss += torch.mean(torch.abs(dr - dg))
-
-    return loss * 2
-
-
-def discriminator_loss(
-    disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor]
-):
-    loss = 0
-    r_losses = []
-    g_losses = []
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        dr = dr.float()
-        dg = dg.float()
-        r_loss = torch.mean((1 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(disc_outputs: list[torch.Tensor]):
-    loss = 0
-    gen_losses = []
-    for dg in disc_outputs:
-        dg = dg.float()
-        l = torch.mean((1 - dg) ** 2)
-        gen_losses.append(l)
-        loss += l
-
-    return loss, gen_losses
-
-
-def kl_loss(
-    z_p: torch.Tensor,
-    logs_q: torch.Tensor,
-    m_p: torch.Tensor,
-    logs_p: torch.Tensor,
-    z_mask: torch.Tensor,
-):
-    """
-    z_p, logs_q: [b, h, t_t]
-    m_p, logs_p: [b, h, t_t]
-    """
-    z_p = z_p.float()
-    logs_q = logs_q.float()
-    m_p = m_p.float()
-    logs_p = logs_p.float()
-    z_mask = z_mask.float()
-
-    kl = logs_p - logs_q - 0.5
-    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
-    kl = torch.sum(kl * z_mask)
-    l = kl / torch.sum(z_mask)
-    return l

+ 0 - 350
fish_speech/models/vits_decoder/modules/attentions.py

@@ -1,350 +0,0 @@
-import math
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from fish_speech.models.vits_decoder.modules import commons
-
-from .modules import LayerNorm
-
-
-class Encoder(nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size=1,
-        p_dropout=0.0,
-        window_size=4,
-        isflow=False,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-
-        self.drop = nn.Dropout(p_dropout)
-        self.attn_layers = nn.ModuleList()
-        self.norm_layers_1 = nn.ModuleList()
-        self.ffn_layers = nn.ModuleList()
-        self.norm_layers_2 = nn.ModuleList()
-        for i in range(self.n_layers):
-            self.attn_layers.append(
-                MultiHeadAttention(
-                    hidden_channels,
-                    hidden_channels,
-                    n_heads,
-                    p_dropout=p_dropout,
-                    window_size=window_size,
-                )
-            )
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(
-                FFN(
-                    hidden_channels,
-                    hidden_channels,
-                    filter_channels,
-                    kernel_size,
-                    p_dropout=p_dropout,
-                )
-            )
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-        if isflow:
-            cond_layer = torch.nn.Conv1d(
-                gin_channels, 2 * hidden_channels * n_layers, 1
-            )
-            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
-            self.cond_layer = weight_norm(cond_layer, "weight")
-            self.gin_channels = gin_channels
-
-    def forward(self, x, x_mask, g=None):
-        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-        x = x * x_mask
-        if g is not None:
-            g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            if g is not None:
-                x = self.cond_pre(x)
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x = commons.fused_add_tanh_sigmoid_multiply(
-                    x, g_l, torch.IntTensor([self.hidden_channels])
-                )
-            y = self.attn_layers[i](x, x, attn_mask)
-            y = self.drop(y)
-            x = self.norm_layers_1[i](x + y)
-
-            y = self.ffn_layers[i](x, x_mask)
-            y = self.drop(y)
-            x = self.norm_layers_2[i](x + y)
-        x = x * x_mask
-        return x
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(
-        self,
-        channels,
-        out_channels,
-        n_heads,
-        p_dropout=0.0,
-        window_size=None,
-        heads_share=True,
-        block_length=None,
-        proximal_bias=False,
-        proximal_init=False,
-    ):
-        super().__init__()
-        assert channels % n_heads == 0
-
-        self.channels = channels
-        self.out_channels = out_channels
-        self.n_heads = n_heads
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-        self.heads_share = heads_share
-        self.block_length = block_length
-        self.proximal_bias = proximal_bias
-        self.proximal_init = proximal_init
-        self.attn = None
-
-        self.k_channels = channels // n_heads
-        self.conv_q = nn.Conv1d(channels, channels, 1)
-        self.conv_k = nn.Conv1d(channels, channels, 1)
-        self.conv_v = nn.Conv1d(channels, channels, 1)
-        self.conv_o = nn.Conv1d(channels, out_channels, 1)
-        self.drop = nn.Dropout(p_dropout)
-
-        if window_size is not None:
-            n_heads_rel = 1 if heads_share else n_heads
-            rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-            self.emb_rel_v = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-
-        nn.init.xavier_uniform_(self.conv_q.weight)
-        nn.init.xavier_uniform_(self.conv_k.weight)
-        nn.init.xavier_uniform_(self.conv_v.weight)
-        if proximal_init:
-            with torch.no_grad():
-                self.conv_k.weight.copy_(self.conv_q.weight)
-                self.conv_k.bias.copy_(self.conv_q.bias)
-
-    def forward(self, x, c, attn_mask=None):
-        q = self.conv_q(x)
-        k = self.conv_k(c)
-        v = self.conv_v(c)
-
-        x, self.attn = self.attention(q, k, v, mask=attn_mask)
-
-        x = self.conv_o(x)
-        return x
-
-    def attention(self, query, key, value, mask=None):
-        # reshape [b, d, t] -> [b, n_h, t, d_k]
-        b, d, t_s, t_t = (*key.size(), query.size(2))
-        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
-        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-        if self.window_size is not None:
-            assert (
-                t_s == t_t
-            ), "Relative attention is only available for self-attention."
-            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
-            scores = scores + scores_local
-        if self.proximal_bias:
-            assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(
-                device=scores.device, dtype=scores.dtype
-            )
-        if mask is not None:
-            scores = scores.masked_fill(mask == 0, -1e4)
-            if self.block_length is not None:
-                assert (
-                    t_s == t_t
-                ), "Local attention is only available for self-attention."
-                block_mask = (
-                    torch.ones_like(scores)
-                    .triu(-self.block_length)
-                    .tril(self.block_length)
-                )
-                scores = scores.masked_fill(block_mask == 0, -1e4)
-        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-        p_attn = self.drop(p_attn)
-        output = torch.matmul(p_attn, value)
-        if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, t_s
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-        output = (
-            output.transpose(2, 3).contiguous().view(b, d, t_t)
-        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
-        return output, p_attn
-
-    def _matmul_with_relative_values(self, x, y):
-        """
-        x: [b, h, l, m]
-        y: [h or 1, m, d]
-        ret: [b, h, l, d]
-        """
-        ret = torch.matmul(x, y.unsqueeze(0))
-        return ret
-
-    def _matmul_with_relative_keys(self, x, y):
-        """
-        x: [b, h, l, d]
-        y: [h or 1, m, d]
-        ret: [b, h, l, m]
-        """
-        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-        return ret
-
-    def _get_relative_embeddings(self, relative_embeddings, length):
-        max_relative_position = 2 * self.window_size + 1
-        # Pad first before slice to avoid using cond ops.
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
-        slice_end_position = slice_start_position + 2 * length - 1
-        if pad_length > 0:
-            padded_relative_embeddings = F.pad(
-                relative_embeddings,
-                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-            )
-        else:
-            padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
-        return used_relative_embeddings
-
-    def _relative_position_to_absolute_position(self, x):
-        """
-        x: [b, h, l, 2*l-1]
-        ret: [b, h, l, l]
-        """
-        batch, heads, length, _ = x.size()
-        # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
-
-        # Concat extra elements so to add up to shape (len+1, 2*len-1).
-        x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(
-            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
-        )
-
-        # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
-        return x_final
-
-    def _absolute_position_to_relative_position(self, x):
-        """
-        x: [b, h, l, l]
-        ret: [b, h, l, 2*l-1]
-        """
-        batch, heads, length, _ = x.size()
-        # pad along column
-        x = F.pad(
-            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
-        )
-        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-        # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-        return x_final
-
-    def _attention_bias_proximal(self, length):
-        """Bias for self-attention to encourage attention to close positions.
-        Args:
-          length: an integer scalar.
-        Returns:
-          a Tensor with shape [1, 1, length, length]
-        """
-        r = torch.arange(length, dtype=torch.float32)
-        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
-class FFN(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        filter_channels,
-        kernel_size,
-        p_dropout=0.0,
-        activation=None,
-        causal=False,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.filter_channels = filter_channels
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.activation = activation
-        self.causal = causal
-
-        if causal:
-            self.padding = self._causal_padding
-        else:
-            self.padding = self._same_padding
-
-        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
-        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
-        self.drop = nn.Dropout(p_dropout)
-
-    def forward(self, x, x_mask):
-        x = self.conv_1(self.padding(x * x_mask))
-        if self.activation == "gelu":
-            x = x * torch.sigmoid(1.702 * x)
-        else:
-            x = torch.relu(x)
-        x = self.drop(x)
-        x = self.conv_2(self.padding(x * x_mask))
-        return x * x_mask
-
-    def _causal_padding(self, x):
-        if self.kernel_size == 1:
-            return x
-        pad_l = self.kernel_size - 1
-        pad_r = 0
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
-        return x
-
-    def _same_padding(self, x):
-        if self.kernel_size == 1:
-            return x
-        pad_l = (self.kernel_size - 1) // 2
-        pad_r = self.kernel_size // 2
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
-        return x

+ 0 - 190
fish_speech/models/vits_decoder/modules/commons.py

@@ -1,190 +0,0 @@
-import math
-
-import torch
-from torch.nn import functional as F
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def intersperse(lst, item):
-    result = [item] * (len(lst) * 2 + 1)
-    result[1::2] = lst
-    return result
-
-
-def kl_divergence(m_p, logs_p, m_q, logs_q):
-    """KL(P||Q)"""
-    kl = (logs_q - logs_p) - 0.5
-    kl += (
-        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
-    )
-    return kl
-
-
-def rand_gumbel(shape):
-    """Sample from the Gumbel distribution, protect from overflows."""
-    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
-    return -torch.log(-torch.log(uniform_samples))
-
-
-def rand_gumbel_like(x):
-    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
-    return g
-
-
-def slice_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, :, idx_str:idx_end]
-    return ret
-
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
-    b, d, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = x_lengths - segment_size + 1
-    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
-    ret = slice_segments(x, ids_str, segment_size)
-    return ret, ids_str
-
-
-def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
-    position = torch.arange(length, dtype=torch.float)
-    num_timescales = channels // 2
-    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
-        num_timescales - 1
-    )
-    inv_timescales = min_timescale * torch.exp(
-        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
-    )
-    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
-    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
-    signal = F.pad(signal, [0, 0, 0, channels % 2])
-    signal = signal.view(1, channels, length)
-    return signal
-
-
-def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return x + signal.to(dtype=x.dtype, device=x.device)
-
-
-def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
-
-
-def subsequent_mask(length):
-    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
-    return mask
-
-
-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
-    in_act = input_a + input_b
-    t_act = torch.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-    acts = t_act * s_act
-    return acts
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def shift_1d(x):
-    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
-    return x
-
-
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def generate_path(duration, mask):
-    """
-    duration: [b, 1, t_x]
-    mask: [b, 1, t_y, t_x]
-    """
-    device = duration.device
-
-    b, _, t_y, t_x = mask.shape
-    cum_duration = torch.cumsum(duration, -1)
-
-    cum_duration_flat = cum_duration.view(b * t_x)
-    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
-    path = path.view(b, t_x, t_y)
-    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-    path = path.unsqueeze(1).transpose(2, 3) * mask
-    return path
-
-
-def clip_grad_value_(parameters, clip_value, norm_type=2):
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
-    norm_type = float(norm_type)
-    if clip_value is not None:
-        clip_value = float(clip_value)
-
-    total_norm = 0
-    for p in parameters:
-        param_norm = p.grad.data.norm(norm_type)
-        total_norm += param_norm.item() ** norm_type
-        if clip_value is not None:
-            p.grad.data.clamp_(min=-clip_value, max=clip_value)
-    total_norm = total_norm ** (1.0 / norm_type)
-    return total_norm
-
-
-def squeeze(x, x_mask=None, n_sqz=2):
-    b, c, t = x.size()
-
-    t = (t // n_sqz) * n_sqz
-    x = x[:, :, :t]
-    x_sqz = x.view(b, c, t // n_sqz, n_sqz)
-    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
-
-    if x_mask is not None:
-        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
-    else:
-        x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
-    return x_sqz * x_mask, x_mask
-
-
-def unsqueeze(x, x_mask=None, n_sqz=2):
-    b, c, t = x.size()
-
-    x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
-    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
-
-    if x_mask is not None:
-        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
-    else:
-        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
-    return x_unsqz * x_mask, x_mask

+ 0 - 686
fish_speech/models/vits_decoder/modules/models.py

@@ -1,686 +0,0 @@
-import torch
-from torch import nn
-from torch.nn import Conv1d, Conv2d, ConvTranspose1d
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
-
-from fish_speech.models.vits_decoder.modules import attentions, commons, modules
-
-from .commons import get_padding, init_weights
-from .mrte import MRTE
-from .vq_encoder import VQEncoder
-
-
-class TextEncoder(nn.Module):
-    def __init__(
-        self,
-        out_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        latent_channels=192,
-        codebook_size=264,
-    ):
-        super().__init__()
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.latent_channels = latent_channels
-
-        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
-
-        self.encoder_ssl = attentions.Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers // 2,
-            kernel_size,
-            p_dropout,
-        )
-
-        self.encoder_text = attentions.Encoder(
-            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
-        )
-        self.text_embedding = nn.Embedding(codebook_size, hidden_channels)
-
-        self.mrte = MRTE()
-
-        self.encoder2 = attentions.Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers // 2,
-            kernel_size,
-            p_dropout,
-        )
-
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, y, y_lengths, text, text_lengths, ge):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-
-        y = self.ssl_proj(y * y_mask) * y_mask
-
-        y = self.encoder_ssl(y * y_mask, y_mask)
-
-        text_mask = torch.unsqueeze(
-            commons.sequence_mask(text_lengths, text.size(1)), 1
-        ).to(y.dtype)
-        text = self.text_embedding(text).transpose(1, 2)
-        text = self.encoder_text(text * text_mask, text_mask)
-
-        y = self.mrte(y, y_mask, text, text_mask, ge)
-
-        y = self.encoder2(y * y_mask, y_mask)
-
-        stats = self.proj(y) * y_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return y, m, logs, y_mask
-
-
-class ResidualCouplingBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        n_flows=4,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.flows = nn.ModuleList()
-        for i in range(n_flows):
-            self.flows.append(
-                modules.ResidualCouplingLayer(
-                    channels,
-                    hidden_channels,
-                    kernel_size,
-                    dilation_rate,
-                    n_layers,
-                    gin_channels=gin_channels,
-                    mean_only=True,
-                )
-            )
-            self.flows.append(modules.Flip())
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        if not reverse:
-            for flow in self.flows:
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
-        else:
-            for flow in reversed(self.flows):
-                x = flow(x, x_mask, g=g, reverse=reverse)
-        return x
-
-
-class PosteriorEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-
-        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
-        self.enc = modules.WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            gin_channels=gin_channels,
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, x, x_lengths, g=None):
-        if g != None:
-            g = g.detach()
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-            x.dtype
-        )
-        x = self.pre(x) * x_mask
-        x = self.enc(x, x_mask, g=g)
-        stats = self.proj(x) * x_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
-        return z, m, logs, x_mask
-
-
-class Generator(torch.nn.Module):
-    def __init__(
-        self,
-        initial_channel,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels=0,
-    ):
-        super(Generator, self).__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(
-                weight_norm(
-                    ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
-                self.resblocks.append(resblock(ch, k, d))
-
-        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-        self.ups.apply(init_weights)
-
-        if gin_channels != 0:
-            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-    def forward(self, x, g=None):
-        x = self.conv_pre(x)
-        if g is not None:
-            x = x + self.cond(g)
-
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        print("Removing weight norm...")
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-
-
-class DiscriminatorP(torch.nn.Module):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        self.use_spectral_norm = use_spectral_norm
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    Conv2d(
-                        1,
-                        32,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        32,
-                        128,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        128,
-                        512,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        512,
-                        1024,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        1024,
-                        1024,
-                        (kernel_size, 1),
-                        1,
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-            ]
-        )
-        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class DiscriminatorS(torch.nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
-                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
-                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
-                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
-                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
-                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
-
-    def forward(self, x):
-        fmap = []
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class EnsembledDiscriminator(torch.nn.Module):
-    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
-        super().__init__()
-        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        discs = discs + [
-            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-        ]
-        self.discriminators = nn.ModuleList(discs)
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            y_d_gs.append(y_d_g)
-            fmap_rs.append(fmap_r)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class SynthesizerTrn(nn.Module):
-    """
-    Synthesizer for Training
-    """
-
-    def __init__(
-        self,
-        *,
-        spec_channels,
-        segment_size,
-        inter_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels=0,
-        codebook_size=264,
-        vq_mask_ratio=0.0,
-        ref_mask_ratio=0.0,
-    ):
-        super().__init__()
-
-        self.spec_channels = spec_channels
-        self.inter_channels = inter_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.resblock = resblock
-        self.resblock_kernel_sizes = resblock_kernel_sizes
-        self.resblock_dilation_sizes = resblock_dilation_sizes
-        self.upsample_rates = upsample_rates
-        self.upsample_initial_channel = upsample_initial_channel
-        self.upsample_kernel_sizes = upsample_kernel_sizes
-        self.segment_size = segment_size
-        self.gin_channels = gin_channels
-        self.vq_mask_ratio = vq_mask_ratio
-        self.ref_mask_ratio = ref_mask_ratio
-
-        self.enc_p = TextEncoder(
-            inter_channels,
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            p_dropout,
-            codebook_size=codebook_size,
-        )
-        self.dec = Generator(
-            inter_channels,
-            resblock,
-            resblock_kernel_sizes,
-            resblock_dilation_sizes,
-            upsample_rates,
-            upsample_initial_channel,
-            upsample_kernel_sizes,
-            gin_channels=gin_channels,
-        )
-        self.enc_q = PosteriorEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            5,
-            1,
-            16,
-            gin_channels=gin_channels,
-        )
-        self.flow = ResidualCouplingBlock(
-            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
-        )
-
-        self.ref_enc = modules.MelStyleEncoder(
-            spec_channels, style_vector_dim=gin_channels
-        )
-
-        self.vq = VQEncoder()
-        for param in self.vq.parameters():
-            param.requires_grad = False
-
-    def forward(
-        self, audio, audio_lengths, gt_specs, gt_spec_lengths, text, text_lengths
-    ):
-        y_mask = torch.unsqueeze(
-            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
-        ).to(gt_specs.dtype)
-        ge = self.ref_enc(gt_specs * y_mask, y_mask)
-
-        if self.training and self.ref_mask_ratio > 0:
-            bs = audio.size(0)
-            mask_speaker_len = int(bs * self.ref_mask_ratio)
-            mask_indices = torch.randperm(bs)[:mask_speaker_len]
-            audio[mask_indices] = 0
-
-        quantized = self.vq(audio, audio_lengths)
-
-        # Block masking, block_size = 4
-        block_size = 4
-        if self.training and self.vq_mask_ratio > 0:
-            reduced_length = quantized.size(-1) // block_size
-            mask_length = int(reduced_length * self.vq_mask_ratio)
-            mask_indices = torch.randperm(reduced_length)[:mask_length]
-            short_mask = torch.zeros(
-                quantized.size(0),
-                quantized.size(1),
-                reduced_length,
-                device=quantized.device,
-                dtype=torch.float,
-            )
-            short_mask[:, :, mask_indices] = 1.0
-            long_mask = short_mask.repeat_interleave(block_size, dim=-1)
-            long_mask = F.interpolate(
-                long_mask, size=quantized.size(-1), mode="nearest"
-            )
-            quantized = quantized.masked_fill(long_mask > 0.5, 0)
-
-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, gt_spec_lengths, text, text_lengths, ge
-        )
-        z, m_q, logs_q, y_mask = self.enc_q(gt_specs, gt_spec_lengths, g=ge)
-        z_p = self.flow(z, y_mask, g=ge)
-
-        z_slice, ids_slice = commons.rand_slice_segments(
-            z, gt_spec_lengths, self.segment_size
-        )
-        o = self.dec(z_slice, g=ge)
-
-        return (
-            o,
-            ids_slice,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-        )
-
-    @torch.no_grad()
-    def infer(
-        self,
-        audio,
-        audio_lengths,
-        gt_specs,
-        gt_spec_lengths,
-        text,
-        text_lengths,
-        noise_scale=0.5,
-    ):
-        quantized = self.vq(audio, audio_lengths)
-        quantized_lengths = audio_lengths // 512
-        ge = self.encode_ref(gt_specs, gt_spec_lengths)
-
-        return self.decode(
-            quantized,
-            quantized_lengths,
-            text,
-            text_lengths,
-            noise_scale=noise_scale,
-            ge=ge,
-        )
-
-    @torch.no_grad()
-    def infer_posterior(
-        self,
-        gt_specs,
-        gt_spec_lengths,
-    ):
-        y_mask = torch.unsqueeze(
-            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
-        ).to(gt_specs.dtype)
-        ge = self.ref_enc(gt_specs * y_mask, y_mask)
-        z, m_q, logs_q, y_mask = self.enc_q(gt_specs, gt_spec_lengths, g=ge)
-        o = self.dec(z * y_mask, g=ge)
-
-        return o
-
-    @torch.no_grad()
-    def decode(
-        self,
-        quantized,
-        quantized_lengths,
-        text,
-        text_lengths,
-        noise_scale=0.5,
-        ge=None,
-    ):
-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, quantized_lengths, text, text_lengths, ge
-        )
-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
-
-        z = self.flow(z_p, y_mask, g=ge, reverse=True)
-
-        o = self.dec(z * y_mask, g=ge)
-
-        return o
-
-    @torch.no_grad()
-    def encode_ref(self, gt_specs, gt_spec_lengths):
-        y_mask = torch.unsqueeze(
-            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
-        ).to(gt_specs.dtype)
-        ge = self.ref_enc(gt_specs * y_mask, y_mask)
-
-        return ge
-
-
-if __name__ == "__main__":
-    import librosa
-    from transformers import AutoTokenizer
-
-    from fish_speech.utils.spectrogram import LinearSpectrogram
-
-    model = SynthesizerTrn(
-        spec_channels=1025,
-        segment_size=20480 // 640,
-        inter_channels=192,
-        hidden_channels=192,
-        filter_channels=768,
-        n_heads=2,
-        n_layers=6,
-        kernel_size=3,
-        p_dropout=0.1,
-        resblock="1",
-        resblock_kernel_sizes=[3, 7, 11],
-        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        upsample_rates=[8, 8, 2, 2, 2],
-        upsample_initial_channel=512,
-        upsample_kernel_sizes=[16, 16, 8, 2, 2],
-        gin_channels=512,
-    )
-
-    ckpt = "checkpoints/Bert-VITS2/G_0.pth"
-    # Try to load the model
-    print(f"Loading model from {ckpt}")
-    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)["model"]
-    # d_checkpoint = torch.load(
-    #     "checkpoints/Bert-VITS2/D_0.pth", map_location="cpu", weights_only=True
-    # )["model"]
-    # print(checkpoint.keys())
-
-    checkpoint.pop("dec.cond.weight")
-    checkpoint.pop("enc_q.enc.cond_layer.weight_v")
-
-    # new_checkpoint = {}
-    # for k, v in checkpoint.items():
-    #     new_checkpoint["generator." + k] = v
-
-    # for k, v in d_checkpoint.items():
-    #     new_checkpoint["discriminator." + k] = v
-
-    # torch.save(new_checkpoint, "checkpoints/Bert-VITS2/ensemble.pth")
-    # exit()
-
-    print(model.load_state_dict(checkpoint, strict=False))
-
-    # Test
-
-    ref_audio = librosa.load(
-        "data/source/云天河/云天河-旁白/《薄太太》第0025集-yth_24.wav", sr=32000
-    )[0]
-    input_audio = librosa.load(
-        "data/source/云天河/云天河-旁白/《薄太太》第0025集-yth_24.wav", sr=32000
-    )[0]
-    ref_audio = input_audio
-    text = "博兴只知道身边的小女人没睡着,他又凑到她耳边压低了声线。阮苏眉睁眼,不觉得你老公像英雄吗?阮苏还是没反应,这男人是不是有病?刚才那冰冷又强势的样子,和现在这幼稚无赖的样子,根本就判若二人。"
-    encoded_text = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
-    spec = LinearSpectrogram(n_fft=2048, hop_length=640, win_length=2048)
-
-    ref_audio = torch.tensor(ref_audio).unsqueeze(0).unsqueeze(0)
-    ref_spec = spec(ref_audio)
-
-    input_audio = torch.tensor(input_audio).unsqueeze(0).unsqueeze(0)
-    text = encoded_text(text, return_tensors="pt")["input_ids"]
-    print(ref_audio.size(), ref_spec.size(), input_audio.size(), text.size())
-
-    o, y_mask, (z, z_p, m_p, logs_p) = model.infer(
-        input_audio,
-        torch.LongTensor([input_audio.size(2)]),
-        ref_spec,
-        torch.LongTensor([ref_spec.size(2)]),
-        text,
-        torch.LongTensor([text.size(1)]),
-    )
-    print(o.size(), y_mask.size(), z.size(), z_p.size(), m_p.size(), logs_p.size())
-
-    # Save output
-    # import soundfile as sf
-
-    # sf.write("output.wav", o.squeeze().detach().numpy(), 32000)

+ 0 - 619
fish_speech/models/vits_decoder/modules/modules.py

@@ -1,619 +0,0 @@
-import numpy as np
-import torch
-from torch import nn
-from torch.nn import Conv1d
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from .commons import fused_add_tanh_sigmoid_multiply, get_padding, init_weights
-
-LRELU_SLOPE = 0.1
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-
-class ConvReluNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        hidden_channels,
-        out_channels,
-        kernel_size,
-        n_layers,
-        p_dropout,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.hidden_channels = hidden_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-        assert n_layers > 1, "Number of layers should be larger than 0."
-
-        self.conv_layers = nn.ModuleList()
-        self.norm_layers = nn.ModuleList()
-        self.conv_layers.append(
-            nn.Conv1d(
-                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-            )
-        )
-        self.norm_layers.append(LayerNorm(hidden_channels))
-        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
-        for _ in range(n_layers - 1):
-            self.conv_layers.append(
-                nn.Conv1d(
-                    hidden_channels,
-                    hidden_channels,
-                    kernel_size,
-                    padding=kernel_size // 2,
-                )
-            )
-            self.norm_layers.append(LayerNorm(hidden_channels))
-        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
-        self.proj.weight.data.zero_()
-        self.proj.bias.data.zero_()
-
-    def forward(self, x, x_mask):
-        x_org = x
-        for i in range(self.n_layers):
-            x = self.conv_layers[i](x * x_mask)
-            x = self.norm_layers[i](x)
-            x = self.relu_drop(x)
-        x = x_org + self.proj(x)
-        return x * x_mask
-
-
-class WN(torch.nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        gin_channels=0,
-        p_dropout=0,
-    ):
-        super(WN, self).__init__()
-        assert kernel_size % 2 == 1
-        self.hidden_channels = hidden_channels
-        self.kernel_size = (kernel_size,)
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-        self.p_dropout = p_dropout
-
-        self.in_layers = torch.nn.ModuleList()
-        self.res_skip_layers = torch.nn.ModuleList()
-        self.drop = nn.Dropout(p_dropout)
-
-        if gin_channels != 0:
-            cond_layer = torch.nn.Conv1d(
-                gin_channels, 2 * hidden_channels * n_layers, 1
-            )
-            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
-
-        for i in range(n_layers):
-            dilation = dilation_rate**i
-            padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(
-                hidden_channels,
-                2 * hidden_channels,
-                kernel_size,
-                dilation=dilation,
-                padding=padding,
-            )
-            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
-            self.in_layers.append(in_layer)
-
-            # last one is not necessary
-            if i < n_layers - 1:
-                res_skip_channels = 2 * hidden_channels
-            else:
-                res_skip_channels = hidden_channels
-
-            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
-            self.res_skip_layers.append(res_skip_layer)
-
-    def forward(self, x, x_mask, g=None, **kwargs):
-        output = torch.zeros_like(x)
-        n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
-        if g is not None:
-            g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            x_in = self.in_layers[i](x)
-            if g is not None:
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-            else:
-                g_l = torch.zeros_like(x_in)
-
-            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
-            acts = self.drop(acts)
-
-            res_skip_acts = self.res_skip_layers[i](acts)
-            if i < self.n_layers - 1:
-                res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                x = (x + res_acts) * x_mask
-                output = output + res_skip_acts[:, self.hidden_channels :, :]
-            else:
-                output = output + res_skip_acts
-        return output * x_mask
-
-    def remove_weight_norm(self):
-        if self.gin_channels != 0:
-            torch.nn.utils.remove_weight_norm(self.cond_layer)
-        for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
-        for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)
-
-
-class ResBlock1(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c2(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
-
-
-class ResBlock2(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-
-class Flip(nn.Module):
-    def forward(self, x, *args, reverse=False, **kwargs):
-        x = torch.flip(x, [1])
-        if not reverse:
-            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-            return x, logdet
-        else:
-            return x
-
-
-class ResidualCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        p_dropout=0,
-        gin_channels=0,
-        mean_only=False,
-    ):
-        assert channels % 2 == 0, "channels should be divisible by 2"
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            p_dropout=p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-class LinearNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        bias=True,
-        spectral_norm=False,
-    ):
-        super(LinearNorm, self).__init__()
-        self.fc = nn.Linear(in_channels, out_channels, bias)
-
-        if spectral_norm:
-            self.fc = nn.utils.spectral_norm(self.fc)
-
-    def forward(self, input):
-        out = self.fc(input)
-        return out
-
-
-class Mish(nn.Module):
-    def __init__(self):
-        super(Mish, self).__init__()
-
-    def forward(self, x):
-        return x * torch.tanh(F.softplus(x))
-
-
-class Conv1dGLU(nn.Module):
-    """
-    Conv1d + GLU(Gated Linear Unit) with residual connection.
-    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
-    """
-
-    def __init__(self, in_channels, out_channels, kernel_size, dropout):
-        super(Conv1dGLU, self).__init__()
-        self.out_channels = out_channels
-        self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        residual = x
-        x = self.conv1(x)
-        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
-        x = x1 * torch.sigmoid(x2)
-        x = residual + self.dropout(x)
-        return x
-
-
-class ConvNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size=1,
-        stride=1,
-        padding=None,
-        dilation=1,
-        bias=True,
-        spectral_norm=False,
-    ):
-        super(ConvNorm, self).__init__()
-
-        if padding is None:
-            assert kernel_size % 2 == 1
-            padding = int(dilation * (kernel_size - 1) / 2)
-
-        self.conv = torch.nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-        )
-
-        if spectral_norm:
-            self.conv = nn.utils.spectral_norm(self.conv)
-
-    def forward(self, input):
-        out = self.conv(input)
-        return out
-
-
-class MultiHeadAttention(nn.Module):
-    """Multi-Head Attention module"""
-
-    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
-        super().__init__()
-
-        self.n_head = n_head
-        self.d_k = d_k
-        self.d_v = d_v
-
-        self.w_qs = nn.Linear(d_model, n_head * d_k)
-        self.w_ks = nn.Linear(d_model, n_head * d_k)
-        self.w_vs = nn.Linear(d_model, n_head * d_v)
-
-        self.attention = ScaledDotProductAttention(
-            temperature=np.power(d_model, 0.5), dropout=dropout
-        )
-
-        self.fc = nn.Linear(n_head * d_v, d_model)
-        self.dropout = nn.Dropout(dropout)
-
-        if spectral_norm:
-            self.w_qs = nn.utils.spectral_norm(self.w_qs)
-            self.w_ks = nn.utils.spectral_norm(self.w_ks)
-            self.w_vs = nn.utils.spectral_norm(self.w_vs)
-            self.fc = nn.utils.spectral_norm(self.fc)
-
-    def forward(self, x, mask=None):
-        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
-        sz_b, len_x, _ = x.size()
-
-        residual = x
-
-        q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
-        k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
-        v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
-        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lq x dk
-        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lk x dk
-        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v)  # (n*b) x lv x dv
-
-        if mask is not None:
-            slf_mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
-        else:
-            slf_mask = None
-        output, attn = self.attention(q, k, v, mask=slf_mask)
-
-        output = output.view(n_head, sz_b, len_x, d_v)
-        output = (
-            output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
-        )  # b x lq x (n*dv)
-
-        output = self.fc(output)
-
-        output = self.dropout(output) + residual
-        return output, attn
-
-
-class ScaledDotProductAttention(nn.Module):
-    """Scaled Dot-Product Attention"""
-
-    def __init__(self, temperature, dropout):
-        super().__init__()
-        self.temperature = temperature
-        self.softmax = nn.Softmax(dim=2)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, q, k, v, mask=None):
-        attn = torch.bmm(q, k.transpose(1, 2))
-        attn = attn / self.temperature
-
-        if mask is not None:
-            attn = attn.masked_fill(mask, -np.inf)
-
-        attn = self.softmax(attn)
-        p_attn = self.dropout(attn)
-
-        output = torch.bmm(p_attn, v)
-        return output, attn
-
-
-class MelStyleEncoder(nn.Module):
-    """MelStyleEncoder"""
-
-    def __init__(
-        self,
-        n_mel_channels=80,
-        style_hidden=128,
-        style_vector_dim=256,
-        style_kernel_size=5,
-        style_head=2,
-        dropout=0.1,
-    ):
-        super(MelStyleEncoder, self).__init__()
-        self.in_dim = n_mel_channels
-        self.hidden_dim = style_hidden
-        self.out_dim = style_vector_dim
-        self.kernel_size = style_kernel_size
-        self.n_head = style_head
-        self.dropout = dropout
-
-        self.spectral = nn.Sequential(
-            LinearNorm(self.in_dim, self.hidden_dim),
-            Mish(),
-            nn.Dropout(self.dropout),
-            LinearNorm(self.hidden_dim, self.hidden_dim),
-            Mish(),
-            nn.Dropout(self.dropout),
-        )
-
-        self.temporal = nn.Sequential(
-            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
-            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
-        )
-
-        self.slf_attn = MultiHeadAttention(
-            self.n_head,
-            self.hidden_dim,
-            self.hidden_dim // self.n_head,
-            self.hidden_dim // self.n_head,
-            self.dropout,
-        )
-
-        self.fc = LinearNorm(self.hidden_dim, self.out_dim)
-
-    def temporal_avg_pool(self, x, mask=None):
-        if mask is None:
-            out = torch.mean(x, dim=1)
-        else:
-            len_ = (~mask).sum(dim=1).unsqueeze(1)
-            x = x.masked_fill(mask.unsqueeze(-1), 0)
-            x = x.sum(dim=1)
-            out = torch.div(x, len_)
-        return out
-
-    def forward(self, x, mask=None):
-        x = x.transpose(1, 2)
-        if mask is not None:
-            mask = (mask.int() == 0).squeeze(1)
-        max_len = x.shape[1]
-        slf_attn_mask = (
-            mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
-        )
-
-        # spectral
-        x = self.spectral(x)
-        # temporal
-        x = x.transpose(1, 2)
-        x = self.temporal(x)
-        x = x.transpose(1, 2)
-        # self-attention
-        if mask is not None:
-            x = x.masked_fill(mask.unsqueeze(-1), 0)
-        x, _ = self.slf_attn(x, mask=slf_attn_mask)
-        # fc
-        x = self.fc(x)
-        # temoral average pooling
-        w = self.temporal_avg_pool(x, mask=mask)
-
-        return w.unsqueeze(-1)

+ 0 - 58
fish_speech/models/vits_decoder/modules/mrte.py

@@ -1,58 +0,0 @@
-import torch
-from torch import nn
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from fish_speech.models.vits_decoder.modules.attentions import MultiHeadAttention
-
-
-class MRTE(nn.Module):
-    def __init__(
-        self,
-        content_enc_channels=192,
-        hidden_size=512,
-        out_channels=192,
-        n_heads=4,
-    ):
-        super(MRTE, self).__init__()
-        self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
-        self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
-        self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
-        self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
-
-    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
-        if ge == None:
-            ge = 0
-        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
-
-        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
-        text_enc = self.text_pre(text * text_mask)
-        if test != None:
-            if test == 0:
-                x = (
-                    self.cross_attention(
-                        ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
-                    )
-                    + ssl_enc
-                    + ge
-                )
-            elif test == 1:
-                x = ssl_enc + ge
-            elif test == 2:
-                x = (
-                    self.cross_attention(
-                        ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
-                    )
-                    + ge
-                )
-            else:
-                raise ValueError("test should be 0,1,2")
-        else:
-            x = (
-                self.cross_attention(
-                    ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
-                )
-                + ssl_enc
-                + ge
-            )
-        x = self.c_post(x * ssl_mask)
-        return x

+ 0 - 101
fish_speech/models/vits_decoder/modules/vq_encoder.py

@@ -1,101 +0,0 @@
-import math
-
-import torch
-from torch import nn
-
-from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize
-from fish_speech.models.vqgan.modules.wavenet import WaveNet
-from fish_speech.models.vqgan.utils import sequence_mask
-from fish_speech.utils.spectrogram import LogMelSpectrogram
-
-
-class VQEncoder(nn.Module):
-    def __init__(
-        self,
-    ):
-        super().__init__()
-
-        self.encoder = WaveNet(
-            input_channels=128,
-            residual_channels=768,
-            residual_layers=20,
-            dilation_cycle=4,
-        )
-
-        self.quantizer = DownsampleFiniteScalarQuantize(
-            input_dim=768, n_codebooks=1, n_groups=2, levels=[8, 5, 5, 5]
-        )
-
-        self.spec = LogMelSpectrogram(
-            sample_rate=44100,
-            n_fft=2048,
-            win_length=2048,
-            hop_length=512,
-            n_mels=128,
-            f_min=0.0,
-            f_max=8000.0,
-        )
-
-        self.eval()
-        e = self.load_state_dict(
-            torch.load("checkpoints/vq-gan-group-fsq-2x1024.pth", map_location="cpu"),
-            strict=False,
-        )
-
-        assert len(e.missing_keys) == 0, e.missing_keys
-        assert all(
-            k.startswith("decoder.")
-            or k.startswith("quality_projection.")
-            or k.startswith("discriminator.")
-            for k in e.unexpected_keys
-        ), e.unexpected_keys
-
-    @torch.no_grad()
-    def forward(self, audios, audio_lengths, sr=None):
-        mel_spec = self.spec(audios, sample_rate=sr)
-
-        if sr is not None:
-            audio_lengths = audio_lengths * 44100 // sr
-
-        mel_lengths = audio_lengths // self.spec.hop_length
-        mel_masks = (
-            torch.arange(mel_spec.shape[2], device=mel_spec.device)
-            < mel_lengths[:, None]
-        )
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        mels = mel_spec * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(mels) * mel_masks_float_conv
-        encoded_features = self.quantizer(encoded_features).z * mel_masks_float_conv
-
-        return encoded_features
-
-    @torch.no_grad()
-    def indicies_to_vq_features(
-        self,
-        indices,
-        feature_lengths,
-    ):
-        factor = math.prod(self.quantizer.downsample_factor)
-        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        z = self.quantizer.decode(indices) * mel_masks_float_conv
-
-        return z
-
-    @torch.no_grad()
-    def encode(self, audios, audio_lengths, sr=None):
-        audios = audios.float()
-
-        mels = self.spec(audios, sample_rate=sr)
-        mel_lengths = audio_lengths // self.spec.hop_length
-        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        mels = mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(mels) * mel_masks_float_conv
-        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
-
-        return self.quantizer.encode(encoded_features), feature_lengths

+ 86 - 0
fish_speech/models/vqgan/modules/firefly.py

@@ -1,5 +1,6 @@
 # A inference only version of the FireflyGAN model
 
+import math
 from functools import partial
 from math import prod
 from typing import Callable
@@ -13,6 +14,8 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.utils.checkpoint import checkpoint
 
+from fish_speech.models.vqgan.utils import sequence_mask
+
 
 def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
@@ -474,6 +477,89 @@ class ConvNeXtEncoder(nn.Module):
         return self.norm(x)
 
 
+class FireflyArchitecture(nn.Module):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        head: nn.Module,
+        quantizer: nn.Module,
+        spec_transform: nn.Module,
+    ):
+        super().__init__()
+
+        self.backbone = backbone
+        self.head = head
+        self.quantizer = quantizer
+        self.spec_transform = spec_transform
+
+    def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
+        if self.spec_transform is not None:
+            x = self.spec_transform(x)
+
+        x = self.backbone(x)
+        if mask is not None:
+            x = x * mask
+
+        if self.quantizer is not None:
+            vq_result = self.quantizer(x)
+            x = vq_result.z
+
+            if mask is not None:
+                x = x * mask
+
+        x = self.head(x, template=template)
+
+        if x.ndim == 2:
+            x = x[:, None, :]
+
+        if self.vq is not None:
+            return x, vq_result
+
+        return x
+
+    def encode(self, audios, audio_lengths):
+        audios = audios.float()
+
+        mels = self.spec_transform(audios)
+        mel_lengths = audio_lengths // self.spec_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        mels = mels * mel_masks_float_conv
+
+        # Encode
+        encoded_features = self.backbone(mels) * mel_masks_float_conv
+        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+        return self.quantizer.encode(encoded_features), feature_lengths
+
+    def decode(self, indices, feature_lengths) -> torch.Tensor:
+        factor = math.prod(self.quantizer.downsample_factor)
+        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        audio_masks = sequence_mask(
+            feature_lengths * factor * self.spec_transform.hop_length,
+            indices.shape[2] * factor * self.spec_transform.hop_length,
+        )
+        audio_masks_float_conv = audio_masks[:, None, :].float()
+
+        z = self.quantizer.decode(indices) * mel_masks_float_conv
+        x = self.head(z) * audio_masks_float_conv
+
+        return x
+
+    def remove_parametrizations(self):
+        if hasattr(self.backbone, "remove_parametrizations"):
+            self.backbone.remove_parametrizations()
+
+        if hasattr(self.head, "remove_parametrizations"):
+            self.head.remove_parametrizations()
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+
 class FireflyBase(nn.Module):
     def __init__(self, ckpt_path: str = None, pretrained: bool = True):
         super().__init__()

+ 1 - 1
fish_speech/models/vqgan/modules/fsq.py

@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
     def __init__(
         self,
         input_dim: int = 512,
-        n_codebooks: int = 9,
+        n_codebooks: int = 1,
         n_groups: int = 1,
         levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
         downsample_factor: tuple[int] = (2, 2),

+ 7 - 7
fish_speech/webui/manage.py

@@ -26,7 +26,7 @@ from fish_speech.i18n import i18n
 from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
 
 config_path = cur_work_dir / "fish_speech" / "configs"
-vqgan_yml_path = config_path / "vqgan_finetune.yaml"
+vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
 llama_yml_path = config_path / "text2semantic_finetune.yaml"
 vits_yml_path = config_path / "vits_decoder_finetune.yaml"
 
@@ -137,7 +137,7 @@ def change_decoder_config(decoder_model_path):
         choices = ["vits_decoder_finetune", "vits_decoder_pretrain"]
         return gr.Dropdown(choices=choices, value=choices[0])
     elif "vqgan" in decoder_model_path or "vq-gan" in decoder_model_path:
-        choices = ["vqgan_finetune", "vqgan_pretrain"]
+        choices = ["firefly_gan_vq", "firefly_gan_vq"]
         return gr.Dropdown(choices=choices, value=choices[0])
     else:
         raise ValueError("Invalid decoder name")
@@ -517,7 +517,7 @@ def train_process(
             PYTHON,
             "fish_speech/train.py",
             "--config-name",
-            "vqgan_finetune",
+            "firefly_gan_vq",
             f"project={project}",
             f"trainer.strategy.process_group_backend={backend}",
             f"model.optimizer.lr={vqgan_lr}",
@@ -590,9 +590,9 @@ def train_process(
                 "--batch-size",
                 "16",
                 "--config-name",
-                "vqgan_pretrain",
+                "firefly_gan_vq",
                 "--checkpoint-path",
-                "checkpoints/vq-gan-group-fsq-2x1024.pth",
+                "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
             ]
         )
 
@@ -1292,8 +1292,8 @@ with gr.Blocks(
                                     choices=[
                                         "vits_decoder_finetune",
                                         "vits_decoder_pretrain",
-                                        "vqgan_finetune",
-                                        "vqgan_pretrain",
+                                        "firefly_gan_vq",
+                                        "firefly_gan_vq",
                                     ],
                                     allow_custom_value=True,
                                 )

+ 2 - 1
pyproject.toml

@@ -35,7 +35,8 @@ dependencies = [
     "vector_quantize_pytorch>=1.14.24",
     "samplerate>=0.2.1",
     "resampy>=0.4.3",
-    "einx[torch]==0.2.2"
+    "einx[torch]==0.2.2",
+    "zstandard>=0.22.0"
 ]
 
 [project.optional-dependencies]

+ 112 - 0
run.py

@@ -0,0 +1,112 @@
+import audioop
+import base64
+
+import numpy as np
+import soundfile as sf
+from fastapi import FastAPI, WebSocket
+from fastapi.responses import Response
+from loguru import logger
+
+from stream_service import FishAgentPipeline
+
+app = FastAPI()
+
+
+@app.post("/incoming")
+async def handle_incoming():
+    xml = """<Response>
+    <Connect>
+    <Stream url="wss://2427-24-4-31-213.ngrok-free.app/connection" />
+    </Connect>
+</Response>"""
+
+    logger.info("Incoming call received")
+    return Response(media_type="text/xml", content=xml)
+
+
+async def send_audio(ws, audio, stream_sid=""):
+    await ws.send_json(
+        {
+            "streamSid": stream_sid,
+            "event": "media",
+            "media": {
+                "payload": audio,
+            },
+        }
+    )
+
+
+def decode_mu_law(data):
+    samples = audioop.ulaw2lin(data, 2)
+    samples = np.frombuffer(samples, dtype=np.int16)
+    samples = samples.astype(np.float32) / 32768.0
+
+    return samples
+
+
+def encode_mu_law(data):
+    samples = np.clip(data, -1.0, 1.0)
+    samples = (samples * 32768).astype(np.int16)
+    samples = audioop.lin2ulaw(samples.tobytes(), 2)
+
+    return samples
+
+
+is_working = False
+
+
+@app.websocket("/connection")
+async def handle_connection(websocket: WebSocket):
+    global is_working
+
+    await websocket.accept()
+    logger.info("Connection established")
+    stream_sid = None
+    call_sid = None
+
+    if is_working:
+        logger.info("Already working, closing connection")
+        await websocket.close()
+        return
+
+    is_working = True
+    pipe.reset()
+
+    while True:
+        data = await websocket.receive_json()
+        if data["event"] == "connected":
+            logger.info("Connected message received")
+        elif data["event"] == "start":
+            stream_sid = data["start"]["streamSid"]
+            call_sid = data["start"]["callSid"]
+            logger.info(f"Start media streaming: {stream_sid} - {call_sid}")
+        elif data["event"] == "media":
+            payload = data["media"]["payload"]
+            chunk = base64.b64decode(payload)
+            samples = decode_mu_law(chunk)
+            for i in pipe.add_chunk(samples, sr=8000):
+                await send_audio(
+                    websocket, base64.b64encode(encode_mu_law(i)).decode(), stream_sid
+                )
+        elif data["event"] == "closed":
+            logger.info("Connection closed")
+            await websocket.close()
+            break
+        elif data["event"] == "stop":
+            logger.info("Stop media streaming")
+            await websocket.close()
+            break
+        else:
+            logger.info(f"Unknown event: {data}")
+
+    is_working = False
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    pipe = FishAgentPipeline()
+    pipe.warmup()
+
+    logger.info("Starting server")
+    uvicorn.run(app, host="localhost", port=5000)

+ 412 - 0
stream_service.py

@@ -0,0 +1,412 @@
+import time
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+from loguru import logger
+from torchaudio import functional as AF
+from transformers import (
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+    AutoTokenizer,
+    pipeline,
+)
+
+from fish_speech.conversation import (
+    CODEBOOK_EOS_TOKEN_ID,
+    Conversation,
+    Message,
+    TokensPart,
+    encode_conversation,
+)
+from fish_speech.models.text2semantic.llama import DualARTransformer
+from tools.api import decode_vq_tokens, encode_reference
+from tools.llama.generate_test import convert_string
+from tools.llama.generate_test import generate as llama_generate
+from tools.llama.generate_test import load_model as load_llama_model
+from tools.vqgan.inference import load_model as load_decoder_model
+
+
+class FishStreamVAD:
+    def __init__(self) -> None:
+        # Args
+        self.sample_rate = 16000
+        self.threshold = 0.5
+        self.neg_threshold = self.threshold - 0.15
+        self.min_speech_duration_ms = 100
+        self.min_silence_ms = 500
+        self.speech_pad_ms = 30
+        self.chunk_size = 512
+
+        # Convert to samples
+        self.min_speech_duration_samples = (
+            self.min_speech_duration_ms * self.sample_rate // 1000
+        )
+        self.min_silence_samples = self.min_silence_ms * self.sample_rate // 1000
+        self.speech_pad_samples = self.speech_pad_ms * self.sample_rate // 1000
+
+        # Core buffers
+        self.reset()
+
+        # Load models
+        logger.info("Loading VAD model")
+        vad_model, vad_utils = torch.hub.load(
+            repo_or_dir="snakers4/silero-vad",
+            model="silero_vad",
+            force_reload=True,
+            onnx=True,
+        )
+
+        self.vad_model = vad_model
+        self.get_speech_timestamps = vad_utils[0]
+        logger.info("VAD model loaded")
+
+    def reset(self):
+        self.audio_chunks = None
+        self.vad_pointer = 0
+        self.speech_probs = []
+
+        self.triggered = False
+        self.start = self.end = self.temp_end = 0
+        self.last_seen_end = 0
+        self.speech_segments = []
+
+    def add_chunk(self, chunk, sr=None):
+        """
+        Add a chunk to the buffer
+        """
+
+        if isinstance(chunk, np.ndarray):
+            chunk = torch.from_numpy(chunk)
+
+        if sr is not None and sr != self.sample_rate:
+            chunk = AF.resample(chunk, sr, self.sample_rate)
+
+        # self.audio_chunks.append(chunk)
+        if self.audio_chunks is None:
+            self.audio_chunks = chunk
+        else:
+            self.audio_chunks = torch.cat([self.audio_chunks, chunk])
+
+        # Trigger VAD
+        yield from self.detect_speech()
+
+    def detect_speech(self):
+        """
+        Run the VAD model on the current buffer
+        """
+
+        speech_prob_start_idx = len(self.speech_probs)
+        while len(self.audio_chunks) - self.vad_pointer >= self.chunk_size:
+            chunk = self.audio_chunks[
+                self.vad_pointer : self.vad_pointer + self.chunk_size
+            ]
+            speech_prob = self.vad_model(chunk, self.sample_rate)
+            self.speech_probs.append(speech_prob)
+            self.vad_pointer += self.chunk_size
+
+        # Process speech probs
+        for i in range(speech_prob_start_idx, len(self.speech_probs)):
+            speech_prob = self.speech_probs[i]
+
+            if speech_prob >= self.threshold and self.temp_end:
+                self.temp_end = 0
+
+            if speech_prob >= self.threshold and self.triggered is False:
+                self.triggered = True
+                self.start = i * self.chunk_size
+                continue
+
+            if speech_prob < self.neg_threshold and self.triggered is True:
+                if self.temp_end == 0:
+                    self.temp_end = i * self.chunk_size
+
+                if i * self.chunk_size - self.temp_end < self.min_silence_samples:
+                    continue
+
+                self.end = self.temp_end
+                if self.end - self.start > self.min_speech_duration_samples:
+                    yield self.audio_chunks[
+                        self.start : self.end + self.speech_pad_samples
+                    ]
+
+                self.triggered = False
+                self.start = self.end = self.temp_end = 0
+
+
+class FishASR:
+    def __init__(self) -> None:
+        self.audio_chunks = None
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        torch_dtype = torch.bfloat16
+        model_id = "openai/whisper-medium.en"
+
+        logger.info("Loading ASR model")
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_id, torch_dtype=torch_dtype, use_safetensors=True
+        ).to(self.device)
+        processor = AutoProcessor.from_pretrained(model_id)
+        self.pipe = pipeline(
+            "automatic-speech-recognition",
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            max_new_tokens=256,
+            torch_dtype=torch_dtype,
+            device=self.device,
+        )
+        logger.info("ASR model loaded")
+
+    @torch.inference_mode()
+    def run(self, chunk):
+        return self.pipe(chunk.numpy())
+
+
+class FishE2EAgent:
+    def __init__(self) -> None:
+        self.device = device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device}")
+
+        decoder_model = load_decoder_model(
+            config_name="firefly_gan_vq",
+            checkpoint_path="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+            device=device,
+        )
+        self.decoder_model = decoder_model
+        logger.info("Decoder model loaded")
+
+        llama_model, decode_one_token = load_llama_model(
+            config_name="dual_ar_2_codebook_1.3b",
+            checkpoint_path="checkpoints/step_000206000.ckpt",
+            device=device,
+            precision=torch.bfloat16,
+            max_length=2048,
+            compile=True,
+        )
+        self.llama_model: DualARTransformer = llama_model
+        self.decode_one_token = decode_one_token
+        logger.info("LLAMA model loaded")
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "checkpoints/fish-speech-agent-1"
+        )
+        self.semantic_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+        self.im_end_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        self.decoder_tokenizer = AutoTokenizer.from_pretrained(
+            "fishaudio/fish-speech-1"
+        )
+
+        # Control params
+        self.temperature = torch.tensor(0.7, device=device, dtype=torch.float)
+        self.top_p = torch.tensor(0.7, device=device, dtype=torch.float)
+        self.repetition_penalty = torch.tensor(1.2, device=device, dtype=torch.float)
+
+        # This is used to control the timbre of the generated audio
+        self.base_messages = [
+            # Message(
+            #     role="user",
+            #     parts=[np.load("example/q0.npy")],
+            # ),
+            # Message(
+            #     role="assistant",
+            #     parts=[
+            #         "Transcribed: Hi, can you briefly describe what is machine learning?\nResponse: Sure! Machine learning is the process of automating tasks that humans are capable of doing with a computer. It involves training computers to make decisions based on data.",
+            #         np.load("example/a0.npy"),
+            #     ],
+            # ),
+        ]
+        self.reference = encode_reference(
+            decoder_model=self.decoder_model,
+            reference_audio="example/a0.wav",
+            enable_reference_audio=True,
+        )
+        self.messages = self.base_messages.copy()
+
+    def reset(self):
+        self.messages = self.base_messages.copy()
+
+    @torch.inference_mode()
+    def vq_encode(self, audios, sr=None):
+        if isinstance(audios, np.ndarray):
+            audios = torch.from_numpy(audios)
+
+        if audios.ndim == 1:
+            audios = audios[None, None, :]
+
+        audios = audios.to(self.decoder_model.device)
+        if sr is not None and sr != self.decoder_model.sampling_rate:
+            audios = AF.resample(audios, sr, self.decoder_model.sampling_rate)
+
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
+        )
+
+        return self.decoder_model.encode(audios, audio_lengths)[0][0]
+
+    @torch.inference_mode()
+    def generate(self, audio_chunk, sr=None, text=None):
+        vq_output = self.vq_encode(audio_chunk, sr)
+        logger.info(f"VQ output: {vq_output.shape}")
+
+        # Encode conversation
+        self.messages.append(
+            Message(
+                role="user",
+                parts=[vq_output],
+            )
+        )
+
+        parts = []
+        if text is not None:
+            parts.append(f"Transcribed: {text}\nResponse:")
+
+        self.messages.append(
+            Message(
+                role="assistant",
+                parts=parts,
+            )
+        )
+        conversation = Conversation(self.messages)
+
+        # Encode the conversation
+        prompt, _ = encode_conversation(
+            conversation, self.tokenizer, self.llama_model.config.num_codebooks
+        )
+        prompt = prompt[:, :-1].to(dtype=torch.int, device=self.device)
+        prompt_length = prompt.shape[1]
+
+        # Generate
+        y = llama_generate(
+            model=self.llama_model,
+            prompt=prompt,
+            max_new_tokens=0,
+            eos_token_id=self.tokenizer.eos_token_id,
+            im_end_id=self.im_end_id,
+            decode_one_token=self.decode_one_token,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            repetition_penalty=self.repetition_penalty,
+        )
+
+        tokens = self.tokenizer.decode(
+            y[0, prompt_length:].tolist(), skip_special_tokens=False
+        )
+        logger.info(f"Generated: {convert_string(tokens)}")
+
+        # Put the generated tokens
+        # since there is <im_end> and <eos> tokens, we remove last 2 tokens
+        code_mask = y[0, prompt_length:-2] == self.semantic_id
+        codes = y[1:, prompt_length:-2][:, code_mask].clone()
+
+        codes = codes - 2
+        assert (codes >= 0).all(), f"Negative code found"
+
+        decoded = y[:, prompt_length:-1].clone()
+        if decoded[0, -1] != self.im_end_id:  # <im_end>
+            val = [[self.im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
+            decoded = torch.cat(
+                (decoded, torch.tensor(val, device=self.device, dtype=torch.int)), dim=1
+            )
+
+        decoded = decoded.cpu()
+        self.messages[-1].parts.append(
+            TokensPart(
+                tokens=decoded[:1],
+                codes=decoded[1:],
+            )
+        )
+
+        # Less than 5 * 20 = 100ms
+        if codes.shape[1] <= 5:
+            return
+
+        # Generate audio
+        main_tokens = decoded[0]
+        text_tokens = main_tokens[main_tokens != self.semantic_id]
+        text = self.tokenizer.decode(text_tokens.tolist(), skip_special_tokens=True)
+        text_tokens = self.decoder_tokenizer.encode(text, return_tensors="pt").to(
+            self.device
+        )
+
+        audio = decode_vq_tokens(
+            decoder_model=self.decoder_model,
+            codes=codes,
+            text_tokens=text_tokens,
+            reference_embedding=self.reference,
+        )
+
+        if sr is not None and sr != self.decoder_model.sampling_rate:
+            audio = AF.resample(audio, self.decoder_model.sampling_rate, sr)
+
+        return audio.float()
+
+
+class FishAgentPipeline:
+    def __init__(self) -> None:
+        self.vad = FishStreamVAD()
+        # Currently use ASR model as intermediate
+        self.asr = FishASR()
+        self.agent = FishE2EAgent()
+
+        self.vad_segments = []
+        self.text_segments = []
+
+    def add_chunk(self, chunk, sr=None):
+        use_np = isinstance(chunk, np.ndarray)
+        if use_np:
+            chunk = torch.from_numpy(chunk)
+
+        if sr is not None and sr != 16000:
+            chunk = AF.resample(chunk, sr, 16000)
+
+        for vad_audio in self.vad.add_chunk(chunk, 16000):
+            self.vad_segments.append(vad_audio)
+            asr_text = self.asr.run(vad_audio)
+            self.text_segments.append(asr_text)
+            logger.info(f"ASR: {asr_text}")
+
+            # Actually should detect if intent is finished here
+            result = self.agent.generate(vad_audio, 16000, text=asr_text)
+            if result is None:
+                continue
+
+            if sr is not None and sr != 16000:
+                result = AF.resample(result, 16000, sr)
+
+            if use_np:
+                result = result.cpu().numpy()
+
+            yield result
+
+    def reset(self):
+        self.vad.reset()
+        self.agent.reset()
+        self.vad_segments = []
+        self.text_segments = []
+
+    def warmup(self):
+        logger.info("Warming up the pipeline")
+        audio, sr = librosa.load("example/q0.mp3", sr=16000)
+        for i in range(0, len(audio), 882):
+            for audio in self.add_chunk(audio[i : i + 882], sr):
+                pass
+        logger.info("Pipeline warmed up")
+        self.reset()
+
+
+if __name__ == "__main__":
+    import soundfile as sf
+
+    service = FishAgentPipeline()
+    service.warmup()
+    logger.info("Stream service started")
+
+    audio, sr = librosa.load("example/q1.mp3", sr=16000)
+    seg = []
+    for i in range(0, len(audio), 882):
+        for audio in service.add_chunk(audio[i : i + 882], sr):
+            seg.append(audio)
+
+    audio = np.concatenate(seg)
+    sf.write("output.wav", audio, 16000)

+ 183 - 0
test_echo.py

@@ -0,0 +1,183 @@
+import io
+import wave
+from typing import List
+
+import av
+import numpy as np
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.responses import HTMLResponse
+
+app = FastAPI()
+
+html = """
+<!DOCTYPE html>
+<html>
+<head>
+    <title>Real-time Chat Room</title>
+</head>
+<body>
+    <h1>Real-time Chat Room</h1>
+    <button id="start">Start Streaming</button>
+    <button id="stop">Stop Streaming</button>
+    <script type="module">
+        import { MediaRecorder, register } from 'https://dev.jspm.io/npm:extendable-media-recorder';
+        import { connect } from 'https://dev.jspm.io/npm:extendable-media-recorder-wav-encoder';
+    
+        await register(await connect());
+
+        let socket;
+        let mediaRecorder;
+        let audioContext;
+
+        function startStreaming() {
+            initWebSocket();
+
+            audioContext = new (window.AudioContext || window.webkitAudioContext)();
+            navigator.mediaDevices.getUserMedia({ audio: {
+                channelCount: 1,  
+                sampleRate: 44100,
+                sampleSize: 16,
+                echoCancellation: true,
+                noiseSuppression: true
+            } })
+                .then(function (stream) {
+                    mediaRecorder = new MediaRecorder(stream, { mimeType: 'audio/webm;codecs=opus' });
+                    mediaRecorder.start(100);
+                    mediaRecorder.addEventListener("dataavailable", function (event) {
+                        socket.send(event.data);
+                    });
+                })
+                .catch(function (err) {
+                    console.error("Error accessing microphone:", err);
+                });
+
+                // Create a MediaSource
+                const mediaSource = new MediaSource();
+                const mediaStream = new MediaStream();
+
+                // Create an HTMLVideoElement and attach the MediaSource to it
+                const audioElement = document.createElement('audio');
+                audioElement.src = URL.createObjectURL(mediaSource);
+                audioElement.autoplay = true;
+                document.body.appendChild(audioElement);
+
+                mediaSource.addEventListener('sourceopen', function() {
+                    const sourceBuffer = mediaSource.addSourceBuffer('audio/webm; codecs=opus');
+
+                    socket.onmessage = function(event) {
+                        const arrayBuffer = event.data;
+
+                        sourceBuffer.appendBuffer(arrayBuffer);
+                    };
+                });
+        }
+
+        function stopStreaming() {
+            mediaRecorder.stop();
+        }
+
+        function initWebSocket() {
+            const is_wss = window.location.protocol === "https:";
+            socket = new WebSocket(`${is_wss ? "wss" : "ws"}://${window.location.host}/ws`);
+            socket.binaryType = 'arraybuffer';
+        }
+
+        document.getElementById("start").onclick = startStreaming;
+        document.getElementById("stop").onclick = stopStreaming;
+    </script>
+</body>
+</html>
+"""
+
+
+def encode_wav(data):
+    sample_rate = 44100
+    samples = np.frombuffer(data, dtype=np.int16)
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(1)
+        wav_file.setsampwidth(2)
+        wav_file.setframerate(sample_rate)
+        wav_file.writeframes(samples.tobytes())
+
+    return buffer.getvalue()
+
+
+class ConnectionManager:
+    def __init__(self):
+        self.active_connections: List[WebSocket] = []
+
+    async def connect(self, websocket: WebSocket):
+        await websocket.accept()
+        self.active_connections.append(websocket)
+
+    def disconnect(self, websocket: WebSocket):
+        self.active_connections.remove(websocket)
+
+    async def broadcast(self, message: bytes, sender: WebSocket):
+        for connection in self.active_connections:
+            if connection == sender:
+                #     print("Sending message to client", connection)
+                await connection.send_bytes(message)
+
+
+manager = ConnectionManager()
+
+
+@app.get("/")
+async def get():
+    return HTMLResponse(html)
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await manager.connect(websocket)
+    try:
+        buffer = io.BytesIO()
+        container = None
+        cur_pos = 0
+        total_size = 0
+
+        while True:
+            data = await websocket.receive_bytes()
+            # data = encode_wav(data)
+            # if len(data) == 1:
+            #     print(f"len(data): {len(data)}, data: {data}")
+            # if len(data) > 1:
+            #     data = b'\x1a' + data
+            #     with open("output.webm", "wb") as f:
+            #         f.write(data)
+            #     exit()
+            # print(f"len(data): {len(data)}")
+
+            # print("Received data:", data)
+            # Save as webm file and exit
+            # with open("output.wav", "wb") as f:
+            #     f.write(encode_wav(data))
+
+            buffer.write(data)
+            buffer.seek(cur_pos)
+            total_size += len(data)
+
+            if not container and total_size > 1000:
+                container = av.open(buffer, "r", format="webm")
+                print(container)
+            elif container:
+                for packet in container.decode(video=0):
+                    if packet.size == 0:
+                        continue
+
+                    cur_pos += packet.size
+                    for frame in packet.decode():
+                        print(frame.to_ndarray().shape)
+
+            await manager.broadcast(data, websocket)
+    except WebSocketDisconnect:
+        manager.disconnect(websocket)
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)

+ 3 - 9
tools/api.py

@@ -400,21 +400,17 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=str,
-        default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
-    )
-    parser.add_argument(
-        "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
+        default="checkpoints/fish-speech-1.2",
     )
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=str,
-        default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+        default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
     )
-    parser.add_argument("--decoder-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
-    parser.add_argument("--max-length", type=int, default=2048)
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
@@ -450,11 +446,9 @@ if __name__ == "__main__":
 
     logger.info("Loading Llama model...")
     llama_queue = launch_thread_safe_queue(
-        config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
         precision=args.precision,
-        max_length=args.max_length,
         compile=args.compile,
     )
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

+ 101 - 0
tools/llama/convert_hf_weights_to_llama.py

@@ -0,0 +1,101 @@
+import torch
+from transformers import LlamaForCausalLM
+
+from fish_speech.models.text2semantic.llama import BaseModelArgs, BaseTransformer
+
+# Load the HF model
+hf_model = LlamaForCausalLM.from_pretrained(
+    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+)
+
+model = BaseTransformer(
+    BaseModelArgs(
+        vocab_size=hf_model.config.vocab_size + 8,
+        n_layer=hf_model.config.num_hidden_layers,
+        n_head=hf_model.config.num_attention_heads,
+        n_local_heads=hf_model.config.num_key_value_heads,
+        dim=hf_model.config.hidden_size,
+        head_dim=hf_model.config.hidden_size // hf_model.config.num_attention_heads,
+        num_codebooks=2,
+        codebook_size=1032,
+    )
+)
+print(model.config)
+
+hf_state_dict = hf_model.state_dict()
+model_state_dict = model.state_dict()
+
+# print(hf_state_dict.keys())
+# print(model_state_dict.keys())
+
+new_state_dict = {}
+
+# Handle embeddings
+new_state_dict["embeddings.weight"] = model_state_dict.pop("embeddings.weight")
+hf_embed_tokens = hf_state_dict.pop("model.embed_tokens.weight")
+new_state_dict["embeddings.weight"][: hf_embed_tokens.shape[0]] = hf_embed_tokens
+
+# Restore layers
+for layer_idx in range(hf_model.config.num_hidden_layers):
+    # Handle attention
+    q_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.self_attn.q_proj.weight")
+    k_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.self_attn.k_proj.weight")
+    v_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.self_attn.v_proj.weight")
+    qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
+    new_state_dict[f"layers.{layer_idx}.attention.wqkv.weight"] = qkv_weight
+    model_state_dict.pop(f"layers.{layer_idx}.attention.wqkv.weight")
+
+    o_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.self_attn.o_proj.weight")
+    new_state_dict[f"layers.{layer_idx}.attention.wo.weight"] = o_weight
+    model_state_dict.pop(f"layers.{layer_idx}.attention.wo.weight")
+
+    # Handle feed forward
+    up_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.mlp.up_proj.weight")
+    down_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.mlp.down_proj.weight")
+    gate_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.mlp.gate_proj.weight")
+
+    new_state_dict[f"layers.{layer_idx}.feed_forward.w1.weight"] = gate_weight
+    new_state_dict[f"layers.{layer_idx}.feed_forward.w2.weight"] = down_weight
+    new_state_dict[f"layers.{layer_idx}.feed_forward.w3.weight"] = up_weight
+
+    model_state_dict.pop(f"layers.{layer_idx}.feed_forward.w1.weight")
+    model_state_dict.pop(f"layers.{layer_idx}.feed_forward.w2.weight")
+    model_state_dict.pop(f"layers.{layer_idx}.feed_forward.w3.weight")
+
+    # Handle layer norms
+    input_layernorm_weight = hf_state_dict.pop(
+        f"model.layers.{layer_idx}.input_layernorm.weight"
+    )
+    post_attention_layernorm_weight = hf_state_dict.pop(
+        f"model.layers.{layer_idx}.post_attention_layernorm.weight"
+    )
+
+    new_state_dict[f"layers.{layer_idx}.ffn_norm.weight"] = (
+        post_attention_layernorm_weight
+    )
+    new_state_dict[f"layers.{layer_idx}.attention_norm.weight"] = input_layernorm_weight
+
+    model_state_dict.pop(f"layers.{layer_idx}.ffn_norm.weight")
+    model_state_dict.pop(f"layers.{layer_idx}.attention_norm.weight")
+
+# Handle final layer norm
+new_state_dict["norm.weight"] = hf_state_dict.pop("model.norm.weight")
+model_state_dict.pop("norm.weight")
+
+# Handle output layer
+w = hf_state_dict.pop("lm_head.weight")
+new_state_dict["output.weight"] = model_state_dict.pop("output.weight")
+new_state_dict["output.weight"][: w.shape[0]] = w
+
+print(hf_state_dict.keys(), len(hf_state_dict))
+print(model_state_dict.keys(), len(model_state_dict))
+
+print(model.load_state_dict(new_state_dict, strict=True))
+
+model = model.bfloat16()
+
+new_state_dict = {f"model.{k}": v for k, v in model.state_dict().items()}
+torch.save(
+    new_state_dict,
+    "checkpoints/fish-speech-agent-1/TinyLlama-1.1B-intermediate-step-1431k-3T.pth",
+)

+ 24 - 85
tools/llama/generate.py

@@ -19,7 +19,7 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from fish_speech.text import clean_text, split_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -31,7 +31,11 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True
 
 
-from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
+from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
 
 
 def multinomial_sample_one_no_sync(
@@ -161,7 +165,6 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    eos_token_id: int = 2,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
@@ -197,11 +200,7 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
         )
 
-        if (
-            cur_token[0, 0, -1] == eos_token_id
-            or cur_token[0, 0, -1] == im_end_id
-            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
-        ):
+        if cur_token[0, 0, -1] == im_end_id:
             break
 
     return previous_tokens[:, : i + 1]
@@ -214,7 +213,6 @@ def generate(
     model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    eos_token_id: int = 2,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
@@ -255,6 +253,7 @@ def generate(
         if isinstance(model, NaiveTransformer)
         else decode_one_token_ar
     )
+
     next_token = prefill_decode(
         model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
     )
@@ -266,7 +265,6 @@ def generate(
         next_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
-        eos_token_id=eos_token_id,
         im_end_id=im_end_id,
         decode_one_token=decode_one_token,
         **sampling_kwargs,
@@ -281,22 +279,12 @@ def generate(
 def encode_tokens(
     tokenizer,
     string,
-    bos=True,
     device="cuda",
     prompt_tokens=None,
-    speaker=None,
     num_codebooks=4,
 ):
     string = clean_text(string)
-
-    if speaker is None:
-        speaker = "assistant"
-
-    string = (
-        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
-    )
-    if bos:
-        string = f"<|begin_of_sequence|>{string}"
+    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
 
     new_tokens = tokenizer.encode(
         string,
@@ -324,7 +312,7 @@ def encode_tokens(
         prompt_tokens = prompt_tokens[0]
 
     assert prompt_tokens.ndim == 2
-    data = prompt_tokens + 2
+    data = prompt_tokens + 1
 
     if prompt_tokens.shape[0] > num_codebooks:
         logger.warning(
@@ -332,13 +320,9 @@ def encode_tokens(
         )
         data = data[:num_codebooks]
 
-    # Add eos token for each codebook
+    # Add pad token for each codebook
     data = torch.cat(
-        (
-            data,
-            torch.ones((data.size(0), 1), dtype=torch.int, device=device)
-            * CODEBOOK_EOS_TOKEN_ID,
-        ),
+        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
         dim=1,
     )
 
@@ -356,16 +340,10 @@ def encode_tokens(
     return prompt
 
 
-def load_model(
-    config_name, checkpoint_path, device, precision, max_length, compile=False
-):
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
-        cfg = compose(
-            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
-        )
-
-    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
+def load_model(checkpoint_path, device, precision, compile=False):
+    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
+        checkpoint_path, load_weights=True
+    )
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")
@@ -384,21 +362,8 @@ def load_model(
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
         model = simple_quantizer.convert_for_runtime()
 
-    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
-    if "state_dict" in checkpoint:
-        checkpoint = checkpoint["state_dict"]
-
-    if any(k.startswith("model.") for k in checkpoint):
-        checkpoint = {
-            k.replace("model.", ""): v
-            for k, v in checkpoint.items()
-            if k.startswith("model.")
-        }
-
-    model.load_state_dict(checkpoint, assign=True)
-
     model = model.to(device=device, dtype=precision)
-    logger.info("Restored model from checkpoint")
+    logger.info(f"Restored model from checkpoint")
 
     if isinstance(model, DualARTransformer):
         decode_one_token = decode_one_token_ar
@@ -426,7 +391,6 @@ class GenerateResponse:
 def generate_long(
     *,
     model,
-    tokenizer: callable,
     device: str | torch.device,
     decode_one_token: callable,
     text: str,
@@ -439,7 +403,6 @@ def generate_long(
     iterative_prompt: bool = True,
     max_length: int = 2048,
     chunk_length: int = 150,
-    speaker: Optional[str] = None,
     prompt_text: Optional[str | list[str]] = None,
     prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
 ):
@@ -457,6 +420,7 @@ def generate_long(
     ), "Prompt text and tokens must have the same length"
 
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    tokenizer = model.tokenizer
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
     encoded = []
@@ -469,10 +433,8 @@ def generate_long(
                 encode_tokens(
                     tokenizer,
                     string=t,
-                    bos=idx == 0,
                     device=device,
                     prompt_tokens=c,
-                    speaker=speaker,
                     num_codebooks=model.config.num_codebooks,
                 )
             )
@@ -482,9 +444,7 @@ def generate_long(
             encode_tokens(
                 tokenizer,
                 string=text,
-                bos=idx == 0 and not use_prompt,
                 device=device,
-                speaker=speaker,
                 num_codebooks=model.config.num_codebooks,
             )
         )
@@ -544,7 +504,6 @@ def generate_long(
                 model=model,
                 prompt=cat_encoded,
                 max_new_tokens=max_new_tokens,
-                eos_token_id=tokenizer.eos_token_id,
                 im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
@@ -576,19 +535,13 @@ def generate_long(
 
             # Put the generated tokens
             # since there is <im_end> and <eos> tokens, we remove last 2 tokens
-            codes = y[1:, prompt_length:-2].clone()
-
-            codes = codes - 2
+            codes = y[1:, prompt_length:-1].clone()
+            codes = codes - 1
             assert (codes >= 0).all(), f"Negative code found"
 
             decoded = y[:, prompt_length:-1].clone()
-            if decoded[0, -1] != im_end_id:  # <im_end>
-                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
-                decoded = torch.cat(
-                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
-                )
-
             # But for global encoding, we should keep the <im_end> token
+
             global_encoded.append(decoded)
             assert (codes >= 0).all(), f"Negative code found: {codes}"
             yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
@@ -611,11 +564,9 @@ class GenerateRequest:
 
 
 def launch_thread_safe_queue(
-    config_name,
     checkpoint_path,
     device,
     precision,
-    max_length: int,
     compile: bool = False,
 ):
     input_queue = queue.Queue()
@@ -623,7 +574,7 @@ def launch_thread_safe_queue(
 
     def worker():
         model, decode_one_token = load_model(
-            config_name, checkpoint_path, device, precision, max_length, compile=compile
+            checkpoint_path, device, precision, compile=compile
         )
         init_event.set()
 
@@ -672,16 +623,12 @@ def launch_thread_safe_queue(
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
+    default="checkpoints/fish-speech-1.2",
 )
-@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
-@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
 @click.option("--compile/--no-compile", default=False)
 @click.option("--seed", type=int, default=42)
-@click.option("--speaker", type=str, default=None)
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
-@click.option("--max-length", type=int, default=2048)
 @click.option("--chunk-length", type=int, default=150)
 def main(
     text: str,
@@ -693,14 +640,10 @@ def main(
     repetition_penalty: float,
     temperature: float,
     checkpoint_path: Path,
-    config_name: str,
-    tokenizer: str,
     compile: bool,
     seed: int,
-    speaker: Optional[str],
     half: bool,
     iterative_prompt: bool,
-    max_length: int,
     chunk_length: int,
 ) -> None:
     device = "cuda"
@@ -715,7 +658,7 @@ def main(
     logger.info("Loading model ...")
     t0 = time.time()
     model, decode_one_token = load_model(
-        config_name, checkpoint_path, device, precision, max_length, compile=compile
+        checkpoint_path, device, precision, compile=compile
     )
 
     if torch.cuda.is_available():
@@ -726,7 +669,6 @@ def main(
     if prompt_tokens is not None:
         prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
 
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     torch.manual_seed(seed)
 
     if torch.cuda.is_available():
@@ -742,11 +684,8 @@ def main(
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         temperature=temperature,
-        tokenizer=tokenizer,
         compile=compile,
-        speaker=speaker,
         iterative_prompt=iterative_prompt,
-        max_length=max_length,
         chunk_length=chunk_length,
         prompt_text=prompt_text,
         prompt_tokens=prompt_tokens,

+ 1 - 3
tools/llama/merge_lora.py

@@ -14,9 +14,7 @@ from fish_speech.models.text2semantic.lora_utils import (
 @click.command()
 @click.option("--llama-config", type=str, default="dual_ar_2_codebook_medium")
 @click.option("--lora-config", type=str, default="r_8_alpha_16")
-@click.option(
-    "--llama-weight", type=str, default="checkpoints/text2semantic-sft-medium-v1-4k.pth"
-)
+@click.option("--llama-weight", type=str, default="checkpoints/fish-speech-1.2")
 @click.option("--lora-weight", type=str, required=True)
 @click.option("--output", type=str, required=True)
 def merge(llama_config, lora_config, llama_weight, lora_weight, output):

+ 1 - 1
tools/llama/quantize.py

@@ -419,7 +419,7 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
+    default="checkpoints/fish-speech-1.2",
 )
 @click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
 @click.option(

+ 1 - 1
tools/vits_decoder/inference.py

@@ -72,7 +72,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 @click.option(
     "--checkpoint-path",
     "-ckpt",
-    default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+    default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
 )
 @click.option(
     "--device",

+ 19 - 9
tools/vqgan/extract_vq.py

@@ -41,23 +41,31 @@ logger.add(sys.stderr, format=logger_format)
 
 @lru_cache(maxsize=1)
 def get_model(
-    config_name: str = "vqgan_pretrain",
-    checkpoint_path: str = "checkpoints/vq-gan-group-fsq-2x1024.pth",
+    config_name: str = "firefly_gan_vq",
+    checkpoint_path: str = "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+    device: str | torch.device = "cuda",
 ):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
         cfg = compose(config_name=config_name)
 
-    model: LightningModule = instantiate(cfg.model)
+    model = instantiate(cfg)
     state_dict = torch.load(
         checkpoint_path,
-        map_location=model.device,
+        map_location=device,
     )
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
 
+    if any("generator" in k for k in state_dict):
+        state_dict = {
+            k.replace("generator.", ""): v
+            for k, v in state_dict.items()
+            if "generator." in k
+        }
+
     model.load_state_dict(state_dict, strict=False)
     model.eval()
-    model.cuda()
+    model.to(device)
 
     logger.info(f"Loaded model")
     return model
@@ -82,8 +90,10 @@ def process_batch(files: list[Path], model) -> float:
         if wav.shape[0] > 1:
             wav = wav.mean(dim=0, keepdim=True)
 
-        wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
-        total_time += len(wav) / model.sampling_rate
+        wav = torchaudio.functional.resample(
+            wav.cuda(), sr, model.spec_transform.sample_rate
+        )[0]
+        total_time += len(wav) / model.spec_transform.sample_rate
         max_length = max(max_length, len(wav))
 
         wavs.append(wav)
@@ -120,10 +130,10 @@ def process_batch(files: list[Path], model) -> float:
 @click.command()
 @click.argument("folder")
 @click.option("--num-workers", default=1)
-@click.option("--config-name", default="vqgan_pretrain")
+@click.option("--config-name", default="firefly_gan_vq")
 @click.option(
     "--checkpoint-path",
-    default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+    default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
 )
 @click.option("--batch-size", default=64)
 @click.option("--filelist", default=None, type=Path)

+ 24 - 22
tools/vqgan/inference.py

@@ -8,7 +8,6 @@ import torch
 import torchaudio
 from hydra import compose, initialize
 from hydra.utils import instantiate
-from lightning import LightningModule
 from loguru import logger
 from omegaconf import OmegaConf
 
@@ -23,20 +22,26 @@ def load_model(config_name, checkpoint_path, device="cuda"):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
         cfg = compose(config_name=config_name)
 
-    model: LightningModule = instantiate(cfg.model)
+    model = instantiate(cfg)
     state_dict = torch.load(
         checkpoint_path,
-        map_location=model.device,
+        map_location=device,
     )
-
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
 
-    model.load_state_dict(state_dict, strict=False)
+    if any("generator" in k for k in state_dict):
+        state_dict = {
+            k.replace("generator.", ""): v
+            for k, v in state_dict.items()
+            if "generator." in k
+        }
+
+    result = model.load_state_dict(state_dict, strict=False)
     model.eval()
     model.to(device)
-    logger.info("Restored model from checkpoint")
 
+    logger.info(f"Loaded model: {result}")
     return model
 
 
@@ -51,11 +56,10 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 @click.option(
     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
 )
-@click.option("--config-name", "-cfg", default="vqgan_pretrain")
+@click.option("--config-name", default="firefly_gan_vq")
 @click.option(
     "--checkpoint-path",
-    "-ckpt",
-    default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+    default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
 )
 @click.option(
     "--device",
@@ -72,17 +76,17 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
         audio, sr = torchaudio.load(str(input_path))
         if audio.shape[0] > 1:
             audio = audio.mean(0, keepdim=True)
-        audio = torchaudio.functional.resample(audio, sr, model.sampling_rate)
+        audio = torchaudio.functional.resample(
+            audio, sr, model.spec_transform.sample_rate
+        )
 
-        audios = audio[None].to(model.device)
+        audios = audio[None].to(device)
         logger.info(
-            f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
+            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
         )
 
         # VQ Encoder
-        audio_lengths = torch.tensor(
-            [audios.shape[2]], device=model.device, dtype=torch.long
-        )
+        audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
         indices = model.encode(audios, audio_lengths)[0][0]
 
         logger.info(f"Generated indices of shape {indices.shape}")
@@ -92,17 +96,15 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
     elif input_path.suffix == ".npy":
         logger.info(f"Processing precomputed indices from {input_path}")
         indices = np.load(input_path)
-        indices = torch.from_numpy(indices).to(model.device).long()
+        indices = torch.from_numpy(indices).to(device).long()
         assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
     else:
         raise ValueError(f"Unknown input type: {input_path}")
 
     # Restore
-    feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
-    fake_audios = model.decode(
-        indices=indices[None], feature_lengths=feature_lengths, return_audios=True
-    )
-    audio_time = fake_audios.shape[-1] / model.sampling_rate
+    feature_lengths = torch.tensor([indices.shape[1]], device=device)
+    fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
+    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
 
     logger.info(
         f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
@@ -110,7 +112,7 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
 
     # Save audio
     fake_audio = fake_audios[0, 0].float().cpu().numpy()
-    sf.write(output_path, fake_audio, model.sampling_rate)
+    sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
     logger.info(f"Saved audio to {output_path}")
 
 

+ 3 - 9
tools/webui.py

@@ -443,21 +443,17 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=Path,
-        default="checkpoints/text2semantic-sft-large-v1.1-4k.pth",
-    )
-    parser.add_argument(
-        "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
+        default="checkpoints/fish-speech-1.2",
     )
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=Path,
-        default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+        default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
     )
-    parser.add_argument("--decoder-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
-    parser.add_argument("--max-length", type=int, default=2048)
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-gradio-length", type=int, default=0)
 
@@ -470,11 +466,9 @@ if __name__ == "__main__":
 
     logger.info("Loading Llama model...")
     llama_queue = launch_thread_safe_queue(
-        config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
         precision=args.precision,
-        max_length=args.max_length,
         compile=args.compile,
     )
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)