"""Initialize a ~300M-parameter Llama-style speech LM from scratch.

Reuses the Llama-2 tokenizer, extends it with 4096 semantic-unit tokens plus
a dedicated ``<pad>`` token, builds a randomly initialized ``LlamaModel``
sized to roughly 300M parameters, sanity-checks that the new tokens
round-trip through the tokenizer, and pushes both model and tokenizer to the
Hugging Face Hub.
"""

from transformers import AutoTokenizer, LlamaConfig, LlamaModel

# Reuse the tokenizer from Llama-2; only the vocabulary is extended below.
model_type = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_type)

# Add 4096 semantic-unit tokens plus a dedicated padding token.
new_tokens = [f"<semantic_{i}>" for i in range(4096)]
tokenizer.add_tokens(new_tokens + ["<pad>"])

# Register <pad> as the pad token so padding uses the newly added id.
tokenizer.pad_token = "<pad>"
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
print(f"Vocab size: {len(tokenizer)}")

hidden_size = 1024
# Llama sizes its SwiGLU MLP at intermediate_size ≈ hidden_size * 11/3;
# round to the nearest multiple of 8 for tensor-core-friendly shapes.
intermediate_size = round(hidden_size * (11 / 3) / 8) * 8
print(f"Hidden size: {hidden_size}")
print(f"Intermediate size: {intermediate_size}")

model = LlamaModel(
    LlamaConfig(
        # NOTE: tokenizer.vocab_size deliberately EXCLUDES the tokens added
        # above; the embedding matrix is grown to the full len(tokenizer)
        # by resize_token_embeddings below, so this initial value is only
        # a placeholder for the base vocabulary.
        vocab_size=tokenizer.vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_hidden_layers=20,
        num_attention_heads=16,
        max_position_embeddings=4096,
    )
)
model = model.bfloat16()

# Resize the token embeddings to include the new tokens.
# Pad the embedding count to a multiple of 8 for faster training.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6:.2f}M")

# Sanity check: encode a sequence containing the new special tokens.
sequence = "Test <semantic_0> <semantic_1023> <pad>"
encoded = tokenizer.encode(sequence)
print("Test encoding....")
print(f"\tSentence: {sequence}")
print(f"\tEncoded: {encoded}")
# batch_decode over a flat list of ids decodes each id individually,
# printing the per-token strings (useful to eyeball the new tokens).
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")

# Local alternative to the Hub push (kept for offline runs):
# model.save_pretrained("./checkpoints/speech-lm-300m-init")
# tokenizer.save_pretrained("./checkpoints/speech-lm-300m-init")
model.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="init")
tokenizer.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="init")