rebuild_tokenizer.py

from transformers import AutoModelForCausalLM, AutoTokenizer

from fish_speech.text.symbols import en_symbols, jp_symbols, zh_symbols

# Reuse the Llama 2 tokenizer as the base vocabulary
model_type = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_type)

# Add the phoneme symbols (Chinese, Japanese, English) as new tokens
new_tokens = list(set(zh_symbols + jp_symbols + en_symbols))
tokenizer.add_tokens(new_tokens)

# Use the EOS token for padding, and pad/truncate on the right
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"
tokenizer.truncation_side = "right"

print(f"Vocab size: {len(tokenizer)}")

# model = AutoModelForCausalLM.from_pretrained(
#     "fishaudio/speech-lm-300m", revision="mqtts-proto"
# )

# Resize the token embeddings to include the new tokens.
# Make sure it's a multiple of 8 for faster training.
# model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

# total_params = sum(p.numel() for p in model.parameters())
# print(f"Total parameters: {total_params / 1e6:.2f}M")

# Try tokenizing a sequence that mixes text, semantic tokens, and phonemes
sequence = "Test <semantic_0> <semantic_1023> </s> uang1 iang5 AA an"
encoded = tokenizer.encode(sequence)

print("Test encoding....")
print(f"\tSentence: {sequence}")
print(f"\tEncoded: {encoded}")
# batch_decode over a list of ids decodes each token individually,
# which shows how the sequence was split into tokens
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")

# model.push_to_hub(
#     "fishaudio/speech-lm-300m", private=True, revision="text-pretrain-10k-phones"
# )

tokenizer.push_to_hub(
    "fishaudio/speech-lm-300m", private=True, revision="mqtts-phones"
)
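
# Optional sanity check (minimal sketch, not part of the original script):
# reload the tokenizer from the revision pushed above and confirm the vocab
# size matches. Assumes the push succeeded and that you are authenticated
# for the private repo; repo id and revision are the same as in push_to_hub.
# reloaded = AutoTokenizer.from_pretrained(
#     "fishaudio/speech-lm-300m", revision="mqtts-phones"
# )
# assert len(reloaded) == len(tokenizer)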