Просмотр исходного кода

This PR brings V1.2 inference into main (#300)

* Add model converter & agent test code

* Optimize training speed

* Fix agent nan loss

* Add new dataset & conversation system

* handle codebook bias in conversation

* Fix compile

* Implement new conversation system

* Support new vq

* Update VQ config

* optm:(build dataset) use mp to speedup grouping

* fix p not defined

* rollback due to thread lock

* remove vits decoder & add firefly vq

* remove unused configs, test generate.py

---------

Co-authored-by: Stardust·减 <stardust@fish.audio>
Leng Yue 1 год назад
Родитель
Сommit
5e7914472f
49 измененных файлов с 1975 добавлено и 3625 удалено
  1. 1 0
      .gitignore
  2. 2 2
      API_FLAGS.txt
  3. 5 5
      docs/en/finetune.md
  4. 5 5
      docs/en/inference.md
  5. 5 5
      docs/zh/finetune.md
  6. 6 6
      docs/zh/inference.md
  7. 34 0
      fish_speech/configs/firefly_gan_vq.yaml
  8. 0 9
      fish_speech/configs/model/dual_ar_2_codebook_large.yaml
  9. 0 9
      fish_speech/configs/model/dual_ar_2_codebook_medium.yaml
  10. 0 13
      fish_speech/configs/model/dual_ar_2_codebook_small.yaml
  11. 0 12
      fish_speech/configs/model/naive_2_codebook_small.yaml
  12. 66 0
      fish_speech/configs/text2semantic_agent.yaml
  13. 0 128
      fish_speech/configs/vits_decoder_finetune.yaml
  14. 0 127
      fish_speech/configs/vits_decoder_pretrain.yaml
  15. 0 137
      fish_speech/configs/vqgan_finetune.yaml
  16. 0 140
      fish_speech/configs/vqgan_pretrain.yaml
  17. 2 0
      fish_speech/conversation.py
  18. 0 35
      fish_speech/datasets/concat_repeat.py
  19. 381 0
      fish_speech/datasets/prompts.py
  20. 263 298
      fish_speech/datasets/text.py
  21. 0 3
      fish_speech/models/text2semantic/__init__.py
  22. 11 7
      fish_speech/models/text2semantic/lit_module.py
  23. 206 68
      fish_speech/models/text2semantic/llama.py
  24. 8 0
      fish_speech/models/text2semantic/lora.py
  25. 0 3
      fish_speech/models/vits_decoder/__init__.py
  26. 0 394
      fish_speech/models/vits_decoder/lit_module.py
  27. 0 67
      fish_speech/models/vits_decoder/losses.py
  28. 0 350
      fish_speech/models/vits_decoder/modules/attentions.py
  29. 0 190
      fish_speech/models/vits_decoder/modules/commons.py
  30. 0 686
      fish_speech/models/vits_decoder/modules/models.py
  31. 0 619
      fish_speech/models/vits_decoder/modules/modules.py
  32. 0 58
      fish_speech/models/vits_decoder/modules/mrte.py
  33. 0 101
      fish_speech/models/vits_decoder/modules/vq_encoder.py
  34. 86 0
      fish_speech/models/vqgan/modules/firefly.py
  35. 1 1
      fish_speech/models/vqgan/modules/fsq.py
  36. 7 7
      fish_speech/webui/manage.py
  37. 2 1
      pyproject.toml
  38. 112 0
      run.py
  39. 412 0
      stream_service.py
  40. 183 0
      test_echo.py
  41. 3 9
      tools/api.py
  42. 101 0
      tools/llama/convert_hf_weights_to_llama.py
  43. 24 85
      tools/llama/generate.py
  44. 1 3
      tools/llama/merge_lora.py
  45. 1 1
      tools/llama/quantize.py
  46. 1 1
      tools/vits_decoder/inference.py
  47. 19 9
      tools/vqgan/extract_vq.py
  48. 24 22
      tools/vqgan/inference.py
  49. 3 9
      tools/webui.py

+ 1 - 0
.gitignore

@@ -25,3 +25,4 @@ asr-label*
 /.locale
 /.locale
 /demo-audios
 /demo-audios
 ref_data*
 ref_data*
+/example

+ 2 - 2
API_FLAGS.txt

@@ -3,5 +3,5 @@
 --listen 0.0.0.0:8000 \
 --listen 0.0.0.0:8000 \
 --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
 --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
 --llama-config-name dual_ar_2_codebook_medium \
 --llama-config-name dual_ar_2_codebook_medium \
---decoder-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
---decoder-config-name vqgan_finetune
+--decoder-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+--decoder-config-name firefly_gan_vq

+ 5 - 5
docs/en/finetune.md

@@ -59,8 +59,8 @@ You can then run the following command to extract semantic tokens:
 ```bash
 ```bash
 python tools/vqgan/extract_vq.py data \
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
     --num-workers 1 --batch-size 16 \
-    --config-name "vqgan_pretrain" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --config-name "firefly_gan_vq" \
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 ```
 
 
 !!! note
 !!! note
@@ -233,16 +233,16 @@ This command will create `data/vq_train_filelist.txt` and `data/vq_val_filelist.
 ### 3. Start Training
 ### 3. Start Training
 
 
 ```bash
 ```bash
-python fish_speech/train.py --config-name vqgan_finetune
+python fish_speech/train.py --config-name firefly_gan_vq
 ```
 ```
 
 
 !!! note
 !!! note
-    You can modify training parameters by editing `fish_speech/configs/vqgan_finetune.yaml`, but in most cases, this won't be necessary.
+    You can modify training parameters by editing `fish_speech/configs/firefly_gan_vq.yaml`, but in most cases, this won't be necessary.
 
 
 ### 4. Test the Audio
 ### 4. Test the Audio
     
     
 ```bash
 ```bash
-python tools/vqgan/inference.py -i test.wav --checkpoint-path results/vqgan_finetune/checkpoints/step_000010000.ckpt
+python tools/vqgan/inference.py -i test.wav --checkpoint-path results/firefly_gan_vq/checkpoints/step_000010000.ckpt
 ```
 ```
 
 
 You can review `fake.wav` to assess the fine-tuning results.
 You can review `fake.wav` to assess the fine-tuning results.

+ 5 - 5
docs/en/inference.md

@@ -31,7 +31,7 @@ huggingface-cli download fishaudio/fish-speech-1 firefly-gan-base-generator.ckpt
 ```bash
 ```bash
 python tools/vqgan/inference.py \
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 ```
 You should get a `fake.npy` file.
 You should get a `fake.npy` file.
 
 
@@ -73,7 +73,7 @@ python tools/vits_decoder/inference.py \
 ```bash
 ```bash
 python tools/vqgan/inference.py \
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 ```
 
 
 ## HTTP API Inference
 ## HTTP API Inference
@@ -85,8 +85,8 @@ python -m tools.api \
     --listen 0.0.0.0:8000 \
     --listen 0.0.0.0:8000 \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_medium \
     --llama-config-name dual_ar_2_codebook_medium \
-    --decoder-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
-    --decoder-config-name vqgan_pretrain
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+    --decoder-config-name firefly_gan_vq
 ```
 ```
 
 
 After that, you can view and test the API at http://127.0.0.1:8000/.  
 After that, you can view and test the API at http://127.0.0.1:8000/.  
@@ -107,7 +107,7 @@ You can start the WebUI using the following command:
 python -m tools.webui \
 python -m tools.webui \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_medium \
     --llama-config-name dual_ar_2_codebook_medium \
-    --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
+    --vqgan-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
     --vits-checkpoint-path "checkpoints/vits_decoder_v1.1.ckpt"
     --vits-checkpoint-path "checkpoints/vits_decoder_v1.1.ckpt"
 ```
 ```
 
 

+ 5 - 5
docs/zh/finetune.md

@@ -63,8 +63,8 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech
 ```bash
 ```bash
 python tools/vqgan/extract_vq.py data \
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
     --num-workers 1 --batch-size 16 \
-    --config-name "vqgan_pretrain" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --config-name "firefly_gan_vq" \
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 ```
 
 
 !!! note
 !!! note
@@ -239,16 +239,16 @@ python tools/vqgan/create_train_split.py data
 ### 3. 启动训练
 ### 3. 启动训练
 
 
 ```bash
 ```bash
-python fish_speech/train.py --config-name vqgan_finetune
+python fish_speech/train.py --config-name firefly_gan_vq
 ```
 ```
 
 
 !!! note
 !!! note
-    你可以通过修改 `fish_speech/configs/vqgan_finetune.yaml` 来修改训练参数, 但大部分情况下, 你不需要这么做.
+    你可以通过修改 `fish_speech/configs/firefly_gan_vq.yaml` 来修改训练参数, 但大部分情况下, 你不需要这么做.
 
 
 ### 4. 测试音频
 ### 4. 测试音频
     
     
 ```bash
 ```bash
-python tools/vqgan/inference.py -i test.wav --checkpoint-path results/vqgan_finetune/checkpoints/step_000010000.ckpt
+python tools/vqgan/inference.py -i test.wav --checkpoint-path results/firefly_gan_vq/checkpoints/step_000010000.ckpt
 ```
 ```
 
 
 你可以查看 `fake.wav` 来判断微调效果.
 你可以查看 `fake.wav` 来判断微调效果.

+ 6 - 6
docs/zh/inference.md

@@ -41,7 +41,7 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech
 ```bash
 ```bash
 python tools/vqgan/inference.py \
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 ```
 你应该能得到一个 `fake.npy` 文件.
 你应该能得到一个 `fake.npy` 文件.
 
 
@@ -83,7 +83,7 @@ python tools/vits_decoder/inference.py \
 ```bash
 ```bash
 python tools/vqgan/inference.py \
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
 ```
 ```
 
 
 ## HTTP API 推理
 ## HTTP API 推理
@@ -95,8 +95,8 @@ python -m tools.api \
     --listen 0.0.0.0:8000 \
     --listen 0.0.0.0:8000 \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_medium \
     --llama-config-name dual_ar_2_codebook_medium \
-    --decoder-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
-    --decoder-config-name vqgan_pretrain
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+    --decoder-config-name firefly_gan_vq
 
 
 # 推荐中国大陆用户运行以下命令来启动 HTTP 服务:
 # 推荐中国大陆用户运行以下命令来启动 HTTP 服务:
 HF_ENDPOINT=https://hf-mirror.com python -m ...
 HF_ENDPOINT=https://hf-mirror.com python -m ...
@@ -120,8 +120,8 @@ HF_ENDPOINT=https://hf-mirror.com python -m ...
 python -m tools.webui \
 python -m tools.webui \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1.1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_medium \
     --llama-config-name dual_ar_2_codebook_medium \
-    --decoder-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth" \
-    --decoder-config-name vqgan_pretrain
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+    --decoder-config-name firefly_gan_vq
 ```
 ```
 
 
 !!! info
 !!! info

+ 34 - 0
fish_speech/configs/firefly_gan_vq.yaml

@@ -0,0 +1,34 @@
+_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
+spec_transform:
+  _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
+  sample_rate: 44100
+  n_mels: 160
+  n_fft: 2048
+  hop_length: 512
+  win_length: 2048
+backbone:
+  _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
+  input_channels: 160
+  depths: [3, 3, 9, 3]
+  dims: [128, 256, 384, 512]
+  drop_path_rate: 0.2
+  kernel_size: 7
+head:
+  _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
+  hop_length: 512
+  upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
+  upsample_kernel_sizes: [16, 16, 4, 4, 4]
+  resblock_kernel_sizes: [3, 7, 11]
+  resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+  num_mels: 512
+  upsample_initial_channel: 512
+  use_template: false
+  pre_conv_kernel_size: 13
+  post_conv_kernel_size: 13
+quantizer:
+  _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+  input_dim: 512
+  n_groups: 4
+  n_codebooks: 1
+  levels: [8, 5, 5, 5]
+  downsample_factor: [2]

+ 0 - 9
fish_speech/configs/model/dual_ar_2_codebook_large.yaml

@@ -1,9 +0,0 @@
-defaults:
-  - dual_ar_2_codebook_small
-  - _self_
-
-config:
-  n_layer: 30
-  n_fast_layer: 6
-  n_head: 24
-  dim: 1536

+ 0 - 9
fish_speech/configs/model/dual_ar_2_codebook_medium.yaml

@@ -1,9 +0,0 @@
-defaults:
-  - dual_ar_2_codebook_small
-  - _self_
-
-config:
-  n_layer: 24
-  n_fast_layer: 6
-  n_head: 16
-  dim: 1024

+ 0 - 13
fish_speech/configs/model/dual_ar_2_codebook_small.yaml

@@ -1,13 +0,0 @@
-_target_: fish_speech.models.text2semantic.llama.DualARTransformer
-config:
-  _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
-  max_seq_len: ${max_length}
-  vocab_size: 264 # pad 262 to 8x
-  n_layer: 12
-  n_fast_layer: 4
-  n_head: 12
-  dim: 768
-  rope_base: 10000
-  norm_eps: 1e-5
-  num_codebooks: 2  # input/output codebook size
-  codebook_size: 1032 # codebook size 1024 + 2 special tokens

+ 0 - 12
fish_speech/configs/model/naive_2_codebook_small.yaml

@@ -1,12 +0,0 @@
-_target_: fish_speech.models.text2semantic.llama.NaiveTransformer
-config:
-  _target_: fish_speech.models.text2semantic.llama.NaiveModelArgs
-  max_seq_len: ${max_length}
-  vocab_size: 36408
-  n_layer: 12
-  n_head: 12
-  dim: 768
-  rope_base: 10000
-  norm_eps: 1e-5
-  num_codebooks: 2  # input/output codebook size
-  codebook_size: 1032 # codebook size 1024 + 2 special tokens

+ 66 - 0
fish_speech/configs/text2semantic_agent.yaml

@@ -0,0 +1,66 @@
+defaults:
+  - base
+  - model@model.model: dual_ar_2_codebook_1.3b
+  - _self_
+
+project: text2semantic_agent_dual_ar_debug
+max_length: 2048
+ckpt_path: checkpoints/fish-speech-agent-1/step_000013000.ckpt
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  max_steps: 1_000_000
+  precision: bf16-true
+  log_every_n_steps: 10
+  limit_val_batches: 10
+  val_check_interval: 1000
+
+# Dataset Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: checkpoints/fish-speech-agent-1
+
+# Dataset Configuration
+train_dataset: {}
+val_dataset: {}
+
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 8
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+  model: {}
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 3e-4
+    weight_decay: 0.01
+    betas: [0.9, 0.95]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 1000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1
+
+# Callbacks
+callbacks:
+  model_checkpoint:
+    every_n_train_steps: ${trainer.val_check_interval}

+ 0 - 128
fish_speech/configs/vits_decoder_finetune.yaml

@@ -1,128 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vits_decoder
-ckpt_path: checkpoints/vits_decoder_v1.1.ckpt
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  strategy:
-    find_unused_parameters: true
-  precision: 32
-  max_steps: 100_000
-  val_check_interval: 100
-  benchmark: false
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/fish-speech-1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vits.VITSDataset
-  filelist: data/source/Genshin/filelist.train.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  suffix: ".lab"
-  tokenizer: ${tokenizer}
-  sentence_mask_ratio: 0.2
-
-val_dataset:
-  _target_: fish_speech.datasets.vits.VITSDataset
-  filelist: data/source/Genshin/filelist.test.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  suffix: ".lab"
-  tokenizer: ${tokenizer}
-
-data:
-  _target_: fish_speech.datasets.vits.VITSDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  val_batch_size: 4
-  tokenizer: ${tokenizer}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vits_decoder.VITSDecoder
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  freeze_discriminator: false
-
-  weight_mel: 45.0
-  weight_kl: 1.0
-
-  generator:
-    _target_: fish_speech.models.vits_decoder.modules.models.SynthesizerTrn
-    spec_channels: 1025
-    segment_size: 32
-    inter_channels: 192
-    hidden_channels: 192
-    filter_channels: 768
-    n_heads: 2
-    n_layers: 6
-    kernel_size: 3
-    p_dropout: 0.1
-    resblock: "1"
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [8, 8, 2, 2, 2]
-    upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 8, 2, 2]
-    gin_channels: 512
-    vq_mask_ratio: 0.2
-    ref_mask_ratio: 0.2
-
-  discriminator:
-    _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator
-    periods: [2, 3, 5, 7, 11]
-
-  mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  spec_transform:
-    _target_: fish_speech.utils.spectrogram.LinearSpectrogram
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    mode: pow2_sqrt
-  
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 1e-4
-    betas: [0.8, 0.99]
-    eps: 1e-6
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
-    _partial_: true
-    gamma: 0.999999
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - generator
-      - discriminator
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
-    save_top_k: 10

+ 0 - 127
fish_speech/configs/vits_decoder_pretrain.yaml

@@ -1,127 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vits_decoder
-ckpt_path: checkpoints/Bert-VITS2/ensemble.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  strategy: ddp_find_unused_parameters_true
-  precision: 32
-  max_steps: 1_000_000
-  val_check_interval: 1000
-  benchmark: false
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/fish-speech-1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vits.VITSDataset
-  filelist: data/source/Genshin/filelist.train.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  suffix: ".lab"
-  tokenizer: ${tokenizer}
-  sentence_mask_ratio: 0.2
-
-val_dataset:
-  _target_: fish_speech.datasets.vits.VITSDataset
-  filelist: data/source/Genshin/filelist.test.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  suffix: ".lab"
-  tokenizer: ${tokenizer}
-
-data:
-  _target_: fish_speech.datasets.vits.VITSDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  val_batch_size: 4
-  tokenizer: ${tokenizer}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vits_decoder.VITSDecoder
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  freeze_discriminator: false
-
-  weight_mel: 45.0
-  weight_kl: 1.0
-
-  generator:
-    _target_: fish_speech.models.vits_decoder.modules.models.SynthesizerTrn
-    spec_channels: 1025
-    segment_size: 32
-    inter_channels: 192
-    hidden_channels: 192
-    filter_channels: 768
-    n_heads: 2
-    n_layers: 6
-    kernel_size: 3
-    p_dropout: 0.1
-    resblock: "1"
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [8, 8, 2, 2, 2]
-    upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 8, 2, 2]
-    gin_channels: 512
-    vq_mask_ratio: 0.2
-    ref_mask_ratio: 0.2
-
-  discriminator:
-    _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator
-    periods: [2, 3, 5, 7, 11]
-
-  mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  spec_transform:
-    _target_: fish_speech.utils.spectrogram.LinearSpectrogram
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    mode: pow2_sqrt
-  
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 1e-4
-    betas: [0.8, 0.99]
-    eps: 1e-6
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
-    _partial_: true
-    gamma: 0.999999
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - generator
-      - discriminator
-
-  model_checkpoint:
-    every_n_train_steps: 1000
-    save_top_k: 10

+ 0 - 137
fish_speech/configs/vqgan_finetune.yaml

@@ -1,137 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vq-gan-finetune
-ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  precision: bf16-mixed
-  max_steps: 100_000
-  val_check_interval: 5000
-  strategy:
-    find_unused_parameters: true
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_train_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_val_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 16
-  val_batch_size: 16
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-
-  sampling_rate: ${sample_rate}
-  weight_adv: 0.2
-  weight_vq: 1.0
-  weight_mel: 1.0
-
-  # Important: Set the freeze_encoder to true to only train the decoder
-  freeze_encoder: true
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    input_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-  
-  quantizer:
-    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-    input_dim: 768
-    n_codebooks: 1
-    n_groups: 2
-    levels: [8, 5, 5, 5]
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    output_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-    condition_channels: 768
-  
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
-
-  vocoder:
-    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
-    ckpt_path: null # You may download the pretrained vocoder and set the path here
-
-  encode_mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-    f_min: 0.0
-    f_max: 8000.0
-
-  gt_mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 4e-5
-    betas: [0.8, 0.99]
-    eps: 1e-5
-    weight_decay: 0.01
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 0
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0
-
-callbacks:
-  model_summary:
-    _target_: lightning.pytorch.callbacks.ModelSummary
-    max_depth: 1
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
-
-  grad_norm_monitor:
-    sub_module: 
-      - encoder
-      - decoder
-      - quantizer
-      - discriminator

+ 0 - 140
fish_speech/configs/vqgan_pretrain.yaml

@@ -1,140 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vq-gan-pretrain
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  precision: bf16-mixed
-  max_steps: 1_000_000
-  val_check_interval: 5000
-  strategy:
-    find_unused_parameters: true
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-train_dataset:
-  _target_: torch.utils.data.ConcatDataset
-  datasets:
-    - _target_: fish_speech.datasets.vqgan.VQGANDataset
-      filelist: data/gigaspeech/vq_train_filelist.txt
-      sample_rate: ${sample_rate}
-      hop_length: ${hop_length}
-      slice_frames: 512
-    - _target_: fish_speech.datasets.vqgan.VQGANDataset
-      filelist: data/sft/vq_train_filelist.txt
-      sample_rate: ${sample_rate}
-      hop_length: ${hop_length}
-      slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/sft/vq_val_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 32
-  val_batch_size: 32
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-
-  sampling_rate: ${sample_rate}
-  weight_adv: 0.2
-  weight_vq: 1.0
-  weight_mel: 1.0
-  freeze_encoder: false
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    input_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-  
-  quantizer:
-    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-    input_dim: 768
-    n_codebooks: 1
-    n_groups: 2
-    levels: [8, 5, 5, 5]
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    output_channels: ${num_mels}
-    residual_channels: 768
-    residual_layers: 20
-    dilation_cycle: 4
-    condition_channels: 768
-  
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
-
-  vocoder:
-    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
-    ckpt_path: null # You may download the pretrained vocoder and set the path here
-
-  encode_mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-    f_min: 0.0
-    f_max: 8000.0
-
-  gt_mel_transform:
-    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 1e-4
-    betas: [0.8, 0.99]
-    eps: 1e-5
-    weight_decay: 0.01
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 100
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0
-
-callbacks:
-  model_summary:
-    _target_: lightning.pytorch.callbacks.ModelSummary
-    max_depth: 1
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
-
-  grad_norm_monitor:
-    sub_module: 
-      - encoder
-      - decoder
-      - quantizer
-      - discriminator

+ 2 - 0
fish_speech/conversation.py

@@ -0,0 +1,2 @@
+SEMANTIC_TOKEN = "<|semantic|>"
+CODEBOOK_PAD_TOKEN_ID = 0

+ 0 - 35
fish_speech/datasets/concat_repeat.py

@@ -51,38 +51,3 @@ class ConcatRepeatDataset(Dataset):
         dataset = self.datasets[dataset_idx]
         dataset = self.datasets[dataset_idx]
 
 
         return dataset[sample_idx % len(dataset)]
         return dataset[sample_idx % len(dataset)]
-
-
-class ConcatWeightedIterableDataset(IterableDataset):
-    datasets: list[IterableDataset]
-    weights: list[float]
-
-    def __init__(self, datasets: Iterable[IterableDataset], weights: list[float]):
-        super().__init__()
-
-        total_weight = sum(weights)
-        self.weights = [w / total_weight for w in weights]
-        self.datasets = list(datasets)
-
-        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
-        assert len(self.datasets) == len(
-            weights
-        ), "datasets and repeats should have the same length"
-
-        for d in self.datasets:
-            assert isinstance(
-                d, IterableDataset
-            ), "ConcatRepeatIterableDataset only supports IterableDataset"
-
-    def __iter__(self):
-        all_datasets = [iter(dataset) for dataset in self.datasets]
-        ids = list(range(len(self.datasets)))
-
-        while True:
-            chosen_dataset = random.choices(ids, self.weights)[0]
-
-            try:
-                yield next(all_datasets[chosen_dataset])
-            except StopIteration:
-                all_datasets[chosen_dataset] = iter(self.datasets[chosen_dataset])
-                yield next(all_datasets[chosen_dataset])

+ 381 - 0
fish_speech/datasets/prompts.py

@@ -0,0 +1,381 @@
+# "Transcribe the following audio into text."
+# "Transcribe what you will hear."
+
+asr_instructions = [
+    "Transcribe:",
+    "Transcribe the following audio into text.",
+    "Convert the audio you're about to hear into written text.",
+    "Please write down what you hear in the audio file.",
+    "Listen to the audio and type out its contents.",
+    "Your task is to write the audio's content in text form.",
+    "Transcribe the content of the audio into text.",
+    "Transform the given audio into a textual format.",
+    "Listen to the following sound clip and transcribe it.",
+    "The audio provided should be converted into written words.",
+    "Document the audio in text.",
+    "Put the audio's dialogue into written form.",
+    "Capture the audio's message in text.",
+    "Turn the sound file's speech into text.",
+    "Render the audio into a text version.",
+    "Translate the audio recording to text.",
+    "Write out the dialogue from the audio.",
+    "Listen and transcribe the audio into words.",
+    "Change the audio into a written transcript.",
+    "Your job is to transcribe the audio to text.",
+    "Please transcribe the spoken words into text.",
+    "The task is to convert audio speech into written text.",
+    "Make a text transcript of the following audio.",
+    "Decode the audio into a written document.",
+    "Write down the transcription of the audio.",
+    "Please provide a text version of this audio.",
+    "The objective is to transcribe the audio into readable text.",
+    "Listen carefully and type out the audio.",
+    "Transform this audio clip into a text document.",
+    "Your assignment is to transcribe this audio.",
+    "Transcribe this sound recording into text format.",
+    "The goal is to turn the audio into text.",
+    "Your duty is to document the audio in written form.",
+    "Listen to this audio piece and write down its contents.",
+    "The task is converting the audio into text.",
+    "Please create a textual transcription of the audio.",
+    "Capture in writing what is said in the audio.",
+    "Transcribe the audible content into a text format.",
+    "The mission is to transcribe the audio into text.",
+    "Your task: convert the audio to text.",
+    "Write the contents of the audio as text.",
+    "Listen to the clip and transcribe its audio to text.",
+    "Transcribe the given audio track into written words.",
+    "The assignment is to write out the audio in text.",
+    "Convert the spoken words into text.",
+    "Transcribe the voice recording into text.",
+    "Your task is to make a written record of the audio.",
+    "Listen to the audio and reproduce it in text.",
+    "Transcribe the following sound into written text.",
+    "Your challenge is to transcribe the audio into written form.",
+    "Make a written version of the audio.",
+    "Take the audio and transcribe it to text.",
+    "Write down everything you hear in the audio.",
+    "Please put the audio into text format.",
+    "Your role is to transcribe the following audio into text.",
+    "Convert the audio message into written text.",
+    "Provide a written transcription of the audio.",
+    "Listen and convert the audio to text.",
+    "The requirement is to transcribe the audio into text form.",
+    "Document in text what the audio says.",
+    "Transcribe into text what you hear in the audio.",
+    "Translate the audio file's contents into text.",
+    "The task is to create a text transcript of the audio.",
+    "Your assignment: Translate the audio into written words.",
+    "Write a textual representation of the audio.",
+    "Capture the essence of the audio in text.",
+    "Your job: Listen to the audio and transcribe it.",
+    "Turn the audio content into a text transcript.",
+    "The task at hand is to transcribe the audio to text.",
+    "Reproduce the audio in text form.",
+    "Your mission: Convert the audio into a textual format.",
+    "Transcribe what is spoken in the audio into text.",
+    "Create a written version of what's in the audio.",
+    "Transform the spoken audio into text.",
+    "Document the spoken words in the audio as text.",
+    "The objective is to write down the audio in text.",
+    "Your goal: Transcribe the audio into text.",
+    "Please convert the audio file into text.",
+    "Transcribe the audio clip into written text.",
+    "Listen to the audio and transcribe the speech into text.",
+    "Transform the voice from the audio into written words.",
+    "The task is to write the audio's speech in text form.",
+    "Your duty: Write down what the audio says.",
+    "Turn the given audio into a written format.",
+    "Write in text form what is said in the audio.",
+    "Your task: Document the audio in text.",
+    "Provide a text transcription of the audio.",
+    "Provide a text transcription of the audio.",
+    "Write down the audio you listen to.",
+    "Type out the spoken words you hear.",
+    "Document the audio content verbatim.",
+    "Transcribe the spoken content accurately.",
+    "Convert the audio you hear into text.",
+    "Record in writing what is said in the audio.",
+    "Capture the spoken words in written form.",
+    "Translate the audio into written text.",
+    "Jot down the words you hear in the audio.",
+    "Put into writing the spoken words you hear.",
+    "Transcribe the auditory information verbatim.",
+    "Note down the dialogue from the audio.",
+    "Write out the spoken words from the audio.",
+    "Transcribe the oral presentation into text.",
+    "Render the spoken audio into written form.",
+    "Reproduce the spoken words in text form.",
+    "Document what is being said in the audio.",
+    "Translate the spoken word into written form.",
+    "Write verbatim what you hear in the audio.",
+    "Capture in writing the contents of the audio.",
+    "Transcribe verbatim the spoken words.",
+    "Write down verbatim what is spoken.",
+    "Transcribe the sounds into words on paper.",
+    "Translate the sounds you hear into words.",
+    "Write the spoken words in text form.",
+    "Reproduce the audio content in writing.",
+    "Note verbatim what is said in the audio.",
+    "Put the audio content into written words.",
+    "Record the spoken words into text format.",
+    "Transcribe the audio into a written document.",
+    "Write down exactly what you hear.",
+    "Type out the content of the audio.",
+    "Document the words spoken in the audio.",
+    "Translate the verbal content into text.",
+    "Convert what you hear into written words.",
+    "Capture the essence of the audio in writing.",
+    "Reproduce the spoken content in written form.",
+    "Jot down exactly what is said in the audio.",
+    "Document every word you hear in the audio.",
+    "Record the audio content by writing it down.",
+    "Capture the audio's spoken words in text.",
+    "Turn the spoken audio into a written transcript.",
+    "Write down the contents of the audio verbatim.",
+    "Transcribe the voice you hear into text.",
+    "Convert the spoken audio into text format.",
+    "Type what is being spoken in the audio.",
+    "Translate the audio speech into written words.",
+    "Write the audio's dialogue in written form.",
+    "Record the verbal content as written text.",
+    "Transcribe the spoken parts of the audio.",
+    "Note down everything you hear in the audio.",
+    "Capture every word from the audio in text.",
+    "Put the spoken audio into text form.",
+    "Transcribe the audible content into words.",
+    "Translate the oral content into written text.",
+    "Type out everything heard in the audio.",
+    "Write down the spoken parts verbatim.",
+    "Document the spoken audio in text form.",
+    "Capture the verbal exchanges in written text.",
+    "Transcribe each word you hear accurately.",
+    "Turn the audio into a textual document.",
+    "Transcribe the sound into written words.",
+    "Write the audio transcript in your own words.",
+    "Document in text what you hear in the audio.",
+    "Record in text the spoken parts of the audio.",
+    "Transcribe the narrative you hear into text.",
+    "Capture the spoken narrative in written form.",
+    "Convert the verbal audio into written script.",
+    "Note down the spoken words in the audio.",
+    "Write in text form what is spoken in the audio.",
+    "Record the audio's spoken words verbatim.",
+    "Jot down the audio's dialogue accurately.",
+    "Transcribe the verbal parts into written words.",
+    "Translate the audio's spoken content into text.",
+    "Document the audio dialogue in written form.",
+    "Type out the words spoken in the audio verbatim.",
+    "Write down word for word what is said in the audio.",
+    "Transcribe the entire audio content into text.",
+    "Note down precisely what is said in the audio.",
+    "Capture in text the spoken content of the audio.",
+    "Record the spoken audio into written language.",
+    "Write the essence of the audio in text form.",
+    "Transcribe the words you hear in the audio.",
+    "Translate every spoken word into written text.",
+    "Convert the oral speech into a written format.",
+    "Jot down the words spoken in the audio.",
+    "Record every word from the audio in writing.",
+    "Document the entire audio in written form.",
+    "Transcribe the spoken language into text.",
+    "Write down the audio's words exactly as spoken.",
+    "Capture the spoken word in written format.",
+    "Type out verbatim the spoken audio content.",
+    "Write precisely what you hear from the audio.",
+]
+
+# "Read the following text with emotion."
+# "Read the following text."
+
+tts_instructions = [
+    "Speak:",
+    "Expressively read the text that follows.",
+    "Convey the upcoming text with emotion.",
+    "Deliver the following passage with heartfelt expression.",
+    "Evoke emotion while reading the text below.",
+    "With feeling, please read the text that comes next.",
+    "Infuse the upcoming words with emotional depth as you read.",
+    "Let your emotions guide you as you read the following lines.",
+    "Channel emotion into your reading of the next passage.",
+    "Read the text below with a sense of emotion.",
+    "Bring the following words to life with emotional expression.",
+    "Engage emotionally with the text as you read it aloud.",
+    "Imbue the subsequent text with feeling as you read.",
+    "Read the following content with genuine emotion.",
+    "Allow your feelings to resonate through the upcoming text.",
+    "Emotionally interpret the text that follows.",
+    "Read the ensuing passage with deep feeling.",
+    "Convey the text below with genuine emotional depth.",
+    "Read the text that comes next, letting your emotions flow.",
+    "With emotion, present the following words.",
+    "Let your emotional expression enhance the next text.",
+    "Embrace emotion as you read the following passage.",
+    "Read aloud the text below with emotive expression.",
+    "Infuse the upcoming lines with emotional intensity.",
+    "With sincerity, read the following text with emotion.",
+    "Project emotion as you deliver the text that follows.",
+    "Let the next words be read with a wealth of emotion.",
+    "Give the upcoming text an emotional rendition.",
+    "With emotion, read the text that is presented next.",
+    "Convey the essence of the following text with heartfelt emotion.",
+    "Inject emotional depth into your reading of the next passage.",
+    "Bring out the emotional undertones in the following text.",
+    "Embody the emotions as you read the text below.",
+    "Express the following narrative with emotional depth.",
+    "Let emotion permeate your reading of the upcoming passage.",
+    "Interpret the following text with a rich emotional tone.",
+    "Elicit emotion through your reading of the next content.",
+    "Read the subsequent text with a deep emotional connection.",
+    "Emote the essence of the text that follows in your reading.",
+    "Render the following lines with emotional expression.",
+    "Expressively interpret the upcoming text.",
+    "Immerse in emotion as you read the following passage.",
+    "Engage with the text below on an emotional level as you read.",
+    "With emotional clarity, read the next text.",
+    "Let an emotional depth inform your reading of the following words.",
+    "Express the following content with deep emotional resonance.",
+    "Deliver the upcoming text with a range of emotions.",
+    "Narrate the following lines with emotional expressiveness.",
+    "Convey emotional texture as you read the text below.",
+    "Instill the next passage with emotive power as you read.",
+    "Read the ensuing text with a palette of emotions.",
+    "With a depth of feeling, present the next text.",
+    "Inflect the upcoming words with emotional vibrancy.",
+    "Emotionally engage with the text that follows in your reading.",
+    "Lend emotional expression to the passage below.",
+    "Evoke a spectrum of emotions as you read the next lines.",
+    "Channel a rich emotional tone into the following text.",
+    "With feeling, convey the essence of the upcoming passage.",
+    "Read the text that comes next with emotional fervor.",
+    "Render the following words with emotional authenticity.",
+    "Give the upcoming passage an emotive interpretation.",
+    "Allow your reading of the text below to be emotionally driven.",
+    "Imbue the next lines with a sense of emotion.",
+    "Emotionally animate the following text as you read.",
+    "Bring emotional depth to the passage that follows.",
+    "Articulate the text below with emotional nuance.",
+    "Project a range of emotions as you read the upcoming text.",
+    "With emotion, breathe life into the following words.",
+    "Narrate the ensuing text with heartfelt emotion.",
+    "Convey the text that follows with emotional richness.",
+    "Read aloud the next passage with a depth of emotion.",
+    "Emphasize emotional expression in your reading of the text below.",
+    "Let your reading of the following lines be emotionally charged.",
+    "With a heartfelt approach, read the upcoming text.",
+    "Express the essence of emotion as you deliver the next passage.",
+    "Read the following text, infused with emotional energy.",
+    "Allow the text that comes next to be expressed with emotion.",
+    "Convey the following passage with an emotional depth.",
+    "Emotionally render the text that follows.",
+    "With an emotional undertone, read the upcoming words.",
+    "Read the text below, letting emotion guide your expression.",
+    "Elicit an emotional response through your reading of the next passage.",
+    "Give the following lines an emotive delivery.",
+    "Read the upcoming text with emotional sincerity.",
+    "Narrate the text that follows with an emotional touch.",
+    "Deliver the following words with an emotive clarity.",
+    "Express the next passage with a range of emotional tones.",
+    "Immerse yourself emotionally in the text below as you read.",
+    "Let the ensuing text be conveyed with profound emotion.",
+    "Infuse the following lines with a sense of heartfelt emotion.",
+    "Emotionally engage with the upcoming text in your reading.",
+    "Convey deep emotion as you read the text that follows.",
+    "Let your reading of the next passage be rich in emotion.",
+    "With emotional depth, narrate the following text.",
+    "Read the text below, capturing its emotional essence.",
+    "Emote through your reading of the upcoming lines.",
+    "Please read the text that follows aloud.",
+    "Proceed to vocalize the upcoming text.",
+    "Kindly articulate the subsequent text.",
+    "Go ahead and pronounce the text below.",
+    "Could you recite the forthcoming passage?",
+    "Start reading the text below out loud.",
+    "Announce the following text audibly.",
+    "Voice the text that comes next.",
+    "Read through the following lines aloud.",
+    "Narrate the text presented below.",
+    "Elevate your voice for the upcoming script.",
+    "Broadcast the text that follows.",
+    "Project the subsequent lines audibly.",
+    "Give voice to the text underneath.",
+    "Unfold the following text with your voice.",
+    "Engage in reading the next piece of text aloud.",
+    "Orate the following series of words.",
+    "Enunciate the text appearing next.",
+    "Verbally present the upcoming text.",
+    "Articulate the passage that follows.",
+    "Read aloud the text that's coming up.",
+    "Proclaim the subsequent words.",
+    "Vocalize the narrative below.",
+    "Bring the following text to life by reading it aloud.",
+    "Express the next text with your voice.",
+    "Render the following text audibly.",
+    "Voice out the lines that follow.",
+    "Orally deliver the upcoming text.",
+    "Loudly read out the text below.",
+    "Share the next text by reading it out loud.",
+    "Speak the following passage aloud.",
+    "Let your voice carry the upcoming words.",
+    "Annunciate the text that follows.",
+    "Sound out the subsequent text.",
+    "Aurally present the text below.",
+    "Elocute the forthcoming lines.",
+    "Recite the text below with clarity.",
+    "Make the next text heard by reading aloud.",
+    "Bring forth your voice for the following script.",
+    "Read the text that ensues out loud.",
+    "Deliver the following lines vocally.",
+    "Voice the ensuing text.",
+    "Publicly read the text that follows.",
+    "Loudly narrate the subsequent text.",
+    "Express the following text through your voice.",
+    "Verbally articulate the next passage.",
+    "Read the forthcoming text clearly.",
+    "Announce the next set of words aloud.",
+    "Broadcast the following narrative.",
+    "Articulate the text coming up next.",
+    "Enunciate the passage that follows clearly.",
+    "Recite the subsequent text audibly.",
+    "Speak out the text below.",
+    "Project your voice with the following words.",
+    "Read the next lines aloud.",
+    "Vocalize the text that is to follow.",
+    "Narrate aloud the text below.",
+    "Orate the forthcoming script.",
+    "Pronounce the next passage.",
+    "Read out the subsequent text.",
+    "Let the following words be heard by reading them aloud.",
+    "Express the text that follows with your voice.",
+    "Give audible life to the text below.",
+    "Speak the ensuing text clearly.",
+    "Make the forthcoming text audible.",
+    "Project the next series of words audibly.",
+    "Voice out the following narrative.",
+    "Elevate the subsequent text with your voice.",
+    "Bring the next passage to audible life.",
+    "Read the lines that come next out loud.",
+    "Announce the text below with clarity.",
+    "Vocalize the script that follows.",
+    "Narrate the following text with emphasis.",
+    "Deliver the upcoming words with your voice.",
+    "Articulate the next set of lines.",
+    "Verbally convey the following text.",
+    "Present the subsequent text vocally.",
+    "Enunciate the upcoming passage loudly.",
+    "Orally render the text that follows.",
+    "Speak out the subsequent narrative.",
+    "Proclaim the next text audibly.",
+    "Elocute the following lines with clarity.",
+    "Give voice to the upcoming script.",
+    "Let your voice express the text below.",
+    "Annunciate the following words clearly.",
+    "Sound out the text that is next.",
+    "Aurally convey the subsequent passage.",
+    "Read the text up next aloud.",
+]
+
+prompt_dict = {
+    "asr": asr_instructions,
+    "tts": tts_instructions,
+}

+ 263 - 298
fish_speech/datasets/text.py

@@ -1,21 +1,29 @@
+import gzip
+import io
+import json
 import random
 import random
 from dataclasses import dataclass
 from dataclasses import dataclass
-from itertools import chain
 from pathlib import Path
 from pathlib import Path
 from random import Random
 from random import Random
 from typing import Optional, Union
 from typing import Optional, Union
 
 
 import numpy as np
 import numpy as np
-import pyarrow.parquet as pq
 import torch
 import torch
 import torch.nn.functional as F
 import torch.nn.functional as F
-from datasets.download.streaming_download_manager import xopen
-from huggingface_hub import HfApi
+import zstandard as zstd
 from lightning import LightningDataModule
 from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
 from torch.distributed import get_rank, get_world_size, is_initialized
 from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 from transformers import AutoTokenizer
 
 
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    SKIP_TEXT_STRING,
+    Conversation,
+    Message,
+    encode_conversation,
+)
+from fish_speech.datasets.prompts import asr_instructions, tts_instructions
 from fish_speech.datasets.protos.text_data_pb2 import SampledData
 from fish_speech.datasets.protos.text_data_pb2 import SampledData
 from fish_speech.datasets.protos.text_data_stream import read_pb_stream
 from fish_speech.datasets.protos.text_data_stream import read_pb_stream
 from fish_speech.text.clean import clean_text
 from fish_speech.text.clean import clean_text
@@ -24,9 +32,7 @@ from fish_speech.utils.braceexpand import braceexpand
 
 
 log = RankedLogger(__name__, rank_zero_only=True)
 log = RankedLogger(__name__, rank_zero_only=True)
 
 
-CODEBOOK_PAD_TOKEN_ID = 0
-CODEBOOK_EOS_TOKEN_ID = 1
-SKIP_TEXT_STRING = "<|skip_text|>"
+DCTX = zstd.ZstdDecompressor(max_window_size=2**31)
 
 
 
 
 def split_by_rank_worker(files):
 def split_by_rank_worker(files):
@@ -56,43 +62,55 @@ def split_by_rank_worker(files):
     return files
     return files
 
 
 
 
-class StreamTextDataset(IterableDataset):
+def expand_split_proto_files(proto_files, seed: int = 42):
+    # Expand the proto files
+    expanded_proto_files = []
+    for filename in proto_files:
+        for i in braceexpand(filename):
+            i = Path(i)
+            if i.is_file():
+                expanded_proto_files.append(i)
+            elif i.is_dir():
+                expanded_proto_files.extend(i.rglob("*.proto"))
+                expanded_proto_files.extend(i.rglob("*.protos"))
+            else:
+                raise ValueError(f"{i} is not a file or directory")
+
+    expanded_proto_files = sorted(expanded_proto_files)
+    Random(seed).shuffle(expanded_proto_files)
+    return split_by_rank_worker(expanded_proto_files)
+
+
+class TextPretrainDataset(IterableDataset):
     def __init__(
     def __init__(
         self,
         self,
-        files: Optional[Union[list[str], str]] = None,
-        prefix: Optional[str] = None,
+        source: str,
         seed: int = 42,
         seed: int = 42,
-        parquet_batch_size: int = 10000,
-        repo: str = "uonlp/CulturaX",
         max_length: int = 1024,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
         tokenizer: AutoTokenizer = None,
+        num_codebooks: int = 2,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
+        self.source = Path(source)
         self.seed = seed
         self.seed = seed
-        self.parquet_batch_size = parquet_batch_size
-        self.repo = repo
         self.max_length = max_length
         self.max_length = max_length
         self.tokenizer = tokenizer
         self.tokenizer = tokenizer
+        self.num_codebooks = num_codebooks
 
 
-        if files is None and prefix is None:
-            raise ValueError("Either files or prefix must be specified")
-
-        if prefix is not None:
-            files = HfApi().list_repo_files(repo, repo_type="dataset")
+        if self.source.is_file():
+            with open(self.source, "r") as f:
+                files = f.read().strip().split("\n")
+            self.root = self.source.parent
+        else:
             files = [
             files = [
-                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
+                str(i.relative_to(self.source)) for i in self.source.rglob("*.jsonl")
             ]
             ]
-            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
-        else:
-            if isinstance(files, str):
-                files = [files]
-
-            files = list(chain.from_iterable(map(braceexpand, files)))
-            log.info(f"Expanded {len(files)} files in {repo}")
+            self.root = self.source
 
 
         # Get sharded files
         # Get sharded files
         self.files = sorted(files)
         self.files = sorted(files)
+
         Random(seed).shuffle(self.files)
         Random(seed).shuffle(self.files)
 
 
     def __iter__(self):
     def __iter__(self):
@@ -105,142 +123,147 @@ class StreamTextDataset(IterableDataset):
             except Exception as e:
             except Exception as e:
                 log.exception(f"Failed to parse {filename}: {e}")
                 log.exception(f"Failed to parse {filename}: {e}")
 
 
-    def parse_data(self, filename: str):
-        for data in self.parse_data_internal(filename):
-            text = data["text"]
+    def read_jsonl(self, filename: str):
+        with open(self.root / filename, "rb") as f:
+            if filename.endswith(".zst"):
+                stream_reader = DCTX.stream_reader(f)
+            elif filename.endswith(".gz"):
+                stream_reader = gzip.open(f, "rb")
+            elif filename.endswith(".jsonl"):
+                stream_reader = f
+            else:
+                raise ValueError(f"Unknown file type: {filename}")
 
 
+            stream = io.TextIOWrapper(stream_reader, encoding="utf-8")
+
+            # Parse jsonl
+            for line in stream:
+                line = json.loads(line)
+                yield line
+
+    def parse_data(self, filename: str):
+        for line in self.read_jsonl(filename):
             # encode
             # encode
             tokens = self.tokenizer.encode(
             tokens = self.tokenizer.encode(
-                text,
+                line["text"],
                 add_special_tokens=False,
                 add_special_tokens=False,
                 truncation=False,
                 truncation=False,
                 max_length=10**6,
                 max_length=10**6,
             )
             )
 
 
-            # Random choice self.max_length
-            if len(tokens) > self.max_length:
-                start = random.randint(0, len(tokens) - self.max_length)
-                tokens = tokens[start : start + self.max_length - 1]
-
             tokens = (
             tokens = (
                 [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
                 [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
             )
             )
-            # Pad dims
-            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
-
-            tokens = torch.concat(
-                [
-                    torch.tensor([tokens], dtype=torch.long),
-                    placeholder_multi_codebook,
-                ],
-                dim=0,
-            )
+
+            if len(tokens) > self.max_length:
+                tokens = tokens[: self.max_length]
+
+            tokens = self.pad_codebooks(tokens)
             labels = tokens.clone()
             labels = tokens.clone()
             tokens = tokens[:, :-1]
             tokens = tokens[:, :-1]
             labels = labels[:, 1:]
             labels = labels[:, 1:]
-            labels[1:] = -100  # remove all placeholders
+            labels[1:] = -100  # no loss on codebook
 
 
             yield {"tokens": tokens, "labels": labels}
             yield {"tokens": tokens, "labels": labels}
 
 
-    def parse_data_internal(self, filename: str):
-        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
+    def pad_codebooks(self, tokens):
+        placeholder_multi_codebook = (
+            torch.zeros((self.num_codebooks, len(tokens)), dtype=torch.long)
+            + CODEBOOK_PAD_TOKEN_ID
+        )
+        return torch.concat(
+            [
+                torch.tensor([tokens], dtype=torch.long),
+                placeholder_multi_codebook,
+            ],
+            dim=0,
+        )
+
 
 
-        with xopen(url, mode="rb") as stream:
-            parquet_file = pq.ParquetFile(stream)
+class TextInstructionDataset(TextPretrainDataset):
+    def parse_data(self, filename: str):
+        for line in self.read_jsonl(filename):
+            messages = []
+            for conversation in line["conversations"]:
+                role = {
+                    "human": "user",
+                    "gpt": "assistant",
+                    "system": "system",
+                }[conversation["from"]]
+
+                message = Message(
+                    role=role,
+                    parts=[conversation["value"]],
+                )
+                messages.append(message)
+
+            conversation = Conversation(messages=messages)
+            tokens, labels = encode_conversation(
+                conversation,
+                self.tokenizer,
+                num_codebooks=self.num_codebooks,
+            )
 
 
-            for batch in parquet_file.iter_batches(
-                batch_size=self.parquet_batch_size, columns=["text"]
-            ):
-                # In-batch shuffling
-                texts = [{"text": text.as_py()} for text in batch["text"]]
-                random.shuffle(texts)
-                yield from texts
+            yield {"tokens": tokens, "labels": labels}
 
 
 
 
-class AutoAugTextDataset(IterableDataset):
-    """
-    Auto Augment Dataset by Speaker
+def semantic_to_tensor(semantics):
+    num_codebooks = len(semantics)
+    codes = [[] for _ in range(num_codebooks)]
 
 
-    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
-    2. Automatically normalize the text
+    for book_idx, book in zip(range(num_codebooks), semantics):
+        for j in book.values:
+            codes[book_idx].append(int(j))
 
 
-    For interactive mode, we use the following format (multiple sequences):
-    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+    return torch.tensor(codes, dtype=torch.int)
 
 
-    For non-interactive mode, we use the following format (one long sequence):
-    <s> [INST] text [/INST] ... </s>
-    """
 
 
+class AutoTextSemanticInstructionDataset(IterableDataset):
     def __init__(
     def __init__(
         self,
         self,
         proto_files: list[str],
         proto_files: list[str],
         seed: int = 42,
         seed: int = 42,
-        interactive_prob: float = 0.5,
         max_length: int = 1024,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
         tokenizer: AutoTokenizer = None,
-        use_speaker: bool | float = True,
-        causual: bool = True,
-        use_negative_samples: bool = False,
+        causual: Union[bool, float] = True,
         num_codebooks: Optional[int] = None,
         num_codebooks: Optional[int] = None,
         skip_text_prob: float = 0.0,
         skip_text_prob: float = 0.0,
+        asr_prob: float = 0.0,
     ):
     ):
         """
         """
         Args:
         Args:
             proto_files: proto buf files if using local data
             proto_files: proto buf files if using local data
             seed: random seed
             seed: random seed
-            interactive_prob: probability to use interactive mode
             max_length: max length of the text
             max_length: max length of the text
             tokenizer: tokenizer
             tokenizer: tokenizer
-            use_speaker: include speaker information in the prompt
             causual: use causual sampling when using local data, disable will lead to random sampling
             causual: use causual sampling when using local data, disable will lead to random sampling
-            use_negative_samples: generate negative samples
             num_codebooks: number of codebooks, if None, it will be automatically detected
             num_codebooks: number of codebooks, if None, it will be automatically detected
             skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
             skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+            asr_prob: probability to use ASR
         """
         """
 
 
         super().__init__()
         super().__init__()
 
 
-        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+        assert 0 <= skip_text_prob <= 1, "skip_text_prob must be in [0, 1]"
+        assert 0 <= asr_prob <= 1, "asr_prob must be in [0, 1]"
 
 
         self.seed = seed
         self.seed = seed
         self.max_length = max_length
         self.max_length = max_length
         self.tokenizer = tokenizer
         self.tokenizer = tokenizer
-        self.interactive_prob = interactive_prob
-        self.use_speaker = use_speaker
         self.proto_files = proto_files
         self.proto_files = proto_files
         self.causual = causual
         self.causual = causual
-        self.use_negative_samples = use_negative_samples
         self.num_codebooks = num_codebooks
         self.num_codebooks = num_codebooks
         self.skip_text_prob = skip_text_prob
         self.skip_text_prob = skip_text_prob
-
-        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+        self.asr_prob = asr_prob
         self.groups = None
         self.groups = None
 
 
     def init_mock_data_server(self):
     def init_mock_data_server(self):
         if self.groups is not None:
         if self.groups is not None:
             return
             return
 
 
-        # Expand the proto files
-        expanded_proto_files = []
-        for filename in self.proto_files:
-            for i in braceexpand(filename):
-                i = Path(i)
-                if i.is_file():
-                    expanded_proto_files.append(i)
-                elif i.is_dir():
-                    expanded_proto_files.extend(i.rglob("*.proto"))
-                    expanded_proto_files.extend(i.rglob("*.protos"))
-                else:
-                    raise ValueError(f"{i} is not a file or directory")
-
-        expanded_proto_files = sorted(expanded_proto_files)
-        Random(self.seed).shuffle(expanded_proto_files)
-
         self.groups = []
         self.groups = []
-        shard_proto_files = split_by_rank_worker(expanded_proto_files)
-        log.info(
-            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
-        )
+        shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
+        log.info(f"Reading {len(shard_proto_files)} files")
 
 
         count = 0
         count = 0
         for filename in shard_proto_files:
         for filename in shard_proto_files:
@@ -279,7 +302,11 @@ class AutoAugTextDataset(IterableDataset):
         # choice group based on their number of samples
         # choice group based on their number of samples
         group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
         group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
 
 
-        if self.causual:
+        causual = self.causual
+        if isinstance(self.causual, float):
+            causual = random.random() < self.causual
+
+        if causual:
             # Sample in order
             # Sample in order
             if num_samples >= len(group.sentences):
             if num_samples >= len(group.sentences):
                 samples = group.sentences
                 samples = group.sentences
@@ -298,7 +325,6 @@ class AutoAugTextDataset(IterableDataset):
         )
         )
 
 
     def augment(self):
     def augment(self):
-        final_text, final_semantic = [], []
         response = self.sample_data()
         response = self.sample_data()
         if len(response.samples) == 0:
         if len(response.samples) == 0:
             # Invalid group
             # Invalid group
@@ -306,29 +332,9 @@ class AutoAugTextDataset(IterableDataset):
 
 
         samples = list(response.samples)
         samples = list(response.samples)
         idx = 0
         idx = 0
-        use_interactive = random.random() < self.interactive_prob
-
-        if use_interactive is False:
-            # Random sample based on speaker using a truncated normal distribution
-            a = torch.tensor([0], dtype=torch.float32)
-            torch.nn.init.trunc_normal_(
-                a,
-                mean=self.max_length // 2,
-                std=self.max_length // 4,
-                a=10,
-                b=self.max_length,
-            )
-            remaining_tokens = a.long().item() - 4
-        else:
-            remaining_tokens = self.max_length
-
-        # Use speaker
-        if isinstance(self.use_speaker, float):
-            use_speaker = random.random() < self.use_speaker
-        else:
-            use_speaker = self.use_speaker
+        remaining_tokens = self.max_length
 
 
-        all_tokens, all_labels = [], []
+        all_messages = []
         while remaining_tokens > 0 and len(samples) > 0:
         while remaining_tokens > 0 and len(samples) > 0:
             sentence = samples.pop(0)
             sentence = samples.pop(0)
 
 
@@ -336,37 +342,52 @@ class AutoAugTextDataset(IterableDataset):
             text, length = self.tokenize_sentence(text)
             text, length = self.tokenize_sentence(text)
             remaining_tokens -= length + len(sentence.semantics[0].values)
             remaining_tokens -= length + len(sentence.semantics[0].values)
 
 
-            if use_interactive is False:
-                final_text.append(text)
-                final_semantic.append(sentence.semantics)
+            # For interactive mode, we only apply speaker for the first sentence
+            # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+
+            if random.random() < self.asr_prob:
+                all_messages.append(
+                    Message(
+                        role="user",
+                        parts=[
+                            random.choice(asr_instructions),
+                            semantic_to_tensor(sentence.semantics),
+                        ],
+                    )
+                )
+                all_messages.append(
+                    Message(
+                        role="assistant",
+                        parts=[text],
+                    )
+                )
             else:
             else:
-                # For interactive mode, we only apply speaker for the first sentence
-                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
-                tokens, labels = self.pack_sentences(
-                    sentences=[text],
-                    semantics=[sentence.semantics],
-                    speaker=response.name if use_speaker else None,
-                    add_bos=idx == 0,
-                    skip_text=random.random() < self.skip_text_prob,
+                skip_text = random.random() < self.skip_text_prob
+                if skip_text:
+                    text = SKIP_TEXT_STRING
+
+                all_messages.append(
+                    Message(
+                        role="user",
+                        parts=[random.choice(tts_instructions) + text],
+                        mask_labels=skip_text,
+                    )
+                )
+                all_messages.append(
+                    Message(
+                        role="assistant",
+                        parts=[semantic_to_tensor(sentence.semantics)],
+                        mask_labels=skip_text,
+                    )
                 )
                 )
-
-                all_tokens.append(tokens)
-                all_labels.append(labels)
 
 
             idx += 1
             idx += 1
 
 
-        if use_interactive is False:
-            tokens, labels = self.pack_sentences(
-                final_text,
-                semantics=final_semantic,
-                speaker=response.name if use_speaker else None,
-                add_bos=True,
-            )
-            all_tokens.append(tokens)
-            all_labels.append(labels)
-
-        tokens = torch.cat(all_tokens, dim=1)
-        labels = torch.cat(all_labels, dim=1)
+        tokens, labels = encode_conversation(
+            Conversation(messages=all_messages),
+            self.tokenizer,
+            num_codebooks=self.num_codebooks,
+        )
 
 
         # Verify that the length is correct
         # Verify that the length is correct
         assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
         assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
@@ -374,156 +395,71 @@ class AutoAugTextDataset(IterableDataset):
         # Verify bos token
         # Verify bos token
         assert tokens[0, 0] == self.tokenizer.bos_token_id
         assert tokens[0, 0] == self.tokenizer.bos_token_id
 
 
-        data = {"tokens": tokens, "labels": labels}
-
-        if self.use_negative_samples:
-            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
-            data.update(negative_samples)
-
-        return data
-
-    def generate_negative_samples(self, all_tokens, all_labels):
-        new_tokens, new_labels = [], []
-
-        for tokens, labels in zip(all_tokens, all_labels):
-            # If all codebooks are not -100, we find where it starts
-            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
-            assert (labels[1:, start:] != -100).all()  # This shouldn't happen
+        return {"tokens": tokens, "labels": labels}
 
 
-            mode = random.choice(["repeat", "lost", "noise"])
-            begin = random.randint(start, labels.size(1) - 1)
-            end = random.randint(begin, labels.size(1) - 1)
 
 
-            if mode == "repeat":
-                tokens = torch.cat(
-                    [
-                        tokens[:, :begin],
-                        tokens[:, begin:end],
-                        tokens[:, begin:end],
-                        tokens[:, end:],
-                    ],
-                    dim=1,
-                )
-                labels = torch.cat(
-                    [
-                        labels[:, :begin],
-                        labels[:, begin:end],
-                        labels[:, begin:end],
-                        labels[:, end:],
-                    ],
-                    dim=1,
-                )
-            elif mode == "lost":
-                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
-                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
-            elif mode == "noise":
-                middle_tokens, middle_labels = (
-                    tokens[:, begin:end],
-                    labels[:, begin:end],
-                )
-                random_order0 = torch.randperm(middle_tokens.size(1))
-                random_order1 = torch.randperm(middle_tokens.size(1))
-                middle_tokens = middle_tokens[:, random_order0]
-                middle_labels = middle_labels[:, random_order1]
-                tokens = torch.cat(
-                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
-                )
-                labels = torch.cat(
-                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
-                )
-
-            new_tokens.append(tokens)
-            new_labels.append(labels)
-
-        tokens = torch.cat(new_tokens, dim=1)
-        labels = torch.cat(new_labels, dim=1)
-
-        # Verify that the length is correct
-        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
-
-        return {"negative_tokens": tokens, "negative_labels": labels}
-
-    def pack_sentences(
+class SemanticInstructionDataset(IterableDataset):
+    def __init__(
         self,
         self,
-        sentences: list[str],
-        semantics: list,
-        speaker: Optional[str] = None,
-        add_bos: bool = True,
-        skip_text: bool = False,
+        proto_files: list[str],
+        seed: int = 42,
+        max_length: int = 1024,
+        tokenizer: AutoTokenizer = None,
+        num_codebooks: Optional[int] = None,
     ):
     ):
-        if speaker is None:
-            speaker = "assistant"
+        super().__init__()
 
 
-        cated_sentences = " ".join(sentences)
-        if skip_text:
-            cated_sentences = SKIP_TEXT_STRING
+        self.seed = seed
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.proto_files = proto_files
+        self.num_codebooks = num_codebooks
 
 
-        final_text = "<|im_start|>user<|im_sep|>" + cated_sentences + "<|im_end|>"
-        final_text = final_text + f"<|im_start|>{speaker}<|im_sep|>"
+    def get_data_generator(self):
+        shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
+        random.shuffle(shard_proto_files)
+        log.info(f"Fetched {len(shard_proto_files)} files")
 
 
-        encoded = self.tokenizer.encode(
-            final_text,
-            add_special_tokens=False,
-            truncation=False,
-            max_length=10**6,
-        )
-        semantic_length = sum([len(i[0].values) for i in semantics])
-        prompt_length = len(encoded)
-        num_codebooks = (
-            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
-        )
+        for filename in shard_proto_files:
+            with open(filename, "rb") as f:
+                for group in read_pb_stream(f):
+                    yield group
 
 
-        bos_bias = 1 if add_bos else 0
+    def pack_one_group(self, group):
+        sentences = group.sentences
 
 
-        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-        tokens = (
-            encoded
-            + [self.semantic_token_id] * semantic_length
-            + self.tokenizer.convert_tokens_to_ids(
-                ["<|im_end|>", "<|end_of_sequence|>"]
+        messages = []
+        for idx, sentence in enumerate(sentences):
+            role = "user" if idx % 2 == 0 else "assistant"
+            semantic = semantic_to_tensor(sentence.semantics)
+            text = random.choice(sentence.texts)
+            parts = [semantic]
+            if role == "assistant":
+                # Let model to predict the text first
+                prev_text = random.choice(sentences[idx - 1].texts)
+                # parts.insert(0, f"Q: {prev_text}\nA: {text}")
+            messages.append(
+                Message(
+                    role=role,
+                    parts=parts,
+                )
             )
             )
-        )
-
-        if add_bos:
-            tokens = [self.tokenizer.bos_token_id] + tokens
-
-        # Codebook bos/padding: 0, eos: 1
-        codes = [
-            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
-            for _ in range(num_codebooks)
-        ]
-        for segment in semantics:
-            for book_idx, book in zip(range(num_codebooks), segment):
-                for j in book.values:
-                    codes[book_idx].append(int(j) + 2)
 
 
-        for book in codes:
-            book.extend([CODEBOOK_EOS_TOKEN_ID] * 2)
-
-        tokens = [tokens] + codes
-
-        tokens = torch.tensor(tokens, dtype=torch.long)
-        labels = tokens.clone()
-
-        if skip_text:
-            # If text is not provided, the sentence is used for condition only, all labels are -100
-            torch.fill_(labels, -100)
-            return tokens, labels
-
-        # Mask out the <s> tokens for semantic, predict semantic tokens only
-        # Since we don't mask out the input tokens, the language modeling still works
-        labels[1:, : (prompt_length + bos_bias)] = -100
-
-        tokens = tokens[:, :-1]
-        labels = labels[:, 1:]
+        conversation = Conversation(messages=messages)
+        tokens, labels = encode_conversation(
+            conversation,
+            self.tokenizer,
+            num_codebooks=self.num_codebooks,
+        )
 
 
-        # Verify the padding is correct, and the last token is eos
-        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
-        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
-        assert labels[0, -1] == self.tokenizer.eos_token_id
-        assert (labels[1:, -2:] == CODEBOOK_EOS_TOKEN_ID).all()
+        return {"tokens": tokens, "labels": labels}
 
 
-        return tokens, labels
+    def __iter__(self):
+        for group in self.get_data_generator():
+            try:
+                yield self.pack_one_group(group)
+            except Exception as e:
+                log.exception(f"Failed to parse {group}: {e}")
 
 
 
 
 @dataclass
 @dataclass
@@ -633,8 +569,18 @@ class InterleaveDataset(IterableDataset):
 class TextDataModule(LightningDataModule):
 class TextDataModule(LightningDataModule):
     def __init__(
     def __init__(
         self,
         self,
-        train_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
-        val_dataset: Union[StreamTextDataset, AutoAugTextDataset, InterleaveDataset],
+        train_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            TextPretrainDataset,
+            TextInstructionDataset,
+            InterleaveDataset,
+        ],
+        val_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            TextPretrainDataset,
+            TextInstructionDataset,
+            InterleaveDataset,
+        ],
         batch_size: int = 32,
         batch_size: int = 32,
         tokenizer: AutoTokenizer = None,
         tokenizer: AutoTokenizer = None,
         max_length: int = 1024,
         max_length: int = 1024,
@@ -671,17 +617,36 @@ class TextDataModule(LightningDataModule):
 if __name__ == "__main__":
 if __name__ == "__main__":
     from tqdm import tqdm
     from tqdm import tqdm
 
 
-    ds = AutoAugTextDataset(
-        ["data/protos"],
-        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
-        use_speaker=False,
-        interactive_prob=1.0,
-        use_negative_samples=False,
-        skip_text_prob=0.5,
+    # ds = AutoTextSemanticInstructionDataset(
+    #     ["data/protos/sft/val/11labs"],
+    #     tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
+    #     skip_text_prob=1.0,
+    #     asr_prob=0.0,
+    #     num_codebooks=2,
+    # )
+    # ds = TextInstructionDataset(
+    #     source="data/openhermes2_5",
+    #     tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
+    # )
+
+    ds = SemanticInstructionDataset(
+        proto_files=["data/protos/sft/val/ultrachat_200k_spoken_openai"],
+        tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
+        num_codebooks=2,
     )
     )
 
 
     for i in ds:
     for i in ds:
-        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
+        # print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
         # i["labels"][0][i["labels"][0] == -100] = 0
         # i["labels"][0][i["labels"][0] == -100] = 0
         # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
         # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
+
+        length = i["tokens"].size(1)
+        print(i["tokens"].size(), i["tokens"].dtype)
+        for j in range(length):
+            print(
+                ds.tokenizer.decode(i["tokens"][0, j]),
+                i["tokens"][:, j],
+                i["labels"][:, j],
+            )
+            input()
         break
         break

+ 0 - 3
fish_speech/models/text2semantic/__init__.py

@@ -1,3 +0,0 @@
-from .lit_module import TextToSemantic
-
-__all__ = ["TextToSemantic"]

+ 11 - 7
fish_speech/models/text2semantic/lit_module.py

@@ -6,8 +6,8 @@ import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 
 
 import fish_speech.utils as utils
 import fish_speech.utils as utils
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from fish_speech.models.text2semantic.llama import NaiveTransformer
 from fish_speech.models.text2semantic.llama import NaiveTransformer
-from fish_speech.models.text2semantic.lora_utils import LoraConfig, setup_lora
 
 
 log = utils.RankedLogger(__name__, rank_zero_only=True)
 log = utils.RankedLogger(__name__, rank_zero_only=True)
 
 
@@ -137,15 +137,15 @@ class TextToSemantic(L.LightningModule):
             labels, negative_labels = labels.chunk(2)
             labels, negative_labels = labels.chunk(2)
 
 
         # Generate labels
         # Generate labels
-        base_loss = F.cross_entropy(
-            token_logits.reshape(-1, token_logits.size(-1)),
+        base_loss = fast_cross_entropy_loss(
+            token_logits.view(-1, token_logits.size(-1)),
             labels[:, 0].reshape(-1),
             labels[:, 0].reshape(-1),
             ignore_index=-100,
             ignore_index=-100,
         )
         )
 
 
         codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
         codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
-        semantic_loss = F.cross_entropy(
-            codebook_logits.reshape(-1, codebook_logits.size(-1)),
+        semantic_loss = fast_cross_entropy_loss(
+            codebook_logits.view(-1, codebook_logits.size(-1)),
             codebook_labels.reshape(-1),
             codebook_labels.reshape(-1),
             ignore_index=-100,
             ignore_index=-100,
         )
         )
@@ -281,11 +281,15 @@ class TextToSemantic(L.LightningModule):
         return loss
         return loss
 
 
     def get_accuracy(self, logits, labels):
     def get_accuracy(self, logits, labels):
+        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
+        if mask.sum() == 0:
+            return torch.tensor(0.0, device=logits.device)
+
         _, indices = logits.topk(5, dim=-1)
         _, indices = logits.topk(5, dim=-1)
         correct = indices.eq(labels.unsqueeze(-1))
         correct = indices.eq(labels.unsqueeze(-1))
-        correct[labels == -100] = 0
+        correct[~mask] = 0
         correct = correct.sum()
         correct = correct.sum()
-        accuracy = correct / (labels != -100).sum()
+        accuracy = correct / mask.sum()
 
 
         return accuracy
         return accuracy
 
 

+ 206 - 68
fish_speech/models/text2semantic/llama.py

@@ -1,5 +1,7 @@
+import json
 import math
 import math
 from dataclasses import dataclass
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Optional
 from typing import Optional
 
 
 import torch
 import torch
@@ -7,7 +9,16 @@ import torch.nn as nn
 from einops import rearrange
 from einops import rearrange
 from torch import Tensor
 from torch import Tensor
 from torch.nn import functional as F
 from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.utils.checkpoint import checkpoint
 from torch.utils.checkpoint import checkpoint
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import SEMANTIC_TOKEN
+from fish_speech.utils import RankedLogger
+
+from .lora import LoraConfig, setup_lora
+
+log = RankedLogger(__name__, rank_zero_only=True)
 
 
 
 
 def find_multiple(n: int, k: int) -> int:
 def find_multiple(n: int, k: int) -> int:
@@ -18,6 +29,8 @@ def find_multiple(n: int, k: int) -> int:
 
 
 @dataclass
 @dataclass
 class BaseModelArgs:
 class BaseModelArgs:
+    model_type: str = "base"
+
     vocab_size: int = 32000
     vocab_size: int = 32000
     n_layer: int = 32
     n_layer: int = 32
     n_head: int = 32
     n_head: int = 32
@@ -29,16 +42,19 @@ class BaseModelArgs:
     norm_eps: float = 1e-5
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
     max_seq_len: int = 2048
     dropout: float = 0.0
     dropout: float = 0.0
+    tie_word_embeddings: bool = True
+    attention_qkv_bias: bool = False
 
 
     # Codebook configs
     # Codebook configs
     codebook_size: int = 160
     codebook_size: int = 160
     num_codebooks: int = 4
     num_codebooks: int = 4
-    num_in_codebooks: Optional[int] = None
-    codebook_padding_idx: int = 0
 
 
     # Gradient checkpointing
     # Gradient checkpointing
     use_gradient_checkpointing: bool = True
     use_gradient_checkpointing: bool = True
 
 
+    # Initialize the model
+    initializer_range: float = 0.02
+
     def __post_init__(self):
     def __post_init__(self):
         if self.n_local_heads == -1:
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
             self.n_local_heads = self.n_head
@@ -46,18 +62,41 @@ class BaseModelArgs:
             hidden_dim = 4 * self.dim
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        if self.num_in_codebooks is None:
-            self.num_in_codebooks = self.num_codebooks
         self.head_dim = self.dim // self.n_head
         self.head_dim = self.dim // self.n_head
 
 
+    @staticmethod
+    def from_pretrained(path: str):
+        path = Path(path)
+
+        if path.is_dir():
+            path = path / "config.json"
+
+        with open(path, "r") as f:
+            data = json.load(f)
+
+        match data["model_type"]:
+            case "naive":
+                cls = NaiveModelArgs
+            case "dual_ar":
+                cls = DualARModelArgs
+            case _:
+                raise ValueError(f"Unknown model type: {data['model_type']}")
+
+        return cls(**data)
+
+    def save(self, path: str):
+        with open(path, "w") as f:
+            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
 
 
 @dataclass
 @dataclass
 class NaiveModelArgs(BaseModelArgs):
 class NaiveModelArgs(BaseModelArgs):
-    pass
+    model_type: str = "naive"
 
 
 
 
 @dataclass
 @dataclass
 class DualARModelArgs(BaseModelArgs):
 class DualARModelArgs(BaseModelArgs):
+    model_type: str = "dual_ar"
     n_fast_layer: int = 4
     n_fast_layer: int = 4
 
 
 
 
@@ -95,24 +134,35 @@ class BaseTransformerForwardResult:
 
 
 
 
 class BaseTransformer(nn.Module):
 class BaseTransformer(nn.Module):
-    def __init__(self, config: BaseModelArgs) -> None:
+    def __init__(
+        self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
+    ) -> None:
         super().__init__()
         super().__init__()
         self.config = config
         self.config = config
+        self.tokenizer = tokenizer
+
+        self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
 
 
         # Slow transformer
         # Slow transformer
         self.embeddings = nn.Embedding(
         self.embeddings = nn.Embedding(
-            config.vocab_size + config.codebook_size * config.num_in_codebooks,
+            config.vocab_size,
+            config.dim,
+        )
+        self.codebook_embeddings = nn.Embedding(
+            config.codebook_size * config.num_codebooks,
             config.dim,
             config.dim,
         )
         )
         self.layers = nn.ModuleList(
         self.layers = nn.ModuleList(
             TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
             TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
         )
         )
         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.output = nn.Linear(
-            config.dim,
-            config.vocab_size,
-            bias=False,
-        )
+
+        if self.config.tie_word_embeddings is False:
+            self.output = nn.Linear(
+                config.dim,
+                config.vocab_size,
+                bias=False,
+            )
 
 
         self.register_buffer(
         self.register_buffer(
             "freqs_cis",
             "freqs_cis",
@@ -139,6 +189,9 @@ class BaseTransformer(nn.Module):
         self.max_batch_size = -1
         self.max_batch_size = -1
         self.max_seq_len = -1
         self.max_seq_len = -1
 
 
+        if init_weights:
+            self.apply(self._init_weights)
+
     def setup_caches(
     def setup_caches(
         self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
         self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
     ):
     ):
@@ -161,11 +214,9 @@ class BaseTransformer(nn.Module):
 
 
     def embed(self, x: Tensor) -> Tensor:
     def embed(self, x: Tensor) -> Tensor:
         vocab_embeds = [self.embeddings(x[:, 0])]
         vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_in_codebooks):
-            emb = self.embeddings(
-                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
-            )
-            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
+        for i in range(self.config.num_codebooks):
+            emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
+            emb[x[:, 0] != self.semantic_token_id] = 0
             vocab_embeds.append(emb)
             vocab_embeds.append(emb)
 
 
         x = torch.stack(vocab_embeds, dim=3)
         x = torch.stack(vocab_embeds, dim=3)
@@ -174,21 +225,23 @@ class BaseTransformer(nn.Module):
         return x
         return x
 
 
     def forward(
     def forward(
-        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
     ) -> BaseTransformerForwardResult:
     ) -> BaseTransformerForwardResult:
-        # x: (batch, num_codebooks + 1, seq_len)
         seq_len = inp.size(2)
         seq_len = inp.size(2)
 
 
         # Here we want to merge the embeddings of the codebooks
         # Here we want to merge the embeddings of the codebooks
         x = self.embed(inp)
         x = self.embed(inp)
 
 
-        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[:seq_len]
         freqs_cis = self.freqs_cis[:seq_len]
 
 
         # Not that the causal mask here follows the definition of scaled_dot_product_attention
         # Not that the causal mask here follows the definition of scaled_dot_product_attention
         # That is, FALSE means masked out
         # That is, FALSE means masked out
         # To maintain consistency, key_padding_mask use TRUE to mask out
         # To maintain consistency, key_padding_mask use TRUE to mask out
+        mask = None
         if key_padding_mask is not None:
         if key_padding_mask is not None:
+            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
 
 
         for layer in self.layers:
         for layer in self.layers:
@@ -199,7 +252,11 @@ class BaseTransformer(nn.Module):
 
 
         # We got slow_out here
         # We got slow_out here
         slow_out = self.norm(x)
         slow_out = self.norm(x)
-        token_logits = self.output(slow_out)
+
+        if self.config.tie_word_embeddings:
+            token_logits = F.linear(slow_out, self.embeddings.weight)
+        else:
+            token_logits = self.output(slow_out)
 
 
         return BaseTransformerForwardResult(
         return BaseTransformerForwardResult(
             logits=token_logits,
             logits=token_logits,
@@ -207,7 +264,10 @@ class BaseTransformer(nn.Module):
         )
         )
 
 
     def forward_generate(
     def forward_generate(
-        self, x: Tensor, input_pos: Optional[Tensor] = None
+        self,
+        x: Tensor,
+        input_pos: Optional[Tensor] = None,
+        return_all: bool = False,
     ) -> BaseTransformerForwardResult:
     ) -> BaseTransformerForwardResult:
         # This is used for generation, optimized for torch compile
         # This is used for generation, optimized for torch compile
         assert (
         assert (
@@ -225,22 +285,99 @@ class BaseTransformer(nn.Module):
             x = layer(x, freqs_cis, mask, input_pos=input_pos)
             x = layer(x, freqs_cis, mask, input_pos=input_pos)
 
 
         # If prefill, we only calculate the logits of last token
         # If prefill, we only calculate the logits of last token
-        if x.size(1) > 1:
+        if x.size(1) > 1 and not return_all:
             x = x[:, -1:]
             x = x[:, -1:]
 
 
         # We got slow_out here
         # We got slow_out here
         slow_out = self.norm(x)
         slow_out = self.norm(x)
-        token_logits = self.output(slow_out)
+
+        if self.config.tie_word_embeddings:
+            token_logits = F.linear(slow_out, self.embeddings.weight)
+        else:
+            token_logits = self.output(slow_out)
 
 
         return BaseTransformerForwardResult(
         return BaseTransformerForwardResult(
             logits=token_logits,
             logits=token_logits,
             hidden_states=x,
             hidden_states=x,
         )
         )
 
 
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @staticmethod
+    def from_pretrained(
+        path: str,
+        load_weights: bool = False,
+        max_length: int | None = None,
+        lora_config: LoraConfig | None = None,
+        rope_base: int | None = None,
+    ) -> "BaseTransformer":
+        config = BaseModelArgs.from_pretrained(path)
+        if max_length is not None:
+            config.max_seq_len = max_length
+            log.info(f"Override max_seq_len to {max_length}")
+
+        if rope_base is not None:
+            config.rope_base = rope_base
+            log.info(f"Override rope_base to {rope_base}")
+
+        match config.model_type:
+            case "naive":
+                model_cls = NaiveTransformer
+            case "dual_ar":
+                model_cls = DualARTransformer
+            case _:
+                raise ValueError(f"Unknown model type: {config.model_type}")
+
+        tokenizer = AutoTokenizer.from_pretrained(str(path))
+        log.info(f"Loading model from {path}, config: {config}")
+        model = model_cls(config, tokenizer=tokenizer)
+
+        if lora_config is not None:
+            setup_lora(model, lora_config)
+            log.info(f"LoRA setup: {lora_config}")
+
+        if load_weights is False:
+            log.info("Randomly initialized model")
+        else:
+            weights = torch.load(
+                Path(path) / "model.pth", map_location="cpu", mmap=True
+            )
+            err = model.load_state_dict(weights, strict=False, assign=True)
+            log.info(f"Loaded weights with error: {err}")
+
+        return model
+
+    def save_pretrained(self, path: str, drop_lora: bool = False):
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        self.config.save(path / "config.json")
+        state_dict = self.state_dict()
+
+        if drop_lora:
+            for key in list(state_dict.keys()):
+                if "lora" not in key:
+                    continue
+
+                state_dict.pop(key)
+                log.info(f"Drop LoRA parameter: {key}")
+
+        torch.save(state_dict, path / "model.pth")
+        self.tokenizer.save_pretrained(path)
+
 
 
 class NaiveTransformer(BaseTransformer):
 class NaiveTransformer(BaseTransformer):
-    def __init__(self, config: NaiveModelArgs) -> None:
-        super().__init__(config)
+    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+        super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
 
         self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
         self.codebook_output = nn.Linear(
         self.codebook_output = nn.Linear(
@@ -249,6 +386,8 @@ class NaiveTransformer(BaseTransformer):
             bias=False,
             bias=False,
         )
         )
 
 
+        self.apply(self._init_weights)
+
     def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
     def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
         token_logits = result.logits
         token_logits = result.logits
         x = result.hidden_states
         x = result.hidden_states
@@ -265,9 +404,14 @@ class NaiveTransformer(BaseTransformer):
         )
         )
 
 
     def forward(
     def forward(
-        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
     ) -> TransformerForwardResult:
-        result = super().forward(inp, key_padding_mask)
+        result = super().forward(
+            inp=inp,
+            key_padding_mask=key_padding_mask,
+        )
         return self.decode(result)
         return self.decode(result)
 
 
     def forward_generate(
     def forward_generate(
@@ -278,13 +422,11 @@ class NaiveTransformer(BaseTransformer):
 
 
 
 
 class DualARTransformer(BaseTransformer):
 class DualARTransformer(BaseTransformer):
-    def __init__(self, config: DualARModelArgs) -> None:
-        super().__init__(config)
+    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+        super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
 
         # Fast transformer
         # Fast transformer
-        self.fast_embeddings = nn.Embedding(
-            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
-        )
+        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
 
 
         # The equivalent bs is so large that sdpa doesn't work
         # The equivalent bs is so large that sdpa doesn't work
         self.fast_layers = nn.ModuleList(
         self.fast_layers = nn.ModuleList(
@@ -297,6 +439,8 @@ class DualARTransformer(BaseTransformer):
             bias=False,
             bias=False,
         )
         )
 
 
+        self.apply(self._init_weights)
+
     def setup_caches(
     def setup_caches(
         self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
         self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
     ):
     ):
@@ -316,7 +460,9 @@ class DualARTransformer(BaseTransformer):
             )
             )
 
 
     def forward(
     def forward(
-        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
     ) -> TransformerForwardResult:
         parent_result = super().forward(inp, key_padding_mask)
         parent_result = super().forward(inp, key_padding_mask)
         token_logits = parent_result.logits
         token_logits = parent_result.logits
@@ -340,6 +486,11 @@ class DualARTransformer(BaseTransformer):
         # Remove padded part
         # Remove padded part
         codebooks = rearrange(codebooks, "b n s -> (b s) n")
         codebooks = rearrange(codebooks, "b n s -> (b s) n")
         codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
         codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
+
+        if torch.all(codebook_mask):
+            # If all codebooks are padded, we keep first 8 to make sure the model runs
+            codebook_mask[:8] = False
+
         x_bs, x_len = x.size(0), x.size(1)
         x_bs, x_len = x.size(0), x.size(1)
         x = x[~codebook_mask]
         x = x[~codebook_mask]
 
 
@@ -422,7 +573,9 @@ class Attention(nn.Module):
 
 
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         # key, query, value projections for all heads, but in a batch
         # key, query, value projections for all heads, but in a batch
-        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wqkv = nn.Linear(
+            config.dim, total_head_dim, bias=config.attention_qkv_bias
+        )
         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None
         self.kv_cache = None
 
 
@@ -469,13 +622,24 @@ class Attention(nn.Module):
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
 
         if self.use_sdpa:
         if self.use_sdpa:
-            y = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=mask,
-                dropout_p=self.dropout if self.training else 0.0,
-            )
+            if mask is None:
+                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+                    y = F.scaled_dot_product_attention(
+                        q,
+                        k,
+                        v,
+                        dropout_p=self.dropout if self.training else 0.0,
+                        is_causal=True,
+                        # No thirdparty attn_mask here to use flash_attention
+                    )
+            else:
+                y = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=mask,
+                    dropout_p=self.dropout if self.training else 0.0,
+                )
         else:
         else:
             y = self.eq_scaled_dot_product_attention(
             y = self.eq_scaled_dot_product_attention(
                 q,
                 q,
@@ -567,29 +731,3 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 
 
     x_out2 = x_out2.flatten(3)
     x_out2 = x_out2.flatten(3)
     return x_out2.type_as(x)
     return x_out2.type_as(x)
-
-
-if __name__ == "__main__":
-    args = DualARModelArgs(
-        max_seq_len=4096,
-        vocab_size=32312,
-        n_layer=12,
-        n_fast_layer=4,
-        n_head=12,
-        dim=768,
-        rope_base=10000,
-        norm_eps=1e-5,
-        codebook_size=128,
-        num_codebooks=4,
-    )
-
-    model = DualARTransformer(args)
-    model = model.cuda().bfloat16()
-    print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
-
-    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
-    key_padding_mask = torch.zeros(2, 128).bool().cuda()
-    key_padding_mask[0, 2:] = True
-    x1 = model(inputs, key_padding_mask=key_padding_mask)
-    print(x1.token_logits.shape)
-    print(x1.codebook_logits.shape)

+ 8 - 0
fish_speech/models/text2semantic/lora_utils.py → fish_speech/models/text2semantic/lora.py

@@ -20,6 +20,14 @@ def setup_lora(model, lora_config):
         lora_alpha=lora_config.lora_alpha,
         lora_alpha=lora_config.lora_alpha,
     )
     )
 
 
+    model.codebook_embeddings = lora.Embedding(
+        num_embeddings=model.codebook_embeddings.num_embeddings,
+        embedding_dim=model.codebook_embeddings.embedding_dim,
+        padding_idx=model.codebook_embeddings.padding_idx,
+        r=lora_config.r,
+        lora_alpha=lora_config.lora_alpha,
+    )
+
     # Replace output layer with a LoRA layer
     # Replace output layer with a LoRA layer
     linears = [(model, "output")]
     linears = [(model, "output")]
 
 

+ 0 - 3
fish_speech/models/vits_decoder/__init__.py

@@ -1,3 +0,0 @@
-from .lit_module import VITSDecoder
-
-__all__ = ["VITSDecoder"]

+ 0 - 394
fish_speech/models/vits_decoder/lit_module.py

@@ -1,394 +0,0 @@
-from typing import Any, Callable
-
-import lightning as L
-import torch
-import torch.nn.functional as F
-import wandb
-from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
-from matplotlib import pyplot as plt
-from torch import nn
-
-from fish_speech.models.vits_decoder.losses import (
-    discriminator_loss,
-    feature_loss,
-    generator_loss,
-    kl_loss,
-)
-from fish_speech.models.vqgan.utils import (
-    avg_with_mask,
-    plot_mel,
-    sequence_mask,
-    slice_segments,
-)
-
-
-class VITSDecoder(L.LightningModule):
-    def __init__(
-        self,
-        optimizer: Callable,
-        lr_scheduler: Callable,
-        generator: nn.Module,
-        discriminator: nn.Module,
-        mel_transform: nn.Module,
-        spec_transform: nn.Module,
-        hop_length: int = 512,
-        sample_rate: int = 44100,
-        freeze_discriminator: bool = False,
-        weight_mel: float = 45,
-        weight_kl: float = 0.1,
-    ):
-        super().__init__()
-
-        # Model parameters
-        self.optimizer_builder = optimizer
-        self.lr_scheduler_builder = lr_scheduler
-
-        # Generator and discriminator
-        self.generator = generator
-        self.discriminator = discriminator
-        self.mel_transform = mel_transform
-        self.spec_transform = spec_transform
-        self.freeze_discriminator = freeze_discriminator
-
-        # Loss weights
-        self.weight_mel = weight_mel
-        self.weight_kl = weight_kl
-
-        # Other parameters
-        self.hop_length = hop_length
-        self.sampling_rate = sample_rate
-
-        # Disable automatic optimization
-        self.automatic_optimization = False
-
-        if self.freeze_discriminator:
-            for p in self.discriminator.parameters():
-                p.requires_grad = False
-
-    def configure_optimizers(self):
-        # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
-        optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
-        )
-
-        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
-        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
-
-        return (
-            {
-                "optimizer": optimizer_generator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_generator,
-                    "interval": "step",
-                    "name": "optimizer/generator",
-                },
-            },
-            {
-                "optimizer": optimizer_discriminator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_discriminator,
-                    "interval": "step",
-                    "name": "optimizer/discriminator",
-                },
-            },
-        )
-
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
-
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        texts, text_lengths = batch["texts"], batch["text_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        with torch.no_grad():
-            gt_mels = self.mel_transform(audios)
-            gt_specs = self.spec_transform(audios)
-
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        (
-            fake_audios,
-            ids_slice,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-        ) = self.generator(
-            audios,
-            audio_lengths,
-            gt_specs,
-            spec_lengths,
-            texts,
-            text_lengths,
-        )
-
-        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
-        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
-        audios = slice_segments(
-            audios,
-            ids_slice * self.hop_length,
-            self.generator.segment_size * self.hop_length,
-        )
-        fake_mels = self.mel_transform(fake_audios.squeeze(1))
-
-        assert (
-            audios.shape == fake_audios.shape
-        ), f"{audios.shape} != {fake_audios.shape}"
-
-        # Discriminator
-        if self.freeze_discriminator is False:
-            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
-                audios, fake_audios.detach()
-            )
-
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
-
-            self.log(
-                f"train/discriminator/loss",
-                loss_disc,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
-            optim_d.zero_grad()
-            self.manual_backward(loss_disc)
-            self.clip_gradients(
-                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-            )
-            optim_d.step()
-
-        # Adv Loss
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)
-
-        # Adversarial Loss
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_adv, _ = generator_loss(y_d_hat_g)
-
-        self.log(
-            f"train/generator/adv",
-            loss_adv,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)
-
-        self.log(
-            f"train/generator/adv_fm",
-            loss_fm,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = avg_with_mask(
-                F.l1_loss(gt_mels, fake_mels, reduction="none"), spec_masks
-            )
-
-        self.log(
-            "train/generator/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
-
-        self.log(
-            "train/generator/loss_kl",
-            loss_kl,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        loss = (
-            loss_mel * self.weight_mel + loss_kl * self.weight_kl + loss_adv + loss_fm
-        )
-        self.log(
-            "train/generator/loss",
-            loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
-        # Backward
-        optim_g.zero_grad()
-
-        self.manual_backward(loss)
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_g.step()
-
-        # Manual LR Scheduler
-        scheduler_g, scheduler_d = self.lr_schedulers()
-        scheduler_g.step()
-        scheduler_d.step()
-
-    def validation_step(self, batch: Any, batch_idx: int):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        texts, text_lengths = batch["texts"], batch["text_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        gt_mels = self.mel_transform(audios)
-        gt_specs = self.spec_transform(audios)
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        prior_audios = self.generator.infer(
-            audios, audio_lengths, gt_specs, spec_lengths, texts, text_lengths
-        )
-        posterior_audios = self.generator.infer_posterior(gt_specs, spec_lengths)
-        prior_mels = self.mel_transform(prior_audios.squeeze(1))
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
-
-        min_mel_length = min(
-            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
-        )
-        gt_mels = gt_mels[:, :, :min_mel_length]
-        prior_mels = prior_mels[:, :, :min_mel_length]
-        posterior_mels = posterior_mels[:, :, :min_mel_length]
-
-        prior_mel_loss = avg_with_mask(
-            F.l1_loss(gt_mels, prior_mels, reduction="none"), spec_masks
-        )
-        posterior_mel_loss = avg_with_mask(
-            F.l1_loss(gt_mels, posterior_mels, reduction="none"), spec_masks
-        )
-
-        self.log(
-            "val/prior_mel_loss",
-            prior_mel_loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        self.log(
-            "val/posterior_mel_loss",
-            posterior_mel_loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        # only log the first batch
-        if batch_idx != 0:
-            return
-
-        for idx, (
-            mel,
-            prior_mel,
-            posterior_mel,
-            audio,
-            prior_audio,
-            posterior_audio,
-            audio_len,
-        ) in enumerate(
-            zip(
-                gt_mels,
-                prior_mels,
-                posterior_mels,
-                audios.detach().float(),
-                prior_audios.detach().float(),
-                posterior_audios.detach().float(),
-                audio_lengths,
-            )
-        ):
-            mel_len = audio_len // self.hop_length
-
-            image_mels = plot_mel(
-                [
-                    prior_mel[:, :mel_len],
-                    posterior_mel[:, :mel_len],
-                    mel[:, :mel_len],
-                ],
-                [
-                    "Prior (VQ)",
-                    "Posterior (Reconstruction)",
-                    "Ground-Truth",
-                ],
-            )
-
-            if isinstance(self.logger, WandbLogger):
-                self.logger.experiment.log(
-                    {
-                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
-                        "wavs": [
-                            wandb.Audio(
-                                audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="gt",
-                            ),
-                            wandb.Audio(
-                                prior_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="prior",
-                            ),
-                            wandb.Audio(
-                                posterior_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="posterior",
-                            ),
-                        ],
-                    },
-                )
-
-            if isinstance(self.logger, TensorBoardLogger):
-                self.logger.experiment.add_figure(
-                    f"sample-{idx}/mels",
-                    image_mels,
-                    global_step=self.global_step,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gt",
-                    audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/prior",
-                    prior_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/posterior",
-                    posterior_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-
-            plt.close(image_mels)

+ 0 - 67
fish_speech/models/vits_decoder/losses.py

@@ -1,67 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
-    loss = 0
-    for dr, dg in zip(fmap_r, fmap_g):
-        dr = dr.float().detach()
-        dg = dg.float()
-        loss += torch.mean(torch.abs(dr - dg))
-
-    return loss * 2
-
-
-def discriminator_loss(
-    disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor]
-):
-    loss = 0
-    r_losses = []
-    g_losses = []
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        dr = dr.float()
-        dg = dg.float()
-        r_loss = torch.mean((1 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(disc_outputs: list[torch.Tensor]):
-    loss = 0
-    gen_losses = []
-    for dg in disc_outputs:
-        dg = dg.float()
-        l = torch.mean((1 - dg) ** 2)
-        gen_losses.append(l)
-        loss += l
-
-    return loss, gen_losses
-
-
-def kl_loss(
-    z_p: torch.Tensor,
-    logs_q: torch.Tensor,
-    m_p: torch.Tensor,
-    logs_p: torch.Tensor,
-    z_mask: torch.Tensor,
-):
-    """
-    z_p, logs_q: [b, h, t_t]
-    m_p, logs_p: [b, h, t_t]
-    """
-    z_p = z_p.float()
-    logs_q = logs_q.float()
-    m_p = m_p.float()
-    logs_p = logs_p.float()
-    z_mask = z_mask.float()
-
-    kl = logs_p - logs_q - 0.5
-    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
-    kl = torch.sum(kl * z_mask)
-    l = kl / torch.sum(z_mask)
-    return l

+ 0 - 350
fish_speech/models/vits_decoder/modules/attentions.py

@@ -1,350 +0,0 @@
-import math
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from fish_speech.models.vits_decoder.modules import commons
-
-from .modules import LayerNorm
-
-
-class Encoder(nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size=1,
-        p_dropout=0.0,
-        window_size=4,
-        isflow=False,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-
-        self.drop = nn.Dropout(p_dropout)
-        self.attn_layers = nn.ModuleList()
-        self.norm_layers_1 = nn.ModuleList()
-        self.ffn_layers = nn.ModuleList()
-        self.norm_layers_2 = nn.ModuleList()
-        for i in range(self.n_layers):
-            self.attn_layers.append(
-                MultiHeadAttention(
-                    hidden_channels,
-                    hidden_channels,
-                    n_heads,
-                    p_dropout=p_dropout,
-                    window_size=window_size,
-                )
-            )
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(
-                FFN(
-                    hidden_channels,
-                    hidden_channels,
-                    filter_channels,
-                    kernel_size,
-                    p_dropout=p_dropout,
-                )
-            )
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-        if isflow:
-            cond_layer = torch.nn.Conv1d(
-                gin_channels, 2 * hidden_channels * n_layers, 1
-            )
-            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
-            self.cond_layer = weight_norm(cond_layer, "weight")
-            self.gin_channels = gin_channels
-
-    def forward(self, x, x_mask, g=None):
-        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-        x = x * x_mask
-        if g is not None:
-            g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            if g is not None:
-                x = self.cond_pre(x)
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x = commons.fused_add_tanh_sigmoid_multiply(
-                    x, g_l, torch.IntTensor([self.hidden_channels])
-                )
-            y = self.attn_layers[i](x, x, attn_mask)
-            y = self.drop(y)
-            x = self.norm_layers_1[i](x + y)
-
-            y = self.ffn_layers[i](x, x_mask)
-            y = self.drop(y)
-            x = self.norm_layers_2[i](x + y)
-        x = x * x_mask
-        return x
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(
-        self,
-        channels,
-        out_channels,
-        n_heads,
-        p_dropout=0.0,
-        window_size=None,
-        heads_share=True,
-        block_length=None,
-        proximal_bias=False,
-        proximal_init=False,
-    ):
-        super().__init__()
-        assert channels % n_heads == 0
-
-        self.channels = channels
-        self.out_channels = out_channels
-        self.n_heads = n_heads
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-        self.heads_share = heads_share
-        self.block_length = block_length
-        self.proximal_bias = proximal_bias
-        self.proximal_init = proximal_init
-        self.attn = None
-
-        self.k_channels = channels // n_heads
-        self.conv_q = nn.Conv1d(channels, channels, 1)
-        self.conv_k = nn.Conv1d(channels, channels, 1)
-        self.conv_v = nn.Conv1d(channels, channels, 1)
-        self.conv_o = nn.Conv1d(channels, out_channels, 1)
-        self.drop = nn.Dropout(p_dropout)
-
-        if window_size is not None:
-            n_heads_rel = 1 if heads_share else n_heads
-            rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-            self.emb_rel_v = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-
-        nn.init.xavier_uniform_(self.conv_q.weight)
-        nn.init.xavier_uniform_(self.conv_k.weight)
-        nn.init.xavier_uniform_(self.conv_v.weight)
-        if proximal_init:
-            with torch.no_grad():
-                self.conv_k.weight.copy_(self.conv_q.weight)
-                self.conv_k.bias.copy_(self.conv_q.bias)
-
-    def forward(self, x, c, attn_mask=None):
-        q = self.conv_q(x)
-        k = self.conv_k(c)
-        v = self.conv_v(c)
-
-        x, self.attn = self.attention(q, k, v, mask=attn_mask)
-
-        x = self.conv_o(x)
-        return x
-
-    def attention(self, query, key, value, mask=None):
-        # reshape [b, d, t] -> [b, n_h, t, d_k]
-        b, d, t_s, t_t = (*key.size(), query.size(2))
-        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
-        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-        if self.window_size is not None:
-            assert (
-                t_s == t_t
-            ), "Relative attention is only available for self-attention."
-            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
-            scores = scores + scores_local
-        if self.proximal_bias:
-            assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(
-                device=scores.device, dtype=scores.dtype
-            )
-        if mask is not None:
-            scores = scores.masked_fill(mask == 0, -1e4)
-            if self.block_length is not None:
-                assert (
-                    t_s == t_t
-                ), "Local attention is only available for self-attention."
-                block_mask = (
-                    torch.ones_like(scores)
-                    .triu(-self.block_length)
-                    .tril(self.block_length)
-                )
-                scores = scores.masked_fill(block_mask == 0, -1e4)
-        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-        p_attn = self.drop(p_attn)
-        output = torch.matmul(p_attn, value)
-        if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, t_s
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-        output = (
-            output.transpose(2, 3).contiguous().view(b, d, t_t)
-        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
-        return output, p_attn
-
-    def _matmul_with_relative_values(self, x, y):
-        """
-        x: [b, h, l, m]
-        y: [h or 1, m, d]
-        ret: [b, h, l, d]
-        """
-        ret = torch.matmul(x, y.unsqueeze(0))
-        return ret
-
-    def _matmul_with_relative_keys(self, x, y):
-        """
-        x: [b, h, l, d]
-        y: [h or 1, m, d]
-        ret: [b, h, l, m]
-        """
-        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-        return ret
-
-    def _get_relative_embeddings(self, relative_embeddings, length):
-        max_relative_position = 2 * self.window_size + 1
-        # Pad first before slice to avoid using cond ops.
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
-        slice_end_position = slice_start_position + 2 * length - 1
-        if pad_length > 0:
-            padded_relative_embeddings = F.pad(
-                relative_embeddings,
-                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-            )
-        else:
-            padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
-        return used_relative_embeddings
-
-    def _relative_position_to_absolute_position(self, x):
-        """
-        x: [b, h, l, 2*l-1]
-        ret: [b, h, l, l]
-        """
-        batch, heads, length, _ = x.size()
-        # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
-
-        # Concat extra elements so to add up to shape (len+1, 2*len-1).
-        x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(
-            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
-        )
-
-        # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
-        return x_final
-
-    def _absolute_position_to_relative_position(self, x):
-        """
-        x: [b, h, l, l]
-        ret: [b, h, l, 2*l-1]
-        """
-        batch, heads, length, _ = x.size()
-        # pad along column
-        x = F.pad(
-            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
-        )
-        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-        # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-        return x_final
-
-    def _attention_bias_proximal(self, length):
-        """Bias for self-attention to encourage attention to close positions.
-        Args:
-          length: an integer scalar.
-        Returns:
-          a Tensor with shape [1, 1, length, length]
-        """
-        r = torch.arange(length, dtype=torch.float32)
-        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
-class FFN(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        filter_channels,
-        kernel_size,
-        p_dropout=0.0,
-        activation=None,
-        causal=False,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.filter_channels = filter_channels
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.activation = activation
-        self.causal = causal
-
-        if causal:
-            self.padding = self._causal_padding
-        else:
-            self.padding = self._same_padding
-
-        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
-        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
-        self.drop = nn.Dropout(p_dropout)
-
-    def forward(self, x, x_mask):
-        x = self.conv_1(self.padding(x * x_mask))
-        if self.activation == "gelu":
-            x = x * torch.sigmoid(1.702 * x)
-        else:
-            x = torch.relu(x)
-        x = self.drop(x)
-        x = self.conv_2(self.padding(x * x_mask))
-        return x * x_mask
-
-    def _causal_padding(self, x):
-        if self.kernel_size == 1:
-            return x
-        pad_l = self.kernel_size - 1
-        pad_r = 0
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
-        return x
-
-    def _same_padding(self, x):
-        if self.kernel_size == 1:
-            return x
-        pad_l = (self.kernel_size - 1) // 2
-        pad_r = self.kernel_size // 2
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
-        return x

+ 0 - 190
fish_speech/models/vits_decoder/modules/commons.py

@@ -1,190 +0,0 @@
-import math
-
-import torch
-from torch.nn import functional as F
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def intersperse(lst, item):
-    result = [item] * (len(lst) * 2 + 1)
-    result[1::2] = lst
-    return result
-
-
-def kl_divergence(m_p, logs_p, m_q, logs_q):
-    """KL(P||Q)"""
-    kl = (logs_q - logs_p) - 0.5
-    kl += (
-        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
-    )
-    return kl
-
-
-def rand_gumbel(shape):
-    """Sample from the Gumbel distribution, protect from overflows."""
-    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
-    return -torch.log(-torch.log(uniform_samples))
-
-
-def rand_gumbel_like(x):
-    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
-    return g
-
-
-def slice_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, :, idx_str:idx_end]
-    return ret
-
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
-    b, d, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = x_lengths - segment_size + 1
-    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
-    ret = slice_segments(x, ids_str, segment_size)
-    return ret, ids_str
-
-
-def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
-    position = torch.arange(length, dtype=torch.float)
-    num_timescales = channels // 2
-    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
-        num_timescales - 1
-    )
-    inv_timescales = min_timescale * torch.exp(
-        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
-    )
-    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
-    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
-    signal = F.pad(signal, [0, 0, 0, channels % 2])
-    signal = signal.view(1, channels, length)
-    return signal
-
-
-def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return x + signal.to(dtype=x.dtype, device=x.device)
-
-
-def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
-
-
-def subsequent_mask(length):
-    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
-    return mask
-
-
-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
-    in_act = input_a + input_b
-    t_act = torch.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-    acts = t_act * s_act
-    return acts
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def shift_1d(x):
-    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
-    return x
-
-
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def generate_path(duration, mask):
-    """
-    duration: [b, 1, t_x]
-    mask: [b, 1, t_y, t_x]
-    """
-    device = duration.device
-
-    b, _, t_y, t_x = mask.shape
-    cum_duration = torch.cumsum(duration, -1)
-
-    cum_duration_flat = cum_duration.view(b * t_x)
-    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
-    path = path.view(b, t_x, t_y)
-    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-    path = path.unsqueeze(1).transpose(2, 3) * mask
-    return path
-
-
-def clip_grad_value_(parameters, clip_value, norm_type=2):
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
-    norm_type = float(norm_type)
-    if clip_value is not None:
-        clip_value = float(clip_value)
-
-    total_norm = 0
-    for p in parameters:
-        param_norm = p.grad.data.norm(norm_type)
-        total_norm += param_norm.item() ** norm_type
-        if clip_value is not None:
-            p.grad.data.clamp_(min=-clip_value, max=clip_value)
-    total_norm = total_norm ** (1.0 / norm_type)
-    return total_norm
-
-
-def squeeze(x, x_mask=None, n_sqz=2):
-    b, c, t = x.size()
-
-    t = (t // n_sqz) * n_sqz
-    x = x[:, :, :t]
-    x_sqz = x.view(b, c, t // n_sqz, n_sqz)
-    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
-
-    if x_mask is not None:
-        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
-    else:
-        x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
-    return x_sqz * x_mask, x_mask
-
-
-def unsqueeze(x, x_mask=None, n_sqz=2):
-    b, c, t = x.size()
-
-    x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
-    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
-
-    if x_mask is not None:
-        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
-    else:
-        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
-    return x_unsqz * x_mask, x_mask

+ 0 - 686
fish_speech/models/vits_decoder/modules/models.py

@@ -1,686 +0,0 @@
-import torch
-from torch import nn
-from torch.nn import Conv1d, Conv2d, ConvTranspose1d
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
-
-from fish_speech.models.vits_decoder.modules import attentions, commons, modules
-
-from .commons import get_padding, init_weights
-from .mrte import MRTE
-from .vq_encoder import VQEncoder
-
-
-class TextEncoder(nn.Module):
-    def __init__(
-        self,
-        out_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        latent_channels=192,
-        codebook_size=264,
-    ):
-        super().__init__()
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.latent_channels = latent_channels
-
-        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
-
-        self.encoder_ssl = attentions.Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers // 2,
-            kernel_size,
-            p_dropout,
-        )
-
-        self.encoder_text = attentions.Encoder(
-            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
-        )
-        self.text_embedding = nn.Embedding(codebook_size, hidden_channels)
-
-        self.mrte = MRTE()
-
-        self.encoder2 = attentions.Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers // 2,
-            kernel_size,
-            p_dropout,
-        )
-
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, y, y_lengths, text, text_lengths, ge):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-
-        y = self.ssl_proj(y * y_mask) * y_mask
-
-        y = self.encoder_ssl(y * y_mask, y_mask)
-
-        text_mask = torch.unsqueeze(
-            commons.sequence_mask(text_lengths, text.size(1)), 1
-        ).to(y.dtype)
-        text = self.text_embedding(text).transpose(1, 2)
-        text = self.encoder_text(text * text_mask, text_mask)
-
-        y = self.mrte(y, y_mask, text, text_mask, ge)
-
-        y = self.encoder2(y * y_mask, y_mask)
-
-        stats = self.proj(y) * y_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return y, m, logs, y_mask
-
-
-class ResidualCouplingBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        n_flows=4,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.flows = nn.ModuleList()
-        for i in range(n_flows):
-            self.flows.append(
-                modules.ResidualCouplingLayer(
-                    channels,
-                    hidden_channels,
-                    kernel_size,
-                    dilation_rate,
-                    n_layers,
-                    gin_channels=gin_channels,
-                    mean_only=True,
-                )
-            )
-            self.flows.append(modules.Flip())
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        if not reverse:
-            for flow in self.flows:
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
-        else:
-            for flow in reversed(self.flows):
-                x = flow(x, x_mask, g=g, reverse=reverse)
-        return x
-
-
-class PosteriorEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-
-        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
-        self.enc = modules.WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            gin_channels=gin_channels,
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, x, x_lengths, g=None):
-        if g != None:
-            g = g.detach()
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-            x.dtype
-        )
-        x = self.pre(x) * x_mask
-        x = self.enc(x, x_mask, g=g)
-        stats = self.proj(x) * x_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
-        return z, m, logs, x_mask
-
-
-class Generator(torch.nn.Module):
-    def __init__(
-        self,
-        initial_channel,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels=0,
-    ):
-        super(Generator, self).__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(
-                weight_norm(
-                    ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
-                self.resblocks.append(resblock(ch, k, d))
-
-        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-        self.ups.apply(init_weights)
-
-        if gin_channels != 0:
-            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-    def forward(self, x, g=None):
-        x = self.conv_pre(x)
-        if g is not None:
-            x = x + self.cond(g)
-
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        print("Removing weight norm...")
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-
-
-class DiscriminatorP(torch.nn.Module):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        self.use_spectral_norm = use_spectral_norm
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    Conv2d(
-                        1,
-                        32,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        32,
-                        128,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        128,
-                        512,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        512,
-                        1024,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        1024,
-                        1024,
-                        (kernel_size, 1),
-                        1,
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-            ]
-        )
-        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class DiscriminatorS(torch.nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
-                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
-                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
-                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
-                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
-                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
-
-    def forward(self, x):
-        fmap = []
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class EnsembledDiscriminator(torch.nn.Module):
-    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
-        super().__init__()
-        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        discs = discs + [
-            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-        ]
-        self.discriminators = nn.ModuleList(discs)
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            y_d_gs.append(y_d_g)
-            fmap_rs.append(fmap_r)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class SynthesizerTrn(nn.Module):
-    """
-    Synthesizer for Training
-    """
-
-    def __init__(
-        self,
-        *,
-        spec_channels,
-        segment_size,
-        inter_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels=0,
-        codebook_size=264,
-        vq_mask_ratio=0.0,
-        ref_mask_ratio=0.0,
-    ):
-        super().__init__()
-
-        self.spec_channels = spec_channels
-        self.inter_channels = inter_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.resblock = resblock
-        self.resblock_kernel_sizes = resblock_kernel_sizes
-        self.resblock_dilation_sizes = resblock_dilation_sizes
-        self.upsample_rates = upsample_rates
-        self.upsample_initial_channel = upsample_initial_channel
-        self.upsample_kernel_sizes = upsample_kernel_sizes
-        self.segment_size = segment_size
-        self.gin_channels = gin_channels
-        self.vq_mask_ratio = vq_mask_ratio
-        self.ref_mask_ratio = ref_mask_ratio
-
-        self.enc_p = TextEncoder(
-            inter_channels,
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            p_dropout,
-            codebook_size=codebook_size,
-        )
-        self.dec = Generator(
-            inter_channels,
-            resblock,
-            resblock_kernel_sizes,
-            resblock_dilation_sizes,
-            upsample_rates,
-            upsample_initial_channel,
-            upsample_kernel_sizes,
-            gin_channels=gin_channels,
-        )
-        self.enc_q = PosteriorEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            5,
-            1,
-            16,
-            gin_channels=gin_channels,
-        )
-        self.flow = ResidualCouplingBlock(
-            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
-        )
-
-        self.ref_enc = modules.MelStyleEncoder(
-            spec_channels, style_vector_dim=gin_channels
-        )
-
-        self.vq = VQEncoder()
-        for param in self.vq.parameters():
-            param.requires_grad = False
-
-    def forward(
-        self, audio, audio_lengths, gt_specs, gt_spec_lengths, text, text_lengths
-    ):
-        y_mask = torch.unsqueeze(
-            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
-        ).to(gt_specs.dtype)
-        ge = self.ref_enc(gt_specs * y_mask, y_mask)
-
-        if self.training and self.ref_mask_ratio > 0:
-            bs = audio.size(0)
-            mask_speaker_len = int(bs * self.ref_mask_ratio)
-            mask_indices = torch.randperm(bs)[:mask_speaker_len]
-            audio[mask_indices] = 0
-
-        quantized = self.vq(audio, audio_lengths)
-
-        # Block masking, block_size = 4
-        block_size = 4
-        if self.training and self.vq_mask_ratio > 0:
-            reduced_length = quantized.size(-1) // block_size
-            mask_length = int(reduced_length * self.vq_mask_ratio)
-            mask_indices = torch.randperm(reduced_length)[:mask_length]
-            short_mask = torch.zeros(
-                quantized.size(0),
-                quantized.size(1),
-                reduced_length,
-                device=quantized.device,
-                dtype=torch.float,
-            )
-            short_mask[:, :, mask_indices] = 1.0
-            long_mask = short_mask.repeat_interleave(block_size, dim=-1)
-            long_mask = F.interpolate(
-                long_mask, size=quantized.size(-1), mode="nearest"
-            )
-            quantized = quantized.masked_fill(long_mask > 0.5, 0)
-
-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, gt_spec_lengths, text, text_lengths, ge
-        )
-        z, m_q, logs_q, y_mask = self.enc_q(gt_specs, gt_spec_lengths, g=ge)
-        z_p = self.flow(z, y_mask, g=ge)
-
-        z_slice, ids_slice = commons.rand_slice_segments(
-            z, gt_spec_lengths, self.segment_size
-        )
-        o = self.dec(z_slice, g=ge)
-
-        return (
-            o,
-            ids_slice,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-        )
-
-    @torch.no_grad()
-    def infer(
-        self,
-        audio,
-        audio_lengths,
-        gt_specs,
-        gt_spec_lengths,
-        text,
-        text_lengths,
-        noise_scale=0.5,
-    ):
-        quantized = self.vq(audio, audio_lengths)
-        quantized_lengths = audio_lengths // 512
-        ge = self.encode_ref(gt_specs, gt_spec_lengths)
-
-        return self.decode(
-            quantized,
-            quantized_lengths,
-            text,
-            text_lengths,
-            noise_scale=noise_scale,
-            ge=ge,
-        )
-
-    @torch.no_grad()
-    def infer_posterior(
-        self,
-        gt_specs,
-        gt_spec_lengths,
-    ):
-        y_mask = torch.unsqueeze(
-            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
-        ).to(gt_specs.dtype)
-        ge = self.ref_enc(gt_specs * y_mask, y_mask)
-        z, m_q, logs_q, y_mask = self.enc_q(gt_specs, gt_spec_lengths, g=ge)
-        o = self.dec(z * y_mask, g=ge)
-
-        return o
-
-    @torch.no_grad()
-    def decode(
-        self,
-        quantized,
-        quantized_lengths,
-        text,
-        text_lengths,
-        noise_scale=0.5,
-        ge=None,
-    ):
-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, quantized_lengths, text, text_lengths, ge
-        )
-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
-
-        z = self.flow(z_p, y_mask, g=ge, reverse=True)
-
-        o = self.dec(z * y_mask, g=ge)
-
-        return o
-
-    @torch.no_grad()
-    def encode_ref(self, gt_specs, gt_spec_lengths):
-        y_mask = torch.unsqueeze(
-            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
-        ).to(gt_specs.dtype)
-        ge = self.ref_enc(gt_specs * y_mask, y_mask)
-
-        return ge
-
-
-if __name__ == "__main__":
-    import librosa
-    from transformers import AutoTokenizer
-
-    from fish_speech.utils.spectrogram import LinearSpectrogram
-
-    model = SynthesizerTrn(
-        spec_channels=1025,
-        segment_size=20480 // 640,
-        inter_channels=192,
-        hidden_channels=192,
-        filter_channels=768,
-        n_heads=2,
-        n_layers=6,
-        kernel_size=3,
-        p_dropout=0.1,
-        resblock="1",
-        resblock_kernel_sizes=[3, 7, 11],
-        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        upsample_rates=[8, 8, 2, 2, 2],
-        upsample_initial_channel=512,
-        upsample_kernel_sizes=[16, 16, 8, 2, 2],
-        gin_channels=512,
-    )
-
-    ckpt = "checkpoints/Bert-VITS2/G_0.pth"
-    # Try to load the model
-    print(f"Loading model from {ckpt}")
-    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)["model"]
-    # d_checkpoint = torch.load(
-    #     "checkpoints/Bert-VITS2/D_0.pth", map_location="cpu", weights_only=True
-    # )["model"]
-    # print(checkpoint.keys())
-
-    checkpoint.pop("dec.cond.weight")
-    checkpoint.pop("enc_q.enc.cond_layer.weight_v")
-
-    # new_checkpoint = {}
-    # for k, v in checkpoint.items():
-    #     new_checkpoint["generator." + k] = v
-
-    # for k, v in d_checkpoint.items():
-    #     new_checkpoint["discriminator." + k] = v
-
-    # torch.save(new_checkpoint, "checkpoints/Bert-VITS2/ensemble.pth")
-    # exit()
-
-    print(model.load_state_dict(checkpoint, strict=False))
-
-    # Test
-
-    ref_audio = librosa.load(
-        "data/source/云天河/云天河-旁白/《薄太太》第0025集-yth_24.wav", sr=32000
-    )[0]
-    input_audio = librosa.load(
-        "data/source/云天河/云天河-旁白/《薄太太》第0025集-yth_24.wav", sr=32000
-    )[0]
-    ref_audio = input_audio
-    text = "博兴只知道身边的小女人没睡着,他又凑到她耳边压低了声线。阮苏眉睁眼,不觉得你老公像英雄吗?阮苏还是没反应,这男人是不是有病?刚才那冰冷又强势的样子,和现在这幼稚无赖的样子,根本就判若二人。"
-    encoded_text = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
-    spec = LinearSpectrogram(n_fft=2048, hop_length=640, win_length=2048)
-
-    ref_audio = torch.tensor(ref_audio).unsqueeze(0).unsqueeze(0)
-    ref_spec = spec(ref_audio)
-
-    input_audio = torch.tensor(input_audio).unsqueeze(0).unsqueeze(0)
-    text = encoded_text(text, return_tensors="pt")["input_ids"]
-    print(ref_audio.size(), ref_spec.size(), input_audio.size(), text.size())
-
-    o, y_mask, (z, z_p, m_p, logs_p) = model.infer(
-        input_audio,
-        torch.LongTensor([input_audio.size(2)]),
-        ref_spec,
-        torch.LongTensor([ref_spec.size(2)]),
-        text,
-        torch.LongTensor([text.size(1)]),
-    )
-    print(o.size(), y_mask.size(), z.size(), z_p.size(), m_p.size(), logs_p.size())
-
-    # Save output
-    # import soundfile as sf
-
-    # sf.write("output.wav", o.squeeze().detach().numpy(), 32000)

+ 0 - 619
fish_speech/models/vits_decoder/modules/modules.py

@@ -1,619 +0,0 @@
-import numpy as np
-import torch
-from torch import nn
-from torch.nn import Conv1d
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from .commons import fused_add_tanh_sigmoid_multiply, get_padding, init_weights
-
-LRELU_SLOPE = 0.1
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-
-class ConvReluNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        hidden_channels,
-        out_channels,
-        kernel_size,
-        n_layers,
-        p_dropout,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.hidden_channels = hidden_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-        assert n_layers > 1, "Number of layers should be larger than 0."
-
-        self.conv_layers = nn.ModuleList()
-        self.norm_layers = nn.ModuleList()
-        self.conv_layers.append(
-            nn.Conv1d(
-                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-            )
-        )
-        self.norm_layers.append(LayerNorm(hidden_channels))
-        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
-        for _ in range(n_layers - 1):
-            self.conv_layers.append(
-                nn.Conv1d(
-                    hidden_channels,
-                    hidden_channels,
-                    kernel_size,
-                    padding=kernel_size // 2,
-                )
-            )
-            self.norm_layers.append(LayerNorm(hidden_channels))
-        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
-        self.proj.weight.data.zero_()
-        self.proj.bias.data.zero_()
-
-    def forward(self, x, x_mask):
-        x_org = x
-        for i in range(self.n_layers):
-            x = self.conv_layers[i](x * x_mask)
-            x = self.norm_layers[i](x)
-            x = self.relu_drop(x)
-        x = x_org + self.proj(x)
-        return x * x_mask
-
-
-class WN(torch.nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        gin_channels=0,
-        p_dropout=0,
-    ):
-        super(WN, self).__init__()
-        assert kernel_size % 2 == 1
-        self.hidden_channels = hidden_channels
-        self.kernel_size = (kernel_size,)
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-        self.p_dropout = p_dropout
-
-        self.in_layers = torch.nn.ModuleList()
-        self.res_skip_layers = torch.nn.ModuleList()
-        self.drop = nn.Dropout(p_dropout)
-
-        if gin_channels != 0:
-            cond_layer = torch.nn.Conv1d(
-                gin_channels, 2 * hidden_channels * n_layers, 1
-            )
-            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
-
-        for i in range(n_layers):
-            dilation = dilation_rate**i
-            padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(
-                hidden_channels,
-                2 * hidden_channels,
-                kernel_size,
-                dilation=dilation,
-                padding=padding,
-            )
-            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
-            self.in_layers.append(in_layer)
-
-            # last one is not necessary
-            if i < n_layers - 1:
-                res_skip_channels = 2 * hidden_channels
-            else:
-                res_skip_channels = hidden_channels
-
-            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
-            self.res_skip_layers.append(res_skip_layer)
-
-    def forward(self, x, x_mask, g=None, **kwargs):
-        output = torch.zeros_like(x)
-        n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
-        if g is not None:
-            g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            x_in = self.in_layers[i](x)
-            if g is not None:
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-            else:
-                g_l = torch.zeros_like(x_in)
-
-            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
-            acts = self.drop(acts)
-
-            res_skip_acts = self.res_skip_layers[i](acts)
-            if i < self.n_layers - 1:
-                res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                x = (x + res_acts) * x_mask
-                output = output + res_skip_acts[:, self.hidden_channels :, :]
-            else:
-                output = output + res_skip_acts
-        return output * x_mask
-
-    def remove_weight_norm(self):
-        if self.gin_channels != 0:
-            torch.nn.utils.remove_weight_norm(self.cond_layer)
-        for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
-        for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)
-
-
-class ResBlock1(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c2(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
-
-
-class ResBlock2(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-
-class Flip(nn.Module):
-    def forward(self, x, *args, reverse=False, **kwargs):
-        x = torch.flip(x, [1])
-        if not reverse:
-            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-            return x, logdet
-        else:
-            return x
-
-
-class ResidualCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        p_dropout=0,
-        gin_channels=0,
-        mean_only=False,
-    ):
-        assert channels % 2 == 0, "channels should be divisible by 2"
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            p_dropout=p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-class LinearNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        bias=True,
-        spectral_norm=False,
-    ):
-        super(LinearNorm, self).__init__()
-        self.fc = nn.Linear(in_channels, out_channels, bias)
-
-        if spectral_norm:
-            self.fc = nn.utils.spectral_norm(self.fc)
-
-    def forward(self, input):
-        out = self.fc(input)
-        return out
-
-
-class Mish(nn.Module):
-    def __init__(self):
-        super(Mish, self).__init__()
-
-    def forward(self, x):
-        return x * torch.tanh(F.softplus(x))
-
-
-class Conv1dGLU(nn.Module):
-    """
-    Conv1d + GLU(Gated Linear Unit) with residual connection.
-    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
-    """
-
-    def __init__(self, in_channels, out_channels, kernel_size, dropout):
-        super(Conv1dGLU, self).__init__()
-        self.out_channels = out_channels
-        self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        residual = x
-        x = self.conv1(x)
-        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
-        x = x1 * torch.sigmoid(x2)
-        x = residual + self.dropout(x)
-        return x
-
-
-class ConvNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size=1,
-        stride=1,
-        padding=None,
-        dilation=1,
-        bias=True,
-        spectral_norm=False,
-    ):
-        super(ConvNorm, self).__init__()
-
-        if padding is None:
-            assert kernel_size % 2 == 1
-            padding = int(dilation * (kernel_size - 1) / 2)
-
-        self.conv = torch.nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-        )
-
-        if spectral_norm:
-            self.conv = nn.utils.spectral_norm(self.conv)
-
-    def forward(self, input):
-        out = self.conv(input)
-        return out
-
-
-class MultiHeadAttention(nn.Module):
-    """Multi-Head Attention module"""
-
-    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
-        super().__init__()
-
-        self.n_head = n_head
-        self.d_k = d_k
-        self.d_v = d_v
-
-        self.w_qs = nn.Linear(d_model, n_head * d_k)
-        self.w_ks = nn.Linear(d_model, n_head * d_k)
-        self.w_vs = nn.Linear(d_model, n_head * d_v)
-
-        self.attention = ScaledDotProductAttention(
-            temperature=np.power(d_model, 0.5), dropout=dropout
-        )
-
-        self.fc = nn.Linear(n_head * d_v, d_model)
-        self.dropout = nn.Dropout(dropout)
-
-        if spectral_norm:
-            self.w_qs = nn.utils.spectral_norm(self.w_qs)
-            self.w_ks = nn.utils.spectral_norm(self.w_ks)
-            self.w_vs = nn.utils.spectral_norm(self.w_vs)
-            self.fc = nn.utils.spectral_norm(self.fc)
-
-    def forward(self, x, mask=None):
-        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
-        sz_b, len_x, _ = x.size()
-
-        residual = x
-
-        q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
-        k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
-        v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
-        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lq x dk
-        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lk x dk
-        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v)  # (n*b) x lv x dv
-
-        if mask is not None:
-            slf_mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
-        else:
-            slf_mask = None
-        output, attn = self.attention(q, k, v, mask=slf_mask)
-
-        output = output.view(n_head, sz_b, len_x, d_v)
-        output = (
-            output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
-        )  # b x lq x (n*dv)
-
-        output = self.fc(output)
-
-        output = self.dropout(output) + residual
-        return output, attn
-
-
-class ScaledDotProductAttention(nn.Module):
-    """Scaled Dot-Product Attention"""
-
-    def __init__(self, temperature, dropout):
-        super().__init__()
-        self.temperature = temperature
-        self.softmax = nn.Softmax(dim=2)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, q, k, v, mask=None):
-        attn = torch.bmm(q, k.transpose(1, 2))
-        attn = attn / self.temperature
-
-        if mask is not None:
-            attn = attn.masked_fill(mask, -np.inf)
-
-        attn = self.softmax(attn)
-        p_attn = self.dropout(attn)
-
-        output = torch.bmm(p_attn, v)
-        return output, attn
-
-
-class MelStyleEncoder(nn.Module):
-    """MelStyleEncoder"""
-
-    def __init__(
-        self,
-        n_mel_channels=80,
-        style_hidden=128,
-        style_vector_dim=256,
-        style_kernel_size=5,
-        style_head=2,
-        dropout=0.1,
-    ):
-        super(MelStyleEncoder, self).__init__()
-        self.in_dim = n_mel_channels
-        self.hidden_dim = style_hidden
-        self.out_dim = style_vector_dim
-        self.kernel_size = style_kernel_size
-        self.n_head = style_head
-        self.dropout = dropout
-
-        self.spectral = nn.Sequential(
-            LinearNorm(self.in_dim, self.hidden_dim),
-            Mish(),
-            nn.Dropout(self.dropout),
-            LinearNorm(self.hidden_dim, self.hidden_dim),
-            Mish(),
-            nn.Dropout(self.dropout),
-        )
-
-        self.temporal = nn.Sequential(
-            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
-            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
-        )
-
-        self.slf_attn = MultiHeadAttention(
-            self.n_head,
-            self.hidden_dim,
-            self.hidden_dim // self.n_head,
-            self.hidden_dim // self.n_head,
-            self.dropout,
-        )
-
-        self.fc = LinearNorm(self.hidden_dim, self.out_dim)
-
-    def temporal_avg_pool(self, x, mask=None):
-        if mask is None:
-            out = torch.mean(x, dim=1)
-        else:
-            len_ = (~mask).sum(dim=1).unsqueeze(1)
-            x = x.masked_fill(mask.unsqueeze(-1), 0)
-            x = x.sum(dim=1)
-            out = torch.div(x, len_)
-        return out
-
-    def forward(self, x, mask=None):
-        x = x.transpose(1, 2)
-        if mask is not None:
-            mask = (mask.int() == 0).squeeze(1)
-        max_len = x.shape[1]
-        slf_attn_mask = (
-            mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
-        )
-
-        # spectral
-        x = self.spectral(x)
-        # temporal
-        x = x.transpose(1, 2)
-        x = self.temporal(x)
-        x = x.transpose(1, 2)
-        # self-attention
-        if mask is not None:
-            x = x.masked_fill(mask.unsqueeze(-1), 0)
-        x, _ = self.slf_attn(x, mask=slf_attn_mask)
-        # fc
-        x = self.fc(x)
-        # temoral average pooling
-        w = self.temporal_avg_pool(x, mask=mask)
-
-        return w.unsqueeze(-1)

+ 0 - 58
fish_speech/models/vits_decoder/modules/mrte.py

@@ -1,58 +0,0 @@
-import torch
-from torch import nn
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from fish_speech.models.vits_decoder.modules.attentions import MultiHeadAttention
-
-
-class MRTE(nn.Module):
-    def __init__(
-        self,
-        content_enc_channels=192,
-        hidden_size=512,
-        out_channels=192,
-        n_heads=4,
-    ):
-        super(MRTE, self).__init__()
-        self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
-        self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
-        self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
-        self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
-
-    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
-        if ge == None:
-            ge = 0
-        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
-
-        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
-        text_enc = self.text_pre(text * text_mask)
-        if test != None:
-            if test == 0:
-                x = (
-                    self.cross_attention(
-                        ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
-                    )
-                    + ssl_enc
-                    + ge
-                )
-            elif test == 1:
-                x = ssl_enc + ge
-            elif test == 2:
-                x = (
-                    self.cross_attention(
-                        ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
-                    )
-                    + ge
-                )
-            else:
-                raise ValueError("test should be 0,1,2")
-        else:
-            x = (
-                self.cross_attention(
-                    ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
-                )
-                + ssl_enc
-                + ge
-            )
-        x = self.c_post(x * ssl_mask)
-        return x

+ 0 - 101
fish_speech/models/vits_decoder/modules/vq_encoder.py

@@ -1,101 +0,0 @@
-import math
-
-import torch
-from torch import nn
-
-from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize
-from fish_speech.models.vqgan.modules.wavenet import WaveNet
-from fish_speech.models.vqgan.utils import sequence_mask
-from fish_speech.utils.spectrogram import LogMelSpectrogram
-
-
-class VQEncoder(nn.Module):
-    def __init__(
-        self,
-    ):
-        super().__init__()
-
-        self.encoder = WaveNet(
-            input_channels=128,
-            residual_channels=768,
-            residual_layers=20,
-            dilation_cycle=4,
-        )
-
-        self.quantizer = DownsampleFiniteScalarQuantize(
-            input_dim=768, n_codebooks=1, n_groups=2, levels=[8, 5, 5, 5]
-        )
-
-        self.spec = LogMelSpectrogram(
-            sample_rate=44100,
-            n_fft=2048,
-            win_length=2048,
-            hop_length=512,
-            n_mels=128,
-            f_min=0.0,
-            f_max=8000.0,
-        )
-
-        self.eval()
-        e = self.load_state_dict(
-            torch.load("checkpoints/vq-gan-group-fsq-2x1024.pth", map_location="cpu"),
-            strict=False,
-        )
-
-        assert len(e.missing_keys) == 0, e.missing_keys
-        assert all(
-            k.startswith("decoder.")
-            or k.startswith("quality_projection.")
-            or k.startswith("discriminator.")
-            for k in e.unexpected_keys
-        ), e.unexpected_keys
-
-    @torch.no_grad()
-    def forward(self, audios, audio_lengths, sr=None):
-        mel_spec = self.spec(audios, sample_rate=sr)
-
-        if sr is not None:
-            audio_lengths = audio_lengths * 44100 // sr
-
-        mel_lengths = audio_lengths // self.spec.hop_length
-        mel_masks = (
-            torch.arange(mel_spec.shape[2], device=mel_spec.device)
-            < mel_lengths[:, None]
-        )
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        mels = mel_spec * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(mels) * mel_masks_float_conv
-        encoded_features = self.quantizer(encoded_features).z * mel_masks_float_conv
-
-        return encoded_features
-
-    @torch.no_grad()
-    def indicies_to_vq_features(
-        self,
-        indices,
-        feature_lengths,
-    ):
-        factor = math.prod(self.quantizer.downsample_factor)
-        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        z = self.quantizer.decode(indices) * mel_masks_float_conv
-
-        return z
-
-    @torch.no_grad()
-    def encode(self, audios, audio_lengths, sr=None):
-        audios = audios.float()
-
-        mels = self.spec(audios, sample_rate=sr)
-        mel_lengths = audio_lengths // self.spec.hop_length
-        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        mels = mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(mels) * mel_masks_float_conv
-        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
-
-        return self.quantizer.encode(encoded_features), feature_lengths

+ 86 - 0
fish_speech/models/vqgan/modules/firefly.py

@@ -1,5 +1,6 @@
 # A inference only version of the FireflyGAN model
 # A inference only version of the FireflyGAN model
 
 
+import math
 from functools import partial
 from functools import partial
 from math import prod
 from math import prod
 from typing import Callable
 from typing import Callable
@@ -13,6 +14,8 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.utils.checkpoint import checkpoint
 from torch.utils.checkpoint import checkpoint
 
 
+from fish_speech.models.vqgan.utils import sequence_mask
+
 
 
 def init_weights(m, mean=0.0, std=0.01):
 def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     classname = m.__class__.__name__
@@ -474,6 +477,89 @@ class ConvNeXtEncoder(nn.Module):
         return self.norm(x)
         return self.norm(x)
 
 
 
 
+class FireflyArchitecture(nn.Module):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        head: nn.Module,
+        quantizer: nn.Module,
+        spec_transform: nn.Module,
+    ):
+        super().__init__()
+
+        self.backbone = backbone
+        self.head = head
+        self.quantizer = quantizer
+        self.spec_transform = spec_transform
+
+    def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
+        if self.spec_transform is not None:
+            x = self.spec_transform(x)
+
+        x = self.backbone(x)
+        if mask is not None:
+            x = x * mask
+
+        if self.quantizer is not None:
+            vq_result = self.quantizer(x)
+            x = vq_result.z
+
+            if mask is not None:
+                x = x * mask
+
+        x = self.head(x, template=template)
+
+        if x.ndim == 2:
+            x = x[:, None, :]
+
+        if self.vq is not None:
+            return x, vq_result
+
+        return x
+
+    def encode(self, audios, audio_lengths):
+        audios = audios.float()
+
+        mels = self.spec_transform(audios)
+        mel_lengths = audio_lengths // self.spec_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        mels = mels * mel_masks_float_conv
+
+        # Encode
+        encoded_features = self.backbone(mels) * mel_masks_float_conv
+        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+        return self.quantizer.encode(encoded_features), feature_lengths
+
+    def decode(self, indices, feature_lengths) -> torch.Tensor:
+        factor = math.prod(self.quantizer.downsample_factor)
+        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        audio_masks = sequence_mask(
+            feature_lengths * factor * self.spec_transform.hop_length,
+            indices.shape[2] * factor * self.spec_transform.hop_length,
+        )
+        audio_masks_float_conv = audio_masks[:, None, :].float()
+
+        z = self.quantizer.decode(indices) * mel_masks_float_conv
+        x = self.head(z) * audio_masks_float_conv
+
+        return x
+
+    def remove_parametrizations(self):
+        if hasattr(self.backbone, "remove_parametrizations"):
+            self.backbone.remove_parametrizations()
+
+        if hasattr(self.head, "remove_parametrizations"):
+            self.head.remove_parametrizations()
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+
 class FireflyBase(nn.Module):
 class FireflyBase(nn.Module):
     def __init__(self, ckpt_path: str = None, pretrained: bool = True):
     def __init__(self, ckpt_path: str = None, pretrained: bool = True):
         super().__init__()
         super().__init__()

+ 1 - 1
fish_speech/models/vqgan/modules/fsq.py

@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
     def __init__(
     def __init__(
         self,
         self,
         input_dim: int = 512,
         input_dim: int = 512,
-        n_codebooks: int = 9,
+        n_codebooks: int = 1,
         n_groups: int = 1,
         n_groups: int = 1,
         levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
         levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
         downsample_factor: tuple[int] = (2, 2),
         downsample_factor: tuple[int] = (2, 2),

+ 7 - 7
fish_speech/webui/manage.py

@@ -26,7 +26,7 @@ from fish_speech.i18n import i18n
 from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
 from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
 
 
 config_path = cur_work_dir / "fish_speech" / "configs"
 config_path = cur_work_dir / "fish_speech" / "configs"
-vqgan_yml_path = config_path / "vqgan_finetune.yaml"
+vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
 llama_yml_path = config_path / "text2semantic_finetune.yaml"
 llama_yml_path = config_path / "text2semantic_finetune.yaml"
 vits_yml_path = config_path / "vits_decoder_finetune.yaml"
 vits_yml_path = config_path / "vits_decoder_finetune.yaml"
 
 
@@ -137,7 +137,7 @@ def change_decoder_config(decoder_model_path):
         choices = ["vits_decoder_finetune", "vits_decoder_pretrain"]
         choices = ["vits_decoder_finetune", "vits_decoder_pretrain"]
         return gr.Dropdown(choices=choices, value=choices[0])
         return gr.Dropdown(choices=choices, value=choices[0])
     elif "vqgan" in decoder_model_path or "vq-gan" in decoder_model_path:
     elif "vqgan" in decoder_model_path or "vq-gan" in decoder_model_path:
-        choices = ["vqgan_finetune", "vqgan_pretrain"]
+        choices = ["firefly_gan_vq", "firefly_gan_vq"]
         return gr.Dropdown(choices=choices, value=choices[0])
         return gr.Dropdown(choices=choices, value=choices[0])
     else:
     else:
         raise ValueError("Invalid decoder name")
         raise ValueError("Invalid decoder name")
@@ -517,7 +517,7 @@ def train_process(
             PYTHON,
             PYTHON,
             "fish_speech/train.py",
             "fish_speech/train.py",
             "--config-name",
             "--config-name",
-            "vqgan_finetune",
+            "firefly_gan_vq",
             f"project={project}",
             f"project={project}",
             f"trainer.strategy.process_group_backend={backend}",
             f"trainer.strategy.process_group_backend={backend}",
             f"model.optimizer.lr={vqgan_lr}",
             f"model.optimizer.lr={vqgan_lr}",
@@ -590,9 +590,9 @@ def train_process(
                 "--batch-size",
                 "--batch-size",
                 "16",
                 "16",
                 "--config-name",
                 "--config-name",
-                "vqgan_pretrain",
+                "firefly_gan_vq",
                 "--checkpoint-path",
                 "--checkpoint-path",
-                "checkpoints/vq-gan-group-fsq-2x1024.pth",
+                "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
             ]
             ]
         )
         )
 
 
@@ -1292,8 +1292,8 @@ with gr.Blocks(
                                     choices=[
                                     choices=[
                                         "vits_decoder_finetune",
                                         "vits_decoder_finetune",
                                         "vits_decoder_pretrain",
                                         "vits_decoder_pretrain",
-                                        "vqgan_finetune",
-                                        "vqgan_pretrain",
+                                        "firefly_gan_vq",
+                                        "firefly_gan_vq",
                                     ],
                                     ],
                                     allow_custom_value=True,
                                     allow_custom_value=True,
                                 )
                                 )

+ 2 - 1
pyproject.toml

@@ -35,7 +35,8 @@ dependencies = [
     "vector_quantize_pytorch>=1.14.24",
     "vector_quantize_pytorch>=1.14.24",
     "samplerate>=0.2.1",
     "samplerate>=0.2.1",
     "resampy>=0.4.3",
     "resampy>=0.4.3",
-    "einx[torch]==0.2.2"
+    "einx[torch]==0.2.2",
+    "zstandard>=0.22.0"
 ]
 ]
 
 
 [project.optional-dependencies]
 [project.optional-dependencies]

+ 112 - 0
run.py

@@ -0,0 +1,112 @@
+import audioop
+import base64
+
+import numpy as np
+import soundfile as sf
+from fastapi import FastAPI, WebSocket
+from fastapi.responses import Response
+from loguru import logger
+
+from stream_service import FishAgentPipeline
+
+app = FastAPI()
+
+
+@app.post("/incoming")
+async def handle_incoming():
+    xml = """<Response>
+    <Connect>
+    <Stream url="wss://2427-24-4-31-213.ngrok-free.app/connection" />
+    </Connect>
+</Response>"""
+
+    logger.info("Incoming call received")
+    return Response(media_type="text/xml", content=xml)
+
+
+async def send_audio(ws, audio, stream_sid=""):
+    await ws.send_json(
+        {
+            "streamSid": stream_sid,
+            "event": "media",
+            "media": {
+                "payload": audio,
+            },
+        }
+    )
+
+
+def decode_mu_law(data):
+    samples = audioop.ulaw2lin(data, 2)
+    samples = np.frombuffer(samples, dtype=np.int16)
+    samples = samples.astype(np.float32) / 32768.0
+
+    return samples
+
+
+def encode_mu_law(data):
+    samples = np.clip(data, -1.0, 1.0)
+    samples = (samples * 32768).astype(np.int16)
+    samples = audioop.lin2ulaw(samples.tobytes(), 2)
+
+    return samples
+
+
+is_working = False
+
+
+@app.websocket("/connection")
+async def handle_connection(websocket: WebSocket):
+    global is_working
+
+    await websocket.accept()
+    logger.info("Connection established")
+    stream_sid = None
+    call_sid = None
+
+    if is_working:
+        logger.info("Already working, closing connection")
+        await websocket.close()
+        return
+
+    is_working = True
+    pipe.reset()
+
+    while True:
+        data = await websocket.receive_json()
+        if data["event"] == "connected":
+            logger.info("Connected message received")
+        elif data["event"] == "start":
+            stream_sid = data["start"]["streamSid"]
+            call_sid = data["start"]["callSid"]
+            logger.info(f"Start media streaming: {stream_sid} - {call_sid}")
+        elif data["event"] == "media":
+            payload = data["media"]["payload"]
+            chunk = base64.b64decode(payload)
+            samples = decode_mu_law(chunk)
+            for i in pipe.add_chunk(samples, sr=8000):
+                await send_audio(
+                    websocket, base64.b64encode(encode_mu_law(i)).decode(), stream_sid
+                )
+        elif data["event"] == "closed":
+            logger.info("Connection closed")
+            await websocket.close()
+            break
+        elif data["event"] == "stop":
+            logger.info("Stop media streaming")
+            await websocket.close()
+            break
+        else:
+            logger.info(f"Unknown event: {data}")
+
+    is_working = False
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    pipe = FishAgentPipeline()
+    pipe.warmup()
+
+    logger.info("Starting server")
+    uvicorn.run(app, host="localhost", port=5000)

+ 412 - 0
stream_service.py

@@ -0,0 +1,412 @@
+import time
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+from loguru import logger
+from torchaudio import functional as AF
+from transformers import (
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+    AutoTokenizer,
+    pipeline,
+)
+
+from fish_speech.conversation import (
+    CODEBOOK_EOS_TOKEN_ID,
+    Conversation,
+    Message,
+    TokensPart,
+    encode_conversation,
+)
+from fish_speech.models.text2semantic.llama import DualARTransformer
+from tools.api import decode_vq_tokens, encode_reference
+from tools.llama.generate_test import convert_string
+from tools.llama.generate_test import generate as llama_generate
+from tools.llama.generate_test import load_model as load_llama_model
+from tools.vqgan.inference import load_model as load_decoder_model
+
+
+class FishStreamVAD:
+    def __init__(self) -> None:
+        # Args
+        self.sample_rate = 16000
+        self.threshold = 0.5
+        self.neg_threshold = self.threshold - 0.15
+        self.min_speech_duration_ms = 100
+        self.min_silence_ms = 500
+        self.speech_pad_ms = 30
+        self.chunk_size = 512
+
+        # Convert to samples
+        self.min_speech_duration_samples = (
+            self.min_speech_duration_ms * self.sample_rate // 1000
+        )
+        self.min_silence_samples = self.min_silence_ms * self.sample_rate // 1000
+        self.speech_pad_samples = self.speech_pad_ms * self.sample_rate // 1000
+
+        # Core buffers
+        self.reset()
+
+        # Load models
+        logger.info("Loading VAD model")
+        vad_model, vad_utils = torch.hub.load(
+            repo_or_dir="snakers4/silero-vad",
+            model="silero_vad",
+            force_reload=True,
+            onnx=True,
+        )
+
+        self.vad_model = vad_model
+        self.get_speech_timestamps = vad_utils[0]
+        logger.info("VAD model loaded")
+
+    def reset(self):
+        self.audio_chunks = None
+        self.vad_pointer = 0
+        self.speech_probs = []
+
+        self.triggered = False
+        self.start = self.end = self.temp_end = 0
+        self.last_seen_end = 0
+        self.speech_segments = []
+
+    def add_chunk(self, chunk, sr=None):
+        """
+        Add a chunk to the buffer
+        """
+
+        if isinstance(chunk, np.ndarray):
+            chunk = torch.from_numpy(chunk)
+
+        if sr is not None and sr != self.sample_rate:
+            chunk = AF.resample(chunk, sr, self.sample_rate)
+
+        # self.audio_chunks.append(chunk)
+        if self.audio_chunks is None:
+            self.audio_chunks = chunk
+        else:
+            self.audio_chunks = torch.cat([self.audio_chunks, chunk])
+
+        # Trigger VAD
+        yield from self.detect_speech()
+
+    def detect_speech(self):
+        """
+        Run the VAD model on the current buffer
+        """
+
+        speech_prob_start_idx = len(self.speech_probs)
+        while len(self.audio_chunks) - self.vad_pointer >= self.chunk_size:
+            chunk = self.audio_chunks[
+                self.vad_pointer : self.vad_pointer + self.chunk_size
+            ]
+            speech_prob = self.vad_model(chunk, self.sample_rate)
+            self.speech_probs.append(speech_prob)
+            self.vad_pointer += self.chunk_size
+
+        # Process speech probs
+        for i in range(speech_prob_start_idx, len(self.speech_probs)):
+            speech_prob = self.speech_probs[i]
+
+            if speech_prob >= self.threshold and self.temp_end:
+                self.temp_end = 0
+
+            if speech_prob >= self.threshold and self.triggered is False:
+                self.triggered = True
+                self.start = i * self.chunk_size
+                continue
+
+            if speech_prob < self.neg_threshold and self.triggered is True:
+                if self.temp_end == 0:
+                    self.temp_end = i * self.chunk_size
+
+                if i * self.chunk_size - self.temp_end < self.min_silence_samples:
+                    continue
+
+                self.end = self.temp_end
+                if self.end - self.start > self.min_speech_duration_samples:
+                    yield self.audio_chunks[
+                        self.start : self.end + self.speech_pad_samples
+                    ]
+
+                self.triggered = False
+                self.start = self.end = self.temp_end = 0
+
+
+class FishASR:
+    def __init__(self) -> None:
+        self.audio_chunks = None
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        torch_dtype = torch.bfloat16
+        model_id = "openai/whisper-medium.en"
+
+        logger.info("Loading ASR model")
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_id, torch_dtype=torch_dtype, use_safetensors=True
+        ).to(self.device)
+        processor = AutoProcessor.from_pretrained(model_id)
+        self.pipe = pipeline(
+            "automatic-speech-recognition",
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            max_new_tokens=256,
+            torch_dtype=torch_dtype,
+            device=self.device,
+        )
+        logger.info("ASR model loaded")
+
+    @torch.inference_mode()
+    def run(self, chunk):
+        return self.pipe(chunk.numpy())
+
+
+class FishE2EAgent:
+    def __init__(self) -> None:
+        self.device = device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device}")
+
+        decoder_model = load_decoder_model(
+            config_name="firefly_gan_vq",
+            checkpoint_path="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+            device=device,
+        )
+        self.decoder_model = decoder_model
+        logger.info("Decoder model loaded")
+
+        llama_model, decode_one_token = load_llama_model(
+            config_name="dual_ar_2_codebook_1.3b",
+            checkpoint_path="checkpoints/step_000206000.ckpt",
+            device=device,
+            precision=torch.bfloat16,
+            max_length=2048,
+            compile=True,
+        )
+        self.llama_model: DualARTransformer = llama_model
+        self.decode_one_token = decode_one_token
+        logger.info("LLAMA model loaded")
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "checkpoints/fish-speech-agent-1"
+        )
+        self.semantic_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+        self.im_end_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        self.decoder_tokenizer = AutoTokenizer.from_pretrained(
+            "fishaudio/fish-speech-1"
+        )
+
+        # Control params
+        self.temperature = torch.tensor(0.7, device=device, dtype=torch.float)
+        self.top_p = torch.tensor(0.7, device=device, dtype=torch.float)
+        self.repetition_penalty = torch.tensor(1.2, device=device, dtype=torch.float)
+
+        # This is used to control the timbre of the generated audio
+        self.base_messages = [
+            # Message(
+            #     role="user",
+            #     parts=[np.load("example/q0.npy")],
+            # ),
+            # Message(
+            #     role="assistant",
+            #     parts=[
+            #         "Transcribed: Hi, can you briefly describe what is machine learning?\nResponse: Sure! Machine learning is the process of automating tasks that humans are capable of doing with a computer. It involves training computers to make decisions based on data.",
+            #         np.load("example/a0.npy"),
+            #     ],
+            # ),
+        ]
+        self.reference = encode_reference(
+            decoder_model=self.decoder_model,
+            reference_audio="example/a0.wav",
+            enable_reference_audio=True,
+        )
+        self.messages = self.base_messages.copy()
+
+    def reset(self):
+        self.messages = self.base_messages.copy()
+
+    @torch.inference_mode()
+    def vq_encode(self, audios, sr=None):
+        if isinstance(audios, np.ndarray):
+            audios = torch.from_numpy(audios)
+
+        if audios.ndim == 1:
+            audios = audios[None, None, :]
+
+        audios = audios.to(self.decoder_model.device)
+        if sr is not None and sr != self.decoder_model.sampling_rate:
+            audios = AF.resample(audios, sr, self.decoder_model.sampling_rate)
+
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
+        )
+
+        return self.decoder_model.encode(audios, audio_lengths)[0][0]
+
+    @torch.inference_mode()
+    def generate(self, audio_chunk, sr=None, text=None):
+        vq_output = self.vq_encode(audio_chunk, sr)
+        logger.info(f"VQ output: {vq_output.shape}")
+
+        # Encode conversation
+        self.messages.append(
+            Message(
+                role="user",
+                parts=[vq_output],
+            )
+        )
+
+        parts = []
+        if text is not None:
+            parts.append(f"Transcribed: {text}\nResponse:")
+
+        self.messages.append(
+            Message(
+                role="assistant",
+                parts=parts,
+            )
+        )
+        conversation = Conversation(self.messages)
+
+        # Encode the conversation
+        prompt, _ = encode_conversation(
+            conversation, self.tokenizer, self.llama_model.config.num_codebooks
+        )
+        prompt = prompt[:, :-1].to(dtype=torch.int, device=self.device)
+        prompt_length = prompt.shape[1]
+
+        # Generate
+        y = llama_generate(
+            model=self.llama_model,
+            prompt=prompt,
+            max_new_tokens=0,
+            eos_token_id=self.tokenizer.eos_token_id,
+            im_end_id=self.im_end_id,
+            decode_one_token=self.decode_one_token,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            repetition_penalty=self.repetition_penalty,
+        )
+
+        tokens = self.tokenizer.decode(
+            y[0, prompt_length:].tolist(), skip_special_tokens=False
+        )
+        logger.info(f"Generated: {convert_string(tokens)}")
+
+        # Put the generated tokens
+        # since there is <im_end> and <eos> tokens, we remove last 2 tokens
+        code_mask = y[0, prompt_length:-2] == self.semantic_id
+        codes = y[1:, prompt_length:-2][:, code_mask].clone()
+
+        codes = codes - 2
+        assert (codes >= 0).all(), f"Negative code found"
+
+        decoded = y[:, prompt_length:-1].clone()
+        if decoded[0, -1] != self.im_end_id:  # <im_end>
+            val = [[self.im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
+            decoded = torch.cat(
+                (decoded, torch.tensor(val, device=self.device, dtype=torch.int)), dim=1
+            )
+
+        decoded = decoded.cpu()
+        self.messages[-1].parts.append(
+            TokensPart(
+                tokens=decoded[:1],
+                codes=decoded[1:],
+            )
+        )
+
+        # Less than 5 * 20 = 100ms
+        if codes.shape[1] <= 5:
+            return
+
+        # Generate audio
+        main_tokens = decoded[0]
+        text_tokens = main_tokens[main_tokens != self.semantic_id]
+        text = self.tokenizer.decode(text_tokens.tolist(), skip_special_tokens=True)
+        text_tokens = self.decoder_tokenizer.encode(text, return_tensors="pt").to(
+            self.device
+        )
+
+        audio = decode_vq_tokens(
+            decoder_model=self.decoder_model,
+            codes=codes,
+            text_tokens=text_tokens,
+            reference_embedding=self.reference,
+        )
+
+        if sr is not None and sr != self.decoder_model.sampling_rate:
+            audio = AF.resample(audio, self.decoder_model.sampling_rate, sr)
+
+        return audio.float()
+
+
+class FishAgentPipeline:
+    def __init__(self) -> None:
+        self.vad = FishStreamVAD()
+        # Currently use ASR model as intermediate
+        self.asr = FishASR()
+        self.agent = FishE2EAgent()
+
+        self.vad_segments = []
+        self.text_segments = []
+
+    def add_chunk(self, chunk, sr=None):
+        use_np = isinstance(chunk, np.ndarray)
+        if use_np:
+            chunk = torch.from_numpy(chunk)
+
+        if sr is not None and sr != 16000:
+            chunk = AF.resample(chunk, sr, 16000)
+
+        for vad_audio in self.vad.add_chunk(chunk, 16000):
+            self.vad_segments.append(vad_audio)
+            asr_text = self.asr.run(vad_audio)
+            self.text_segments.append(asr_text)
+            logger.info(f"ASR: {asr_text}")
+
+            # Actually should detect if intent is finished here
+            result = self.agent.generate(vad_audio, 16000, text=asr_text)
+            if result is None:
+                continue
+
+            if sr is not None and sr != 16000:
+                result = AF.resample(result, 16000, sr)
+
+            if use_np:
+                result = result.cpu().numpy()
+
+            yield result
+
+    def reset(self):
+        self.vad.reset()
+        self.agent.reset()
+        self.vad_segments = []
+        self.text_segments = []
+
+    def warmup(self):
+        logger.info("Warming up the pipeline")
+        audio, sr = librosa.load("example/q0.mp3", sr=16000)
+        for i in range(0, len(audio), 882):
+            for audio in self.add_chunk(audio[i : i + 882], sr):
+                pass
+        logger.info("Pipeline warmed up")
+        self.reset()
+
+
+if __name__ == "__main__":
+    import soundfile as sf
+
+    service = FishAgentPipeline()
+    service.warmup()
+    logger.info("Stream service started")
+
+    audio, sr = librosa.load("example/q1.mp3", sr=16000)
+    seg = []
+    for i in range(0, len(audio), 882):
+        for audio in service.add_chunk(audio[i : i + 882], sr):
+            seg.append(audio)
+
+    audio = np.concatenate(seg)
+    sf.write("output.wav", audio, 16000)

+ 183 - 0
test_echo.py

@@ -0,0 +1,183 @@
+import io
+import wave
+from typing import List
+
+import av
+import numpy as np
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.responses import HTMLResponse
+
+app = FastAPI()
+
+html = """
+<!DOCTYPE html>
+<html>
+<head>
+    <title>Real-time Chat Room</title>
+</head>
+<body>
+    <h1>Real-time Chat Room</h1>
+    <button id="start">Start Streaming</button>
+    <button id="stop">Stop Streaming</button>
+    <script type="module">
+        import { MediaRecorder, register } from 'https://dev.jspm.io/npm:extendable-media-recorder';
+        import { connect } from 'https://dev.jspm.io/npm:extendable-media-recorder-wav-encoder';
+    
+        await register(await connect());
+
+        let socket;
+        let mediaRecorder;
+        let audioContext;
+
+        function startStreaming() {
+            initWebSocket();
+
+            audioContext = new (window.AudioContext || window.webkitAudioContext)();
+            navigator.mediaDevices.getUserMedia({ audio: {
+                channelCount: 1,  
+                sampleRate: 44100,
+                sampleSize: 16,
+                echoCancellation: true,
+                noiseSuppression: true
+            } })
+                .then(function (stream) {
+                    mediaRecorder = new MediaRecorder(stream, { mimeType: 'audio/webm;codecs=opus' });
+                    mediaRecorder.start(100);
+                    mediaRecorder.addEventListener("dataavailable", function (event) {
+                        socket.send(event.data);
+                    });
+                })
+                .catch(function (err) {
+                    console.error("Error accessing microphone:", err);
+                });
+
+                // Create a MediaSource
+                const mediaSource = new MediaSource();
+                const mediaStream = new MediaStream();
+
+                // Create an HTMLVideoElement and attach the MediaSource to it
+                const audioElement = document.createElement('audio');
+                audioElement.src = URL.createObjectURL(mediaSource);
+                audioElement.autoplay = true;
+                document.body.appendChild(audioElement);
+
+                mediaSource.addEventListener('sourceopen', function() {
+                    const sourceBuffer = mediaSource.addSourceBuffer('audio/webm; codecs=opus');
+
+                    socket.onmessage = function(event) {
+                        const arrayBuffer = event.data;
+
+                        sourceBuffer.appendBuffer(arrayBuffer);
+                    };
+                });
+        }
+
+        function stopStreaming() {
+            mediaRecorder.stop();
+        }
+
+        function initWebSocket() {
+            const is_wss = window.location.protocol === "https:";
+            socket = new WebSocket(`${is_wss ? "wss" : "ws"}://${window.location.host}/ws`);
+            socket.binaryType = 'arraybuffer';
+        }
+
+        document.getElementById("start").onclick = startStreaming;
+        document.getElementById("stop").onclick = stopStreaming;
+    </script>
+</body>
+</html>
+"""
+
+
+def encode_wav(data):
+    sample_rate = 44100
+    samples = np.frombuffer(data, dtype=np.int16)
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(1)
+        wav_file.setsampwidth(2)
+        wav_file.setframerate(sample_rate)
+        wav_file.writeframes(samples.tobytes())
+
+    return buffer.getvalue()
+
+
+class ConnectionManager:
+    def __init__(self):
+        self.active_connections: List[WebSocket] = []
+
+    async def connect(self, websocket: WebSocket):
+        await websocket.accept()
+        self.active_connections.append(websocket)
+
+    def disconnect(self, websocket: WebSocket):
+        self.active_connections.remove(websocket)
+
+    async def broadcast(self, message: bytes, sender: WebSocket):
+        for connection in self.active_connections:
+            if connection == sender:
+                #     print("Sending message to client", connection)
+                await connection.send_bytes(message)
+
+
+manager = ConnectionManager()
+
+
+@app.get("/")
+async def get():
+    return HTMLResponse(html)
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await manager.connect(websocket)
+    try:
+        buffer = io.BytesIO()
+        container = None
+        cur_pos = 0
+        total_size = 0
+
+        while True:
+            data = await websocket.receive_bytes()
+            # data = encode_wav(data)
+            # if len(data) == 1:
+            #     print(f"len(data): {len(data)}, data: {data}")
+            # if len(data) > 1:
+            #     data = b'\x1a' + data
+            #     with open("output.webm", "wb") as f:
+            #         f.write(data)
+            #     exit()
+            # print(f"len(data): {len(data)}")
+
+            # print("Received data:", data)
+            # Save as webm file and exit
+            # with open("output.wav", "wb") as f:
+            #     f.write(encode_wav(data))
+
+            buffer.write(data)
+            buffer.seek(cur_pos)
+            total_size += len(data)
+
+            if not container and total_size > 1000:
+                container = av.open(buffer, "r", format="webm")
+                print(container)
+            elif container:
+                for packet in container.decode(video=0):
+                    if packet.size == 0:
+                        continue
+
+                    cur_pos += packet.size
+                    for frame in packet.decode():
+                        print(frame.to_ndarray().shape)
+
+            await manager.broadcast(data, websocket)
+    except WebSocketDisconnect:
+        manager.disconnect(websocket)
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)

+ 3 - 9
tools/api.py

@@ -400,21 +400,17 @@ def parse_args():
     parser.add_argument(
     parser.add_argument(
         "--llama-checkpoint-path",
         "--llama-checkpoint-path",
         type=str,
         type=str,
-        default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
-    )
-    parser.add_argument(
-        "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
+        default="checkpoints/fish-speech-1.2",
     )
     )
     parser.add_argument(
     parser.add_argument(
         "--decoder-checkpoint-path",
         "--decoder-checkpoint-path",
         type=str,
         type=str,
-        default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+        default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
     )
     )
-    parser.add_argument("--decoder-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--half", action="store_true")
-    parser.add_argument("--max-length", type=int, default=2048)
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
     parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
@@ -450,11 +446,9 @@ if __name__ == "__main__":
 
 
     logger.info("Loading Llama model...")
     logger.info("Loading Llama model...")
     llama_queue = launch_thread_safe_queue(
     llama_queue = launch_thread_safe_queue(
-        config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
         device=args.device,
         precision=args.precision,
         precision=args.precision,
-        max_length=args.max_length,
         compile=args.compile,
         compile=args.compile,
     )
     )
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

+ 101 - 0
tools/llama/convert_hf_weights_to_llama.py

@@ -0,0 +1,101 @@
+import torch
+from transformers import LlamaForCausalLM
+
+from fish_speech.models.text2semantic.llama import BaseModelArgs, BaseTransformer
+
+# Load the HF model
+hf_model = LlamaForCausalLM.from_pretrained(
+    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+)
+
+model = BaseTransformer(
+    BaseModelArgs(
+        vocab_size=hf_model.config.vocab_size + 8,
+        n_layer=hf_model.config.num_hidden_layers,
+        n_head=hf_model.config.num_attention_heads,
+        n_local_heads=hf_model.config.num_key_value_heads,
+        dim=hf_model.config.hidden_size,
+        head_dim=hf_model.config.hidden_size // hf_model.config.num_attention_heads,
+        num_codebooks=2,
+        codebook_size=1032,
+    )
+)
+print(model.config)
+
+hf_state_dict = hf_model.state_dict()
+model_state_dict = model.state_dict()
+
+# print(hf_state_dict.keys())
+# print(model_state_dict.keys())
+
+new_state_dict = {}
+
+# Handle embeddings
+new_state_dict["embeddings.weight"] = model_state_dict.pop("embeddings.weight")
+hf_embed_tokens = hf_state_dict.pop("model.embed_tokens.weight")
+new_state_dict["embeddings.weight"][: hf_embed_tokens.shape[0]] = hf_embed_tokens
+
+# Restore layers
+for layer_idx in range(hf_model.config.num_hidden_layers):
+    # Handle attention
+    q_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.self_attn.q_proj.weight")
+    k_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.self_attn.k_proj.weight")
+    v_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.self_attn.v_proj.weight")
+    qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
+    new_state_dict[f"layers.{layer_idx}.attention.wqkv.weight"] = qkv_weight
+    model_state_dict.pop(f"layers.{layer_idx}.attention.wqkv.weight")
+
+    o_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.self_attn.o_proj.weight")
+    new_state_dict[f"layers.{layer_idx}.attention.wo.weight"] = o_weight
+    model_state_dict.pop(f"layers.{layer_idx}.attention.wo.weight")
+
+    # Handle feed forward
+    up_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.mlp.up_proj.weight")
+    down_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.mlp.down_proj.weight")
+    gate_weight = hf_state_dict.pop(f"model.layers.{layer_idx}.mlp.gate_proj.weight")
+
+    new_state_dict[f"layers.{layer_idx}.feed_forward.w1.weight"] = gate_weight
+    new_state_dict[f"layers.{layer_idx}.feed_forward.w2.weight"] = down_weight
+    new_state_dict[f"layers.{layer_idx}.feed_forward.w3.weight"] = up_weight
+
+    model_state_dict.pop(f"layers.{layer_idx}.feed_forward.w1.weight")
+    model_state_dict.pop(f"layers.{layer_idx}.feed_forward.w2.weight")
+    model_state_dict.pop(f"layers.{layer_idx}.feed_forward.w3.weight")
+
+    # Handle layer norms
+    input_layernorm_weight = hf_state_dict.pop(
+        f"model.layers.{layer_idx}.input_layernorm.weight"
+    )
+    post_attention_layernorm_weight = hf_state_dict.pop(
+        f"model.layers.{layer_idx}.post_attention_layernorm.weight"
+    )
+
+    new_state_dict[f"layers.{layer_idx}.ffn_norm.weight"] = (
+        post_attention_layernorm_weight
+    )
+    new_state_dict[f"layers.{layer_idx}.attention_norm.weight"] = input_layernorm_weight
+
+    model_state_dict.pop(f"layers.{layer_idx}.ffn_norm.weight")
+    model_state_dict.pop(f"layers.{layer_idx}.attention_norm.weight")
+
+# Handle final layer norm
+new_state_dict["norm.weight"] = hf_state_dict.pop("model.norm.weight")
+model_state_dict.pop("norm.weight")
+
+# Handle output layer
+w = hf_state_dict.pop("lm_head.weight")
+new_state_dict["output.weight"] = model_state_dict.pop("output.weight")
+new_state_dict["output.weight"][: w.shape[0]] = w
+
+print(hf_state_dict.keys(), len(hf_state_dict))
+print(model_state_dict.keys(), len(model_state_dict))
+
+print(model.load_state_dict(new_state_dict, strict=True))
+
+model = model.bfloat16()
+
+new_state_dict = {f"model.{k}": v for k, v in model.state_dict().items()}
+torch.save(
+    new_state_dict,
+    "checkpoints/fish-speech-agent-1/TinyLlama-1.1B-intermediate-step-1431k-3T.pth",
+)

+ 24 - 85
tools/llama/generate.py

@@ -19,7 +19,7 @@ from loguru import logger
 from tqdm import tqdm
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers import AutoTokenizer
 
 
-from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from fish_speech.text import clean_text, split_text
 from fish_speech.text import clean_text, split_text
 
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -31,7 +31,11 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True
     torch._inductor.config.fx_graph_cache = True
 
 
 
 
-from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
+from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
 
 
 
 
 def multinomial_sample_one_no_sync(
 def multinomial_sample_one_no_sync(
@@ -161,7 +165,6 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
     num_new_tokens: int,
-    eos_token_id: int = 2,
     im_end_id: int = 4,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
     **sampling_kwargs,
@@ -197,11 +200,7 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
             model.config.num_codebooks + 1, -1
         )
         )
 
 
-        if (
-            cur_token[0, 0, -1] == eos_token_id
-            or cur_token[0, 0, -1] == im_end_id
-            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
-        ):
+        if cur_token[0, 0, -1] == im_end_id:
             break
             break
 
 
     return previous_tokens[:, : i + 1]
     return previous_tokens[:, : i + 1]
@@ -214,7 +213,6 @@ def generate(
     model: NaiveTransformer,
     model: NaiveTransformer,
     prompt: torch.Tensor,
     prompt: torch.Tensor,
     max_new_tokens: int,
     max_new_tokens: int,
-    eos_token_id: int = 2,
     im_end_id: int = 4,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
     **sampling_kwargs,
@@ -255,6 +253,7 @@ def generate(
         if isinstance(model, NaiveTransformer)
         if isinstance(model, NaiveTransformer)
         else decode_one_token_ar
         else decode_one_token_ar
     )
     )
+
     next_token = prefill_decode(
     next_token = prefill_decode(
         model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
         model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
     )
     )
@@ -266,7 +265,6 @@ def generate(
         next_token.view(1, codebook_dim, -1),
         next_token.view(1, codebook_dim, -1),
         input_pos,
         input_pos,
         max_new_tokens - 1,
         max_new_tokens - 1,
-        eos_token_id=eos_token_id,
         im_end_id=im_end_id,
         im_end_id=im_end_id,
         decode_one_token=decode_one_token,
         decode_one_token=decode_one_token,
         **sampling_kwargs,
         **sampling_kwargs,
@@ -281,22 +279,12 @@ def generate(
 def encode_tokens(
 def encode_tokens(
     tokenizer,
     tokenizer,
     string,
     string,
-    bos=True,
     device="cuda",
     device="cuda",
     prompt_tokens=None,
     prompt_tokens=None,
-    speaker=None,
     num_codebooks=4,
     num_codebooks=4,
 ):
 ):
     string = clean_text(string)
     string = clean_text(string)
-
-    if speaker is None:
-        speaker = "assistant"
-
-    string = (
-        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
-    )
-    if bos:
-        string = f"<|begin_of_sequence|>{string}"
+    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
 
 
     new_tokens = tokenizer.encode(
     new_tokens = tokenizer.encode(
         string,
         string,
@@ -324,7 +312,7 @@ def encode_tokens(
         prompt_tokens = prompt_tokens[0]
         prompt_tokens = prompt_tokens[0]
 
 
     assert prompt_tokens.ndim == 2
     assert prompt_tokens.ndim == 2
-    data = prompt_tokens + 2
+    data = prompt_tokens + 1
 
 
     if prompt_tokens.shape[0] > num_codebooks:
     if prompt_tokens.shape[0] > num_codebooks:
         logger.warning(
         logger.warning(
@@ -332,13 +320,9 @@ def encode_tokens(
         )
         )
         data = data[:num_codebooks]
         data = data[:num_codebooks]
 
 
-    # Add eos token for each codebook
+    # Add pad token for each codebook
     data = torch.cat(
     data = torch.cat(
-        (
-            data,
-            torch.ones((data.size(0), 1), dtype=torch.int, device=device)
-            * CODEBOOK_EOS_TOKEN_ID,
-        ),
+        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
         dim=1,
         dim=1,
     )
     )
 
 
@@ -356,16 +340,10 @@ def encode_tokens(
     return prompt
     return prompt
 
 
 
 
-def load_model(
-    config_name, checkpoint_path, device, precision, max_length, compile=False
-):
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
-        cfg = compose(
-            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
-        )
-
-    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
+def load_model(checkpoint_path, device, precision, compile=False):
+    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
+        checkpoint_path, load_weights=True
+    )
 
 
     if "int8" in str(checkpoint_path):
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")
         logger.info("Using int8 weight-only quantization!")
@@ -384,21 +362,8 @@ def load_model(
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
         model = simple_quantizer.convert_for_runtime()
         model = simple_quantizer.convert_for_runtime()
 
 
-    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
-    if "state_dict" in checkpoint:
-        checkpoint = checkpoint["state_dict"]
-
-    if any(k.startswith("model.") for k in checkpoint):
-        checkpoint = {
-            k.replace("model.", ""): v
-            for k, v in checkpoint.items()
-            if k.startswith("model.")
-        }
-
-    model.load_state_dict(checkpoint, assign=True)
-
     model = model.to(device=device, dtype=precision)
     model = model.to(device=device, dtype=precision)
-    logger.info("Restored model from checkpoint")
+    logger.info(f"Restored model from checkpoint")
 
 
     if isinstance(model, DualARTransformer):
     if isinstance(model, DualARTransformer):
         decode_one_token = decode_one_token_ar
         decode_one_token = decode_one_token_ar
@@ -426,7 +391,6 @@ class GenerateResponse:
 def generate_long(
 def generate_long(
     *,
     *,
     model,
     model,
-    tokenizer: callable,
     device: str | torch.device,
     device: str | torch.device,
     decode_one_token: callable,
     decode_one_token: callable,
     text: str,
     text: str,
@@ -439,7 +403,6 @@ def generate_long(
     iterative_prompt: bool = True,
     iterative_prompt: bool = True,
     max_length: int = 2048,
     max_length: int = 2048,
     chunk_length: int = 150,
     chunk_length: int = 150,
-    speaker: Optional[str] = None,
     prompt_text: Optional[str | list[str]] = None,
     prompt_text: Optional[str | list[str]] = None,
     prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
     prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
 ):
 ):
@@ -457,6 +420,7 @@ def generate_long(
     ), "Prompt text and tokens must have the same length"
     ), "Prompt text and tokens must have the same length"
 
 
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    tokenizer = model.tokenizer
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
 
     encoded = []
     encoded = []
@@ -469,10 +433,8 @@ def generate_long(
                 encode_tokens(
                 encode_tokens(
                     tokenizer,
                     tokenizer,
                     string=t,
                     string=t,
-                    bos=idx == 0,
                     device=device,
                     device=device,
                     prompt_tokens=c,
                     prompt_tokens=c,
-                    speaker=speaker,
                     num_codebooks=model.config.num_codebooks,
                     num_codebooks=model.config.num_codebooks,
                 )
                 )
             )
             )
@@ -482,9 +444,7 @@ def generate_long(
             encode_tokens(
             encode_tokens(
                 tokenizer,
                 tokenizer,
                 string=text,
                 string=text,
-                bos=idx == 0 and not use_prompt,
                 device=device,
                 device=device,
-                speaker=speaker,
                 num_codebooks=model.config.num_codebooks,
                 num_codebooks=model.config.num_codebooks,
             )
             )
         )
         )
@@ -544,7 +504,6 @@ def generate_long(
                 model=model,
                 model=model,
                 prompt=cat_encoded,
                 prompt=cat_encoded,
                 max_new_tokens=max_new_tokens,
                 max_new_tokens=max_new_tokens,
-                eos_token_id=tokenizer.eos_token_id,
                 im_end_id=im_end_id,
                 im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
                 temperature=temperature,
@@ -576,19 +535,13 @@ def generate_long(
 
 
             # Put the generated tokens
             # Put the generated tokens
             # since there is <im_end> and <eos> tokens, we remove last 2 tokens
             # since there is <im_end> and <eos> tokens, we remove last 2 tokens
-            codes = y[1:, prompt_length:-2].clone()
-
-            codes = codes - 2
+            codes = y[1:, prompt_length:-1].clone()
+            codes = codes - 1
             assert (codes >= 0).all(), f"Negative code found"
             assert (codes >= 0).all(), f"Negative code found"
 
 
             decoded = y[:, prompt_length:-1].clone()
             decoded = y[:, prompt_length:-1].clone()
-            if decoded[0, -1] != im_end_id:  # <im_end>
-                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
-                decoded = torch.cat(
-                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
-                )
-
             # But for global encoding, we should keep the <im_end> token
             # But for global encoding, we should keep the <im_end> token
+
             global_encoded.append(decoded)
             global_encoded.append(decoded)
             assert (codes >= 0).all(), f"Negative code found: {codes}"
             assert (codes >= 0).all(), f"Negative code found: {codes}"
             yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
             yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
@@ -611,11 +564,9 @@ class GenerateRequest:
 
 
 
 
 def launch_thread_safe_queue(
 def launch_thread_safe_queue(
-    config_name,
     checkpoint_path,
     checkpoint_path,
     device,
     device,
     precision,
     precision,
-    max_length: int,
     compile: bool = False,
     compile: bool = False,
 ):
 ):
     input_queue = queue.Queue()
     input_queue = queue.Queue()
@@ -623,7 +574,7 @@ def launch_thread_safe_queue(
 
 
     def worker():
     def worker():
         model, decode_one_token = load_model(
         model, decode_one_token = load_model(
-            config_name, checkpoint_path, device, precision, max_length, compile=compile
+            checkpoint_path, device, precision, compile=compile
         )
         )
         init_event.set()
         init_event.set()
 
 
@@ -672,16 +623,12 @@ def launch_thread_safe_queue(
 @click.option(
 @click.option(
     "--checkpoint-path",
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
     type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
+    default="checkpoints/fish-speech-1.2",
 )
 )
-@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
-@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
 @click.option("--compile/--no-compile", default=False)
 @click.option("--compile/--no-compile", default=False)
 @click.option("--seed", type=int, default=42)
 @click.option("--seed", type=int, default=42)
-@click.option("--speaker", type=str, default=None)
 @click.option("--half/--no-half", default=False)
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
-@click.option("--max-length", type=int, default=2048)
 @click.option("--chunk-length", type=int, default=150)
 @click.option("--chunk-length", type=int, default=150)
 def main(
 def main(
     text: str,
     text: str,
@@ -693,14 +640,10 @@ def main(
     repetition_penalty: float,
     repetition_penalty: float,
     temperature: float,
     temperature: float,
     checkpoint_path: Path,
     checkpoint_path: Path,
-    config_name: str,
-    tokenizer: str,
     compile: bool,
     compile: bool,
     seed: int,
     seed: int,
-    speaker: Optional[str],
     half: bool,
     half: bool,
     iterative_prompt: bool,
     iterative_prompt: bool,
-    max_length: int,
     chunk_length: int,
     chunk_length: int,
 ) -> None:
 ) -> None:
     device = "cuda"
     device = "cuda"
@@ -715,7 +658,7 @@ def main(
     logger.info("Loading model ...")
     logger.info("Loading model ...")
     t0 = time.time()
     t0 = time.time()
     model, decode_one_token = load_model(
     model, decode_one_token = load_model(
-        config_name, checkpoint_path, device, precision, max_length, compile=compile
+        checkpoint_path, device, precision, compile=compile
     )
     )
 
 
     if torch.cuda.is_available():
     if torch.cuda.is_available():
@@ -726,7 +669,6 @@ def main(
     if prompt_tokens is not None:
     if prompt_tokens is not None:
         prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
         prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
 
 
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     torch.manual_seed(seed)
     torch.manual_seed(seed)
 
 
     if torch.cuda.is_available():
     if torch.cuda.is_available():
@@ -742,11 +684,8 @@ def main(
         top_p=top_p,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         repetition_penalty=repetition_penalty,
         temperature=temperature,
         temperature=temperature,
-        tokenizer=tokenizer,
         compile=compile,
         compile=compile,
-        speaker=speaker,
         iterative_prompt=iterative_prompt,
         iterative_prompt=iterative_prompt,
-        max_length=max_length,
         chunk_length=chunk_length,
         chunk_length=chunk_length,
         prompt_text=prompt_text,
         prompt_text=prompt_text,
         prompt_tokens=prompt_tokens,
         prompt_tokens=prompt_tokens,

+ 1 - 3
tools/llama/merge_lora.py

@@ -14,9 +14,7 @@ from fish_speech.models.text2semantic.lora_utils import (
 @click.command()
 @click.command()
 @click.option("--llama-config", type=str, default="dual_ar_2_codebook_medium")
 @click.option("--llama-config", type=str, default="dual_ar_2_codebook_medium")
 @click.option("--lora-config", type=str, default="r_8_alpha_16")
 @click.option("--lora-config", type=str, default="r_8_alpha_16")
-@click.option(
-    "--llama-weight", type=str, default="checkpoints/text2semantic-sft-medium-v1-4k.pth"
-)
+@click.option("--llama-weight", type=str, default="checkpoints/fish-speech-1.2")
 @click.option("--lora-weight", type=str, required=True)
 @click.option("--lora-weight", type=str, required=True)
 @click.option("--output", type=str, required=True)
 @click.option("--output", type=str, required=True)
 def merge(llama_config, lora_config, llama_weight, lora_weight, output):
 def merge(llama_config, lora_config, llama_weight, lora_weight, output):

+ 1 - 1
tools/llama/quantize.py

@@ -419,7 +419,7 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 @click.option(
 @click.option(
     "--checkpoint-path",
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
     type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
+    default="checkpoints/fish-speech-1.2",
 )
 )
 @click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
 @click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
 @click.option(
 @click.option(

+ 1 - 1
tools/vits_decoder/inference.py

@@ -72,7 +72,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 @click.option(
 @click.option(
     "--checkpoint-path",
     "--checkpoint-path",
     "-ckpt",
     "-ckpt",
-    default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+    default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
 )
 )
 @click.option(
 @click.option(
     "--device",
     "--device",

+ 19 - 9
tools/vqgan/extract_vq.py

@@ -41,23 +41,31 @@ logger.add(sys.stderr, format=logger_format)
 
 
 @lru_cache(maxsize=1)
 @lru_cache(maxsize=1)
 def get_model(
 def get_model(
-    config_name: str = "vqgan_pretrain",
-    checkpoint_path: str = "checkpoints/vq-gan-group-fsq-2x1024.pth",
+    config_name: str = "firefly_gan_vq",
+    checkpoint_path: str = "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+    device: str | torch.device = "cuda",
 ):
 ):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
         cfg = compose(config_name=config_name)
         cfg = compose(config_name=config_name)
 
 
-    model: LightningModule = instantiate(cfg.model)
+    model = instantiate(cfg)
     state_dict = torch.load(
     state_dict = torch.load(
         checkpoint_path,
         checkpoint_path,
-        map_location=model.device,
+        map_location=device,
     )
     )
     if "state_dict" in state_dict:
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
         state_dict = state_dict["state_dict"]
 
 
+    if any("generator" in k for k in state_dict):
+        state_dict = {
+            k.replace("generator.", ""): v
+            for k, v in state_dict.items()
+            if "generator." in k
+        }
+
     model.load_state_dict(state_dict, strict=False)
     model.load_state_dict(state_dict, strict=False)
     model.eval()
     model.eval()
-    model.cuda()
+    model.to(device)
 
 
     logger.info(f"Loaded model")
     logger.info(f"Loaded model")
     return model
     return model
@@ -82,8 +90,10 @@ def process_batch(files: list[Path], model) -> float:
         if wav.shape[0] > 1:
         if wav.shape[0] > 1:
             wav = wav.mean(dim=0, keepdim=True)
             wav = wav.mean(dim=0, keepdim=True)
 
 
-        wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
-        total_time += len(wav) / model.sampling_rate
+        wav = torchaudio.functional.resample(
+            wav.cuda(), sr, model.spec_transform.sample_rate
+        )[0]
+        total_time += len(wav) / model.spec_transform.sample_rate
         max_length = max(max_length, len(wav))
         max_length = max(max_length, len(wav))
 
 
         wavs.append(wav)
         wavs.append(wav)
@@ -120,10 +130,10 @@ def process_batch(files: list[Path], model) -> float:
 @click.command()
 @click.command()
 @click.argument("folder")
 @click.argument("folder")
 @click.option("--num-workers", default=1)
 @click.option("--num-workers", default=1)
-@click.option("--config-name", default="vqgan_pretrain")
+@click.option("--config-name", default="firefly_gan_vq")
 @click.option(
 @click.option(
     "--checkpoint-path",
     "--checkpoint-path",
-    default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+    default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
 )
 )
 @click.option("--batch-size", default=64)
 @click.option("--batch-size", default=64)
 @click.option("--filelist", default=None, type=Path)
 @click.option("--filelist", default=None, type=Path)

+ 24 - 22
tools/vqgan/inference.py

@@ -8,7 +8,6 @@ import torch
 import torchaudio
 import torchaudio
 from hydra import compose, initialize
 from hydra import compose, initialize
 from hydra.utils import instantiate
 from hydra.utils import instantiate
-from lightning import LightningModule
 from loguru import logger
 from loguru import logger
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
 
 
@@ -23,20 +22,26 @@ def load_model(config_name, checkpoint_path, device="cuda"):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
         cfg = compose(config_name=config_name)
         cfg = compose(config_name=config_name)
 
 
-    model: LightningModule = instantiate(cfg.model)
+    model = instantiate(cfg)
     state_dict = torch.load(
     state_dict = torch.load(
         checkpoint_path,
         checkpoint_path,
-        map_location=model.device,
+        map_location=device,
     )
     )
-
     if "state_dict" in state_dict:
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
         state_dict = state_dict["state_dict"]
 
 
-    model.load_state_dict(state_dict, strict=False)
+    if any("generator" in k for k in state_dict):
+        state_dict = {
+            k.replace("generator.", ""): v
+            for k, v in state_dict.items()
+            if "generator." in k
+        }
+
+    result = model.load_state_dict(state_dict, strict=False)
     model.eval()
     model.eval()
     model.to(device)
     model.to(device)
-    logger.info("Restored model from checkpoint")
 
 
+    logger.info(f"Loaded model: {result}")
     return model
     return model
 
 
 
 
@@ -51,11 +56,10 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 @click.option(
 @click.option(
     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
 )
 )
-@click.option("--config-name", "-cfg", default="vqgan_pretrain")
+@click.option("--config-name", default="firefly_gan_vq")
 @click.option(
 @click.option(
     "--checkpoint-path",
     "--checkpoint-path",
-    "-ckpt",
-    default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+    default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
 )
 )
 @click.option(
 @click.option(
     "--device",
     "--device",
@@ -72,17 +76,17 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
         audio, sr = torchaudio.load(str(input_path))
         audio, sr = torchaudio.load(str(input_path))
         if audio.shape[0] > 1:
         if audio.shape[0] > 1:
             audio = audio.mean(0, keepdim=True)
             audio = audio.mean(0, keepdim=True)
-        audio = torchaudio.functional.resample(audio, sr, model.sampling_rate)
+        audio = torchaudio.functional.resample(
+            audio, sr, model.spec_transform.sample_rate
+        )
 
 
-        audios = audio[None].to(model.device)
+        audios = audio[None].to(device)
         logger.info(
         logger.info(
-            f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
+            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
         )
         )
 
 
         # VQ Encoder
         # VQ Encoder
-        audio_lengths = torch.tensor(
-            [audios.shape[2]], device=model.device, dtype=torch.long
-        )
+        audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
         indices = model.encode(audios, audio_lengths)[0][0]
         indices = model.encode(audios, audio_lengths)[0][0]
 
 
         logger.info(f"Generated indices of shape {indices.shape}")
         logger.info(f"Generated indices of shape {indices.shape}")
@@ -92,17 +96,15 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
     elif input_path.suffix == ".npy":
     elif input_path.suffix == ".npy":
         logger.info(f"Processing precomputed indices from {input_path}")
         logger.info(f"Processing precomputed indices from {input_path}")
         indices = np.load(input_path)
         indices = np.load(input_path)
-        indices = torch.from_numpy(indices).to(model.device).long()
+        indices = torch.from_numpy(indices).to(device).long()
         assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
         assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
     else:
     else:
         raise ValueError(f"Unknown input type: {input_path}")
         raise ValueError(f"Unknown input type: {input_path}")
 
 
     # Restore
     # Restore
-    feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
-    fake_audios = model.decode(
-        indices=indices[None], feature_lengths=feature_lengths, return_audios=True
-    )
-    audio_time = fake_audios.shape[-1] / model.sampling_rate
+    feature_lengths = torch.tensor([indices.shape[1]], device=device)
+    fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
+    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
 
 
     logger.info(
     logger.info(
         f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
         f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
@@ -110,7 +112,7 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
 
 
     # Save audio
     # Save audio
     fake_audio = fake_audios[0, 0].float().cpu().numpy()
     fake_audio = fake_audios[0, 0].float().cpu().numpy()
-    sf.write(output_path, fake_audio, model.sampling_rate)
+    sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
     logger.info(f"Saved audio to {output_path}")
     logger.info(f"Saved audio to {output_path}")
 
 
 
 

+ 3 - 9
tools/webui.py

@@ -443,21 +443,17 @@ def parse_args():
     parser.add_argument(
     parser.add_argument(
         "--llama-checkpoint-path",
         "--llama-checkpoint-path",
         type=Path,
         type=Path,
-        default="checkpoints/text2semantic-sft-large-v1.1-4k.pth",
-    )
-    parser.add_argument(
-        "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
+        default="checkpoints/fish-speech-1.2",
     )
     )
     parser.add_argument(
     parser.add_argument(
         "--decoder-checkpoint-path",
         "--decoder-checkpoint-path",
         type=Path,
         type=Path,
-        default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+        default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
     )
     )
-    parser.add_argument("--decoder-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--half", action="store_true")
-    parser.add_argument("--max-length", type=int, default=2048)
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-gradio-length", type=int, default=0)
     parser.add_argument("--max-gradio-length", type=int, default=0)
 
 
@@ -470,11 +466,9 @@ if __name__ == "__main__":
 
 
     logger.info("Loading Llama model...")
     logger.info("Loading Llama model...")
     llama_queue = launch_thread_safe_queue(
     llama_queue = launch_thread_safe_queue(
-        config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
         device=args.device,
         precision=args.precision,
         precision=args.precision,
-        max_length=args.max_length,
         compile=args.compile,
         compile=args.compile,
     )
     )
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)