فهرست منبع

Finetune support of OpenAudio-S1 (#1115)

* fix voice-cloning links.

* rebuild scheduler.py

* move back finetune files.

* fix spec_transform bugs.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bring back train.py

* fix lit module import and rebuild semantic's dataset class.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move back callbacks.

* fix tokenizer size.

* fix lora bugs.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* seems to be the last problem to fix.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update ft docs.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update several docs.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update main doc and mkdocs.yml

* update other docs.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix torchaudio version.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo 5 ماه پیش
والد
کامیت
781bf1cd7a

+ 1 - 5
README.md

@@ -56,6 +56,7 @@
 Here are the official documents for Fish Speech, follow the instructions to get started easily.
 
 - [Installation](https://speech.fish.audio/install/)
+- [Finetune](https://speech.fish.audio/finetune/)
 - [Inference](https://speech.fish.audio/inference/)
 - [Samples](https://speech.fish.audio/examples)
 
@@ -172,11 +173,6 @@ Both S1 and S1-mini incorporate online Reinforcement Learning from Human Feedbac
     <img src="docs/assets/Thumbnail.jpg" alt="OpenAudio S1 Video" style="width: 50%;" />
 </a>
 
-### **Audio Samples**
-<div style="margin: 20px 0;">
-    <em> High-quality audio samples will be available soon, demonstrating our multilingual TTS capabilities across different languages and emotions.</em>
-</div>
-
 </div>
 
 ---

+ 45 - 53
docs/README.ar.md

@@ -45,146 +45,138 @@
 
 > [!IMPORTANT]
 > **إشعار الترخيص**  
-> يتم إصدار قاعدة الكود هذه تحت **رخصة Apache** ويتم إصدار جميع أوزان النماذج تحت **رخصة CC-BY-NC-SA-4.0**. يرجى الرجوع إلى [LICENSE](../LICENSE) لمزيد من التفاصيل.
+> تم إصدار قاعدة الكود هذه بموجب **ترخيص Apache** وتم إصدار جميع أوزان النموذج بموجب **ترخيص CC-BY-NC-SA-4.0**. يرجى الرجوع إلى [LICENSE](LICENSE) لمزيد من التفاصيل.
 
 > [!WARNING]
-> **إخلاء المسؤولية القانونية**  
-> نحن لا نتحمل أي مسؤولية عن أي استخدام غير قانوني لقاعدة الكود. يرجى الرجوع إلى القوانين المحلية حول DMCA والقوانين الأخرى ذات الصلة.
+> **إخلاء المسؤولية القانوني**  
+> نحن لا نتحمل أي مسؤولية عن أي استخدام غير قانوني لقاعدة الكود. يرجى الرجوع إلى القوانين المحلية الخاصة بك فيما يتعلق بقانون الألفية الجديدة لحقوق طبع ونشر المواد الرقمية والقوانين الأخرى ذات الصلة.
 
-## ابدأ من هنا
+## ابدأ هنا
 
-هنا هي الوثائق الرسمية لـ Fish Speech، اتبع التعليمات للبدء بسهولة.
+فيما يلي المستندات الرسمية لـ Fish Speech، اتبع التعليمات للبدء بسهولة.
 
-- [التثبيت](https://speech.fish.audio/ar/install/)
-- [الاستنتاج](https://speech.fish.audio/ar/inference/)
+- [التثبيت](https://speech.fish.audio/install/)
+- [الضبط الدقيق](https://speech.fish.audio/finetune/)
+- [الاستدلال](https://speech.fish.audio/inference/)
 - [العينات](https://speech.fish.audio/examples)
 
-## 🎉 الإعلان
+## 🎉 إعلان
 
-نحن متحمسون للإعلان عن إعادة تسمية علامتنا التجارية إلى **OpenAudio** — تقديم سلسلة جديدة ثورية من نماذج تحويل النص إلى كلام المتقدمة التي تبني على أساس Fish-Speech.
+يسعدنا أن نعلن أننا قمنا بإعادة تسمية العلامة التجارية إلى **OpenAudio** — تقديم سلسلة جديدة ثورية من نماذج تحويل النص إلى كلام المتقدمة التي تبني على أساس Fish-Speech.
 
-نحن فخورون بإطلاق **OpenAudio-S1** كأول نموذج في هذه السلسلة، يقدم تحسينات كبيرة في الجودة والأداء والقدرات.
+نحن فخورون بإصدار **OpenAudio-S1** كنموذج أول في هذه السلسلة، حيث يوفر تحسينات كبيرة في الجودة والأداء والقدرات.
 
-يأتي OpenAudio-S1 في إصدارين: **OpenAudio-S1** و **OpenAudio-S1-mini**. كلا النموذجين متاحان الآن على [Fish Audio Playground](https://fish.audio) (لـ **OpenAudio-S1**) و [Hugging Face](https://huggingface.co/fishaudio/openaudio-s1-mini) (لـ **OpenAudio-S1-mini**).
+يأتي OpenAudio-S1 في نسختين: **OpenAudio-S1** و **OpenAudio-S1-mini**. كلا النموذجين متاحان الآن على [Fish Audio Playground](https://fish.audio) (لـ **OpenAudio-S1**) و [Hugging Face](https://huggingface.co/fishaudio/openaudio-s1-mini) (لـ **OpenAudio-S1-mini**).
 
 قم بزيارة [موقع OpenAudio](https://openaudio.com/blogs/s1) للمدونة والتقرير التقني.
 
-## النقاط البارزة
+## أبرز المميزات
 
 ### **جودة TTS ممتازة**
 
-نستخدم مقاييس تقييم Seed TTS لتقييم أداء النموذج، وتظهر النتائج أن OpenAudio S1 يحقق **0.008 WER** و **0.004 CER** على النص الإنجليزي، وهو أفضل بكثير من النماذج السابقة. (الإنجليزية، التقييم التلقائي، بناءً على OpenAI gpt-4o-transcribe، مسافة المتحدث باستخدام Revai/pyannote-wespeaker-voxceleb-resnet34-LM)
+نستخدم مقاييس تقييم Seed TTS لتقييم أداء النموذج، وتظهر النتائج أن OpenAudio S1 يحقق **0.008 WER** و **0.004 CER** على النص الإنجليزي، وهو أفضل بشكل ملحوظ من النماذج السابقة. (الإنجليزية، التقييم التلقائي، بناءً على OpenAI gpt-4o-transcribe، مسافة المتحدث باستخدام Revai/pyannote-wespeaker-voxceleb-resnet34-LM)
 
-| النموذج | معدل خطأ الكلمات (WER) | معدل خطأ الأحرف (CER) | مسافة المتحدث |
+| النموذج | معدل الخطأ في الكلمات (WER) | معدل الخطأ في الأحرف (CER) | مسافة المتحدث |
 |-------|----------------------|---------------------------|------------------|
 | **S1** | **0.008**  | **0.004**  | **0.332** |
 | **S1-mini** | **0.011** | **0.005** | **0.380** |
 
 ### **أفضل نموذج في TTS-Arena2** 🏆
 
-حقق OpenAudio S1 **المرتبة الأولى** في [TTS-Arena2](https://arena.speechcolab.org/)، المعيار لتقييم تحويل النص إلى كلام:
+حقق OpenAudio S1 **المركز الأول** على [TTS-Arena2](https://arena.speechcolab.org/)، المعيار لتقييم تحويل النص إلى كلام:
 
 <div align="center">
-    <img src="assets/Elo.jpg" alt="TTS-Arena2 Ranking" style="width: 75%;" />
+    <img src="../docs/assets/Elo.jpg" alt="TTS-Arena2 Ranking" style="width: 75%;" />
 </div>
 
 ### **التحكم في الكلام**
-يدعم OpenAudio S1 **مجموعة متنوعة من العلامات العاطفية والنبرة والخاصة** لتعزيز تركيب الكلام:
 
-- **المشاعر الأساسية**:
+يدعم OpenAudio S1 **مجموعة متنوعة من العلامات العاطفية والنبرة والعلامات الخاصة** لتعزيز تخليق الكلام:
+
+- **العواطف الأساسية**:
 ```
 (غاضب) (حزين) (متحمس) (مندهش) (راضي) (مسرور) 
-(خائف) (قلق) (منزعج) (عصبي) (محبط) (مكتئب)
-(متعاطف) (محرج) (مشمئز) (متأثر) (فخور) (مسترخي)
+(خائف) (قلق) (منزعج) (متوتر) (محبط) (مكتئب)
+(متعاطف) (محرج) (مشمئز) (متحرك) (فخور) (مرتاح)
 (ممتن) (واثق) (مهتم) (فضولي) (مرتبك) (مبتهج)
 ```
 
-- **المشاعر المتقدمة**:
+- **العواطف المتقدمة**:
 ```
 (محتقر) (غير سعيد) (قلق) (هستيري) (غير مبال) 
-(نافد الصبر) (مذنب) (ازدرائي) (مذعور) (غاضب) (مترد)
-(متحمس) (غير موافق) (سلبي) (منكر) (مندهش) (جدي)
-(ساخر) (مصالح) (مواسي) (صادق) (ساخر)
+(غير صبور) (مذنب) (ساخر) (ذعر) (غاضب) (متردد)
+(متحمس) (غير موافق) (سلبي) (نافي) (مندهش) (جاد)
+(ساخر) (مصالح) (مريح) (صادق) (ساخر)
 (متردد) (مستسلم) (مؤلم) (محرج) (مسلي)
 ```
 
 - **علامات النبرة**:
 ```
-(بنبرة مستعجلة) (صراخ) (صراخ) (همس) (نبرة ناعمة)
+(بنبرة مستعجلة) (يصرخ) (يصرخ) (يهمهم) (بنبرة ناعمة)
 ```
 
 - **تأثيرات صوتية خاصة**:
 ```
-(ضحك) (قهقهة) (نشيج) (بكاء بصوت عالٍ) (تنهد) (لهاث)
-(أنين) (ضحك الجمهور) (ضحك الخلفية) (ضحك الجمهور)
+(يضحك) (يقهقه) (ينتحب) (يبكي بصوت عال) (يتنهد) (يلهث)
+(يئن) (ضحك الجمهور) (ضحك في الخلفية) (ضحك الجمهور)
 ```
 
-يمكنك أيضًا استخدام ها،ها،ها للتحكم، هناك العديد من الحالات الأخرى في انتظار استكشافك بنفسك.
+يمكنك أيضًا استخدام Ha,ha,ha للتحكم، وهناك العديد من الحالات الأخرى التي تنتظر استكشافها بنفسك.
 
-(الدعم للإنجليزية والصينية واليابانية الآن، والمزيد من اللغات قادم قريبًا!)
+(الدعم متاح للإنجليزية والصينية واليابانية الآن، والمزيد من اللغات قريبًا!)
 
 ### **نوعان من النماذج**
 
 | النموذج | الحجم | التوفر | الميزات |
 |-------|------|--------------|----------|
-| **S1** | 4 مليار معامل | متاح على [fish.audio](https://fish.audio) | النموذج الرئيسي كامل الميزات |
-| **S1-mini** | 0.5 مليار معامل | متاح على Hugging Face [hf space](https://huggingface.co/spaces/fishaudio/openaudio-s1-mini) | إصدار مقطر بالقدرات الأساسية |
+| **S1** | 4B معامل | متوفر على [fish.audio](https://fish.audio/) | النموذج الرئيسي كامل الميزات |
+| **S1-mini** | 0.5B معامل | متوفر على huggingface [hf space](https://huggingface.co/spaces/fishaudio/openaudio-s1-mini) | نسخة مقطرة بالقدرات الأساسية |
 
-كل من S1 و S1-mini يدمجان التعلم المعزز عبر الإنترنت من ردود الفعل البشرية (RLHF).
+كلا النموذجين S1 و S1-mini يتضمنان التعلم المعزز من التغذية الراجعة البشرية (RLHF) عبر الإنترنت.
 
 ## **الميزات**
 
-1. **TTS بدون عينات وبعينات قليلة:** أدخل عينة صوتية من 10 إلى 30 ثانية لإنتاج مخرجات TTS عالية الجودة. **للإرشادات التفصيلية، راجع [أفضل ممارسات استنساخ الصوت](https://docs.fish.audio/text-to-speech/voice-clone-best-practices).**
+1. **TTS بدون عينات وقليل العينات:** أدخل عينة صوتية مدتها 10 إلى 30 ثانية لتوليد مخرجات TTS عالية الجودة. **للحصول على إرشادات مفصلة، راجع [أفضل ممارسات استنساخ الصوت](https://docs.fish.audio/resources/best-practices/voice-cloning).**
 
-2. **الدعم متعدد اللغات وعبر اللغات:** ببساطة انسخ والصق النص متعدد اللغات في مربع الإدخال—لا حاجة للقلق بشأن اللغة. يدعم حاليًا الإنجليزية واليابانية والكورية والصينية والفرنسية والألمانية والعربية والإسبانية.
+2. **الدعم متعدد اللغات وعبر اللغات:** ما عليك سوى نسخ ولصق النص متعدد اللغات في مربع الإدخال — لا داعي للقلق بشأن اللغة. يدعم حاليًا الإنجليزية واليابانية والكورية والصينية والفرنسية والألمانية والعربية والإسبانية.
 
-3. **لا يعتمد على الصوتيات:** النموذج لديه قدرات تعميم قوية ولا يعتمد على الصوتيات لـ TTS. يمكنه التعامل مع النص في أي نص لغوي.
+3. **لا يعتمد على الفونيمات:** يتمتع النموذج بقدرات تعميم قوية ولا يعتمد على الفونيمات لـ TTS. يمكنه التعامل مع النص بأي لغة نصية.
 
-4. **دقيق للغاية:** يحقق معدل خطأ أحرف منخفض (CER) حوالي 0.4% ومعدل خطأ كلمات (WER) حوالي 0.8% لـ Seed-TTS Eval.
+4. **دقيق للغاية:** يحقق معدل خطأ في الأحرف (CER) حوالي 0.4٪ ومعدل خطأ في الكلمات (WER) حوالي 0.8٪ لـ Seed-TTS Eval.
 
-5. **سريع:** مع تسريع fish-tech، عامل الوقت الحقيقي حوالي 1:5 على كمبيوتر محمول Nvidia RTX 4060 و 1:15 على Nvidia RTX 4090.
+5. **سريع:** مع التسريع بواسطة torch compile، فإن عامل الوقت الحقيقي هو حوالي 1:7 على بطاقة Nvidia RTX 4090 GPU.
 
-6. **استدلال WebUI:** يتميز بواجهة مستخدم ويب سهلة الاستخدام تعتمد على Gradio ومتوافقة مع متصفحات Chrome و Firefox و Edge وغيرها.
+6. **استدلال WebUI:** يتميز بواجهة ويب سهلة الاستخدام تعتمد على Gradio متوافقة مع Chrome و Firefox و Edge والمتصفحات الأخرى.
 
 7. **سهولة النشر:** يمكنك إعداد خادم استدلال بسهولة مع دعم أصلي لأنظمة Linux و Windows (دعم macOS قريبًا)، مما يقلل من فقدان الأداء.
 
-## المجتمع والدعم
+## **وسائل الإعلام والعروض التوضيحية**
 
 <div align="center">
 
 ### **وسائل التواصل الاجتماعي**
 <a href="https://x.com/FishAudio/status/1929915992299450398" target="_blank">
-    <img src="https://img.shields.io/badge/𝕏-Latest_Demo-black?style=for-the-badge&logo=x&logoColor=white" alt="أحدث عرض توضيحي على X" />
+    <img src="https://img.shields.io/badge/𝕏-أحدث_عرض_توضيحي-black?style=for-the-badge&logo=x&logoColor=white" alt="أحدث عرض توضيحي على X" />
 </a>
 
 ### **العروض التوضيحية التفاعلية**
 <a href="https://fish.audio" target="_blank">
-    <img src="https://img.shields.io/badge/Fish_Audio-Try_OpenAudio_S1-blue?style=for-the-badge" alt="جرب OpenAudio S1" />
+    <img src="https://img.shields.io/badge/Fish_Audio-جرب_OpenAudio_S1-blue?style=for-the-badge" alt="جرب OpenAudio S1" />
 </a>
 <a href="https://huggingface.co/spaces/fishaudio/openaudio-s1-mini" target="_blank">
-    <img src="https://img.shields.io/badge/Hugging_Face-Try_S1_Mini-yellow?style=for-the-badge" alt="جرب S1 Mini" />
+    <img src="https://img.shields.io/badge/Hugging_Face-جرب_S1_Mini-yellow?style=for-the-badge" alt="جرب S1 Mini" />
 </a>
 
 ### **عروض الفيديو**
 
 <a href="https://www.youtube.com/watch?v=SYuPvd7m06A" target="_blank">
-    <img src="../docs/assets/Thumbnail.jpg" alt="OpenAudio S1 Video" style="width: 50%;" />
+    <img src="../docs/assets/Thumbnail.jpg" alt="فيديو OpenAudio S1" style="width: 50%;" />
 </a>
 
-### **عينات الصوت**
-<div style="margin: 20px 0;">
-    <em>ستتوفر عينات صوتية عالية الجودة قريبًا، تُظهر قدراتنا في TTS متعدد اللغات عبر لغات ومشاعر مختلفة.</em>
-</div>
-
 </div>
 
 ---
 
-## الوثائق
-
-- [بناء البيئة](ar/install.md)
-- [الاستنتاج](ar/inference.md)
-
 ## الاعتمادات
 
 - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)

+ 8 - 16
docs/README.ja.md

@@ -1,7 +1,7 @@
 <div align="center">
 <h1>Fish Speech</h1>
 
-[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | **日本語** | [한국어](README.ko.md) <br>
+[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | **日本語** | [한국어](README.ko.md) | [العربية](README.ar.md) <br>
 
 <a href="https://www.producthunt.com/posts/fish-speech-1-4?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-fish&#0045;speech&#0045;1&#0045;4" target="_blank">
     <img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=488440&theme=light" alt="Fish&#0032;Speech&#0032;1&#0046;4 - Open&#0045;Source&#0032;Multilingual&#0032;Text&#0045;to&#0045;Speech&#0032;with&#0032;Voice&#0032;Cloning | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" />
@@ -56,6 +56,7 @@
 こちらは Fish Speech の公式ドキュメントです。手順に従って簡単に始めることができます。
 
 - [インストール](https://speech.fish.audio/ja/install/)
+- [ファインチューニング](https://speech.fish.audio/ja/finetune/)
 - [推論](https://speech.fish.audio/ja/inference/)
 - [サンプル](https://speech.fish.audio/examples)
 
@@ -89,6 +90,7 @@ OpenAudio S1は、テキスト音声変換評価のベンチマークである[T
 </div>
 
 ### **音声制御**
+
 OpenAudio S1は**音声合成を強化するための様々な感情、トーン、特別なマーカーをサポート**しています:
 
 - **基本感情**:
@@ -126,15 +128,15 @@ OpenAudio S1は**音声合成を強化するための様々な感情、トーン
 ### **2種類のモデル**
 
 | モデル | サイズ | 利用可能性 | 機能 |
-|-------|--------|------------|------|
-| **S1** | 4Bパラメータ | [fish.audio](fish.audio)で利用可能 | フル機能のフラッグシップモデル |
+|-------|------|--------------|----------|
+| **S1** | 4Bパラメータ | [fish.audio](https://fish.audio/)で利用可能 | フル機能のフラッグシップモデル |
 | **S1-mini** | 0.5Bパラメータ | huggingface [hf space](https://huggingface.co/spaces/fishaudio/openaudio-s1-mini)で利用可能 | コア機能を持つ蒸留版 |
 
 S1とS1-miniの両方がオンライン人間フィードバック強化学習(RLHF)を組み込んでいます。
 
 ## **機能**
 
-1. **ゼロショット・少数ショットTTS:** 10〜30秒の音声サンプルを入力して高品質のTTS出力を生成します。**詳細なガイドラインについては、[Voice Cloning Best Practices](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)をご覧ください。**
+1. **ゼロショット・少数ショットTTS:** 10〜30秒の音声サンプルを入力して高品質のTTS出力を生成します。**詳細なガイドラインについては、[Voice Cloning Best Practices](https://docs.fish.audio/resources/best-practices/voice-cloning)をご覧ください。**
 
 2. **多言語・言語横断サポート:** 多言語テキストを入力ボックスにコピー&ペーストするだけで、言語を気にする必要はありません。現在、英語、日本語、韓国語、中国語、フランス語、ドイツ語、アラビア語、スペイン語をサポートしています。
 
@@ -142,7 +144,7 @@ S1とS1-miniの両方がオンライン人間フィードバック強化学習
 
 4. **高精度:** Seed-TTS Evalで約0.4%の低いCER(文字誤り率)と約0.8%のWER(単語誤り率)を達成します。
 
-5. **高速:** fish-tech加速により、Nvidia RTX 4060ラップトップで約1:5、Nvidia RTX 4090で約1:15のリアルタイム係数を実現します。
+5. **高速:** torch compileによる加速により、Nvidia RTX 4090 GPUで約1:7のリアルタイム係数を実現します。
 
 6. **WebUI推論:** 使いやすいGradioベースのWeb UIを搭載し、Chrome、Firefox、Edgeなどのブラウザと互換性があります。
 
@@ -168,23 +170,13 @@ S1とS1-miniの両方がオンライン人間フィードバック強化学習
 ### **ビデオショーケース**
 
 <a href="https://www.youtube.com/watch?v=SYuPvd7m06A" target="_blank">
-    <img src="../docs/assets/Thumbnail.jpg" alt="OpenAudio S1 Video" style="width: 50%;" />
+    <img src="assets/Thumbnail.jpg" alt="OpenAudio S1 Video" style="width: 50%;" />
 </a>
 
-### **音声サンプル**
-<div style="margin: 20px 0;">
-    <em>高品質の音声サンプルは間もなく公開予定で、異なる言語と感情における私たちの多言語TTS機能を実演します。</em>
-</div>
-
 </div>
 
 ---
 
-## ドキュメント
-
-- [環境構築](ja/install.md)
-- [推論](ja/inference.md)
-
 ## クレジット
 
 - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)

+ 5 - 14
docs/README.ko.md

@@ -1,7 +1,7 @@
 <div align="center">
 <h1>Fish Speech</h1>
 
-[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | **한국어** <br>
+[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | **한국어** | [العربية](README.ar.md) <br>
 
 <a href="https://www.producthunt.com/posts/fish-speech-1-4?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-fish&#0045;speech&#0045;1&#0045;4" target="_blank">
     <img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=488440&theme=light" alt="Fish&#0032;Speech&#0032;1&#0046;4 - Open&#0045;Source&#0032;Multilingual&#0032;Text&#0045;to&#0045;Speech&#0032;with&#0032;Voice&#0032;Cloning | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" />
@@ -56,6 +56,7 @@
 여기는 Fish Speech의 공식 문서입니다. 지침을 따라 쉽게 시작하세요.
 
 - [설치](https://speech.fish.audio/ko/install/)
+- [파인튜닝](https://speech.fish.audio/ko/finetune/)
 - [추론](https://speech.fish.audio/ko/inference/)
 - [샘플](https://speech.fish.audio/examples)
 
@@ -85,7 +86,7 @@ OpenAudio-S1은 두 가지 버전으로 제공됩니다: **OpenAudio-S1**과 **O
 OpenAudio S1은 텍스트 음성 변환 평가의 벤치마크인 [TTS-Arena2](https://arena.speechcolab.org/)에서 **1위**를 달성했습니다:
 
 <div align="center">
-    <img src="assets/Elo.jpg" alt="TTS-Arena2 순위" style="width: 75%;" />
+    <img src="../docs/assets/Elo.jpg" alt="TTS-Arena2 순위" style="width: 75%;" />
 </div>
 
 ### **음성 제어**
@@ -134,7 +135,7 @@ S1과 S1-mini 모두 온라인 인간 피드백 강화학습(RLHF)을 통합하
 
 ## **기능**
 
-1. **제로샷 및 퓨샷 TTS:** 10~30초의 음성 샘플을 입력하여 고품질 TTS 출력을 생성합니다. **자세한 가이드라인은 [음성 복제 모범 사례](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)를 참조하세요.**
+1. **제로샷 및 퓨샷 TTS:** 10~30초의 음성 샘플을 입력하여 고품질 TTS 출력을 생성합니다. **자세한 가이드라인은 [음성 복제 모범 사례](https://docs.fish.audio/resources/best-practices/voice-cloning)를 참조하세요.**
 
 2. **다국어 및 교차 언어 지원:** 다국어 텍스트를 입력 상자에 복사하여 붙여넣기만 하면 됩니다. 언어를 걱정할 필요가 없습니다. 현재 영어, 일본어, 한국어, 중국어, 프랑스어, 독일어, 아랍어, 스페인어를 지원합니다.
 
@@ -142,7 +143,7 @@ S1과 S1-mini 모두 온라인 인간 피드백 강화학습(RLHF)을 통합하
 
 4. **높은 정확도:** Seed-TTS Eval에서 약 0.4%의 낮은 CER(문자 오류율)과 약 0.8%의 WER(단어 오류율)을 달성합니다.
 
-5. **빠른 속도:** fish-tech 가속화로 Nvidia RTX 4060 노트북에서 실시간 팩터가 약 1:5, Nvidia RTX 4090에서 1:15입니다.
+5. **빠른 속도:** torch compile로 가속화되어 Nvidia RTX 4090 GPU에서 실시간 팩터가 약 1:7입니다.
 
 6. **WebUI 추론:** 사용하기 쉬운 Gradio 기반 웹 UI를 제공하며 Chrome, Firefox, Edge 등 다른 브라우저와 호환됩니다.
 
@@ -171,20 +172,10 @@ S1과 S1-mini 모두 온라인 인간 피드백 강화학습(RLHF)을 통합하
     <img src="../docs/assets/Thumbnail.jpg" alt="OpenAudio S1 Video" style="width: 50%;" />
 </a>
 
-### **오디오 샘플**
-<div style="margin: 20px 0;">
-    <em> 다양한 언어와 감정에 걸친 다국어 TTS 기능을 보여주는 고품질 오디오 샘플이 곧 제공될 예정입니다.</em>
-</div>
-
 </div>
 
 ---
 
-## 문서
-
-- [환경 구축](ko/install.md)
-- [추론](ko/inference.md)
-
 ## 크레딧
 
 - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)

+ 11 - 19
docs/README.pt-BR.md

@@ -1,7 +1,7 @@
 <div align="center">
 <h1>Fish Speech</h1>
 
-[English](../README.md) | [简体中文](README.zh.md) | **Portuguese** | [日本語](README.ja.md) | [한국어](README.ko.md) <br>
+[English](../README.md) | [简体中文](README.zh.md) | **Portuguese** | [日本語](README.ja.md) | [한국어](README.ko.md) | [العربية](README.ar.md) <br>
 
 <a href="https://www.producthunt.com/posts/fish-speech-1-4?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-fish&#0045;speech&#0045;1&#0045;4" target="_blank">
     <img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=488440&theme=light" alt="Fish&#0032;Speech&#0032;1&#0046;4 - Open&#0045;Source&#0032;Multilingual&#0032;Text&#0045;to&#0045;Speech&#0032;with&#0032;Voice&#0032;Cloning | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" />
@@ -55,8 +55,9 @@
 
 Aqui estão os documentos oficiais do Fish Speech, siga as instruções para começar facilmente.
 
-- [Instalação](https://speech.fish.audio/pt/install/)
-- [Inferência](https://speech.fish.audio/pt/inference/)
+- [Instalação](https://speech.fish.audio/install/)
+- [Fine-tune](https://speech.fish.audio/finetune/)
+- [Inferência](https://speech.fish.audio/inference/)
 - [Amostras](https://speech.fish.audio/examples)
 
 ## 🎉 Anúncio
@@ -89,6 +90,7 @@ O OpenAudio S1 alcançou a **classificação #1** no [TTS-Arena2](https://arena.
 </div>
 
 ### **Controle de Fala**
+
 O OpenAudio S1 **suporta uma variedade de marcadores emocionais, de tom e especiais** para aprimorar a síntese de fala:
 
 - **Emoções básicas**:
@@ -127,14 +129,14 @@ Você também pode usar Ha,ha,ha para controlar, há muitos outros casos esperan
 
 | Modelo | Tamanho | Disponibilidade | Recursos |
 |-------|------|--------------|----------|
-| **S1** | 4B parâmetros | Disponível em [fish.audio](https://fish.audio) | Modelo flagship com recursos completos |
+| **S1** | 4B parâmetros | Disponível em [fish.audio](https://fish.audio/) | Modelo flagship com recursos completos |
 | **S1-mini** | 0.5B parâmetros | Disponível no Hugging Face [hf space](https://huggingface.co/spaces/fishaudio/openaudio-s1-mini) | Versão destilada com capacidades principais |
 
 Tanto S1 quanto S1-mini incorporam Aprendizado por Reforço online com Feedback Humano (RLHF).
-   
-   ## **Recursos**
 
-1. **TTS Zero-shot e Few-shot:** Insira uma amostra vocal de 10 a 30 segundos para gerar saída TTS de alta qualidade. **Para diretrizes detalhadas, veja [Melhores Práticas de Clonagem de Voz](https://docs.fish.audio/text-to-speech/voice-clone-best-practices).**
+## **Recursos**
+
+1. **TTS Zero-shot e Few-shot:** Insira uma amostra vocal de 10 a 30 segundos para gerar saída TTS de alta qualidade. **Para diretrizes detalhadas, veja [Melhores Práticas de Clonagem de Voz](https://docs.fish.audio/resources/best-practices/voice-cloning).**
 
 2. **Suporte Multilíngue e Cross-lingual:** Simplesmente copie e cole texto multilíngue na caixa de entrada—não precisa se preocupar com o idioma. Atualmente suporta inglês, japonês, coreano, chinês, francês, alemão, árabe e espanhol.
 
@@ -142,7 +144,7 @@ Tanto S1 quanto S1-mini incorporam Aprendizado por Reforço online com Feedback
 
 4. **Altamente Preciso:** Alcança um baixo CER (Taxa de Erro de Caractere) de cerca de 0.4% e WER (Taxa de Erro de Palavra) de cerca de 0.8% para Seed-TTS Eval.
 
-5. **Rápido:** Com aceleração fish-tech, o fator de tempo real é aproximadamente 1:5 em um laptop Nvidia RTX 4060 e 1:15 em um Nvidia RTX 4090.
+5. **Rápido:** Com aceleração por torch compile, o fator de tempo real é aproximadamente 1:7 em uma GPU Nvidia RTX 4090.
 
 6. **Inferência via WebUI:** Apresenta uma interface de usuário baseada em Gradio, fácil de usar e compatível com Chrome, Firefox, Edge e outros navegadores.
 
@@ -168,23 +170,13 @@ Tanto S1 quanto S1-mini incorporam Aprendizado por Reforço online com Feedback
 ### **Vitrines de Vídeo**
 
 <a href="https://www.youtube.com/watch?v=SYuPvd7m06A" target="_blank">
-    <img src="../docs/assets/Thumbnail.jpg" alt="OpenAudio S1 Video" style="width: 50%;" />
+    <img src="assets/Thumbnail.jpg" alt="OpenAudio S1 Video" style="width: 50%;" />
 </a>
 
-### **Amostras de Áudio**
-<div style="margin: 20px 0;">
-    <em> Amostras de áudio de alta qualidade estarão disponíveis em breve, demonstrando nossas capacidades TTS multilíngues em diferentes idiomas e emoções.</em>
-</div>
-
 </div>
 
 ---
 
-## Documentos
-
-- [Construir Ambiente](pt/install.md)
-- [Inferência](pt/inference.md)
-
 ## Créditos
 
 - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)

+ 11 - 19
docs/README.zh.md

@@ -1,7 +1,7 @@
 <div align="center">
 <h1>Fish Speech</h1>
 
-[English](../README.md) | **简体中文** | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | [한국어](README.ko.md) <br>
+[English](../README.md) | **简体中文** | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [العربية](README.ar.md) <br>
 
 <a href="https://www.producthunt.com/posts/fish-speech-1-4?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-fish&#0045;speech&#0045;1&#0045;4" target="_blank">
     <img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=488440&theme=light" alt="Fish&#0032;Speech&#0032;1&#0046;4 - Open&#0045;Source&#0032;Multilingual&#0032;Text&#0045;to&#0045;Speech&#0032;with&#0032;Voice&#0032;Cloning | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" />
@@ -56,6 +56,7 @@
 这里是 Fish Speech 的官方文档,按照说明轻松开始使用。
 
 - [安装](https://speech.fish.audio/zh/install/)
+- [微调](https://speech.fish.audio/zh/finetune/)
 - [推理](https://speech.fish.audio/zh/inference/)
 - [示例](https://speech.fish.audio/examples)
 
@@ -85,10 +86,11 @@ OpenAudio-S1 提供两个版本:**OpenAudio-S1** 和 **OpenAudio-S1-mini**。
 OpenAudio S1 在 [TTS-Arena2](https://arena.speechcolab.org/) 上取得了 **第一名**,这是文本转语音评估的基准:
 
 <div align="center">
-    <img src="assets/Elo.jpg" alt="TTS-Arena2 排名" style="width: 75%;" />
+    <img src="../docs/assets/Elo.jpg" alt="TTS-Arena2 排名" style="width: 75%;" />
 </div>
 
 ### **语音控制**
+
 OpenAudio S1 **支持多种情感、语调和特殊标记** 来增强语音合成:
 
 - **基础情感**:
@@ -131,10 +133,10 @@ OpenAudio S1 **支持多种情感、语调和特殊标记** 来增强语音合
 | **S1-mini** | 0.5B 参数 | 在 Hugging Face [hf space](https://huggingface.co/spaces/fishaudio/openaudio-s1-mini) 上可用 | 具有核心功能的精简版本 |
 
 S1 和 S1-mini 都集成了在线人类反馈强化学习(RLHF)。
-   
-   ## **功能**
 
-1. **零样本和少样本 TTS:** 输入 10 到 30 秒的语音样本以生成高质量的 TTS 输出。**详细指南请参见 [语音克隆最佳实践](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)。**
+## **功能**
+
+1. **零样本和少样本 TTS:** 输入 10 到 30 秒的语音样本以生成高质量的 TTS 输出。**详细指南请参见 [语音克隆最佳实践](https://docs.fish.audio/resources/best-practices/voice-cloning)。**
 
 2. **多语言和跨语言支持:** 只需将多语言文本复制并粘贴到输入框中——无需担心语言问题。目前支持英语、日语、韩语、中文、法语、德语、阿拉伯语和西班牙语。
 
@@ -142,7 +144,7 @@ S1 和 S1-mini 都集成了在线人类反馈强化学习(RLHF)。
 
 4. **高准确性:** 在 Seed-TTS Eval 上实现约 0.4% 的低 CER(字符错误率)和约 0.8% 的 WER(词错误率)。
 
-5. **快速:** 通过 fish-tech 加速,在 Nvidia RTX 4060 笔记本电脑上实时因子约为 1:5,在 Nvidia RTX 4090 上为 1:15
+5. **快速:** 通过 torch compile 加速,在 Nvidia RTX 4090 GPU 上的实时因子约为 1:7
 
 6. **WebUI 推理:** 提供简单易用的、基于 Gradio 的 Web UI,兼容 Chrome、Firefox、Edge 等浏览器。
 
@@ -154,15 +156,15 @@ S1 和 S1-mini 都集成了在线人类反馈强化学习(RLHF)。
 
 ### **社交媒体**
 <a href="https://x.com/FishAudio/status/1929915992299450398" target="_blank">
-    <img src="https://img.shields.io/badge/𝕏-Latest_Demo-black?style=for-the-badge&logo=x&logoColor=white" alt="X 上的最新演示" />
+    <img src="https://img.shields.io/badge/𝕏-最新演示-black?style=for-the-badge&logo=x&logoColor=white" alt="X 上的最新演示" />
 </a>
 
 ### **交互式演示**
 <a href="https://fish.audio" target="_blank">
-    <img src="https://img.shields.io/badge/Fish_Audio-Try_OpenAudio_S1-blue?style=for-the-badge" alt="试用 OpenAudio S1" />
+    <img src="https://img.shields.io/badge/Fish_Audio-试用_OpenAudio_S1-blue?style=for-the-badge" alt="试用 OpenAudio S1" />
 </a>
 <a href="https://huggingface.co/spaces/fishaudio/openaudio-s1-mini" target="_blank">
-    <img src="https://img.shields.io/badge/Hugging_Face-Try_S1_Mini-yellow?style=for-the-badge" alt="试用 S1 Mini" />
+    <img src="https://img.shields.io/badge/Hugging_Face-试用_S1_Mini-yellow?style=for-the-badge" alt="试用 S1 Mini" />
 </a>
 
 ### **视频展示**
@@ -171,20 +173,10 @@ S1 和 S1-mini 都集成了在线人类反馈强化学习(RLHF)。
     <img src="../docs/assets/Thumbnail.jpg" alt="OpenAudio S1 Video" style="width: 50%;" />
 </a>
 
-### **音频样本**
-<div style="margin: 20px 0;">
-    <em> 展示我们跨不同语言和情感的多语言 TTS 功能的高质量音频样本即将推出。</em>
-</div>
-
 </div>
 
 ---
 
-## 文档
-
-- [构建环境](zh/install.md)
-- [推理](zh/inference.md)
-
 ## 致谢
 
 - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)

+ 125 - 0
docs/ar/finetune.md

@@ -0,0 +1,125 @@
+# الضبط الدقيق (Fine-tuning)
+
+من الواضح أنك عندما فتحت هذه الصفحة، لم تكن راضيًا عن أداء النموذج المدرب مسبقًا في وضع zero-shot. أنت ترغب في إجراء ضبط دقيق لنموذج لتحسين أدائه على مجموعة البيانات الخاصة بك.
+
+في الإصدار الحالي، ما عليك سوى إجراء الضبط الدقيق لجزء 'LLAMA'.
+
+## الضبط الدقيق لـ LLAMA
+### 1. إعداد مجموعة البيانات
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 30.1-32.71.lab
+│   └── 30.1-32.71.mp3
+└── SPK2
+    ├── 38.79-40.85.lab
+    └── 38.79-40.85.mp3
+```
+
+تحتاج إلى تحويل مجموعة البيانات الخاصة بك إلى التنسيق أعلاه ووضعها تحت مجلد `data`. يمكن أن يكون للملف الصوتي الامتدادات `.mp3`، `.wav`، أو `.flac`، ويجب أن يكون لملف التعليقات التوضيحية الامتداد `.lab`.
+
+!!! info "تنسيق مجموعة البيانات"
+    يحتاج ملف التعليقات التوضيحية `.lab` فقط إلى احتواء النص المكتوب للمقطع الصوتي، دون الحاجة إلى تنسيق خاص. على سبيل المثال، إذا كان محتوى `hi.mp3` هو "مرحبًا، وداعًا"، فسيحتوي ملف `hi.lab` على سطر واحد من النص: "مرحبًا، وداعًا".
+
+!!! warning "تحذير"
+    يوصى بتطبيق تسوية جهارة الصوت (loudness normalization) على مجموعة البيانات. يمكنك استخدام [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) للقيام بذلك.
+    ```bash
+    fap loudness-norm data-raw data --clean
+    ```
+
+### 2. الاستخراج الدفعي للرموز الدلالية (semantic tokens)
+
+تأكد من أنك قمت بتنزيل أوزان VQGAN. إذا لم تكن قد فعلت، قم بتشغيل الأمر التالي:
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+يمكنك بعد ذلك تشغيل الأمر التالي لاستخراج الرموز الدلالية:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+    --num-workers 1 --batch-size 16 \
+    --config-name "modded_dac_vq" \
+    --checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
+```
+
+!!! note "ملاحظة"
+    يمكنك ضبط `--num-workers` و `--batch-size` لزيادة سرعة الاستخراج، ولكن يرجى التأكد من عدم تجاوز حد ذاكرة وحدة معالجة الرسومات (GPU) الخاصة بك.
+
+سيقوم هذا الأمر بإنشاء ملفات `.npy` في مجلد `data`، كما هو موضح أدناه:
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 21.15-26.44.npy
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 27.51-29.98.npy
+│   ├── 30.1-32.71.lab
+│   ├── 30.1-32.71.mp3
+│   └── 30.1-32.71.npy
+└── SPK2
+    ├── 38.79-40.85.lab
+    ├── 38.79-40.85.mp3
+    └── 38.79-40.85.npy
+```
+
+### 3. حزم مجموعة البيانات في protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+    --input "data" \
+    --output "data/protos" \
+    --text-extension .lab \
+    --num-workers 16
+```
+
+بعد انتهاء تنفيذ الأمر، يجب أن ترى ملف `protos` في مجلد `data`.
+
+### 4. أخيرًا، الضبط الدقيق باستخدام LoRA
+
+بالمثل، تأكد من أنك قمت بتنزيل أوزان `LLAMA`. إذا لم تكن قد فعلت، قم بتشغيل الأمر التالي:
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+أخيرًا، يمكنك بدء الضبط الدقيق عن طريق تشغيل الأمر التالي:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+    project=$project \
+    +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note "ملاحظة"
+    يمكنك تعديل معلمات التدريب مثل `batch_size`، `gradient_accumulation_steps`، وما إلى ذلك لتناسب ذاكرة وحدة معالجة الرسومات الخاصة بك عن طريق تعديل `fish_speech/configs/text2semantic_finetune.yaml`.
+
+!!! note "ملاحظة"
+    لمستخدمي Windows، يمكنك استخدام `trainer.strategy.process_group_backend=gloo` لتجنب مشكلات `nccl`.
+
+بعد اكتمال التدريب، يمكنك الرجوع إلى قسم [الاستدلال (inference)](inference.md) لاختبار نموذجك.
+
+!!! info "معلومات"
+    بشكل افتراضي، سيتعلم النموذج فقط أنماط كلام المتحدث وليس جرس الصوت (timbre). لا تزال بحاجة إلى استخدام التلقينات (prompts) لضمان استقرار جرس الصوت.
+    إذا كنت ترغب في تعلم جرس الصوت، يمكنك زيادة عدد خطوات التدريب، ولكن هذا قد يؤدي إلى الإفراط في التخصيص (overfitting).
+
+بعد التدريب، تحتاج إلى تحويل أوزان LoRA إلى أوزان عادية قبل إجراء الاستدلال.
+
+```bash
+python tools/llama/merge_lora.py \
+	--lora-config r_8_alpha_16 \
+	--base-weight checkpoints/openaudio-s1-mini \
+	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
+	--output checkpoints/openaudio-s1-mini-yth-lora/
+```
+!!! note "ملاحظة"
+    يمكنك أيضًا تجربة نقاط تحقق (checkpoints) أخرى. نقترح استخدام أقدم نقطة تحقق تلبي متطلباتك، حيث إنها غالبًا ما تؤدي أداءً أفضل على البيانات خارج التوزيع (OOD).

+ 127 - 0
docs/en/finetune.md

@@ -0,0 +1,127 @@
+# Fine-tuning
+
+Obviously, when you opened this page, you were not satisfied with the performance of the zero-shot pre-trained model. You want to fine-tune a model to improve its performance on your dataset.
+
+In the current version, you only need to finetune the 'LLAMA' part.
+
+## Fine-tuning LLAMA
+### 1. Prepare the dataset
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 30.1-32.71.lab
+│   └── 30.1-32.71.mp3
+└── SPK2
+    ├── 38.79-40.85.lab
+    └── 38.79-40.85.mp3
+```
+
+You need to convert your dataset into the above format and place it under `data`. The audio file can have the extensions `.mp3`, `.wav`, or `.flac`, and the annotation file should have the extension `.lab`.
+
+!!! info
+    The `.lab` annotation file only needs to contain the transcription of the audio, with no special formatting required. For example, if `hi.mp3` says "Hello, goodbye," then the `hi.lab` file would contain a single line of text: "Hello, goodbye."
+
+!!! warning
+    It's recommended to apply loudness normalization to the dataset. You can use [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) to do this.
+
+    ```bash
+    fap loudness-norm data-raw data --clean
+    ```
+
+
+### 2. Batch extraction of semantic tokens
+
+Make sure you have downloaded the VQGAN weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+You can then run the following command to extract semantic tokens:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+    --num-workers 1 --batch-size 16 \
+    --config-name "modded_dac_vq" \
+    --checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
+```
+
+!!! note
+    You can adjust `--num-workers` and `--batch-size` to increase extraction speed, but please make sure not to exceed your GPU memory limit.
+
+This command will create `.npy` files in the `data` directory, as shown below:
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 21.15-26.44.npy
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 27.51-29.98.npy
+│   ├── 30.1-32.71.lab
+│   ├── 30.1-32.71.mp3
+│   └── 30.1-32.71.npy
+└── SPK2
+    ├── 38.79-40.85.lab
+    ├── 38.79-40.85.mp3
+    └── 38.79-40.85.npy
+```
+
+### 3. Pack the dataset into protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+    --input "data" \
+    --output "data/protos" \
+    --text-extension .lab \
+    --num-workers 16
+```
+
+After the command finishes executing, you should see the `protos` file in the `data` directory.
+
+### 4. Finally, fine-tuning with LoRA
+
+Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+Finally, you can start the fine-tuning by running the following command:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+    project=$project \
+    +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+    You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`.
+
+!!! note
+    For Windows users, you can use `trainer.strategy.process_group_backend=gloo` to avoid `nccl` issues.
+
+After training is complete, you can refer to the [inference](inference.md) section to test your model.
+
+!!! info
+    By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability.
+    If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting.
+
+After training, you need to convert the LoRA weights to regular weights before performing inference.
+
+```bash
+python tools/llama/merge_lora.py \
+	--lora-config r_8_alpha_16 \
+	--base-weight checkpoints/openaudio-s1-mini \
+	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
+	--output checkpoints/openaudio-s1-mini-yth-lora/
+```
+!!! note
+    You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data.

+ 125 - 0
docs/ja/finetune.md

@@ -0,0 +1,125 @@
+# ファインチューニング
+
+このページを開いたということは、明らかに、事前学習済みモデルのゼロショット性能に満足していないということでしょう。データセットでより良い性能を発揮するようにモデルをファインチューニングしたいとお考えのはずです。
+
+現在のバージョンでは、「LLAMA」部分のみをファインチューニングする必要があります。
+
+## LLAMA のファインチューニング
+### 1. データセットの準備
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 30.1-32.71.lab
+│   └── 30.1-32.71.mp3
+└── SPK2
+    ├── 38.79-40.85.lab
+    └── 38.79-40.85.mp3
+```
+
+データセットを上記の形式に変換し、`data` ディレクトリに配置する必要があります。音声ファイルの拡張子は `.mp3`、`.wav`、または `.flac` が使用でき、注釈ファイルの拡張子は `.lab` にすることを推奨します。
+
+!!! info
+    `.lab` 注釈ファイルには、音声の書き起こしテキストのみを含める必要があり、特別なフォーマット要件はありません。たとえば、`hi.mp3` の内容が「こんにちは、さようなら。」である場合、`hi.lab` ファイルには「こんにちは、さようなら。」という一行のテキストのみが含まれます。
+
+!!! warning
+    データセットにラウドネス正規化を適用することをお勧めします。これには [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) を使用できます。
+    ```bash
+    fap loudness-norm data-raw data --clean
+    ```
+
+### 2. セマンティックトークンの一括抽出
+
+VQGANの重みをダウンロードしていることを確認してください。まだの場合は、次のコマンドを実行してください。
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+その後、次のコマンドを実行してセマンティックトークンを抽出できます。
+
+```bash
+python tools/vqgan/extract_vq.py data \
+    --num-workers 1 --batch-size 16 \
+    --config-name "modded_dac_vq" \
+    --checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
+```
+
+!!! note
+    `--num-workers` と `--batch-size` を調整して抽出速度を向上させることができますが、GPUメモリの制限を超えないように注意してください。
+
+このコマンドは `data` ディレクトリに `.npy` ファイルを作成します。以下のようになります。
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 21.15-26.44.npy
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 27.51-29.98.npy
+│   ├── 30.1-32.71.lab
+│   ├── 30.1-32.71.mp3
+│   └── 30.1-32.71.npy
+└── SPK2
+    ├── 38.79-40.85.lab
+    ├── 38.79-40.85.mp3
+    └── 38.79-40.85.npy```
+
+### 3. データセットを protobuf にパックする
+
+```bash
+python tools/llama/build_dataset.py \
+    --input "data" \
+    --output "data/protos" \
+    --text-extension .lab \
+    --num-workers 16
+```
+
+コマンドの実行が完了すると、`data` ディレクトリに `protos` ファイルが表示されるはずです。
+
+### 4. 最後に LoRA でファインチューニング
+
+同様に、`LLAMA` の重みをダウンロードしていることを確認してください。まだの場合は、次のコマンドを実行してください。
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+最後に、次のコマンドを実行してファインチューニングを開始できます。
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+    project=$project \
+    +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+    `fish_speech/configs/text2semantic_finetune.yaml` を変更することで、`batch_size` や `gradient_accumulation_steps` などのトレーニングパラメータをGPUメモリに合わせて変更できます。
+
+!!! note
+    Windows ユーザーの場合、`trainer.strategy.process_group_backend=gloo` を使用して `nccl` の問題を回避できます。
+
+トレーニングが完了したら、[推論](inference.md) のセクションを参照してモデルをテストできます。
+
+!!! info
+    デフォルト設定では、モデルは話者の発音方法のみを学習し、音色は学習しません。音色の安定性を確保するためには、依然としてプロンプトを使用する必要があります。
+    音色を学習させたい場合は、トレーニングステップ数を増やしてください。ただし、これにより過学習が発生する可能性があります。
+
+トレーニング後、推論を行う前に LoRA の重みを通常の重みに変換する必要があります。
+
+```bash
+python tools/llama/merge_lora.py \
+	--lora-config r_8_alpha_16 \
+	--base-weight checkpoints/openaudio-s1-mini \
+	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
+	--output checkpoints/openaudio-s1-mini-yth-lora/
+```
+
+!!! note
+    他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、OOD(分布外)データに対してより良いパフォーマンスを発揮します。

+ 126 - 0
docs/ko/finetune.md

@@ -0,0 +1,126 @@
+# 미세 조정 (Fine-tuning)
+
+이 페이지를 열었다는 것은, 사전 훈련된 모델의 제로샷(zero-shot) 성능에 만족하지 못했다는 의미일 것입니다. 여러분의 데이터셋에서 더 나은 성능을 내도록 모델을 미세 조정하고 싶으실 겁니다.
+
+현재 버전에서는 'LLAMA' 부분만 미세 조정하면 됩니다.
+
+## LLAMA 미세 조정
+### 1. 데이터셋 준비
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 30.1-32.71.lab
+│   └── 30.1-32.71.mp3
+└── SPK2
+    ├── 38.79-40.85.lab
+    └── 38.79-40.85.mp3
+```
+
+데이터셋을 위 형식으로 변환하여 `data` 폴더 아래에 배치해야 합니다. 오디오 파일 확장자는 `.mp3`, `.wav` 또는 `.flac`일 수 있으며, 주석 파일 확장자는 `.lab`을 권장합니다.
+
+!!! info
+    `.lab` 주석 파일에는 오디오의 전사 텍스트만 포함하면 되며, 특별한 형식 요구사항은 없습니다. 예를 들어 `hi.mp3`의 내용이 "안녕하세요, 안녕히 가세요."라면, `hi.lab` 파일에는 "안녕하세요, 안녕히 가세요."라는 한 줄의 텍스트만 포함하면 됩니다.
+
+!!! warning
+    데이터셋에 음량 정규화를 적용하는 것이 좋습니다. 이를 위해 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess)를 사용할 수 있습니다.
+    ```bash
+    fap loudness-norm data-raw data --clean
+    ```
+
+### 2. 시맨틱 토큰 일괄 추출
+
+VQGAN 가중치를 다운로드했는지 확인하세요. 그렇지 않은 경우 다음 명령을 실행하세요.
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+그런 다음 다음 명령을 실행하여 시맨틱 토큰을 추출할 수 있습니다.
+
+```bash
+python tools/vqgan/extract_vq.py data \
+    --num-workers 1 --batch-size 16 \
+    --config-name "modded_dac_vq" \
+    --checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
+```
+
+!!! note
+    `--num-workers`와 `--batch-size`를 조정하여 추출 속도를 높일 수 있지만, GPU 메모리 한도를 초과하지 않도록 주의하세요.
+
+이 명령은 `data` 디렉토리에 `.npy` 파일을 생성합니다. 결과는 다음과 같습니다.
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 21.15-26.44.npy
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 27.51-29.98.npy
+│   ├── 30.1-32.71.lab
+│   ├── 30.1-32.71.mp3
+│   └── 30.1-32.71.npy
+└── SPK2
+    ├── 38.79-40.85.lab
+    ├── 38.79-40.85.mp3
+    └── 38.79-40.85.npy
+```
+
+### 3. 데이터셋을 protobuf로 패킹하기
+
+```bash
+python tools/llama/build_dataset.py \
+    --input "data" \
+    --output "data/protos" \
+    --text-extension .lab \
+    --num-workers 16
+```
+
+명령 실행이 완료되면 `data` 디렉토리에서 `protos` 파일을 볼 수 있어야 합니다.
+
+### 4. 마지막으로, LoRA로 미세 조정하기
+
+마찬가지로, `LLAMA` 가중치를 다운로드했는지 확인하세요. 그렇지 않은 경우 다음 명령을 실행하세요.
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+마지막으로, 다음 명령을 실행하여 미세 조정을 시작할 수 있습니다.
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+    project=$project \
+    +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+    `fish_speech/configs/text2semantic_finetune.yaml` 파일을 수정하여 `batch_size`, `gradient_accumulation_steps` 등 훈련 매개변수를 GPU 메모리에 맞게 조정할 수 있습니다.
+
+!!! note
+    Windows 사용자의 경우, `trainer.strategy.process_group_backend=gloo`를 사용하여 `nccl` 관련 문제를 피할 수 있습니다.
+
+훈련이 완료되면 [추론](inference.md) 섹션을 참조하여 모델을 테스트할 수 있습니다.
+
+!!! info
+    기본 설정에서는 모델이 화자의 발음 방식만 학습하고 음색은 학습하지 않습니다. 음색 안정성을 보장하려면 여전히 프롬프트를 사용해야 합니다.
+    음색을 학습시키고 싶다면 훈련 스텝 수를 늘리되, 이는 과적합(overfitting)으로 이어질 수 있습니다.
+
+훈련 후, 추론을 수행하기 전에 LoRA 가중치를 일반 가중치로 변환해야 합니다.
+
+```bash
+python tools/llama/merge_lora.py \
+	--lora-config r_8_alpha_16 \
+	--base-weight checkpoints/openaudio-s1-mini \
+	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
+	--output checkpoints/openaudio-s1-mini-yth-lora/
+```
+
+!!! note
+    다른 체크포인트를 시도해 볼 수도 있습니다. 요구 사항을 충족하는 가장 이른 체크포인트를 사용하는 것이 좋습니다. 이러한 체크포인트는 보통 OOD(분포 외) 데이터에서 더 나은 성능을 보입니다.

+ 124 - 0
docs/pt/finetune.md

@@ -0,0 +1,124 @@
+# Ajuste Fino (Fine-tuning)
+
+Obviamente, ao abrir esta página, você não estava satisfeito com o desempenho do modelo pré-treinado em modo zero-shot. Você deseja fazer um ajuste fino em um modelo para melhorar seu desempenho em seu conjunto de dados.
+
+Na versão atual, você só precisa fazer o ajuste fino da parte 'LLAMA'.
+
+## Ajuste Fino do LLAMA
+### 1. Prepare o conjunto de dados
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 30.1-32.71.lab
+│   └── 30.1-32.71.mp3
+└── SPK2
+    ├── 38.79-40.85.lab
+    └── 38.79-40.85.mp3
+```
+
+Você precisa converter seu conjunto de dados para o formato acima e colocá-lo no diretório `data`. O arquivo de áudio pode ter as extensões `.mp3`, `.wav` ou `.flac`, e o arquivo de anotação deve ter a extensão `.lab`.
+
+!!! info
+    O arquivo de anotação `.lab` precisa conter apenas a transcrição do áudio, sem necessidade de formatação especial. Por exemplo, se `hi.mp3` contiver "Olá, adeus.", então o arquivo `hi.lab` conterá uma única linha de texto: "Olá, adeus.".
+
+!!! warning
+    Recomenda-se aplicar a normalização de volume (loudness) ao conjunto de dados. Você pode usar o [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) para fazer isso.
+    ```bash
+    fap loudness-norm data-raw data --clean
+    ```
+
+### 2. Extração em lote de tokens semânticos
+
+Certifique-se de que você baixou os pesos do VQGAN. Se não, execute o seguinte comando:
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+    --num-workers 1 --batch-size 16 \
+    --config-name "modded_dac_vq" \
+    --checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
+```
+
+!!! note
+    Você pode ajustar `--num-workers` e `--batch-size` para aumentar a velocidade de extração, mas certifique-se de não exceder o limite de memória da sua GPU.
+
+Este comando criará arquivos `.npy` no diretório `data`, como mostrado abaixo:
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 21.15-26.44.npy
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 27.51-29.98.npy
+│   ├── 30.1-32.71.lab
+│   ├── 30.1-32.71.mp3
+│   └── 30.1-32.71.npy
+└── SPK2
+    ├── 38.79-40.85.lab
+    ├── 38.79-40.85.mp3
+    └── 38.79-40.85.npy
+```
+
+### 3. Empacote o conjunto de dados em protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+    --input "data" \
+    --output "data/protos" \
+    --text-extension .lab \
+    --num-workers 16
+```
+
+Após a conclusão da execução do comando, você deverá ver o arquivo `protos` no diretório `data`.
+
+### 4. Finalmente, ajuste fino com LoRA
+
+Da mesma forma, certifique-se de que você baixou os pesos do `LLAMA`. Se não, execute o seguinte comando:
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+Finalmente, você pode iniciar o ajuste fino executando o seguinte comando:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+    project=$project \
+    +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+    Você pode modificar os parâmetros de treinamento, como `batch_size`, `gradient_accumulation_steps`, etc., para se adequar à memória da sua GPU, modificando `fish_speech/configs/text2semantic_finetune.yaml`.
+
+!!! note
+    Para usuários do Windows, você pode usar `trainer.strategy.process_group_backend=gloo` para evitar problemas com `nccl`.
+
+Após o treinamento ser concluído, você pode consultar a seção de [inferência](inference.md) para testar seu modelo.
+
+!!! info
+    Por padrão, o modelo aprenderá apenas os padrões de fala do locutor e não o timbre. Você ainda precisará usar prompts para garantir a estabilidade do timbre.
+    Se você quiser aprender o timbre, pode aumentar o número de passos de treinamento, mas isso pode levar a um sobreajuste (overfitting).
+
+Após o treinamento, você precisa converter os pesos do LoRA para pesos regulares antes de realizar a inferência.
+
+```bash
+python tools/llama/merge_lora.py \
+	--lora-config r_8_alpha_16 \
+	--base-weight checkpoints/openaudio-s1-mini \
+	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
+	--output checkpoints/openaudio-s1-mini-yth-lora/```
+!!! note
+    Você também pode tentar outros checkpoints. Sugerimos usar o checkpoint mais antigo que atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora de distribuição (OOD).

+ 139 - 0
docs/zh/finetune.md

@@ -0,0 +1,139 @@
+# 微调
+
+显然, 当你打开这个页面的时候, 你已经对预训练模型 zero-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好.  
+
+在目前版本,你只需要微调'LLAMA'部分即可.
+
+## LLAMA 微调
+### 1. 准备数据集
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 30.1-32.71.lab
+│   └── 30.1-32.71.mp3
+└── SPK2
+    ├── 38.79-40.85.lab
+    └── 38.79-40.85.mp3
+```
+
+你需要将数据集转为以上格式, 并放到 `data` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀建议为 `.lab`.
+
+!!! info
+    标注文件 `.lab` 仅需包含音频的转写文本,无需遵循特殊格式要求。例如,如果 `hi.mp3` 中的内容是“你好,再见。”,那么 `hi.lab` 文件中只需包含一行文本:“你好,再见”。    
+
+!!! warning
+    建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤. 
+    ```bash
+    fap loudness-norm data-raw data --clean
+    ```
+
+### 2. 批量提取语义 token
+
+确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+对于中国大陆用户, 可使用 mirror 下载.
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+随后可运行以下命令来提取语义 token:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+    --num-workers 1 --batch-size 16 \
+    --config-name "modded_dac_vq" \
+    --checkpoint-path "checkpoints/openaudio-s1-mini/codec.pth"
+```
+
+!!! note
+    你可以调整 `--num-workers` 和 `--batch-size` 来提高提取速度, 但是请注意不要超过你的显存限制.  
+
+该命令会在 `data` 目录下创建 `.npy` 文件, 如下所示:
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 21.15-26.44.npy
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 27.51-29.98.npy
+│   ├── 30.1-32.71.lab
+│   ├── 30.1-32.71.mp3
+│   └── 30.1-32.71.npy
+└── SPK2
+    ├── 38.79-40.85.lab
+    ├── 38.79-40.85.mp3
+    └── 38.79-40.85.npy
+```
+
+### 3. 打包数据集为 protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+    --input "data" \
+    --output "data/protos" \
+    --text-extension .lab \
+    --num-workers 16
+```
+
+命令执行完毕后, 你应该能在 `data` 目录下看到 `protos` 文件.
+
+
+### 4. 最后, 使用 LoRA 进行微调
+
+同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令:
+
+```bash
+huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+对于中国大陆用户, 可使用 mirror 下载.
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/openaudio-s1-mini
+```
+
+最后, 你可以运行以下命令来启动微调:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+    project=$project \
+    +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+    你可以通过修改 `fish_speech/configs/text2semantic_finetune.yaml` 来修改训练参数如 `batch_size`, `gradient_accumulation_steps` 等, 来适应你的显存.
+
+!!! note
+    对于 Windows 用户, 你可以使用 `trainer.strategy.process_group_backend=gloo` 来避免 `nccl` 的问题.
+
+训练结束后, 你可以参考 [推理](inference.md) 部分来测试你的模型.
+
+!!! info
+    默认配置下, 基本只会学到说话人的发音方式, 而不包含音色, 你依然需要使用 prompt 来保证音色的稳定性.  
+    如果你想要学到音色, 请将训练步数调大, 但这有可能会导致过拟合. 
+
+训练完成后, 你需要先将 loRA 的权重转为普通权重, 然后再进行推理.
+
+```bash
+python tools/llama/merge_lora.py \
+	--lora-config r_8_alpha_16 \
+	--base-weight checkpoints/openaudio-s1-mini \
+	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
+	--output checkpoints/openaudio-s1-mini-yth-lora/
+```
+
+!!! note
+    你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好.

+ 3 - 0
fish_speech/callbacks/__init__.py

@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]

+ 113 - 0
fish_speech/callbacks/grad_norm.py

@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+    _group_tensors_by_device_and_dtype,
+    _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+    parameters: Union[Tensor, list[Tensor]],
+    norm_type: float = 2.0,
+) -> float:
+    """
+    Returns the norm of the gradients of the given parameters.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        norm_type (float): type of the used p-norm.
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """  # noqa: E501
+
+    if isinstance(parameters, Tensor):
+        parameters = [parameters]
+
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return None
+
+    first_device = grads[0].device
+    grouped_grads: dict[
+        tuple[torch.device, torch.dtype], list[list[Tensor]]
+    ] = _group_tensors_by_device_and_dtype(
+        [[g.detach() for g in grads]]
+    )  # type: ignore[assignment]
+
+    norms = []
+    for (device, _), ([grads], _) in grouped_grads.items():
+        if _has_foreach_support(grads, device=device):
+            norms.extend(torch._foreach_norm(grads, norm_type))
+        else:
+            norms.extend([torch.norm(g, norm_type) for g in grads])
+
+    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+    """
+    Callback that computes the gradient norm of the model parameters.
+    """
+
+    def __init__(
+        self,
+        norm_type: float = 2.0,
+        logging_interval: str = "step",
+        sub_module: Optional[Union[str, list[str]]] = None,
+    ) -> None:
+        """
+        Args:
+            norm_type (float): type of the used p-norm.
+            logging_interval (str): "step" or "epoch".
+        """
+        super().__init__()
+
+        self.norm_type = norm_type
+        self.logging_interval = logging_interval
+        self.sub_module = sub_module
+
+    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+        """
+        Computes the gradient norm of the model parameters and logs it to the logger.
+
+        Args:
+            trainer (Trainer): The trainer object
+            model (LightningModule): The current lightningModule
+        """
+
+        lightning_model = model
+
+        if self.sub_module is None:
+            return self.log_sub_module_grad_norm(lightning_model, model, "")
+
+        sub_modules = self.sub_module
+        if isinstance(sub_modules, str):
+            sub_modules = [sub_modules]
+
+        for sub_module in sub_modules:
+            self.log_sub_module_grad_norm(
+                lightning_model, getattr(model, sub_module), f"/{sub_module}"
+            )
+
+    def log_sub_module_grad_norm(
+        self, lightning_model: LightningModule, model: nn.Module, path: str
+    ) -> None:
+        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+        if grad_norm_val is None:
+            return
+
+        on_step = self.logging_interval == "step"
+        lightning_model.log(
+            f"train{path}/grad_norm",
+            grad_norm_val,
+            on_step=on_step,
+            on_epoch=not on_step,
+        )

+ 53 - 0
fish_speech/datasets/concat_repeat.py

@@ -0,0 +1,53 @@
+import bisect
+import random
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ConcatRepeatDataset(Dataset):
+    datasets: list[Dataset]
+    cumulative_sizes: list[int]
+    repeats: list[int]
+
+    @staticmethod
+    def cumsum(sequence, repeats):
+        r, s = [], 0
+        for dataset, repeat in zip(sequence, repeats):
+            l = len(dataset) * repeat
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
+        super().__init__()
+
+        self.datasets = list(datasets)
+        self.repeats = repeats
+
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+        assert len(self.datasets) == len(
+            repeats
+        ), "datasets and repeats should have the same length"
+
+        for d in self.datasets:
+            assert not isinstance(
+                d, IterableDataset
+            ), "ConcatRepeatDataset does not support IterableDataset"
+
+        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+        dataset = self.datasets[dataset_idx]
+
+        return dataset[sample_idx % len(dataset)]

+ 24 - 0
fish_speech/datasets/protos/text-data.proto

@@ -0,0 +1,24 @@
+syntax = "proto3";
+
+package text_data;
+
+message Semantics {
+    repeated uint32 values = 1;
+}
+
+message Sentence {
+    repeated string texts = 1;
+    repeated Semantics semantics = 3;
+}
+
+message TextData {
+    string source = 1;
+    string name = 2;
+    repeated Sentence sentences = 4;
+}
+
+message SampledData {
+    string source = 1;
+    string name = 2;
+    repeated Sentence samples = 3;
+}

+ 33 - 0
fish_speech/datasets/protos/text_data_pb2.py

@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# source: text-data.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals["_SEMANTICS"]._serialized_start = 30
+    _globals["_SEMANTICS"]._serialized_end = 57
+    _globals["_SENTENCE"]._serialized_start = 59
+    _globals["_SENTENCE"]._serialized_end = 125
+    _globals["_TEXTDATA"]._serialized_start = 127
+    _globals["_TEXTDATA"]._serialized_end = 207
+    _globals["_SAMPLEDDATA"]._serialized_start = 209
+    _globals["_SAMPLEDDATA"]._serialized_end = 290
+# @@protoc_insertion_point(module_scope)

+ 36 - 0
fish_speech/datasets/protos/text_data_stream.py

@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
+def read_pb_stream(f):
+    while True:
+        buf = f.read(4)
+        if len(buf) == 0:
+            break
+        size = struct.unpack("I", buf)[0]
+        buf = f.read(size)
+        text_data = TextData()
+        text_data.ParseFromString(buf)
+        yield text_data
+
+
+def write_pb_stream(f, text_data):
+    buf = text_data.SerializeToString()
+    f.write(struct.pack("I", len(buf)))
+    f.write(buf)
+
+
+def pack_pb_stream(text_data):
+    buf = text_data.SerializeToString()
+    return struct.pack("I", len(buf)) + buf
+
+
+def split_pb_stream(f):
+    while True:
+        head = f.read(4)
+        if len(head) == 0:
+            break
+        size = struct.unpack("I", head)[0]
+        buf = f.read(size)
+        yield head + buf

+ 627 - 0
fish_speech/datasets/semantic.py

@@ -0,0 +1,627 @@
+import random
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from random import Random
+from typing import Optional, Union
+
+import numpy as np
+import pyarrow.parquet as pq
+import torch
+import torch.nn.functional as F
+from datasets.download.streaming_download_manager import xopen
+from huggingface_hub import HfApi
+from lightning import LightningDataModule
+from torch.distributed import get_rank, get_world_size, is_initialized
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
+
+from fish_speech.content_sequence import ContentSequence, TextPart, VQPart
+
+CODEBOOK_PAD_TOKEN_ID = 0
+
+from fish_speech.datasets.protos.text_data_pb2 import SampledData
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.clean import clean_text
+from fish_speech.tokenizer import FishTokenizer
+from fish_speech.utils import RankedLogger
+from fish_speech.utils.braceexpand import braceexpand
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def split_by_rank_worker(files):
+    # We need to know the total number of devices
+    # to split the data properly
+
+    total_devices = 1
+    if is_initialized():
+        total_devices = get_world_size()
+
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        total_devices *= worker_info.num_workers
+
+    if len(files) < total_devices:
+        # Repeat the files N times to match the number of devices
+        files = files * (total_devices // len(files) + 1)
+
+    # DDP
+    if is_initialized():
+        files = files[get_rank() :: get_world_size()]
+
+    # Split by worker
+    if worker_info is not None:
+        files = files[worker_info.id :: worker_info.num_workers]
+
+    return files
+
+
+class AutoTextSemanticInstructionIterableDataset(IterableDataset):
+    """
+    Auto Augment Dataset by Speaker
+
+    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalize the text
+
+    For interactive mode, we use the following format (multiple sequences):
+    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+
+    For non-interactive mode, we use the following format (one long sequence):
+    <s> [INST] text [/INST] ... </s>
+    """
+
+    def __init__(
+        self,
+        proto_files: list[str],
+        seed: int = 42,
+        interactive_prob: float = 0.5,
+        max_length: int = 1024,
+        tokenizer: FishTokenizer = None,
+        use_speaker: bool | float = True,
+        causal: bool = True,
+        num_codebooks: Optional[int] = None,
+        skip_text_prob: float = 0.0,
+    ):
+        """
+        Args:
+            proto_files: proto buf files if using local data
+            seed: random seed
+            interactive_prob: probability to use interactive mode
+            max_length: max length of the text
+            tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt
+            causal: use causal sampling when using local data, disable will lead to random sampling
+            num_codebooks: number of codebooks, if None, it will be automatically detected
+            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+        """
+
+        super().__init__()
+
+        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+        self.seed = seed
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.interactive_prob = interactive_prob
+        self.use_speaker = use_speaker
+        self.proto_files = proto_files
+        self.causal = causal
+        self.num_codebooks = num_codebooks
+        self.skip_text_prob = skip_text_prob
+
+        self.groups = None
+
+    def __iter__(self):
+        while True:
+            yield self.augment()
+
+    def init_mock_data_server(self):
+        if self.groups is not None:
+            return
+
+        # Expand the proto files
+        expanded_proto_files = []
+        for filename in self.proto_files:
+            for i in braceexpand(filename):
+                i = Path(i)
+                if i.is_file():
+                    expanded_proto_files.append(i)
+                elif i.is_dir():
+                    expanded_proto_files.extend(i.rglob("*.proto"))
+                    expanded_proto_files.extend(i.rglob("*.protos"))
+                else:
+                    raise ValueError(f"{i} is not a file or directory")
+
+        expanded_proto_files = sorted(expanded_proto_files)
+        Random(self.seed).shuffle(expanded_proto_files)
+
+        self.groups = []
+        shard_proto_files = split_by_rank_worker(expanded_proto_files)
+        log.info(
+            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+        )
+
+        count = 0
+        for filename in shard_proto_files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    self.groups.append(text_data)
+                    count += 1
+
+        log.info(f"Read total {count} groups of data")
+
+        # Shuffle the lines
+        Random(self.seed).shuffle(self.groups)
+        self.group_weights = [len(i.sentences) for i in self.groups]
+
+    def sample_data(self):
+        if self.groups is None:
+            self.init_mock_data_server()
+
+        # Shuffle unique lines, estimate that each sample is at least 20 tokens
+        num_samples = self.max_length // 20
+
+        # choice group based on their number of samples
+        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
+
+        if self.causal:
+            # Sample in order
+            if num_samples >= len(group.sentences):
+                samples = group.sentences
+            else:
+                begin = random.randint(0, len(group.sentences) - num_samples)
+                samples = group.sentences[begin : begin + num_samples]
+        else:
+            samples = random.choices(
+                group.sentences, k=min(num_samples, len(group.sentences))
+            )
+
+        return SampledData(
+            source=group.source,
+            name=group.name,
+            samples=samples,
+        )
+
+    def pack_sentences(
+        self,
+        sentences: list[str],
+        semantics: list,
+        # speaker: Optional[str] = None, # speaker is now handled by tokens
+        skip_text: bool = False,
+    ):
+
+        seq = ContentSequence()
+
+        seq.append(TextPart(text="Speak out the provided text."))
+
+        # User's turn
+        cated_sentences = " ".join(sentences)
+        if skip_text:
+            cated_sentences = "<|skip_text|>"
+
+        seq.append(
+            TextPart(text=f"<|speaker:user|> {cated_sentences}"),
+            add_end=True,
+        )
+
+        # Assistant's turn
+        vq_codes = [x.values for x in semantics[0]]
+        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
+
+        # 将 cal_loss=True 直接关联到 VQPart 上,这比之前更精确
+        vq_part = VQPart(codes=vq_codes_tensor, cal_loss=True)
+
+        # 将多个 parts 一起添加,最后也加上 <|im_end|>
+        seq.append(
+            [TextPart(text="<|speaker:assistant|> <|voice|>"), vq_part],
+            add_end=True,
+        )
+
+        encoded = seq.encode(
+            tokenizer=self.tokenizer,
+        )
+
+        num_codebooks = (
+            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+        )
+
+        tokens_raw = encoded.tokens
+        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
+        tokens[0] = tokens_raw
+
+        vq_parts = encoded.vq_parts
+        vq_parts = [part.to(tokens.device) for part in vq_parts]
+        vq_parts = torch.cat(vq_parts, dim=1)
+        tokens[1:, encoded.vq_mask_tokens] = vq_parts
+
+        labels_raw = encoded.labels
+        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
+        labels[0, :] = labels_raw
+        labels[1:, encoded.vq_mask_labels] = vq_parts
+        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
+
+        tokens = tokens.long()
+        labels = labels.long()
+
+        # Verify the padding is correct, and the last token is eos
+        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+        return tokens, labels
+
+    def augment(self):
+        response = self.sample_data()
+        if len(response.samples) == 0:
+            # Invalid group
+            return None
+
+        samples = list(response.samples)
+        all_tokens, all_labels = [], []
+
+        while len(samples) > 0:
+            sentence = samples.pop(0)
+            text = clean_text(random.choice(sentence.texts))
+
+            tokens, labels = self.pack_sentences(
+                sentences=[text],
+                semantics=[sentence.semantics],
+                # speaker=response.name if use_speaker else None,
+                skip_text=random.random() < self.skip_text_prob,
+            )
+
+            all_tokens.append(tokens)
+            all_labels.append(labels)
+
+        tokens = torch.cat(all_tokens, dim=1)
+        labels = torch.cat(all_labels, dim=1)
+
+        # Verify that the length is correct
+        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+        data = {"tokens": tokens, "labels": labels}
+
+        return data
+
+
+class AutoTextSemanticInstructionDataset(Dataset):
+    """
+    Auto Augment Dataset by Speaker
+
+    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalize the text
+
+    For interactive mode, we use the following format (multiple sequences):
+    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+
+    For non-interactive mode, we use the following format (one long sequence):
+    <s> [INST] text [/INST] ... </s>
+    """
+
+    def __init__(
+        self,
+        proto_files: list[str],
+        seed: int = 42,
+        interactive_prob: float = 0.5,
+        max_length: int = 1024,
+        tokenizer: FishTokenizer = None,
+        use_speaker: bool | float = True,
+        causal: bool = True,
+        num_codebooks: Optional[int] = None,
+        skip_text_prob: float = 0.0,
+    ):
+        """
+        Args:
+            proto_files: proto buf files if using local data
+            seed: random seed
+            interactive_prob: probability to use interactive mode
+            max_length: max length of the text
+            tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt
+            causal: use causal sampling when using local data, disable will lead to random sampling
+            num_codebooks: number of codebooks, if None, it will be automatically detected
+            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+        """
+        super().__init__()
+
+        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+        self.seed = seed
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.interactive_prob = interactive_prob
+        self.use_speaker = use_speaker
+        self.proto_files = proto_files
+        self.causal = causal
+        self.num_codebooks = num_codebooks
+        self.skip_text_prob = skip_text_prob
+
+        self.data = []
+        self._init_data()
+
+    def _init_data(self):
+        expanded_proto_files = []
+        for filename in self.proto_files:
+            for i in braceexpand(filename):
+                i = Path(i)
+                if i.is_file():
+                    expanded_proto_files.append(i)
+                elif i.is_dir():
+                    expanded_proto_files.extend(i.rglob("*.proto"))
+                    expanded_proto_files.extend(i.rglob("*.protos"))
+                else:
+                    raise ValueError(f"{i} is not a file or directory")
+
+        expanded_proto_files = sorted(expanded_proto_files)
+        Random(self.seed).shuffle(expanded_proto_files)
+
+        groups = []
+        shard_proto_files = split_by_rank_worker(expanded_proto_files)
+        log.info(
+            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+        )
+
+        count = 0
+        for filename in shard_proto_files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    groups.append(text_data)
+                    count += 1
+
+        log.info(f"Read total {count} groups of data")
+
+        for group in groups:
+            if len(group.sentences) == 0:
+                continue
+
+            samples = list(group.sentences)
+            for sentence in samples:
+                text = clean_text(random.choice(sentence.texts))
+
+                tokens, labels = self.pack_sentences(
+                    sentences=[text],
+                    semantics=[sentence.semantics],
+                    skip_text=random.random() < self.skip_text_prob,
+                )
+
+                self.data.append({"tokens": tokens, "labels": labels})
+
+        random.Random(self.seed).shuffle(self.data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+    def pack_sentences(
+        self,
+        sentences: list[str],
+        semantics: list,
+        skip_text: bool = False,
+    ):
+        messages = [
+            Message(
+                role="system",
+                parts=[TextPart(text="Speak out the provided text.")],
+            )
+        ]
+
+        cated_sentences = " ".join(sentences)
+        if skip_text:
+            cated_sentences = "<|skip_text|>"
+
+        messages.append(
+            Message(
+                role="user",
+                parts=[TextPart(text=cated_sentences)],
+            )
+        )
+
+        vq_codes = [x.values for x in semantics[0]]
+        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
+        vqpart = VQPart(codes=vq_codes_tensor)
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vqpart],
+                cal_loss=True,
+            )
+        )
+
+        num_codebooks = (
+            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+        )
+
+        conversation = Conversation(messages=messages)
+        encoded = conversation.encode(
+            tokenizer=self.tokenizer,
+        )
+
+        tokens_raw = encoded.tokens
+        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
+        tokens[0] = tokens_raw
+
+        vq_parts = encoded.vq_parts
+        vq_parts = [part.to(tokens.device) for part in vq_parts]
+        vq_parts = torch.cat(vq_parts, dim=1)
+        tokens[1:, encoded.vq_mask_tokens] = vq_parts
+
+        labels_raw = encoded.labels
+        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
+        labels[0, :] = labels_raw
+        labels[1:, encoded.vq_mask_labels] = vq_parts
+        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
+
+        tokens = tokens.long()
+        labels = labels.long()
+
+        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+        return tokens, labels
+
+
+class InterleaveDataset(IterableDataset):
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        probabilities: list[float],
+        seed: int = 42,
+    ):
+        super().__init__()
+
+        self.datasets = datasets
+        self.probabilities = probabilities
+        self.seed = seed
+
+    def __iter__(self):
+        rng = np.random.default_rng(self.seed)
+        dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+        while True:
+            # Random choice one
+            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+            dataset_iterator = dataset_iterators[dataset_idx]
+
+            try:
+                yield next(dataset_iterator)
+            except StopIteration:
+                # Exhausted, create a new iterator
+                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+                yield next(dataset_iterators[dataset_idx])
+
+
+@dataclass
+class TextDataCollator:
+    tokenizer: FishTokenizer
+    max_length: int = 1024
+
+    def __call__(self, examples):
+        if "negative_tokens" in examples:
+            positive_examples = []
+            negative_examples = []
+
+            for i in examples:
+                positive_examples.append(
+                    {
+                        "tokens": i["tokens"],
+                        "labels": i["labels"],
+                    }
+                )
+                negative_examples.append(
+                    {
+                        "tokens": i["negative_tokens"],
+                        "labels": i["negative_labels"],
+                    }
+                )
+
+            examples = positive_examples + negative_examples
+
+        return self.batchify(examples)
+
+    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
+        tokens, attention_masks, labels = [], [], []
+
+        # Calculate the max length
+        max_tokens_length = 0
+        for example in examples:
+            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
+        max_tokens_length = min(max_tokens_length, self.max_length)
+
+        for example in examples:
+            _tokens = example[tokens_key][:, :max_tokens_length]
+            _labels = example[labels_key][:, :max_tokens_length]
+            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
+            tokens_length = _tokens.size(1)
+            _attention_mask[:tokens_length] = False
+
+            assert tokens_length == _labels.size(
+                1
+            ), f"{tokens_length} != {_labels.size(1)}"
+
+            if tokens_length < max_tokens_length:
+                _tokens = F.pad(
+                    _tokens,
+                    (0, max_tokens_length - tokens_length),
+                    value=self.tokenizer.get_token_id("<|end_of_text|>"),
+                )
+                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
+                _labels = F.pad(
+                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
+                )
+
+            tokens.append(_tokens)
+            attention_masks.append(_attention_mask)
+            labels.append(_labels)
+
+        tokens = torch.stack(tokens, dim=0)
+        attention_masks = torch.stack(attention_masks, dim=0)
+        labels = torch.stack(labels, dim=0)
+
+        return {
+            "inputs": tokens,
+            "attention_masks": attention_masks,
+            "labels": labels,
+        }
+
+
+class SemanticDataModule(LightningDataModule):
+    def __init__(
+        self,
+        train_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            AutoTextSemanticInstructionIterableDataset,
+            InterleaveDataset,
+        ],
+        val_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            AutoTextSemanticInstructionIterableDataset,
+            InterleaveDataset,
+        ],
+        batch_size: int = 32,
+        tokenizer: FishTokenizer = None,
+        max_length: int = 1024,
+        num_workers: int = 4,
+    ):
+        super().__init__()
+
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.batch_size = batch_size
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.num_workers = num_workers
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+            num_workers=self.num_workers,
+            persistent_workers=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+            num_workers=self.num_workers,
+            persistent_workers=True,
+        )
+
+
+if __name__ == "__main__":
+    from tqdm import tqdm
+
+    ds = AutoTextSemanticInstructionDataset(
+        ["data/protos"],
+        tokenizer=FishTokenizer("checkpoints/fish-speech-1.5/tokenizer.tiktoken"),
+        use_speaker=False,
+        interactive_prob=1.0,
+        skip_text_prob=0.5,
+    )
+
+    for i in range(100):
+        # Please uncomment line 235 to visualize the tokenized message
+        print(ds[i])

+ 147 - 0
fish_speech/datasets/vqgan.py

@@ -0,0 +1,147 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import librosa
+import numpy as np
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
+
+class VQGANDataset(Dataset):
+    def __init__(
+        self,
+        filelist: str,
+        sample_rate: int = 32000,
+        hop_length: int = 640,
+        slice_frames: Optional[int] = None,
+    ):
+        super().__init__()
+
+        filelist = Path(filelist)
+        root = filelist.parent
+
+        self.files = [
+            root / line.strip()
+            for line in filelist.read_text(encoding="utf-8").splitlines()
+            if line.strip()
+        ]
+        self.sample_rate = sample_rate
+        self.hop_length = hop_length
+        self.slice_frames = slice_frames
+
+    def __len__(self):
+        return len(self.files)
+
+    def get_item(self, idx):
+        file = self.files[idx]
+
+        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
+
+        # Slice audio and features
+        if (
+            self.slice_frames is not None
+            and audio.shape[0] > self.slice_frames * self.hop_length
+        ):
+            start = np.random.randint(
+                0, audio.shape[0] - self.slice_frames * self.hop_length
+            )
+            audio = audio[start : start + self.slice_frames * self.hop_length]
+
+        if len(audio) == 0:
+            return None
+
+        max_value = np.abs(audio).max()
+        if max_value > 1.0:
+            audio = audio / max_value
+
+        return {
+            "audio": torch.from_numpy(audio),
+        }
+
+    def __getitem__(self, idx):
+        try:
+            return self.get_item(idx)
+        except Exception as e:
+            import traceback
+
+            traceback.print_exc()
+            logger.error(f"Error loading {self.files[idx]}: {e}")
+            return None
+
+
+@dataclass
+class VQGANCollator:
+    def __call__(self, batch):
+        batch = [x for x in batch if x is not None]
+
+        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
+        audio_maxlen = audio_lengths.max()
+
+        # Rounds up to nearest multiple of 2 (audio_lengths)
+        audios = []
+        for x in batch:
+            audios.append(
+                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
+            )
+
+        return {
+            "audios": torch.stack(audios),
+            "audio_lengths": audio_lengths,
+        }
+
+
+class VQGANDataModule(LightningDataModule):
+    def __init__(
+        self,
+        train_dataset: VQGANDataset,
+        val_dataset: VQGANDataset,
+        batch_size: int = 32,
+        num_workers: int = 4,
+        val_batch_size: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.batch_size = batch_size
+        self.val_batch_size = val_batch_size or batch_size
+        self.num_workers = num_workers
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            collate_fn=VQGANCollator(),
+            num_workers=self.num_workers,
+            shuffle=True,
+            persistent_workers=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.val_batch_size,
+            collate_fn=VQGANCollator(),
+            num_workers=self.num_workers,
+            persistent_workers=True,
+        )
+
+
+if __name__ == "__main__":
+    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
+    dataloader = DataLoader(
+        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
+    )
+
+    for batch in dataloader:
+        print(batch["audios"].shape)
+        print(batch["features"].shape)
+        print(batch["audio_lengths"])
+        print(batch["feature_lengths"])
+        break

+ 210 - 0
fish_speech/models/text2semantic/lit_module.py

@@ -0,0 +1,210 @@
+from typing import Any, Optional
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
+
+import fish_speech.utils as utils
+
+CODEBOOK_PAD_TOKEN_ID = 0
+from fish_speech.models.text2semantic.llama import NaiveTransformer
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+class TextToSemantic(L.LightningModule):
+    def __init__(
+        self,
+        model: NaiveTransformer,
+        optimizer: Any,
+        lr_scheduler: Any,
+    ):
+        super().__init__()
+
+        self.model = model
+        self.optimizer_builder = optimizer
+        self.lr_scheduler_builder = lr_scheduler
+
+    def forward(self, x):
+        return self.model(x)
+
+    def on_save_checkpoint(self, checkpoint):
+        # Save only LoRA parameters
+        state_dict = checkpoint["state_dict"]
+        use_lora = any("lora" in name for name in state_dict.keys())
+        if not use_lora:
+            return
+
+        for name in list(state_dict.keys()):
+            if "lora" not in name:
+                state_dict.pop(name)
+
+    def configure_optimizers(self) -> OptimizerLRScheduler:
+        # Get weight decay parameters
+        weight_decay_parameters, other_parameters = [], []
+        for name, param in self.named_parameters():
+            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
+                other_parameters.append(param)
+            else:
+                weight_decay_parameters.append(param)
+
+        optimizer = self.optimizer_builder(
+            [
+                {"params": weight_decay_parameters},
+                {"params": other_parameters, "weight_decay": 0.0},
+            ]
+        )
+
+        # Print the parameters and their weight decay
+        for i in optimizer.param_groups:
+            log.info(
+                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
+            )
+
+        lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+            },
+        }
+
+    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
+    def get_batch_logps(
+        self,
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        average_log_prob: bool = False,
+    ) -> torch.FloatTensor:
+        """Compute the log probabilities of the given labels under the given logits.
+
+        Args:
+            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
+            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
+            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+        """
+        assert logits.shape[:-1] == labels.shape
+
+        labels = labels.clone()
+        loss_mask = labels != -100
+
+        # dummy token; we'll ignore the losses on these tokens later
+        labels[labels == -100] = 0
+
+        per_token_logps = torch.gather(
+            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
+        ).squeeze(-1)
+
+        if average_log_prob:
+            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        else:
+            return (per_token_logps * loss_mask).sum(-1)
+
+    def _step(self, batch, batch_idx, stage: str):
+        is_train = stage == "train"
+
+        if is_train:
+            # Key part to make lora work
+            # Otherwise the parameters are merged, which lead to incorrect gradients
+            self.model.train()
+
+        # Do positive and negative samples in the same batch to speed up training
+        labels = batch["labels"]
+        outputs = self.model(
+            inp=batch["inputs"],
+            key_padding_mask=batch["attention_masks"],
+            labels=batch["labels"],
+        )
+        token_logits = outputs.token_logits
+        codebook_logits = outputs.codebook_logits
+
+        # Generate labels
+        base_loss = F.cross_entropy(
+            token_logits.view(-1, token_logits.size(-1)),
+            labels[:, 0].reshape(-1),
+            ignore_index=-100,
+        )
+
+        token_ids = labels[:, 0]
+        semantic_mask = (token_ids >= self.model.tokenizer.semantic_begin_id) & (
+            token_ids <= self.model.tokenizer.semantic_end_id
+        )
+        all_codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks]
+        all_codebook_labels_permuted = all_codebook_labels.permute(0, 2, 1)
+        filtered_codebook_labels = all_codebook_labels_permuted[semantic_mask]
+        semantic_loss = F.cross_entropy(
+            codebook_logits.reshape(-1, codebook_logits.size(-1)),
+            filtered_codebook_labels.reshape(-1),
+            ignore_index=-100,
+        )
+
+        loss = base_loss + semantic_loss
+
+        self.log(
+            f"{stage}/loss",
+            loss,
+            on_step=is_train,
+            on_epoch=not is_train,
+            prog_bar=True,
+            logger=True,
+            sync_dist=not is_train,
+        )
+
+        self.log(
+            f"{stage}/base_loss",
+            base_loss,
+            on_step=is_train,
+            on_epoch=not is_train,
+            prog_bar=False,
+            logger=True,
+            sync_dist=not is_train,
+        )
+
+        self.log(
+            f"{stage}/semantic_loss",
+            semantic_loss,
+            on_step=is_train,
+            on_epoch=not is_train,
+            prog_bar=False,
+            logger=True,
+            sync_dist=not is_train,
+        )
+
+        # Top-5 accuracy
+        accuracy = self.get_accuracy(codebook_logits, filtered_codebook_labels)
+        self.log(
+            f"{stage}/top_5_accuracy",
+            accuracy,
+            on_step=is_train,
+            on_epoch=not is_train,
+            prog_bar=True,
+            logger=True,
+            sync_dist=not is_train,
+        )
+
+        return loss
+
+    def get_accuracy(self, logits, labels):
+        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
+        if mask.sum() == 0:
+            return torch.tensor(0.0, device=logits.device)
+
+        _, indices = logits.topk(5, dim=-1)
+        correct = indices.eq(labels.unsqueeze(-1))
+        correct[~mask] = 0
+        correct = correct.sum()
+        accuracy = correct / mask.sum()
+
+        return accuracy
+
+    def training_step(self, batch, batch_idx):
+        return self._step(batch, batch_idx, "train")
+
+    def validation_step(self, batch, batch_idx):
+        return self._step(batch, batch_idx, "val")

+ 12 - 12
fish_speech/models/text2semantic/llama.py

@@ -432,10 +432,6 @@ class BaseTransformer(nn.Module):
         logger.info(f"Loading model from {path}, config: {config}")
         model = model_cls(config, tokenizer=tokenizer)
 
-        if lora_config is not None:
-            setup_lora(model, lora_config)
-            logger.info(f"LoRA setup: {lora_config}")
-
         if load_weights is False:
             logger.info("Randomly initialized model")
         else:
@@ -497,6 +493,10 @@ class BaseTransformer(nn.Module):
             err = model.load_state_dict(weights, strict=False, assign=True)
             logger.info(f"Loaded weights with error: {err}")
 
+        if lora_config is not None:
+            setup_lora(model, lora_config)
+            logger.info(f"LoRA setup: {lora_config}")
+
         return model
 
     def save_pretrained(self, path: str, drop_lora: bool = False):
@@ -642,10 +642,6 @@ class DualARTransformer(BaseTransformer):
         parent_result = super().forward(
             inp=inp,
             key_padding_mask=key_padding_mask,
-            vq_parts=vq_parts,
-            vq_masks=vq_masks,
-            mel_parts=mel_parts,
-            mel_masks=mel_masks,
         )
         token_logits = parent_result.logits
         x = parent_result.hidden_states
@@ -658,7 +654,10 @@ class DualARTransformer(BaseTransformer):
         fast_freqs_cis = self.fast_freqs_cis[:fast_seq_len]
 
         # Extract corresponding parts with labels
-        codebook_mask = labels == self.semantic_token_id
+        token_labels = labels[:, 0]
+        codebook_mask = (token_labels >= self.tokenizer.semantic_begin_id) & (
+            token_labels <= self.tokenizer.semantic_end_id
+        )
         # This gives where input token is <|semantic|>
         x = x[codebook_mask]
 
@@ -675,9 +674,10 @@ class DualARTransformer(BaseTransformer):
                 dtype=torch.int,
             )
         else:
-            codebooks = vq_parts[..., :-1][vq_require_losses][
-                vq_masks[vq_require_losses]
-            ]
+            all_codebooks = labels[:, 1:, :]
+            all_codebooks_permuted = all_codebooks.permute(0, 2, 1)
+            semantic_codebooks = all_codebooks_permuted[codebook_mask]
+            codebooks = semantic_codebooks[:, :-1]
 
         x = self.fast_project_in(x)
         codebook_embeddings = self.fast_embeddings(codebooks)

+ 40 - 0
fish_speech/scheduler.py

@@ -0,0 +1,40 @@
+import math
+
+
+def get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int | float,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    final_lr_ratio: float = 0.0,
+):
+    if 0 < num_warmup_steps < 1:  # float mode
+        num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+
+    return max(
+        final_lr_ratio,
+        0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+    )
+
+
+def get_constant_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int | float,
+    num_training_steps: int | None = None,
+):
+    if 0 < num_warmup_steps < 1:  # float mode
+        num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    return 1.0

+ 1 - 1
fish_speech/tokenizer.py

@@ -46,7 +46,7 @@ MODALITY_TOKENS = {
 }
 
 SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
-SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
+SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(4096)]
 
 # Warning: when you add a new special token, you should only add it to the end of the list.
 ALL_SPECIAL_TOKENS = [

+ 141 - 0
fish_speech/train.py

@@ -0,0 +1,141 @@
+import os
+
+os.environ["USE_LIBUV"] = "0"
+import sys
+from typing import Optional
+
+import hydra
+import lightning as L
+import pyrootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies import DDPStrategy
+from omegaconf import DictConfig, OmegaConf
+
+os.environ.pop("SLURM_NTASKS", None)
+os.environ.pop("SLURM_JOB_NAME", None)
+os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+import fish_speech.utils as utils
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> tuple[dict, dict]:
+    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+    training.
+    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+    failure. Useful for multiruns, saving info about the crash, etc.
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+    Returns:
+        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+    """  # noqa: E501
+
+    # set seed for random number generators in pytorch, numpy and python.random
+    if cfg.get("seed"):
+        L.seed_everything(cfg.seed, workers=False)
+
+    if cfg.get("deterministic"):
+        torch.use_deterministic_algorithms(True)
+
+    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+    log.info(f"Instantiating model <{cfg.model._target_}>")
+    model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+    log.info("Instantiating callbacks...")
+    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+    log.info("Instantiating loggers...")
+    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer,
+        callbacks=callbacks,
+        logger=logger,
+    )
+
+    object_dict = {
+        "cfg": cfg,
+        "datamodule": datamodule,
+        "model": model,
+        "callbacks": callbacks,
+        "logger": logger,
+        "trainer": trainer,
+    }
+
+    if logger:
+        log.info("Logging hyperparameters!")
+        utils.log_hyperparameters(object_dict)
+
+    if cfg.get("train"):
+        log.info("Starting training!")
+
+        ckpt_path = cfg.get("ckpt_path")
+        auto_resume = False
+
+        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+        if resume_ckpt_path is not None:
+            ckpt_path = resume_ckpt_path
+            auto_resume = True
+
+        if ckpt_path is not None:
+            log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+        # resume weights only is disabled for auto-resume
+        if cfg.get("resume_weights_only") and auto_resume is False:
+            log.info("Resuming weights only!")
+            ckpt = torch.load(ckpt_path, map_location=model.device)
+            if "state_dict" in ckpt:
+                ckpt = ckpt["state_dict"]
+            err = model.load_state_dict(ckpt, strict=False)
+            log.info(f"Error loading state dict: {err}")
+            ckpt_path = None
+
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+    train_metrics = trainer.callback_metrics
+
+    if cfg.get("test"):
+        log.info("Starting testing!")
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+        if ckpt_path == "":
+            log.warning("Best ckpt not found! Using current weights for testing...")
+            ckpt_path = cfg.get("ckpt_path")
+
+        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+        log.info(f"Best ckpt path: {ckpt_path}")
+
+    test_metrics = trainer.callback_metrics
+
+    # merge train and test metrics
+    metric_dict = {**train_metrics, **test_metrics}
+
+    return metric_dict, object_dict
+
+
+@hydra.main(
+    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig) -> Optional[float]:
+    # train the model
+    train(cfg)
+
+
+if __name__ == "__main__":
+    main()

+ 6 - 0
mkdocs.yml

@@ -57,6 +57,7 @@ theme:
 nav:
   - Introduction: en/index.md
   - Installation: en/install.md
+  - Finetune: en/finetune.md
   - Inference: en/inference.md
   - Samples: en/samples.md
 
@@ -84,6 +85,7 @@ plugins:
           nav:
             - 介绍: zh/index.md
             - 安装: zh/install.md
+            - 微调: zh/finetune.md
             - 推理: zh/inference.md
             - 示例: zh/samples.md
         - locale: ja
@@ -92,6 +94,7 @@ plugins:
           nav:
             - はじめに: ja/index.md
             - インストール: ja/install.md
+            - ファインチューニング: ja/finetune.md
             - 推論: ja/inference.md
             - サンプル: ja/samples.md
         - locale: pt
@@ -100,6 +103,7 @@ plugins:
           nav:
             - Introdução: pt/index.md
             - Instalação: pt/install.md
+            - Ajuste Fino: pt/finetune.md
             - Inferência: pt/inference.md
             - Amostras: pt/samples.md
         - locale: ko
@@ -108,6 +112,7 @@ plugins:
           nav:
             - 소개: ko/index.md
             - 설치: ko/install.md
+            - 파인튜닝: ko/finetune.md
             - 추론: ko/inference.md
             - 샘플: ko/samples.md
         - locale: ar
@@ -116,6 +121,7 @@ plugins:
           nav:
             - مقدمة: ar/index.md
             - التثبيت: ar/install.md
+            - الضبط الدقيق: ar/finetune.md
             - الاستنتاج: ar/inference.md
             - العينات: ar/samples.md
 

+ 1 - 1
pyproject.toml

@@ -49,7 +49,7 @@ dependencies = [
 
 [project.optional-dependencies]
 stable = [
-    "torch>=2.5.1",
+    "torch<2.9.0",
     "torchaudio",
 ]
 cpu = [

+ 169 - 0
tools/llama/build_dataset.py

@@ -0,0 +1,169 @@
+import itertools
+import os
+import re
+from collections import defaultdict
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
+from fish_speech.utils.file import load_filelist
+
+# To avoid CPU overload
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+
+def task_generator_folder(root: Path, text_extension: str):
+    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
+    files = sorted(files)
+
+    grouped_files = defaultdict(list)
+    for file in tqdm(files, desc=f"Grouping {root}"):
+        p = str(file.parent)
+        speaker = file.parent.name
+
+        try:
+            if isinstance(text_extension, str):
+                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
+            else:
+                texts = [
+                    file.with_suffix(ext).read_text(encoding="utf-8")
+                    for ext in text_extension
+                ]
+        except Exception as e:
+            logger.error(f"Failed to read text {file}: {e}")
+            continue
+
+        grouped_files[p].append((speaker, file, texts))
+
+    logger.info(
+        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
+    )
+
+    for i in grouped_files.values():
+        subset = [(f, t) for _, f, t in i]
+        yield i[0][0], subset, "folder"
+
+
+def task_generator_filelist(filelist):
+    grouped_files = defaultdict(list)
+    for filename, speaker, _, text in load_filelist(filelist):
+        grouped_files[speaker].append((Path(filename), [text]))
+
+    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+    for speaker, values in grouped_files.items():
+        yield speaker, values, "filelist"
+
+
+def run_task(task):
+    name, subset, source = task
+
+    # Parse the files
+    sentences = []
+    for file, texts in subset:
+        np_file = file.with_suffix(".npy")
+        if np_file.exists() is False:
+            logger.warning(f"Can't find {np_file}")
+            continue
+
+        new_texts = []
+
+        for text in texts:
+            # Simple cleaning: replace { xxx } and < xxx > with space
+            text = re.sub(r"\{.*?\}", " ", text)
+            text = re.sub(r"<.*?>", " ", text)
+            text = re.sub(r"\s+", " ", text)
+            new_texts.append(text)
+
+        try:
+            semantics = np.load(np_file)
+        except Exception as e:
+            logger.error(f"Failed to parse {file}: {e}")
+            continue
+
+        if isinstance(semantics, np.ndarray):
+            semantics = semantics.tolist()
+
+        sentences.append(
+            Sentence(
+                texts=new_texts,
+                semantics=[Semantics(values=s) for s in semantics],
+            )
+        )
+
+    # Pack the sentences
+    return pack_pb_stream(
+        TextData(
+            source=source,
+            name=name,
+            sentences=sentences,
+        )
+    )
+
+
+@click.command()
+@click.option(
+    "--input",
+    type=click.Path(path_type=Path),
+    required=True,
+    help="A folder containing the dataset or a filelist",
+    multiple=True,
+)
+@click.option(
+    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
+)
+@click.option("--num-workers", type=int, default=16)
+@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
+@click.option(
+    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
+)
+def main(input, output, num_workers, text_extension, shard_size):
+    generator_fns = []
+
+    for f in input:
+        assert f.exists(), f"{f} not found"
+
+        if f.is_dir():
+            generator_fn = task_generator_folder(f, text_extension)
+        else:
+            generator_fn = task_generator_filelist(f)
+
+        generator_fns.append(generator_fn)
+
+    generator_fn = itertools.chain(*generator_fns)
+    output.mkdir(parents=True, exist_ok=True)
+
+    dataset_fp = None
+    tar_idx = 0
+    written_size = 0
+
+    with Pool(num_workers) as p:
+        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
+            if dataset_fp is None:
+                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
+
+            dataset_fp.write(result)
+            written_size += len(result)
+
+            if written_size > shard_size * 1024 * 1024:
+                logger.info(f"Finished writing {tar_idx} shards to {output}")
+                dataset_fp.close()
+                dataset_fp = None
+                written_size = 0
+                tar_idx += 1
+
+    if dataset_fp is not None:
+        dataset_fp.close()
+
+    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
+
+
+if __name__ == "__main__":
+    main()

+ 171 - 0
tools/llama/eval_in_context.py

@@ -0,0 +1,171 @@
+import pyrootutils
+import torch
+import torch.nn.functional as F
+from matplotlib import pyplot as plt
+from transformers import AutoTokenizer
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from torch.utils.data import DataLoader
+
+from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
+from fish_speech.models.text2semantic.inference import load_model
+
+
+def smooth(
+    scalars: list[float], weight: float
+) -> list[float]:  # Weight between 0 and 1
+    last = scalars[0]  # First value in the plot (first timestep)
+    smoothed = list()
+    for point in scalars:
+        smoothed_val = last * weight + (1 - weight) * point  # Calculate smoothed value
+        smoothed.append(smoothed_val)  # Save it
+        last = smoothed_val  # Anchor the last smoothed value
+
+    return smoothed
+
+
+@torch.inference_mode()
+def analyze_one_model(loader, config, weight, max_length):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = load_model(
+        config,
+        weight,
+        device,
+        torch.bfloat16,
+        max_length,
+        compile=False,
+    )[0]
+
+    current_step = 0
+    model.eval()
+
+    semantic_loss_sum = torch.zeros(
+        max_length,
+        dtype=torch.float32,
+        device=device,
+    )
+    counter = torch.zeros(
+        max_length,
+        dtype=torch.long,
+        device=device,
+    )
+
+    for batch in loader:
+        batch = {k: v.to(device) for k, v in batch.items()}
+
+        labels = batch["labels"]
+        outputs = model(
+            inp=batch["inputs"],
+            key_padding_mask=batch["attention_masks"],
+        )
+
+        token_logits = outputs.token_logits
+        codebook_logits = outputs.codebook_logits
+
+        # Generate labels
+        base_loss = F.cross_entropy(
+            token_logits.reshape(-1, token_logits.size(-1)),
+            labels[:, 0].reshape(-1),
+            ignore_index=-100,
+            reduction="none",
+        )
+
+        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
+        semantic_loss = F.cross_entropy(
+            codebook_logits.reshape(-1, codebook_logits.size(-1)),
+            codebook_labels.reshape(-1),
+            ignore_index=-100,
+            reduction="none",
+        )
+
+        base_loss = base_loss.reshape(labels[:, 0].shape)
+        semantic_loss = semantic_loss.reshape(codebook_labels.shape)
+
+        semantic_loss_frame = semantic_loss.mean(-1)
+        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
+
+        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
+            semantic_loss_sum[~pad] += loss_sample[~pad]
+            counter[~pad] += 1
+
+        current_step += 1
+        if current_step == 10:
+            break
+
+    semantic_loss = semantic_loss.cpu()
+    counter = counter.cpu()
+    xs, ys = [], []
+
+    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
+        if count > 0:
+            xs.append(i)
+            ys.append((loss / count).item())  # for better loss visualization
+
+    smoothed_ys = smooth(ys, 0.95)
+
+    # Unload model
+    del model
+    torch.cuda.empty_cache()
+
+    return xs, ys, smoothed_ys
+
+
+def main():
+    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+    max_length = 4096
+
+    ds = AutoAugTextDataset(
+        ["data/protos/sft/云天河"],
+        tokenizer=tokenizer,
+        use_speaker=False,
+        interactive_prob=1.0,
+        max_length=max_length,
+    )
+
+    loader = DataLoader(
+        ds,
+        batch_size=8,
+        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
+        num_workers=0,
+        shuffle=False,
+    )
+
+    plt.figure(figsize=(10, 5), dpi=200)
+
+    plt.xlabel("Frame")
+    plt.ylabel("Loss")
+    plt.yscale("log")
+    plt.title("Semantic Loss")
+    plt.grid(which="both", axis="both")
+    plt.xlim(0, max_length)
+
+    tests = [
+        (
+            "pertrain-medium",
+            "dual_ar_2_codebook_medium",
+            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
+        ),
+        (
+            "sft-medium",
+            "dual_ar_2_codebook_medium",
+            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+        ),
+        (
+            "sft-large",
+            "dual_ar_2_codebook_large",
+            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
+        ),
+    ]
+
+    for name, config, weight in tests:
+        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
+        plt.plot(xs, smoothed_ys, label=name)
+
+    plt.legend()
+    plt.savefig("semantic_loss.png")
+
+
+if __name__ == "__main__":
+    main()

+ 96 - 0
tools/llama/merge_lora.py

@@ -0,0 +1,96 @@
+import shutil
+from copy import deepcopy
+from pathlib import Path
+
+import click
+import hydra
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+
+from fish_speech.models.text2semantic.llama import BaseTransformer
+from fish_speech.models.text2semantic.lora import get_merged_state_dict
+
+
+@click.command()
+@click.option("--lora-config", type=str, default="r_8_alpha_16")
+@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
+@click.option("--lora-weight", type=str, required=True)
+@click.option("--output", type=str, required=True)
+def merge(lora_config, base_weight, lora_weight, output):
+    output = Path(output)
+    logger.info(
+        f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
+    )
+
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
+        cfg = compose(config_name=lora_config)
+
+    lora_config = instantiate(cfg)
+    logger.info(f"Loaded lora model with config {lora_config}")
+
+    llama_model = BaseTransformer.from_pretrained(
+        path=base_weight,
+        load_weights=True,
+        lora_config=lora_config,
+    )
+    logger.info(f"Loaded llama model")
+
+    llama_state_dict = llama_model.state_dict()
+    llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
+    llama_state_dict_copy = deepcopy(llama_state_dict)
+    lora_state_dict = torch.load(lora_weight, map_location="cpu", weights_only=False)
+
+    if "state_dict" in llama_state_dict:
+        llama_state_dict = llama_state_dict["state_dict"]
+
+    if "state_dict" in lora_state_dict:
+        lora_state_dict = lora_state_dict["state_dict"]
+
+    # remove prefix model.
+    if any(k.startswith("model.") for k in llama_state_dict.keys()):
+        llama_state_dict = {
+            k.replace("model.", ""): v
+            for k, v in llama_state_dict.items()
+            if k.startswith("model.")
+        }
+    if any(k.startswith("model.") for k in lora_state_dict.keys()):
+        lora_state_dict = {
+            k.replace("model.", ""): v
+            for k, v in lora_state_dict.items()
+            if k.startswith("model.")
+        }
+
+    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
+    logger.info(f"Found {len(lora_state_dict)} keys in lora model")
+
+    merged_state_dict = llama_state_dict | lora_state_dict
+    llama_model.load_state_dict(merged_state_dict, strict=True)
+    logger.info(f"Merged model loaded")
+
+    # Trigger eval mode to merge lora
+    llama_model.eval()
+    llama_model.save_pretrained(output, drop_lora=True)
+    logger.info(f"Saved merged model to {output}, validating")
+
+    new_state_dict = torch.load(output / "model.pth", map_location="cpu")
+    original_keys = set(llama_state_dict_copy.keys())
+
+    tolerance = 1e-5
+    for key in original_keys:
+        diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
+        if diff_l1 > tolerance:
+            logger.info(f"Significant difference found in key: {key}")
+            break
+
+    if diff_l1 <= tolerance:
+        logger.warning(
+            "Merged model seems identical to the original model. Further validation might be needed."
+        )
+    else:
+        logger.info("Merged model is different from the original model, check passed")
+
+
+if __name__ == "__main__":
+    merge()

+ 2 - 4
tools/vqgan/extract_vq.py

@@ -95,10 +95,8 @@ def process_batch(files: list[Path], model) -> float:
         if wav.shape[0] > 1:
             wav = wav.mean(dim=0, keepdim=True)
 
-        wav = torchaudio.functional.resample(
-            wav.cuda(), sr, model.spec_transform.sample_rate
-        )[0]
-        total_time += len(wav) / model.spec_transform.sample_rate
+        wav = torchaudio.functional.resample(wav.cuda(), sr, model.sample_rate)[0]
+        total_time += len(wav) / model.sample_rate
         max_length = max(max_length, len(wav))
 
         wavs.append(wav)