Просмотр исходного кода

Optimize logger & add multilingual data

Lengyue 2 года назад
Родитель
Commit
a03b1b2767

+ 2 - 5
fish_speech/configs/base.yaml

@@ -17,7 +17,7 @@ trainer:
   devices: 8
   strategy:
     _target_: lightning.pytorch.strategies.DDPStrategy
-    static_graph: true
+
   precision: bf16-mixed
 
   # disable validation by epoch end
@@ -43,12 +43,9 @@ callbacks:
     auto_insert_metric_name: false
 
   model_summary:
-    _target_: lightning.pytorch.callbacks.RichModelSummary
+    _target_: lightning.pytorch.callbacks.ModelSummary
     max_depth: 2 # the maximum depth of layer nesting that the summary will include
 
-  rich_progress_bar:
-    _target_: lightning.pytorch.callbacks.RichProgressBar
-
   learning_rate_monitor:
     _target_: lightning.pytorch.callbacks.LearningRateMonitor
     logging_interval: step

+ 1 - 1
fish_speech/models/text2semantic/lit_module.py

@@ -47,7 +47,7 @@ class TextToSemantic(L.LightningModule):
         correct = indices.eq(batch["labels"].unsqueeze(-1)).sum()
         accuracy = correct / batch["labels"].numel()
         self.log(
-            f"{stage}/accuracy",
+            f"{stage}/top_5_accuracy",
             accuracy,
             on_step=True,
             on_epoch=False,

+ 41 - 17
tools/build_vq_text.py

@@ -1,32 +1,56 @@
+from functools import partial
 from pathlib import Path
 
+import numpy as np
 from datasets import Dataset
 
 
-def parse_data(wav_dir, item):
-    text_file = (wav_dir / item["item_name"]).with_suffix(".txt")
-    text = text_file.read_text().strip()
+def parse_data(phones, items):
+    results = []
 
-    semantic = item["semantic_audio"]
-    semantic = [f"<semantic_{x}>" for x in semantic.split(" ")]
-    semantic = " ".join(semantic)
+    for item_name, semantic_audio in zip(items["item_name"], items["semantic_audio"]):
+        wav_file = Path(item_name)
+        text_file = wav_file.with_suffix(".txt")
 
-    text = f"[INST] {text} [/INST] {semantic} </s>"
+        if not text_file.exists():
+            text_file = wav_file.with_suffix(".lab")
+
+        if not text_file.exists():
+            print(f"Missing {text_file}")
+            return None
+
+        text = text_file.read_text().strip()
+        semantic = [f"<semantic_{x}>" for x in semantic_audio.split(" ")]
+        semantic = " ".join(semantic)
+        results.append(f"[INST] {text} [/INST] {semantic} </s>")
+        results.append(f"[INST] {phones[item_name]} [/INST] {semantic} </s>")
 
     return {
-        "text": text,
+        "text": results,
     }
 
 
 if __name__ == "__main__":
-    # dataset = WenetVQDataset()
-    # dataset = list(dataset)
-    # print("Initialized dataset.")
-    dataset = Dataset.from_csv("data/cn-hubert-wenet-25hz-semantic.tsv", delimiter="\t")
+    phones = np.load("dump/phoneme_train.npy", allow_pickle=True).item()
+    phones1 = np.load(
+        "/home/fish/hubert-vq-vits/dump/phoneme_train.npy", allow_pickle=True
+    ).item()
+    phones.update(phones1)
+    print(len(phones))
+
+    dataset = Dataset.from_csv(
+        [
+            "dump/semantic_train.tsv",
+            "/home/fish/hubert-vq-vits/dump/semantic_train.tsv",
+        ],
+        delimiter="\t",
+        split="train",
+    )
     dataset = dataset.map(
-        lambda item: parse_data(Path("data/WenetSpeech"), item), num_proc=64
+        partial(parse_data, phones),
+        num_proc=32,
+        remove_columns=dataset.column_names,
+        batched=True,
     )
-    dataset = dataset.remove_columns(["item_name", "semantic_audio"])
-    dataset = dataset.train_test_split(test_size=0.01)
-    print(dataset["test"][0])
-    dataset.push_to_hub("fishaudio/wenet-vq", private=True)
+    print(len(dataset), dataset[0])
+    dataset.push_to_hub("fishaudio/cn-hubert-25hz-vq", private=True)