# train.py

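"""Fine-tune a causal speech language model (default: fishaudio/speech-lm-300m)
with the Hugging Face Trainer, optionally wrapping the model in LoRA adapters
via peft. Arguments are parsed from the command line with HfArgumentParser."""
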
from dataclasses import dataclass, field
from functools import partial
from typing import Optional

# dataset_transform (used below) is assumed to be importable from
# speech_lm/dataset.py, alongside the package's build_dataset helper; an
# absolute import keeps this runnable when train.py is executed directly.
from speech_lm.dataset import dataset_transform
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithPadding,
    HfArgumentParser,
    Trainer,
)
from transformers import TrainingArguments as _TrainingArguments
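
# The transform is expected to turn a batch of raw examples into tokenized
# model inputs. A minimal sketch of that contract is shown below, assuming a
# "text" column; the real implementation in speech_lm/dataset.py may differ:
#
#     def dataset_transform(batch, tokenizer):
#         tokens = tokenizer(
#             batch["text"],
#             truncation=True,
#             padding="max_length",
#             max_length=tokenizer.model_max_length,
#         )
#         # Copy input_ids into labels, masking padded positions from the loss.
#         tokens["labels"] = [
#             [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
#             for ids, attn in zip(tokens["input_ids"], tokens["attention_mask"])
#         ]
#         return tokens
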
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="fishaudio/speech-lm-300m")
    model_revision: Optional[str] = field(default="main")


@dataclass
class DataArguments:
    # data_path is referenced in train(); declared here as a required field so
    # argument parsing works (the original dataclass listed no fields).
    data_path: str = field(
        metadata={"help": "Path to a dataset saved with save_to_disk, or a Hub dataset name."}
    )


@dataclass
class TrainingArguments(_TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)

def train():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
        revision=model_args.model_revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir,
        revision=model_args.model_revision,
    )
    # Use the EOS token for padding so DataCollatorWithPadding can pad batches.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    if training_args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack"],
            inference_mode=False,
            r=16,
            lora_alpha=64,
            lora_dropout=0.1,
        )
        # Make embedding outputs require grad so gradients flow to the LoRA
        # adapters while the base weights stay frozen.
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    # Prefer a local dataset saved with save_to_disk; fall back to loading
    # data_path as a Hub dataset name or loading script.
    try:
        dataset = load_from_disk(data_args.data_path)
        if "train" in dataset:
            dataset = dataset["train"]
    except Exception:
        dataset = load_dataset(data_args.data_path, split="train")

    # Tokenize lazily on access, then hold out a fixed evaluation split.
    dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
    dataset = dataset.train_test_split(test_size=1000, seed=42)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer),
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)

if __name__ == "__main__":
    train()
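
# Example launch (flag names follow the dataclasses above; the dataset path is
# a placeholder):
#
#   python train.py \
#       --data_path ./data/my_dataset \
#       --output_dir ./results \
#       --per_device_train_batch_size 8 \
#       --num_train_epochs 1 \
#       --use_lora True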