Browse source

Fix error handling & optimize dataset

Lengyue 2 years ago
Parent
commit
467353b23a

+ 1 - 1
fish_speech/configs/data/libri-light.yaml

@@ -1,7 +1,7 @@
 datasets:
   - root: /***REMOVED***/workspace/eva-gan/data/libri-light
     source: LibriLight
-    languages: [ZH, EN]
+    languages: [EN]
     extension: .txt
     # This controls the grouping of the dataset (i.e. speaker)
     # 1 means we use the parent folder of the file as the group name

+ 1 - 2
fish_speech/datasets/text.py

@@ -308,11 +308,10 @@ class AutoAugTextDataset(IterableDataset):
             tokens, labels = self.pack_sentences(
                 final_text,
                 semantics=final_semantic,
-                speaker=None if self.use_speaker else sentence.speaker,
+                speaker=None if self.use_speaker else response.name,
                 add_bos=True,
             )
         else:
-            print(all_tokens[0].shape)
             tokens = torch.cat(all_tokens, dim=1)
             labels = torch.cat(all_labels, dim=1)
 

+ 5 - 1
fish_speech/text/parser.py

@@ -3,6 +3,8 @@ import re
 import string
 from typing import Optional
 
+from loguru import logger
+
 from fish_speech.text.chinese import g2p as g2p_chinese
 from fish_speech.text.english import g2p as g2p_english
 from fish_speech.text.japanese import g2p as g2p_japanese
@@ -214,7 +216,9 @@ def segments_to_phones(
                 phones.append(q1)
                 ids.append(symbols_to_id[q1])
             else:
-                raise ValueError(f"Unknown phone: {segment.language} - {phone} -")
+                logger.warning(
+                    f"Unknown phone: {segment.language}: `{phone}`, ignored."
+                )
 
     return phones, ids
 

+ 2 - 3
tools/llama/build_dataset.py

@@ -1,6 +1,5 @@
-import glob
 import re
-from collections import Counter, defaultdict
+from collections import defaultdict
 from multiprocessing import Pool
 from pathlib import Path
 
@@ -13,7 +12,7 @@ from tqdm import tqdm
 from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
 from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
 from fish_speech.text import g2p
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+from fish_speech.utils.file import load_filelist
 
 
 def task_generator_yaml(config):

+ 2 - 2
tools/llama/generate.py

@@ -298,7 +298,7 @@ def encode_tokens(
     tokens = torch.tensor([tokens], dtype=torch.int, device=device)
 
     # Codebooks
-    zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
+    zeros = torch.zeros((8, tokens.size(1)), dtype=torch.int, device=device)
     prompt = torch.cat((tokens, zeros), dim=0)
 
     if prompt_tokens is None:
@@ -368,7 +368,7 @@ def load_model(config_name, checkpoint_path, device, precision):
     "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
 )
 @click.option("--num-samples", type=int, default=1)
-@click.option("--max_new_tokens", type=int, default=0)
+@click.option("--max-new-tokens", type=int, default=0)
 @click.option("--top-k", type=int, default=None)
 @click.option("--top-p", type=float, default=0.5)
 @click.option("--repetition-penalty", type=float, default=1.5)