Explorar el Código

Optimize text loader & add phones to tokenizer

Lengyue hace 2 años
padre
commit
78a02fe9e1
Se han modificado 5 ficheros con 52 adiciones y 2 borrados
  1. 5 0
      README.md
  2. 1 1
      fish_speech/text/parser.py
  3. 1 1
      pyproject.toml
  4. 0 0
      tools/llama/init_model.py
  5. 45 0
      tools/llama/rebuild_tokenizer.py

+ 5 - 0
README.md

@@ -15,3 +15,8 @@ pip3 install ninja && MAX_JOBS=4 pip3 install flash-attn --no-build-isolation
 # Install fish-speech
 pip3 install -e .
 ```
+
+## Credits
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)

+ 1 - 1
fish_speech/text/parser.py

@@ -213,7 +213,7 @@ def segments_to_phones(
 
 def g2p(text, order=None):
     segments = parse_text_to_segments(text, order=order)
-    _, phones = segments_to_phones(segments)
+    phones, _ = segments_to_phones(segments)
     return phones
 
 

+ 1 - 1
pyproject.toml

@@ -23,7 +23,7 @@ dependencies = [
     "natsort>=8.4.0",
     "einops>=0.7.0",
     "librosa>=0.10.1",
-    "vector-quantize-pytorch>=1.9.18",
+    "vector-quantize-pytorch>=1.10.0",
     "rich>=13.5.3",
     "gradio>=4.0.0",
     "cn2an",

+ 0 - 0
tools/init_llama_model.py → tools/llama/init_model.py


+ 45 - 0
tools/llama/rebuild_tokenizer.py

@@ -0,0 +1,45 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from fish_speech.text.symbols import en_symbols, jp_symbols, zh_symbols
+
+# reuse the tokenizer from the llama
+model_type = "meta-llama/Llama-2-7b-hf"
+tokenizer = AutoTokenizer.from_pretrained(model_type)
+
+# new tokens
+new_tokens = [f"<semantic_{i}>" for i in range(4096)] + list(
+    set(zh_symbols + jp_symbols + en_symbols)
+)
+tokenizer.add_tokens(new_tokens)
+
+# pad token
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.pad_token_id = tokenizer.eos_token_id
+
+print(f"Vocab size: {len(tokenizer)}")
+
+model = AutoModelForCausalLM.from_pretrained(
+    "fishaudio/speech-lm-300m", revision="text-pretrain-10k"
+)
+
+# Resize the token embeddings to include the new tokens
+# Make sure it's a multiple of 8 for faster training
+model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
+
+total_params = sum(p.numel() for p in model.parameters())
+print(f"Total parameters: {total_params / 1e6:.2f}M")
+
+# Try tokenizing a new sequence
+sequence = "Test <semantic_0> <semantic_1023> </s> uang1 iang5 AA an"
+encoded = tokenizer.encode(sequence)
+print("Test encoding....")
+print(f"\tSentence: {sequence}")
+print(f"\tEncoded: {encoded}")
+print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
+
+model.push_to_hub(
+    "fishaudio/speech-lm-300m", private=True, revision="text-pretrain-10k-phones"
+)
+tokenizer.push_to_hub(
+    "fishaudio/speech-lm-300m", private=True, revision="text-pretrain-10k-phones"
+)