Lengyue 1 год назад
Родитель
Commit
c9209e94d9

+ 2 - 1
fish_speech/configs/lora/r_8_alpha_16.yaml

@@ -1,3 +1,4 @@
-_target_: fish_speech.models.text2semantic.lora_utils.LoraConfig
+_target_: fish_speech.models.text2semantic.lora.LoraConfig
 r: 8
 lora_alpha: 16
+lora_dropout: 0.01

+ 0 - 66
fish_speech/configs/text2semantic_agent.yaml

@@ -1,66 +0,0 @@
-defaults:
-  - base
-  - model@model.model: dual_ar_2_codebook_1.3b
-  - _self_
-
-project: text2semantic_agent_dual_ar_debug
-max_length: 2048
-ckpt_path: checkpoints/fish-speech-agent-1/step_000013000.ckpt
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accumulate_grad_batches: 1
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  max_steps: 1_000_000
-  precision: bf16-true
-  log_every_n_steps: 10
-  limit_val_batches: 10
-  val_check_interval: 1000
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: checkpoints/fish-speech-agent-1
-
-# Dataset Configuration
-train_dataset: {}
-val_dataset: {}
-
-data:
-  _target_: fish_speech.datasets.text.TextDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.text2semantic.TextToSemantic
-  model: {}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 3e-4
-    weight_decay: 0.01
-    betas: [0.9, 0.95]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 1000
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.1
-
-# Callbacks
-callbacks:
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}

+ 22 - 19
fish_speech/configs/text2semantic_finetune.yaml

@@ -1,12 +1,10 @@
 defaults:
   - base
-  - model@model.model: dual_ar_2_codebook_small
   - _self_
 
 project: text2semantic_finetune_dual_ar
-max_length: 2048
-ckpt_path: checkpoints/text2semantic-sft-medium-v1.1-4k.pth
-resume_weights_only: true
+max_length: 4096
+pretrained_ckpt_path: checkpoints/fish-speech-1.2
 
 # Lightning Trainer
 trainer:
@@ -21,31 +19,31 @@ trainer:
 # Dataset Configuration
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/fish-speech-1
+  pretrained_model_name_or_path: ${pretrained_ckpt_path}
 
 # Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
   proto_files:
     - data/protos
   tokenizer: ${tokenizer}
+  causal: true
   max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: 0.5
+  use_speaker: false
   interactive_prob: 0.7
 
 val_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
   proto_files:
     - data/protos
   tokenizer: ${tokenizer}
+  causal: true
   max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: 0.5
+  use_speaker: false
   interactive_prob: 0.7
 
 data:
-  _target_: fish_speech.datasets.text.TextDataModule
+  _target_: fish_speech.datasets.semantic.SemanticDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
@@ -55,13 +53,18 @@ data:
 
 # Model Configuration
 model:
-  _target_: fish_speech.models.text2semantic.TextToSemantic
-  model: {}
+  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
+  model: 
+    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
+    path: ${pretrained_ckpt_path}
+    load_weights: true
+    max_length: ${max_length}
+    lora_config: null
 
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
-    lr: 1e-5
+    lr: 1e-4
     weight_decay: 0
     betas: [0.9, 0.95]
     eps: 1e-5
@@ -70,12 +73,12 @@ model:
     _target_: torch.optim.lr_scheduler.LambdaLR
     _partial_: true
     lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 0.1
-      num_training_steps: ${trainer.max_steps}
+      num_warmup_steps: 10
 
 # Callbacks
 callbacks:
   model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}
+    every_n_train_steps: 10
+    # ${trainer.val_check_interval}

+ 0 - 76
fish_speech/configs/text2semantic_pretrain.yaml

@@ -1,76 +0,0 @@
-defaults:
-  - base
-  - model@model.model: dual_ar_2_codebook_small
-  - _self_
-
-project: text2semantic_pretrain_dual_ar_debug
-max_length: 2048
-
-# Lightning Trainer
-trainer:
-  accumulate_grad_batches: 1
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  max_steps: 1_000_000
-  precision: bf16-true
-  limit_val_batches: 10
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/fish-speech-1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  proto_files:
-    - data/protos/train
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: false
-  interactive_prob: 0.5
-  skip_text_prob: 0.1
-
-val_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  proto_files:
-    - data/protos/test
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: false
-  interactive_prob: 0.5
-  skip_text_prob: 0.1
-
-data:
-  _target_: fish_speech.datasets.text.TextDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.text2semantic.TextToSemantic
-  model: {}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 3e-4
-    weight_decay: 0.01
-    betas: [0.9, 0.95]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 2000
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.1

+ 0 - 82
fish_speech/configs/text2semantic_sft.yaml

@@ -1,82 +0,0 @@
-defaults:
-  - base
-  - model@model.model: dual_ar_2_codebook_small
-  - _self_
-
-project: text2semantic_sft_dual_ar
-max_length: 4096
-ckpt_path: checkpoints/text2semantic-medium-v1-2k.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accumulate_grad_batches: 1
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  max_steps: 10_000
-  precision: bf16-true
-  limit_val_batches: 10
-  val_check_interval: 500
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/fish-speech-1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  proto_files:
-    - data/protos/sft
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: 0.5
-  interactive_prob: 0.7
-
-val_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  proto_files:
-    - data/protos/sft
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: 0.5
-  interactive_prob: 0.7
-
-data:
-  _target_: fish_speech.datasets.text.TextDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.text2semantic.TextToSemantic
-  model: {}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 4e-5
-    weight_decay: 0
-    betas: [0.9, 0.95]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 100
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0
-
-callbacks:
-  model_checkpoint:
-    every_n_train_steps: 1000
-    save_top_k: 10

+ 0 - 381
fish_speech/datasets/prompts.py

@@ -1,381 +0,0 @@
-# "Transcribe the following audio into text."
-# "Transcribe what you will hear."
-
-asr_instructions = [
-    "Transcribe:",
-    "Transcribe the following audio into text.",
-    "Convert the audio you're about to hear into written text.",
-    "Please write down what you hear in the audio file.",
-    "Listen to the audio and type out its contents.",
-    "Your task is to write the audio's content in text form.",
-    "Transcribe the content of the audio into text.",
-    "Transform the given audio into a textual format.",
-    "Listen to the following sound clip and transcribe it.",
-    "The audio provided should be converted into written words.",
-    "Document the audio in text.",
-    "Put the audio's dialogue into written form.",
-    "Capture the audio's message in text.",
-    "Turn the sound file's speech into text.",
-    "Render the audio into a text version.",
-    "Translate the audio recording to text.",
-    "Write out the dialogue from the audio.",
-    "Listen and transcribe the audio into words.",
-    "Change the audio into a written transcript.",
-    "Your job is to transcribe the audio to text.",
-    "Please transcribe the spoken words into text.",
-    "The task is to convert audio speech into written text.",
-    "Make a text transcript of the following audio.",
-    "Decode the audio into a written document.",
-    "Write down the transcription of the audio.",
-    "Please provide a text version of this audio.",
-    "The objective is to transcribe the audio into readable text.",
-    "Listen carefully and type out the audio.",
-    "Transform this audio clip into a text document.",
-    "Your assignment is to transcribe this audio.",
-    "Transcribe this sound recording into text format.",
-    "The goal is to turn the audio into text.",
-    "Your duty is to document the audio in written form.",
-    "Listen to this audio piece and write down its contents.",
-    "The task is converting the audio into text.",
-    "Please create a textual transcription of the audio.",
-    "Capture in writing what is said in the audio.",
-    "Transcribe the audible content into a text format.",
-    "The mission is to transcribe the audio into text.",
-    "Your task: convert the audio to text.",
-    "Write the contents of the audio as text.",
-    "Listen to the clip and transcribe its audio to text.",
-    "Transcribe the given audio track into written words.",
-    "The assignment is to write out the audio in text.",
-    "Convert the spoken words into text.",
-    "Transcribe the voice recording into text.",
-    "Your task is to make a written record of the audio.",
-    "Listen to the audio and reproduce it in text.",
-    "Transcribe the following sound into written text.",
-    "Your challenge is to transcribe the audio into written form.",
-    "Make a written version of the audio.",
-    "Take the audio and transcribe it to text.",
-    "Write down everything you hear in the audio.",
-    "Please put the audio into text format.",
-    "Your role is to transcribe the following audio into text.",
-    "Convert the audio message into written text.",
-    "Provide a written transcription of the audio.",
-    "Listen and convert the audio to text.",
-    "The requirement is to transcribe the audio into text form.",
-    "Document in text what the audio says.",
-    "Transcribe into text what you hear in the audio.",
-    "Translate the audio file's contents into text.",
-    "The task is to create a text transcript of the audio.",
-    "Your assignment: Translate the audio into written words.",
-    "Write a textual representation of the audio.",
-    "Capture the essence of the audio in text.",
-    "Your job: Listen to the audio and transcribe it.",
-    "Turn the audio content into a text transcript.",
-    "The task at hand is to transcribe the audio to text.",
-    "Reproduce the audio in text form.",
-    "Your mission: Convert the audio into a textual format.",
-    "Transcribe what is spoken in the audio into text.",
-    "Create a written version of what's in the audio.",
-    "Transform the spoken audio into text.",
-    "Document the spoken words in the audio as text.",
-    "The objective is to write down the audio in text.",
-    "Your goal: Transcribe the audio into text.",
-    "Please convert the audio file into text.",
-    "Transcribe the audio clip into written text.",
-    "Listen to the audio and transcribe the speech into text.",
-    "Transform the voice from the audio into written words.",
-    "The task is to write the audio's speech in text form.",
-    "Your duty: Write down what the audio says.",
-    "Turn the given audio into a written format.",
-    "Write in text form what is said in the audio.",
-    "Your task: Document the audio in text.",
-    "Provide a text transcription of the audio.",
-    "Provide a text transcription of the audio.",
-    "Write down the audio you listen to.",
-    "Type out the spoken words you hear.",
-    "Document the audio content verbatim.",
-    "Transcribe the spoken content accurately.",
-    "Convert the audio you hear into text.",
-    "Record in writing what is said in the audio.",
-    "Capture the spoken words in written form.",
-    "Translate the audio into written text.",
-    "Jot down the words you hear in the audio.",
-    "Put into writing the spoken words you hear.",
-    "Transcribe the auditory information verbatim.",
-    "Note down the dialogue from the audio.",
-    "Write out the spoken words from the audio.",
-    "Transcribe the oral presentation into text.",
-    "Render the spoken audio into written form.",
-    "Reproduce the spoken words in text form.",
-    "Document what is being said in the audio.",
-    "Translate the spoken word into written form.",
-    "Write verbatim what you hear in the audio.",
-    "Capture in writing the contents of the audio.",
-    "Transcribe verbatim the spoken words.",
-    "Write down verbatim what is spoken.",
-    "Transcribe the sounds into words on paper.",
-    "Translate the sounds you hear into words.",
-    "Write the spoken words in text form.",
-    "Reproduce the audio content in writing.",
-    "Note verbatim what is said in the audio.",
-    "Put the audio content into written words.",
-    "Record the spoken words into text format.",
-    "Transcribe the audio into a written document.",
-    "Write down exactly what you hear.",
-    "Type out the content of the audio.",
-    "Document the words spoken in the audio.",
-    "Translate the verbal content into text.",
-    "Convert what you hear into written words.",
-    "Capture the essence of the audio in writing.",
-    "Reproduce the spoken content in written form.",
-    "Jot down exactly what is said in the audio.",
-    "Document every word you hear in the audio.",
-    "Record the audio content by writing it down.",
-    "Capture the audio's spoken words in text.",
-    "Turn the spoken audio into a written transcript.",
-    "Write down the contents of the audio verbatim.",
-    "Transcribe the voice you hear into text.",
-    "Convert the spoken audio into text format.",
-    "Type what is being spoken in the audio.",
-    "Translate the audio speech into written words.",
-    "Write the audio's dialogue in written form.",
-    "Record the verbal content as written text.",
-    "Transcribe the spoken parts of the audio.",
-    "Note down everything you hear in the audio.",
-    "Capture every word from the audio in text.",
-    "Put the spoken audio into text form.",
-    "Transcribe the audible content into words.",
-    "Translate the oral content into written text.",
-    "Type out everything heard in the audio.",
-    "Write down the spoken parts verbatim.",
-    "Document the spoken audio in text form.",
-    "Capture the verbal exchanges in written text.",
-    "Transcribe each word you hear accurately.",
-    "Turn the audio into a textual document.",
-    "Transcribe the sound into written words.",
-    "Write the audio transcript in your own words.",
-    "Document in text what you hear in the audio.",
-    "Record in text the spoken parts of the audio.",
-    "Transcribe the narrative you hear into text.",
-    "Capture the spoken narrative in written form.",
-    "Convert the verbal audio into written script.",
-    "Note down the spoken words in the audio.",
-    "Write in text form what is spoken in the audio.",
-    "Record the audio's spoken words verbatim.",
-    "Jot down the audio's dialogue accurately.",
-    "Transcribe the verbal parts into written words.",
-    "Translate the audio's spoken content into text.",
-    "Document the audio dialogue in written form.",
-    "Type out the words spoken in the audio verbatim.",
-    "Write down word for word what is said in the audio.",
-    "Transcribe the entire audio content into text.",
-    "Note down precisely what is said in the audio.",
-    "Capture in text the spoken content of the audio.",
-    "Record the spoken audio into written language.",
-    "Write the essence of the audio in text form.",
-    "Transcribe the words you hear in the audio.",
-    "Translate every spoken word into written text.",
-    "Convert the oral speech into a written format.",
-    "Jot down the words spoken in the audio.",
-    "Record every word from the audio in writing.",
-    "Document the entire audio in written form.",
-    "Transcribe the spoken language into text.",
-    "Write down the audio's words exactly as spoken.",
-    "Capture the spoken word in written format.",
-    "Type out verbatim the spoken audio content.",
-    "Write precisely what you hear from the audio.",
-]
-
-# "Read the following text with emotion."
-# "Read the following text."
-
-tts_instructions = [
-    "Speak:",
-    "Expressively read the text that follows.",
-    "Convey the upcoming text with emotion.",
-    "Deliver the following passage with heartfelt expression.",
-    "Evoke emotion while reading the text below.",
-    "With feeling, please read the text that comes next.",
-    "Infuse the upcoming words with emotional depth as you read.",
-    "Let your emotions guide you as you read the following lines.",
-    "Channel emotion into your reading of the next passage.",
-    "Read the text below with a sense of emotion.",
-    "Bring the following words to life with emotional expression.",
-    "Engage emotionally with the text as you read it aloud.",
-    "Imbue the subsequent text with feeling as you read.",
-    "Read the following content with genuine emotion.",
-    "Allow your feelings to resonate through the upcoming text.",
-    "Emotionally interpret the text that follows.",
-    "Read the ensuing passage with deep feeling.",
-    "Convey the text below with genuine emotional depth.",
-    "Read the text that comes next, letting your emotions flow.",
-    "With emotion, present the following words.",
-    "Let your emotional expression enhance the next text.",
-    "Embrace emotion as you read the following passage.",
-    "Read aloud the text below with emotive expression.",
-    "Infuse the upcoming lines with emotional intensity.",
-    "With sincerity, read the following text with emotion.",
-    "Project emotion as you deliver the text that follows.",
-    "Let the next words be read with a wealth of emotion.",
-    "Give the upcoming text an emotional rendition.",
-    "With emotion, read the text that is presented next.",
-    "Convey the essence of the following text with heartfelt emotion.",
-    "Inject emotional depth into your reading of the next passage.",
-    "Bring out the emotional undertones in the following text.",
-    "Embody the emotions as you read the text below.",
-    "Express the following narrative with emotional depth.",
-    "Let emotion permeate your reading of the upcoming passage.",
-    "Interpret the following text with a rich emotional tone.",
-    "Elicit emotion through your reading of the next content.",
-    "Read the subsequent text with a deep emotional connection.",
-    "Emote the essence of the text that follows in your reading.",
-    "Render the following lines with emotional expression.",
-    "Expressively interpret the upcoming text.",
-    "Immerse in emotion as you read the following passage.",
-    "Engage with the text below on an emotional level as you read.",
-    "With emotional clarity, read the next text.",
-    "Let an emotional depth inform your reading of the following words.",
-    "Express the following content with deep emotional resonance.",
-    "Deliver the upcoming text with a range of emotions.",
-    "Narrate the following lines with emotional expressiveness.",
-    "Convey emotional texture as you read the text below.",
-    "Instill the next passage with emotive power as you read.",
-    "Read the ensuing text with a palette of emotions.",
-    "With a depth of feeling, present the next text.",
-    "Inflect the upcoming words with emotional vibrancy.",
-    "Emotionally engage with the text that follows in your reading.",
-    "Lend emotional expression to the passage below.",
-    "Evoke a spectrum of emotions as you read the next lines.",
-    "Channel a rich emotional tone into the following text.",
-    "With feeling, convey the essence of the upcoming passage.",
-    "Read the text that comes next with emotional fervor.",
-    "Render the following words with emotional authenticity.",
-    "Give the upcoming passage an emotive interpretation.",
-    "Allow your reading of the text below to be emotionally driven.",
-    "Imbue the next lines with a sense of emotion.",
-    "Emotionally animate the following text as you read.",
-    "Bring emotional depth to the passage that follows.",
-    "Articulate the text below with emotional nuance.",
-    "Project a range of emotions as you read the upcoming text.",
-    "With emotion, breathe life into the following words.",
-    "Narrate the ensuing text with heartfelt emotion.",
-    "Convey the text that follows with emotional richness.",
-    "Read aloud the next passage with a depth of emotion.",
-    "Emphasize emotional expression in your reading of the text below.",
-    "Let your reading of the following lines be emotionally charged.",
-    "With a heartfelt approach, read the upcoming text.",
-    "Express the essence of emotion as you deliver the next passage.",
-    "Read the following text, infused with emotional energy.",
-    "Allow the text that comes next to be expressed with emotion.",
-    "Convey the following passage with an emotional depth.",
-    "Emotionally render the text that follows.",
-    "With an emotional undertone, read the upcoming words.",
-    "Read the text below, letting emotion guide your expression.",
-    "Elicit an emotional response through your reading of the next passage.",
-    "Give the following lines an emotive delivery.",
-    "Read the upcoming text with emotional sincerity.",
-    "Narrate the text that follows with an emotional touch.",
-    "Deliver the following words with an emotive clarity.",
-    "Express the next passage with a range of emotional tones.",
-    "Immerse yourself emotionally in the text below as you read.",
-    "Let the ensuing text be conveyed with profound emotion.",
-    "Infuse the following lines with a sense of heartfelt emotion.",
-    "Emotionally engage with the upcoming text in your reading.",
-    "Convey deep emotion as you read the text that follows.",
-    "Let your reading of the next passage be rich in emotion.",
-    "With emotional depth, narrate the following text.",
-    "Read the text below, capturing its emotional essence.",
-    "Emote through your reading of the upcoming lines.",
-    "Please read the text that follows aloud.",
-    "Proceed to vocalize the upcoming text.",
-    "Kindly articulate the subsequent text.",
-    "Go ahead and pronounce the text below.",
-    "Could you recite the forthcoming passage?",
-    "Start reading the text below out loud.",
-    "Announce the following text audibly.",
-    "Voice the text that comes next.",
-    "Read through the following lines aloud.",
-    "Narrate the text presented below.",
-    "Elevate your voice for the upcoming script.",
-    "Broadcast the text that follows.",
-    "Project the subsequent lines audibly.",
-    "Give voice to the text underneath.",
-    "Unfold the following text with your voice.",
-    "Engage in reading the next piece of text aloud.",
-    "Orate the following series of words.",
-    "Enunciate the text appearing next.",
-    "Verbally present the upcoming text.",
-    "Articulate the passage that follows.",
-    "Read aloud the text that's coming up.",
-    "Proclaim the subsequent words.",
-    "Vocalize the narrative below.",
-    "Bring the following text to life by reading it aloud.",
-    "Express the next text with your voice.",
-    "Render the following text audibly.",
-    "Voice out the lines that follow.",
-    "Orally deliver the upcoming text.",
-    "Loudly read out the text below.",
-    "Share the next text by reading it out loud.",
-    "Speak the following passage aloud.",
-    "Let your voice carry the upcoming words.",
-    "Annunciate the text that follows.",
-    "Sound out the subsequent text.",
-    "Aurally present the text below.",
-    "Elocute the forthcoming lines.",
-    "Recite the text below with clarity.",
-    "Make the next text heard by reading aloud.",
-    "Bring forth your voice for the following script.",
-    "Read the text that ensues out loud.",
-    "Deliver the following lines vocally.",
-    "Voice the ensuing text.",
-    "Publicly read the text that follows.",
-    "Loudly narrate the subsequent text.",
-    "Express the following text through your voice.",
-    "Verbally articulate the next passage.",
-    "Read the forthcoming text clearly.",
-    "Announce the next set of words aloud.",
-    "Broadcast the following narrative.",
-    "Articulate the text coming up next.",
-    "Enunciate the passage that follows clearly.",
-    "Recite the subsequent text audibly.",
-    "Speak out the text below.",
-    "Project your voice with the following words.",
-    "Read the next lines aloud.",
-    "Vocalize the text that is to follow.",
-    "Narrate aloud the text below.",
-    "Orate the forthcoming script.",
-    "Pronounce the next passage.",
-    "Read out the subsequent text.",
-    "Let the following words be heard by reading them aloud.",
-    "Express the text that follows with your voice.",
-    "Give audible life to the text below.",
-    "Speak the ensuing text clearly.",
-    "Make the forthcoming text audible.",
-    "Project the next series of words audibly.",
-    "Voice out the following narrative.",
-    "Elevate the subsequent text with your voice.",
-    "Bring the next passage to audible life.",
-    "Read the lines that come next out loud.",
-    "Announce the text below with clarity.",
-    "Vocalize the script that follows.",
-    "Narrate the following text with emphasis.",
-    "Deliver the upcoming words with your voice.",
-    "Articulate the next set of lines.",
-    "Verbally convey the following text.",
-    "Present the subsequent text vocally.",
-    "Enunciate the upcoming passage loudly.",
-    "Orally render the text that follows.",
-    "Speak out the subsequent narrative.",
-    "Proclaim the next text audibly.",
-    "Elocute the following lines with clarity.",
-    "Give voice to the upcoming script.",
-    "Let your voice express the text below.",
-    "Annunciate the following words clearly.",
-    "Sound out the text that is next.",
-    "Aurally convey the subsequent passage.",
-    "Read the text up next aloud.",
-]
-
-prompt_dict = {
-    "asr": asr_instructions,
-    "tts": tts_instructions,
-}

+ 580 - 0
fish_speech/datasets/semantic.py

@@ -0,0 +1,580 @@
+import random
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from random import Random
+from typing import Optional, Union
+
+import numpy as np
+import pyarrow.parquet as pq
+import torch
+import torch.nn.functional as F
+from datasets.download.streaming_download_manager import xopen
+from huggingface_hub import HfApi
+from lightning import LightningDataModule
+from torch.distributed import get_rank, get_world_size, is_initialized
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.datasets.protos.text_data_pb2 import SampledData
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.clean import clean_text
+from fish_speech.utils import RankedLogger
+from fish_speech.utils.braceexpand import braceexpand
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def split_by_rank_worker(files):
+    # We need to know the total number of devices
+    # to split the data properly
+
+    total_devices = 1
+    if is_initialized():
+        total_devices = get_world_size()
+
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        total_devices *= worker_info.num_workers
+
+    if len(files) < total_devices:
+        # Repeat the files N times to match the number of devices
+        files = files * (total_devices // len(files) + 1)
+
+    # DDP
+    if is_initialized():
+        files = files[get_rank() :: get_world_size()]
+
+    # Split by worker
+    if worker_info is not None:
+        files = files[worker_info.id :: worker_info.num_workers]
+
+    return files
+
+
+class AutoTextSemanticInstructionDataset(IterableDataset):
+    """
+    Auto Augment Dataset by Speaker
+
+    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalize the text
+
+    For interactive mode, we use the following format (multiple sequences):
+    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+
+    For non-interactive mode, we use the following format (one long sequence):
+    <s> [INST] text [/INST] ... </s>
+    """
+
+    def __init__(
+        self,
+        proto_files: list[str],
+        seed: int = 42,
+        interactive_prob: float = 0.5,
+        max_length: int = 1024,
+        tokenizer: AutoTokenizer = None,
+        use_speaker: bool | float = True,
+        causal: bool = True,
+        use_negative_samples: bool = False,
+        num_codebooks: Optional[int] = None,
+        skip_text_prob: float = 0.0,
+    ):
+        """
+        Args:
+            proto_files: proto buf files if using local data
+            seed: random seed
+            interactive_prob: probability to use interactive mode
+            max_length: max length of the text
+            tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt
+            causal: use causal sampling when using local data, disable will lead to random sampling
+            use_negative_samples: generate negative samples
+            num_codebooks: number of codebooks, if None, it will be automatically detected
+            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+        """
+
+        super().__init__()
+
+        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+        self.seed = seed
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.interactive_prob = interactive_prob
+        self.use_speaker = use_speaker
+        self.proto_files = proto_files
+        self.causal = causal
+        self.use_negative_samples = use_negative_samples
+        self.num_codebooks = num_codebooks
+        self.skip_text_prob = skip_text_prob
+
+        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+        self.groups = None
+
+    def init_mock_data_server(self):
+        if self.groups is not None:
+            return
+
+        # Expand the proto files
+        expanded_proto_files = []
+        for filename in self.proto_files:
+            for i in braceexpand(filename):
+                i = Path(i)
+                if i.is_file():
+                    expanded_proto_files.append(i)
+                elif i.is_dir():
+                    expanded_proto_files.extend(i.rglob("*.proto"))
+                    expanded_proto_files.extend(i.rglob("*.protos"))
+                else:
+                    raise ValueError(f"{i} is not a file or directory")
+
+        expanded_proto_files = sorted(expanded_proto_files)
+        Random(self.seed).shuffle(expanded_proto_files)
+
+        self.groups = []
+        shard_proto_files = split_by_rank_worker(expanded_proto_files)
+        log.info(
+            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+        )
+
+        count = 0
+        for filename in shard_proto_files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    self.groups.append(text_data)
+                    count += 1
+
+        log.info(f"Read total {count} groups of data")
+
+        # Shuffle the lines
+        Random(self.seed).shuffle(self.groups)
+        self.group_weights = [len(i.sentences) for i in self.groups]
+
+    def __iter__(self):
+        while True:
+            yield self.augment()
+
+    def tokenize_sentence(self, sentence: str):
+        sentence = clean_text(sentence)
+        tokens = self.tokenizer.encode(
+            f"{sentence}",
+            max_length=10**6,
+            add_special_tokens=False,
+            truncation=False,
+        )
+        return sentence, len(tokens)
+
+    def sample_data(self):
+        if self.groups is None:
+            self.init_mock_data_server()
+
+        # Shuffle unique lines, estimate that each sample is at least 20 tokens
+        num_samples = self.max_length // 20
+
+        # choice group based on their number of samples
+        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
+
+        if self.causal:
+            # Sample in order
+            if num_samples >= len(group.sentences):
+                samples = group.sentences
+            else:
+                begin = random.randint(0, len(group.sentences) - num_samples)
+                samples = group.sentences[begin : begin + num_samples]
+        else:
+            samples = random.choices(
+                group.sentences, k=min(num_samples, len(group.sentences))
+            )
+
+        return SampledData(
+            source=group.source,
+            name=group.name,
+            samples=samples,
+        )
+
+    def augment(self):
+        final_text, final_semantic = [], []
+        response = self.sample_data()
+        if len(response.samples) == 0:
+            # Invalid group
+            return None
+
+        samples = list(response.samples)
+        idx = 0
+        use_interactive = random.random() < self.interactive_prob
+
+        if use_interactive is False:
+            # Random sample based on speaker using a truncated normal distribution
+            a = torch.tensor([0], dtype=torch.float32)
+            torch.nn.init.trunc_normal_(
+                a,
+                mean=self.max_length // 2,
+                std=self.max_length // 4,
+                a=10,
+                b=self.max_length,
+            )
+            remaining_tokens = a.long().item() - 4
+        else:
+            remaining_tokens = self.max_length
+
+        # Use speaker
+        if isinstance(self.use_speaker, float):
+            use_speaker = random.random() < self.use_speaker
+        else:
+            use_speaker = self.use_speaker
+
+        all_tokens, all_labels = [], []
+        while remaining_tokens > 0 and len(samples) > 0:
+            sentence = samples.pop(0)
+
+            text = random.choice(sentence.texts)
+            text, length = self.tokenize_sentence(text)
+            remaining_tokens -= length + len(sentence.semantics[0].values)
+
+            if use_interactive is False:
+                final_text.append(text)
+                final_semantic.append(sentence.semantics)
+            else:
+                # For interactive mode, we only apply speaker for the first sentence
+                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+                tokens, labels = self.pack_sentences(
+                    sentences=[text],
+                    semantics=[sentence.semantics],
+                    speaker=response.name if use_speaker else None,
+                    add_bos=idx == 0,
+                    skip_text=random.random() < self.skip_text_prob,
+                )
+
+                all_tokens.append(tokens)
+                all_labels.append(labels)
+
+            idx += 1
+
+        if use_interactive is False:
+            tokens, labels = self.pack_sentences(
+                final_text,
+                semantics=final_semantic,
+                speaker=response.name if use_speaker else None,
+                add_bos=True,
+            )
+            all_tokens.append(tokens)
+            all_labels.append(labels)
+
+        tokens = torch.cat(all_tokens, dim=1)
+        labels = torch.cat(all_labels, dim=1)
+
+        # Verify that the length is correct
+        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+        # Verify bos token
+        assert tokens[0, 0] == self.tokenizer.bos_token_id
+
+        data = {"tokens": tokens, "labels": labels}
+
+        if self.use_negative_samples:
+            negative_samples = self.generate_negative_samples(all_tokens, all_labels)
+            data.update(negative_samples)
+
+        return data
+
+    def generate_negative_samples(self, all_tokens, all_labels):
+        new_tokens, new_labels = [], []
+
+        for tokens, labels in zip(all_tokens, all_labels):
+            # If all codebooks are not -100, we find where it starts
+            start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
+            assert (labels[1:, start:] != -100).all()  # This shouldn't happen
+
+            mode = random.choice(["repeat", "lost", "noise"])
+            begin = random.randint(start, labels.size(1) - 1)
+            end = random.randint(begin, labels.size(1) - 1)
+
+            if mode == "repeat":
+                tokens = torch.cat(
+                    [
+                        tokens[:, :begin],
+                        tokens[:, begin:end],
+                        tokens[:, begin:end],
+                        tokens[:, end:],
+                    ],
+                    dim=1,
+                )
+                labels = torch.cat(
+                    [
+                        labels[:, :begin],
+                        labels[:, begin:end],
+                        labels[:, begin:end],
+                        labels[:, end:],
+                    ],
+                    dim=1,
+                )
+            elif mode == "lost":
+                tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
+                labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
+            elif mode == "noise":
+                middle_tokens, middle_labels = (
+                    tokens[:, begin:end],
+                    labels[:, begin:end],
+                )
+                random_order0 = torch.randperm(middle_tokens.size(1))
+                random_order1 = torch.randperm(middle_tokens.size(1))
+                middle_tokens = middle_tokens[:, random_order0]
+                middle_labels = middle_labels[:, random_order1]
+                tokens = torch.cat(
+                    [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
+                )
+                labels = torch.cat(
+                    [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
+                )
+
+            new_tokens.append(tokens)
+            new_labels.append(labels)
+
+        tokens = torch.cat(new_tokens, dim=1)
+        labels = torch.cat(new_labels, dim=1)
+
+        # Verify that the length is correct
+        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+        return {"negative_tokens": tokens, "negative_labels": labels}
+
+    def pack_sentences(
+        self,
+        sentences: list[str],
+        semantics: list,
+        speaker: Optional[str] = None,
+        add_bos: bool = True,
+        skip_text: bool = False,
+    ):
+        if speaker is None:
+            speaker = "assistant"
+
+        cated_sentences = " ".join(sentences)
+        if skip_text:
+            cated_sentences = "<|skip_text|>"
+
+        final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
+        final_text = final_text + f"<|im_start|>{speaker}\n"
+
+        encoded = self.tokenizer.encode(
+            final_text,
+            add_special_tokens=False,
+            truncation=False,
+            max_length=10**6,
+        )
+        semantic_length = sum([len(i[0].values) for i in semantics])
+        prompt_length = len(encoded)
+        num_codebooks = (
+            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+        )
+
+        bos_bias = 1 if add_bos else 0
+
+        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
+        tokens = (
+            encoded
+            + [self.semantic_token_id] * semantic_length
+            + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
+        )
+
+        if add_bos:
+            tokens = [self.tokenizer.bos_token_id] + tokens
+
+        # Codebook bos/padding: 0, eos: 1
+        codes = [
+            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
+            for _ in range(num_codebooks)
+        ]
+        for segment in semantics:
+            for book_idx, book in zip(range(num_codebooks), segment):
+                for j in book.values:
+                    codes[book_idx].append(int(j) + 1)
+
+        for book in codes:
+            book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
+
+        tokens = [tokens] + codes
+
+        tokens = torch.tensor(tokens, dtype=torch.long)
+        labels = tokens.clone()
+
+        if skip_text:
+            # If text is not provided, the sentence is used for condition only, all labels are -100
+            torch.fill_(labels, -100)
+            return tokens, labels
+
+        # Mask out the <s> tokens for semantic, predict semantic tokens only
+        # Since we don't mask out the input tokens, the language modeling still works
+        labels[1:, : (prompt_length + bos_bias)] = -100
+
+        tokens = tokens[:, :-1]
+        labels = labels[:, 1:]
+
+        # Verify the padding is correct, and the last token is eos
+        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
+        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+        return tokens, labels
+
+
+@dataclass
+class TextDataCollator:
+    tokenizer: AutoTokenizer
+    max_length: int = 1024
+
+    def __call__(self, examples):
+        if "negative_tokens" in examples:
+            positive_examples = []
+            negative_examples = []
+
+            for i in examples:
+                positive_examples.append(
+                    {
+                        "tokens": i["tokens"],
+                        "labels": i["labels"],
+                    }
+                )
+                negative_examples.append(
+                    {
+                        "tokens": i["negative_tokens"],
+                        "labels": i["negative_labels"],
+                    }
+                )
+
+            examples = positive_examples + negative_examples
+
+        return self.batchify(examples)
+
+    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
+        tokens, attention_masks, labels = [], [], []
+
+        # Calculate the max length
+        max_tokens_length = 0
+        for example in examples:
+            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
+        max_tokens_length = min(max_tokens_length, self.max_length)
+
+        for example in examples:
+            _tokens = example[tokens_key][:, :max_tokens_length]
+            _labels = example[labels_key][:, :max_tokens_length]
+            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
+            tokens_length = _tokens.size(1)
+            _attention_mask[:tokens_length] = False
+
+            assert tokens_length == _labels.size(
+                1
+            ), f"{tokens_length} != {_labels.size(1)}"
+
+            if tokens_length < max_tokens_length:
+                _tokens = F.pad(
+                    _tokens,
+                    (0, max_tokens_length - tokens_length),
+                    value=self.tokenizer.eos_token_id,
+                )
+                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
+                _labels = F.pad(
+                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
+                )
+
+            tokens.append(_tokens)
+            attention_masks.append(_attention_mask)
+            labels.append(_labels)
+
+        tokens = torch.stack(tokens, dim=0)
+        attention_masks = torch.stack(attention_masks, dim=0)
+        labels = torch.stack(labels, dim=0)
+
+        return {
+            "inputs": tokens,
+            "attention_masks": attention_masks,
+            "labels": labels,
+        }
+
+
+class InterleaveDataset(IterableDataset):
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        probabilities: list[float],
+        seed: int = 42,
+    ):
+        super().__init__()
+
+        self.datasets = datasets
+        self.probabilities = probabilities
+        self.seed = seed
+
+    def __iter__(self):
+        rng = np.random.default_rng(self.seed)
+        dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+        while True:
+            # Random choice one
+            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+            dataset_iterator = dataset_iterators[dataset_idx]
+
+            try:
+                yield next(dataset_iterator)
+            except StopIteration:
+                # Exhausted, create a new iterator
+                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+                yield next(dataset_iterators[dataset_idx])
+
+
+class SemanticDataModule(LightningDataModule):
+    def __init__(
+        self,
+        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+        batch_size: int = 32,
+        tokenizer: AutoTokenizer = None,
+        max_length: int = 1024,
+        num_workers: int = 4,
+    ):
+        super().__init__()
+
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.batch_size = batch_size
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.num_workers = num_workers
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+            num_workers=self.num_workers,
+            persistent_workers=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+            num_workers=self.num_workers,
+            persistent_workers=True,
+        )
+
+
+if __name__ == "__main__":
+    from tqdm import tqdm
+
+    ds = AutoAugTextDataset(
+        ["data/protos"],
+        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
+        use_speaker=False,
+        interactive_prob=1.0,
+        use_negative_samples=False,
+        skip_text_prob=0.5,
+    )
+
+    for i in ds:
+        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
+        # i["labels"][0][i["labels"][0] == -100] = 0
+        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
+        break

+ 0 - 652
fish_speech/datasets/text.py

@@ -1,652 +0,0 @@
-import gzip
-import io
-import json
-import random
-from dataclasses import dataclass
-from pathlib import Path
-from random import Random
-from typing import Optional, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-import zstandard as zstd
-from lightning import LightningDataModule
-from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
-from transformers import AutoTokenizer
-
-from fish_speech.conversation import (
-    CODEBOOK_PAD_TOKEN_ID,
-    SKIP_TEXT_STRING,
-    Conversation,
-    Message,
-    encode_conversation,
-)
-from fish_speech.datasets.prompts import asr_instructions, tts_instructions
-from fish_speech.datasets.protos.text_data_pb2 import SampledData
-from fish_speech.datasets.protos.text_data_stream import read_pb_stream
-from fish_speech.text.clean import clean_text
-from fish_speech.utils import RankedLogger
-from fish_speech.utils.braceexpand import braceexpand
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-DCTX = zstd.ZstdDecompressor(max_window_size=2**31)
-
-
-def split_by_rank_worker(files):
-    # We need to know the total number of devices
-    # to split the data properly
-
-    total_devices = 1
-    if is_initialized():
-        total_devices = get_world_size()
-
-    worker_info = get_worker_info()
-    if worker_info is not None:
-        total_devices *= worker_info.num_workers
-
-    if len(files) < total_devices:
-        # Repeat the files N times to match the number of devices
-        files = files * (total_devices // len(files) + 1)
-
-    # DDP
-    if is_initialized():
-        files = files[get_rank() :: get_world_size()]
-
-    # Split by worker
-    if worker_info is not None:
-        files = files[worker_info.id :: worker_info.num_workers]
-
-    return files
-
-
-def expand_split_proto_files(proto_files, seed: int = 42):
-    # Expand the proto files
-    expanded_proto_files = []
-    for filename in proto_files:
-        for i in braceexpand(filename):
-            i = Path(i)
-            if i.is_file():
-                expanded_proto_files.append(i)
-            elif i.is_dir():
-                expanded_proto_files.extend(i.rglob("*.proto"))
-                expanded_proto_files.extend(i.rglob("*.protos"))
-            else:
-                raise ValueError(f"{i} is not a file or directory")
-
-    expanded_proto_files = sorted(expanded_proto_files)
-    Random(seed).shuffle(expanded_proto_files)
-    return split_by_rank_worker(expanded_proto_files)
-
-
-class TextPretrainDataset(IterableDataset):
-    def __init__(
-        self,
-        source: str,
-        seed: int = 42,
-        max_length: int = 1024,
-        tokenizer: AutoTokenizer = None,
-        num_codebooks: int = 2,
-    ):
-        super().__init__()
-
-        self.source = Path(source)
-        self.seed = seed
-        self.max_length = max_length
-        self.tokenizer = tokenizer
-        self.num_codebooks = num_codebooks
-
-        if self.source.is_file():
-            with open(self.source, "r") as f:
-                files = f.read().strip().split("\n")
-            self.root = self.source.parent
-        else:
-            files = [
-                str(i.relative_to(self.source)) for i in self.source.rglob("*.jsonl")
-            ]
-            self.root = self.source
-
-        # Get sharded files
-        self.files = sorted(files)
-
-        Random(seed).shuffle(self.files)
-
-    def __iter__(self):
-        files = split_by_rank_worker(self.files)
-        random.shuffle(files)
-
-        for filename in files:
-            try:
-                yield from self.parse_data(filename)
-            except Exception as e:
-                log.exception(f"Failed to parse {filename}: {e}")
-
-    def read_jsonl(self, filename: str):
-        with open(self.root / filename, "rb") as f:
-            if filename.endswith(".zst"):
-                stream_reader = DCTX.stream_reader(f)
-            elif filename.endswith(".gz"):
-                stream_reader = gzip.open(f, "rb")
-            elif filename.endswith(".jsonl"):
-                stream_reader = f
-            else:
-                raise ValueError(f"Unknown file type: {filename}")
-
-            stream = io.TextIOWrapper(stream_reader, encoding="utf-8")
-
-            # Parse jsonl
-            for line in stream:
-                line = json.loads(line)
-                yield line
-
-    def parse_data(self, filename: str):
-        for line in self.read_jsonl(filename):
-            # encode
-            tokens = self.tokenizer.encode(
-                line["text"],
-                add_special_tokens=False,
-                truncation=False,
-                max_length=10**6,
-            )
-
-            tokens = (
-                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
-            )
-
-            if len(tokens) > self.max_length:
-                tokens = tokens[: self.max_length]
-
-            tokens = self.pad_codebooks(tokens)
-            labels = tokens.clone()
-            tokens = tokens[:, :-1]
-            labels = labels[:, 1:]
-            labels[1:] = -100  # no loss on codebook
-
-            yield {"tokens": tokens, "labels": labels}
-
-    def pad_codebooks(self, tokens):
-        placeholder_multi_codebook = (
-            torch.zeros((self.num_codebooks, len(tokens)), dtype=torch.long)
-            + CODEBOOK_PAD_TOKEN_ID
-        )
-        return torch.concat(
-            [
-                torch.tensor([tokens], dtype=torch.long),
-                placeholder_multi_codebook,
-            ],
-            dim=0,
-        )
-
-
-class TextInstructionDataset(TextPretrainDataset):
-    def parse_data(self, filename: str):
-        for line in self.read_jsonl(filename):
-            messages = []
-            for conversation in line["conversations"]:
-                role = {
-                    "human": "user",
-                    "gpt": "assistant",
-                    "system": "system",
-                }[conversation["from"]]
-
-                message = Message(
-                    role=role,
-                    parts=[conversation["value"]],
-                )
-                messages.append(message)
-
-            conversation = Conversation(messages=messages)
-            tokens, labels = encode_conversation(
-                conversation,
-                self.tokenizer,
-                num_codebooks=self.num_codebooks,
-            )
-
-            yield {"tokens": tokens, "labels": labels}
-
-
-def semantic_to_tensor(semantics):
-    num_codebooks = len(semantics)
-    codes = [[] for _ in range(num_codebooks)]
-
-    for book_idx, book in zip(range(num_codebooks), semantics):
-        for j in book.values:
-            codes[book_idx].append(int(j))
-
-    return torch.tensor(codes, dtype=torch.int)
-
-
-class AutoTextSemanticInstructionDataset(IterableDataset):
-    def __init__(
-        self,
-        proto_files: list[str],
-        seed: int = 42,
-        max_length: int = 1024,
-        tokenizer: AutoTokenizer = None,
-        causual: Union[bool, float] = True,
-        num_codebooks: Optional[int] = None,
-        skip_text_prob: float = 0.0,
-        asr_prob: float = 0.0,
-    ):
-        """
-        Args:
-            proto_files: proto buf files if using local data
-            seed: random seed
-            max_length: max length of the text
-            tokenizer: tokenizer
-            causual: use causual sampling when using local data, disable will lead to random sampling
-            num_codebooks: number of codebooks, if None, it will be automatically detected
-            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
-            asr_prob: probability to use ASR
-        """
-
-        super().__init__()
-
-        assert 0 <= skip_text_prob <= 1, "skip_text_prob must be in [0, 1]"
-        assert 0 <= asr_prob <= 1, "asr_prob must be in [0, 1]"
-
-        self.seed = seed
-        self.max_length = max_length
-        self.tokenizer = tokenizer
-        self.proto_files = proto_files
-        self.causual = causual
-        self.num_codebooks = num_codebooks
-        self.skip_text_prob = skip_text_prob
-        self.asr_prob = asr_prob
-        self.groups = None
-
-    def init_mock_data_server(self):
-        if self.groups is not None:
-            return
-
-        self.groups = []
-        shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
-        log.info(f"Reading {len(shard_proto_files)} files")
-
-        count = 0
-        for filename in shard_proto_files:
-            with open(filename, "rb") as f:
-                for text_data in read_pb_stream(f):
-                    self.groups.append(text_data)
-                    count += 1
-
-        log.info(f"Read total {count} groups of data")
-
-        # Shuffle the lines
-        Random(self.seed).shuffle(self.groups)
-        self.group_weights = [len(i.sentences) for i in self.groups]
-
-    def __iter__(self):
-        while True:
-            yield self.augment()
-
-    def tokenize_sentence(self, sentence: str):
-        sentence = clean_text(sentence)
-        tokens = self.tokenizer.encode(
-            f"{sentence}",
-            max_length=10**6,
-            add_special_tokens=False,
-            truncation=False,
-        )
-        return sentence, len(tokens)
-
-    def sample_data(self):
-        if self.groups is None:
-            self.init_mock_data_server()
-
-        # Shuffle unique lines, estimate that each sample is at least 20 tokens
-        num_samples = self.max_length // 20
-
-        # choice group based on their number of samples
-        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
-
-        causual = self.causual
-        if isinstance(self.causual, float):
-            causual = random.random() < self.causual
-
-        if causual:
-            # Sample in order
-            if num_samples >= len(group.sentences):
-                samples = group.sentences
-            else:
-                begin = random.randint(0, len(group.sentences) - num_samples)
-                samples = group.sentences[begin : begin + num_samples]
-        else:
-            samples = random.choices(
-                group.sentences, k=min(num_samples, len(group.sentences))
-            )
-
-        return SampledData(
-            source=group.source,
-            name=group.name,
-            samples=samples,
-        )
-
-    def augment(self):
-        response = self.sample_data()
-        if len(response.samples) == 0:
-            # Invalid group
-            return None
-
-        samples = list(response.samples)
-        idx = 0
-        remaining_tokens = self.max_length
-
-        all_messages = []
-        while remaining_tokens > 0 and len(samples) > 0:
-            sentence = samples.pop(0)
-
-            text = random.choice(sentence.texts)
-            text, length = self.tokenize_sentence(text)
-            remaining_tokens -= length + len(sentence.semantics[0].values)
-
-            # For interactive mode, we only apply speaker for the first sentence
-            # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
-
-            if random.random() < self.asr_prob:
-                all_messages.append(
-                    Message(
-                        role="user",
-                        parts=[
-                            random.choice(asr_instructions),
-                            semantic_to_tensor(sentence.semantics),
-                        ],
-                    )
-                )
-                all_messages.append(
-                    Message(
-                        role="assistant",
-                        parts=[text],
-                    )
-                )
-            else:
-                skip_text = random.random() < self.skip_text_prob
-                if skip_text:
-                    text = SKIP_TEXT_STRING
-
-                all_messages.append(
-                    Message(
-                        role="user",
-                        parts=[random.choice(tts_instructions) + text],
-                        mask_labels=skip_text,
-                    )
-                )
-                all_messages.append(
-                    Message(
-                        role="assistant",
-                        parts=[semantic_to_tensor(sentence.semantics)],
-                        mask_labels=skip_text,
-                    )
-                )
-
-            idx += 1
-
-        tokens, labels = encode_conversation(
-            Conversation(messages=all_messages),
-            self.tokenizer,
-            num_codebooks=self.num_codebooks,
-        )
-
-        # Verify that the length is correct
-        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
-
-        # Verify bos token
-        assert tokens[0, 0] == self.tokenizer.bos_token_id
-
-        return {"tokens": tokens, "labels": labels}
-
-
-class SemanticInstructionDataset(IterableDataset):
-    def __init__(
-        self,
-        proto_files: list[str],
-        seed: int = 42,
-        max_length: int = 1024,
-        tokenizer: AutoTokenizer = None,
-        num_codebooks: Optional[int] = None,
-    ):
-        super().__init__()
-
-        self.seed = seed
-        self.max_length = max_length
-        self.tokenizer = tokenizer
-        self.proto_files = proto_files
-        self.num_codebooks = num_codebooks
-
-    def get_data_generator(self):
-        shard_proto_files = expand_split_proto_files(self.proto_files, seed=self.seed)
-        random.shuffle(shard_proto_files)
-        log.info(f"Fetched {len(shard_proto_files)} files")
-
-        for filename in shard_proto_files:
-            with open(filename, "rb") as f:
-                for group in read_pb_stream(f):
-                    yield group
-
-    def pack_one_group(self, group):
-        sentences = group.sentences
-
-        messages = []
-        for idx, sentence in enumerate(sentences):
-            role = "user" if idx % 2 == 0 else "assistant"
-            semantic = semantic_to_tensor(sentence.semantics)
-            text = random.choice(sentence.texts)
-            parts = [semantic]
-            if role == "assistant":
-                # Let model to predict the text first
-                prev_text = random.choice(sentences[idx - 1].texts)
-                # parts.insert(0, f"Q: {prev_text}\nA: {text}")
-            messages.append(
-                Message(
-                    role=role,
-                    parts=parts,
-                )
-            )
-
-        conversation = Conversation(messages=messages)
-        tokens, labels = encode_conversation(
-            conversation,
-            self.tokenizer,
-            num_codebooks=self.num_codebooks,
-        )
-
-        return {"tokens": tokens, "labels": labels}
-
-    def __iter__(self):
-        for group in self.get_data_generator():
-            try:
-                yield self.pack_one_group(group)
-            except Exception as e:
-                log.exception(f"Failed to parse {group}: {e}")
-
-
-@dataclass
-class TextDataCollator:
-    tokenizer: AutoTokenizer
-    max_length: int = 1024
-
-    def __call__(self, examples):
-        if "negative_tokens" in examples:
-            positive_examples = []
-            negative_examples = []
-
-            for i in examples:
-                positive_examples.append(
-                    {
-                        "tokens": i["tokens"],
-                        "labels": i["labels"],
-                    }
-                )
-                negative_examples.append(
-                    {
-                        "tokens": i["negative_tokens"],
-                        "labels": i["negative_labels"],
-                    }
-                )
-
-            examples = positive_examples + negative_examples
-
-        return self.batchify(examples)
-
-    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
-        tokens, attention_masks, labels = [], [], []
-
-        # Calculate the max length
-        max_tokens_length = 0
-        for example in examples:
-            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
-        max_tokens_length = min(max_tokens_length, self.max_length)
-
-        for example in examples:
-            _tokens = example[tokens_key][:, :max_tokens_length]
-            _labels = example[labels_key][:, :max_tokens_length]
-            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
-            tokens_length = _tokens.size(1)
-            _attention_mask[:tokens_length] = False
-
-            assert tokens_length == _labels.size(
-                1
-            ), f"{tokens_length} != {_labels.size(1)}"
-
-            if tokens_length < max_tokens_length:
-                _tokens = F.pad(
-                    _tokens,
-                    (0, max_tokens_length - tokens_length),
-                    value=self.tokenizer.eos_token_id,
-                )
-                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
-                _labels = F.pad(
-                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
-                )
-
-            tokens.append(_tokens)
-            attention_masks.append(_attention_mask)
-            labels.append(_labels)
-
-        tokens = torch.stack(tokens, dim=0)
-        attention_masks = torch.stack(attention_masks, dim=0)
-        labels = torch.stack(labels, dim=0)
-
-        return {
-            "inputs": tokens,
-            "attention_masks": attention_masks,
-            "labels": labels,
-        }
-
-
-class InterleaveDataset(IterableDataset):
-    def __init__(
-        self,
-        datasets: list[IterableDataset],
-        probabilities: list[float],
-        seed: int = 42,
-    ):
-        super().__init__()
-
-        self.datasets = datasets
-        self.probabilities = probabilities
-        self.seed = seed
-
-    def __iter__(self):
-        rng = np.random.default_rng(self.seed)
-        dataset_iterators = [iter(dataset) for dataset in self.datasets]
-
-        while True:
-            # Random choice one
-            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
-            dataset_iterator = dataset_iterators[dataset_idx]
-
-            try:
-                yield next(dataset_iterator)
-            except StopIteration:
-                # Exhausted, create a new iterator
-                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
-                yield next(dataset_iterators[dataset_idx])
-
-
-class TextDataModule(LightningDataModule):
-    def __init__(
-        self,
-        train_dataset: Union[
-            AutoTextSemanticInstructionDataset,
-            TextPretrainDataset,
-            TextInstructionDataset,
-            InterleaveDataset,
-        ],
-        val_dataset: Union[
-            AutoTextSemanticInstructionDataset,
-            TextPretrainDataset,
-            TextInstructionDataset,
-            InterleaveDataset,
-        ],
-        batch_size: int = 32,
-        tokenizer: AutoTokenizer = None,
-        max_length: int = 1024,
-        num_workers: int = 4,
-    ):
-        super().__init__()
-
-        self.train_dataset = train_dataset
-        self.val_dataset = val_dataset
-        self.batch_size = batch_size
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-        self.num_workers = num_workers
-
-    def train_dataloader(self):
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size,
-            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
-            num_workers=self.num_workers,
-            persistent_workers=True,
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            self.val_dataset,
-            batch_size=self.batch_size,
-            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
-            num_workers=self.num_workers,
-            persistent_workers=True,
-        )
-
-
-if __name__ == "__main__":
-    from tqdm import tqdm
-
-    # ds = AutoTextSemanticInstructionDataset(
-    #     ["data/protos/sft/val/11labs"],
-    #     tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
-    #     skip_text_prob=1.0,
-    #     asr_prob=0.0,
-    #     num_codebooks=2,
-    # )
-    # ds = TextInstructionDataset(
-    #     source="data/openhermes2_5",
-    #     tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
-    # )
-
-    ds = SemanticInstructionDataset(
-        proto_files=["data/protos/sft/val/ultrachat_200k_spoken_openai"],
-        tokenizer=AutoTokenizer.from_pretrained("checkpoints/fish-speech-agent-1"),
-        num_codebooks=2,
-    )
-
-    for i in ds:
-        # print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
-        # i["labels"][0][i["labels"][0] == -100] = 0
-        # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
-
-        length = i["tokens"].size(1)
-        print(i["tokens"].size(), i["tokens"].dtype)
-        for j in range(length):
-            print(
-                ds.tokenizer.decode(i["tokens"][0, j]),
-                i["tokens"][:, j],
-                i["labels"][:, j],
-            )
-            input()
-        break

+ 0 - 195
fish_speech/datasets/vits.py

@@ -1,195 +0,0 @@
-import random
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional
-
-import librosa
-import numpy as np
-import torch
-import torch.distributed as dist
-from lightning import LightningDataModule
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer
-
-from fish_speech.utils import RankedLogger
-
-logger = RankedLogger(__name__, rank_zero_only=False)
-
-
-class VITSDataset(Dataset):
-    def __init__(
-        self,
-        filelist: str,
-        tokenizer: AutoTokenizer,
-        sample_rate: int = 44100,
-        hop_length: int = 512,
-        min_duration: float = 1.5,
-        max_duration: float = 30.0,
-        suffix: str = ".lab",
-        sentence_mask_ratio: float = 0.0,
-    ):
-        super().__init__()
-
-        filelist = Path(filelist)
-        root = filelist.parent
-
-        self.files = []
-        for line in filelist.read_text(encoding="utf-8").splitlines():
-            path = root / line
-            self.files.append(path)
-
-        self.sample_rate = sample_rate
-        self.hop_length = hop_length
-        self.min_duration = min_duration
-        self.max_duration = max_duration
-        self.tokenizer = tokenizer
-        self.suffix = suffix
-        self.sentence_mask_ratio = sentence_mask_ratio
-
-    def __len__(self):
-        return len(self.files)
-
-    def get_item(self, idx):
-        audio_file = self.files[idx]
-        text_file = audio_file.with_suffix(self.suffix)
-
-        if text_file.exists() is False or audio_file.exists() is False:
-            return None
-
-        audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
-        duration = len(audio) / self.sample_rate
-
-        # Pad to minimum duration
-        if duration < self.min_duration:
-            pad_duration = self.min_duration - duration
-            pad_samples = int(pad_duration * self.sample_rate)
-            audio = np.pad(audio, (0, pad_samples))
-
-        # Truncate to maximum duration
-        if duration > self.max_duration:
-            random_start = random.randint(
-                0, len(audio) - int(self.max_duration * self.sample_rate) - 1
-            )
-            audio = audio[
-                random_start : random_start + int(self.max_duration * self.sample_rate)
-            ]
-
-        max_value = np.abs(audio).max()
-        if max_value > 1.0:
-            audio = audio / max_value
-
-        if random.random() < self.sentence_mask_ratio:
-            text = "-"
-        else:
-            text = text_file.read_text(encoding="utf-8")
-
-        input_ids = self.tokenizer(text, return_tensors="pt").input_ids.squeeze(0)
-
-        return {
-            "audio": torch.from_numpy(audio),
-            "text": input_ids,
-        }
-
-    def __getitem__(self, idx):
-        try:
-            return self.get_item(idx)
-        except Exception as e:
-            import traceback
-
-            traceback.print_exc()
-            logger.error(f"Error loading {self.files[idx]}: {e}")
-            return None
-
-
-@dataclass
-class VITSCollator:
-    tokenizer: AutoTokenizer
-
-    def __call__(self, batch):
-        batch = [x for x in batch if x is not None]
-
-        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
-        audio_maxlen = audio_lengths.max()
-
-        text_lengths = torch.tensor([len(x["text"]) for x in batch])
-        text_maxlen = text_lengths.max()
-
-        # Rounds up to nearest multiple of 2 (audio_lengths)
-        audios = []
-        texts = []
-        for x in batch:
-            audios.append(
-                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
-            )
-
-            texts.append(
-                torch.nn.functional.pad(
-                    x["text"],
-                    (0, text_maxlen - len(x["text"])),
-                    value=self.tokenizer.eos_token_id,
-                )
-            )
-
-        return {
-            "audios": torch.stack(audios),
-            "audio_lengths": audio_lengths,
-            "texts": torch.stack(texts),
-            "text_lengths": text_lengths,
-        }
-
-
-class VITSDataModule(LightningDataModule):
-    def __init__(
-        self,
-        train_dataset: VITSDataset,
-        val_dataset: VITSDataset,
-        tokenizer: AutoTokenizer,
-        batch_size: int = 32,
-        num_workers: int = 4,
-        val_batch_size: Optional[int] = None,
-    ):
-        super().__init__()
-
-        self.train_dataset = train_dataset
-        self.val_dataset = val_dataset
-        self.batch_size = batch_size
-        self.val_batch_size = val_batch_size or batch_size
-        self.num_workers = num_workers
-        self.tokenizer = tokenizer
-
-    def train_dataloader(self):
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size,
-            collate_fn=VITSCollator(self.tokenizer),
-            num_workers=self.num_workers,
-            shuffle=False,
-            persistent_workers=True,
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            self.val_dataset,
-            batch_size=self.val_batch_size,
-            collate_fn=VITSCollator(self.tokenizer),
-            num_workers=self.num_workers,
-            persistent_workers=True,
-        )
-
-
-if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
-    dataset = VITSDataset(
-        "data/source/Genshin/filelist.train.txt", tokenizer=tokenizer, suffix=".lab"
-    )
-    dataloader = DataLoader(
-        dataset, batch_size=4, shuffle=False, collate_fn=VITSCollator(tokenizer)
-    )
-
-    for batch in dataloader:
-        print(batch["audios"].shape)
-        print(batch["audio_lengths"])
-        print(batch["texts"].shape)
-        print(batch["text_lengths"])
-        break

+ 6 - 104
fish_speech/models/text2semantic/lit_module.py

@@ -18,31 +18,23 @@ class TextToSemantic(L.LightningModule):
         model: NaiveTransformer,
         optimizer: Any,
         lr_scheduler: Any,
-        lora_config: Optional[LoraConfig] = None,
-        use_dpo: bool = False,
-        dpo_beta: float = 0.2,
     ):
         super().__init__()
 
         self.model = model
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
-        self.lora_config = lora_config
-        self.use_dpo = use_dpo  # We don't support reference model yet
-        self.dpo_beta = dpo_beta
-
-        if self.lora_config is not None:
-            setup_lora(self.model, self.lora_config)
 
     def forward(self, x):
         return self.model(x)
 
     def on_save_checkpoint(self, checkpoint):
-        if self.lora_config is None:
-            return
-
         # Save only LoRA parameters
         state_dict = checkpoint["state_dict"]
+        use_lora = any("lora" in name for name in state_dict.keys())
+        if not use_lora:
+            return
+
         for name in list(state_dict.keys()):
             if "lora" not in name:
                 state_dict.pop(name)
@@ -130,21 +122,15 @@ class TextToSemantic(L.LightningModule):
         token_logits = outputs.token_logits
         codebook_logits = outputs.codebook_logits
 
-        if self.use_dpo:
-            # Firtst half is positive, second half is negative
-            token_logits, negative_token_logits = token_logits.chunk(2)
-            codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
-            labels, negative_labels = labels.chunk(2)
-
         # Generate labels
-        base_loss = fast_cross_entropy_loss(
+        base_loss = F.cross_entropy(
             token_logits.view(-1, token_logits.size(-1)),
             labels[:, 0].reshape(-1),
             ignore_index=-100,
         )
 
         codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
-        semantic_loss = fast_cross_entropy_loss(
+        semantic_loss = F.cross_entropy(
             codebook_logits.view(-1, codebook_logits.size(-1)),
             codebook_labels.reshape(-1),
             ignore_index=-100,
@@ -152,74 +138,6 @@ class TextToSemantic(L.LightningModule):
 
         loss = base_loss + semantic_loss
 
-        # If we use dpo
-        if self.use_dpo:
-            negative_codebook_labels = negative_labels[
-                :, 1 : 1 + self.model.config.num_codebooks
-            ].mT
-
-            positive_codebook_logps = self.get_batch_logps(
-                codebook_logits, codebook_labels
-            )
-            negative_codebook_logps = self.get_batch_logps(
-                negative_codebook_logits, negative_codebook_labels
-            )
-
-            # TODO: implement the reference model, avoid screwing up the gradients
-            dpo_loss = -F.logsigmoid(
-                (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
-            ).mean()
-
-            chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
-            rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
-            reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
-            chosen_rewards, rejected_rewards = (
-                chosen_rewards.mean(),
-                rejected_rewards.mean(),
-            )
-
-            loss = loss + dpo_loss
-
-            self.log(
-                f"{stage}/dpo_loss",
-                dpo_loss,
-                on_step=is_train,
-                on_epoch=not is_train,
-                prog_bar=False,
-                logger=True,
-                sync_dist=not is_train,
-            )
-
-            self.log(
-                f"{stage}/chosen_rewards",
-                chosen_rewards,
-                on_step=is_train,
-                on_epoch=not is_train,
-                prog_bar=False,
-                logger=True,
-                sync_dist=not is_train,
-            )
-
-            self.log(
-                f"{stage}/rejected_rewards",
-                rejected_rewards,
-                on_step=is_train,
-                on_epoch=not is_train,
-                prog_bar=False,
-                logger=True,
-                sync_dist=not is_train,
-            )
-
-            self.log(
-                f"{stage}/reward_accuracy",
-                reward_accuracy,
-                on_step=is_train,
-                on_epoch=not is_train,
-                prog_bar=False,
-                logger=True,
-                sync_dist=not is_train,
-            )
-
         self.log(
             f"{stage}/loss",
             loss,
@@ -262,22 +180,6 @@ class TextToSemantic(L.LightningModule):
             sync_dist=not is_train,
         )
 
-        if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
-            accuracy = self.get_accuracy(
-                codebook_logits[:, :, : self.model.config.num_in_codebooks],
-                codebook_labels[:, :, : self.model.config.num_in_codebooks],
-            )
-
-            self.log(
-                f"{stage}/top_5_accuracy_in",
-                accuracy,
-                on_step=is_train,
-                on_epoch=not is_train,
-                prog_bar=True,
-                logger=True,
-                sync_dist=not is_train,
-            )
-
         return loss
 
     def get_accuracy(self, logits, labels):

+ 2 - 2
fish_speech/models/text2semantic/llama.py

@@ -477,7 +477,7 @@ class DualARTransformer(BaseTransformer):
 
         # Drop the last token and rotate left
         codebooks = inp[:, 1:-1, 1:]
-        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
+        codebooks = F.pad(codebooks, (0, 1), value=0)
         codebook_embeddings = self.fast_embeddings(codebooks)
         x = torch.cat([x[:, None], codebook_embeddings], dim=1)
         b, s = x.size(0), x.size(2)
@@ -485,7 +485,7 @@ class DualARTransformer(BaseTransformer):
 
         # Remove padded part
         codebooks = rearrange(codebooks, "b n s -> (b s) n")
-        codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
+        codebook_mask = (codebooks == 0).all(dim=-1)
 
         if torch.all(codebook_mask):
             # If all codebooks are padded, we keep first 8 to make sure the model runs

+ 15 - 0
fish_speech/scheduler.py

@@ -23,3 +23,18 @@ def get_cosine_schedule_with_warmup_lr_lambda(
         final_lr_ratio,
         0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
     )
+
+
+def get_constant_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int | float,
+    num_training_steps: int | None = None,
+):
+    if 0 < num_warmup_steps < 1:  # float mode
+        num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    return 1.0

+ 1 - 1
tools/llama/eval_in_context.py

@@ -9,7 +9,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 from torch.utils.data import DataLoader
 
-from fish_speech.datasets.text import AutoAugTextDataset, TextDataCollator
+from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
 from tools.llama.generate import load_model
 
 

+ 39 - 23
tools/llama/merge_lora.py

@@ -1,3 +1,7 @@
+import shutil
+from copy import deepcopy
+from pathlib import Path
+
 import click
 import hydra
 import torch
@@ -5,42 +9,37 @@ from hydra import compose, initialize
 from hydra.utils import instantiate
 from loguru import logger
 
-from fish_speech.models.text2semantic.lora_utils import (
-    get_merged_state_dict,
-    setup_lora,
-)
+from fish_speech.models.text2semantic.llama import BaseTransformer
+from fish_speech.models.text2semantic.lora import get_merged_state_dict
 
 
 @click.command()
-@click.option("--llama-config", type=str, default="dual_ar_2_codebook_medium")
 @click.option("--lora-config", type=str, default="r_8_alpha_16")
-@click.option("--llama-weight", type=str, default="checkpoints/fish-speech-1.2")
+@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2")
 @click.option("--lora-weight", type=str, required=True)
 @click.option("--output", type=str, required=True)
-def merge(llama_config, lora_config, llama_weight, lora_weight, output):
+def merge(lora_config, base_weight, lora_weight, output):
+    output = Path(output)
     logger.info(
-        f"Merging {llama_weight} and {lora_weight} into {output} with configs {llama_config} and {lora_config}"
+        f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
     )
 
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
-        # The max_seq_len here doesn't matter.
-        cfg = compose(config_name=llama_config, overrides=[f"config.max_seq_len=2048"])
-
-    llama_model = instantiate(cfg)
-    logger.info(f"Loaded llama model with config {llama_config}")
-
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
     with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
         cfg = compose(config_name=lora_config)
 
     lora_config = instantiate(cfg)
     logger.info(f"Loaded lora model with config {lora_config}")
 
-    setup_lora(llama_model, lora_config)
-    logger.info(f"Merged model setup complete")
+    llama_model = BaseTransformer.from_pretrained(
+        path=base_weight,
+        load_weights=True,
+        lora_config=lora_config,
+    )
+    logger.info(f"Loaded llama model")
 
-    llama_state_dict = torch.load(llama_weight, map_location="cpu")
+    llama_state_dict = llama_model.state_dict()
+    llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
+    llama_state_dict_copy = deepcopy(llama_state_dict)
     lora_state_dict = torch.load(lora_weight, map_location="cpu")
 
     if "state_dict" in llama_state_dict:
@@ -70,9 +69,26 @@ def merge(llama_config, lora_config, llama_weight, lora_weight, output):
     llama_model.load_state_dict(merged_state_dict, strict=True)
     logger.info(f"Merged model loaded")
 
-    state_dict = get_merged_state_dict(llama_model)
-    torch.save(state_dict, output)
-    logger.info(f"Merged model saved to {output}")
+    # Trigger eval mode to merge lora
+    llama_model.eval()
+    llama_model.save_pretrained(output, drop_lora=True)
+    logger.info(f"Saved merged model to {output}, validating")
+
+    new_state_dict = torch.load(output / "model.pth", map_location="cpu")
+    original_keys = set(llama_state_dict_copy.keys())
+    merged_keys = set(new_state_dict.keys())
+
+    assert original_keys == merged_keys, "Keys should be same"
+
+    for key in original_keys:
+        diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
+        if diff_l1 != 0:
+            break
+    else:
+        logger.error("Merged model is same as the original model")
+        exit(1)
+
+    logger.info("Merged model is different from the original model, check passed")
 
 
 if __name__ == "__main__":