rebuild_tokenizer.py 1.4 KB

  1. from transformers import AutoModelForCausalLM, AutoTokenizer
  2. from fish_speech.text.symbols import en_symbols, jp_symbols, zh_symbols
  3. # reuse the tokenizer from the llama
  4. model_type = "meta-llama/Llama-2-7b-hf"
  5. tokenizer = AutoTokenizer.from_pretrained(model_type)
  6. # new tokens
  7. new_tokens = [f"<semantic_{i}>" for i in range(4096)] + list(
  8. set(zh_symbols + jp_symbols + en_symbols)
  9. )
  10. tokenizer.add_tokens(new_tokens)
  11. # pad token
  12. tokenizer.pad_token = tokenizer.eos_token
  13. tokenizer.pad_token_id = tokenizer.eos_token_id
  14. print(f"Vocab size: {len(tokenizer)}")
  15. model = AutoModelForCausalLM.from_pretrained(
  16. "fishaudio/speech-lm-300m", revision="text-pretrain-10k"
  17. )
  18. # Resize the token embeddings to include the new tokens
  19. # Make sure it's a multiple of 8 for faster training
  20. model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
  21. total_params = sum(p.numel() for p in model.parameters())
  22. print(f"Total parameters: {total_params / 1e6:.2f}M")
  23. # Try tokenizing a new sequence
  24. sequence = "Test <semantic_0> <semantic_1023> </s> uang1 iang5 AA an"
  25. encoded = tokenizer.encode(sequence)
  26. print("Test encoding....")
  27. print(f"\tSentence: {sequence}")
  28. print(f"\tEncoded: {encoded}")
  29. print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
  30. model.push_to_hub(
  31. "fishaudio/speech-lm-300m", private=True, revision="text-pretrain-10k-phones"
  32. )
  33. tokenizer.push_to_hub(
  34. "fishaudio/speech-lm-300m", private=True, revision="text-pretrain-10k-phones"
  35. )