Lengyue 2 года назад
Родитель
Commit
05c342704b

+ 3 - 0
fish_speech/datasets/text.py

@@ -21,6 +21,7 @@ from transformers import AutoTokenizer
 
 from fish_speech.datasets.protos.text_data_pb2 import SampleDataRequest
 from fish_speech.datasets.protos.text_data_pb2_grpc import DataServiceStub
+from fish_speech.text.parser import clean_text
 from fish_speech.text.symbols import pad as pad_symbol
 from fish_speech.text.symbols import pu_symbols
 from fish_speech.utils import RankedLogger
@@ -189,6 +190,8 @@ class AutoAugTextDataset(IterableDataset):
                     for i in phones
                 ]
             )
+        else:
+            sentence = clean_text(sentence)
 
         tokens = self.tokenizer.encode(
             f"{sentence}",

+ 3 - 0
fish_speech/models/vqgan/lit_module.py

@@ -169,6 +169,9 @@ class VQGAN(L.LightningModule):
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
+        if loss_vq.ndim > 1:
+            loss_vq = loss_vq.mean()
+
         # Sample mels
         speaker_features = (
             self.speaker_encoder(gt_mels, mel_masks)

+ 3 - 2
fish_speech/models/vqgan/modules/encoders.py

@@ -275,17 +275,18 @@ class VQEncoder(nn.Module):
         codebook_size: int = 2048,
         downsample: int = 1,
         codebook_groups: int = 1,
+        codebook_layers: int = 1,
     ):
         super().__init__()
 
-        if codebook_groups > 1:
+        if codebook_groups > 1 or codebook_layers > 1:
             self.vq = GroupedResidualVQ(
                 dim=vq_channels,
                 codebook_size=codebook_size,
                 threshold_ema_dead_code=2,
                 kmeans_init=False,
                 groups=codebook_groups,
-                num_quantizers=1,
+                num_quantizers=codebook_layers,
             )
         else:
             self.vq = VectorQuantize(

+ 6 - 0
fish_speech/text/parser.py

@@ -98,9 +98,13 @@ REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
 def clean_text(text):
     # Clean the text
     text = text.strip()
+    # Replace <p:(.*?)> with <PPP(.*?)PPP>
+    text = re.sub(r"<p:(.*?)>", r"<PPP\1PPP>", text)
     # Replace all chinese symbols with their english counterparts
     text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
     text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
+    # Replace <PPP(.*?)PPP> with <p:(.*?)>
+    text = re.sub(r"<PPP(.*?)PPP>", r"<p:\1>", text)
 
     return text
 
@@ -231,3 +235,5 @@ if __name__ == "__main__":
         "测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。"  # noqa: E501
     )
     print(segments)
+
+    print(clean_text("测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。<p:123> <p:aH>"))

+ 5 - 1
tools/llama/generate.py

@@ -14,6 +14,8 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
+from fish_speech.text.parser import clean_text
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
@@ -266,12 +268,14 @@ def encode_tokens(
         string = prompt_string + " " + string
 
     if use_g2p:
-        prompt = g2p(prompt)
+        prompt = g2p(string)
         prompt = [
             (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
             for _, i in prompt
         ]
         string = " ".join(prompt)
+    else:
+        string = clean_text(string)
 
     string = f"[INST] {string} [/INST]"