fine-tune.py

from dataclasses import dataclass, field
from functools import partial
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithPadding,
    HfArgumentParser,
    Trainer,
)
from transformers import TrainingArguments as _TrainingArguments


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(_TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


def dataset_transform(batch, tokenizer: AutoTokenizer = None):
    # Tokenize the raw text, padding to the longest sequence in the batch and
    # truncating to the maximum length.
    outputs = tokenizer(
        batch["prompt"],
        padding="longest",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = outputs.input_ids.clone()
    # Set the labels at padded positions to -100 so they are ignored by the loss.
    labels[outputs.attention_mask == 0] = -100
    return {
        "input_ids": outputs.input_ids,
        "attention_mask": outputs.attention_mask,
        "labels": labels,
    }
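
# The transform above expects a dataset with a text column named "prompt"
# (the only schema detail visible in this file). A minimal, illustrative way to
# prepare such a dataset for --data_path; the contents are an example, not
# taken from this script:
#
#   from datasets import Dataset
#   Dataset.from_dict({"prompt": ["<your training text>", ...]}).save_to_disk("./data")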


def train():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir,
    )
    # Use the EOS token as the padding token.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    if training_args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        # Baichuan packs the query/key/value projections into a single "W_pack"
        # linear layer, which is the module the LoRA adapters are attached to.
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack"],
            inference_mode=False,
            r=16,
            lora_alpha=64,
            lora_dropout=0.1,
        )
        # Make the embedding outputs require grad so gradients reach the LoRA
        # layers while the base model stays frozen.
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
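
    # When training with LoRA, the saved checkpoint contains only the adapter
    # weights. One way to obtain a standalone model afterwards (illustrative,
    # not part of this script; `output_dir` refers to the trained checkpoint):
    #
    #   from peft import AutoPeftModelForCausalLM
    #   merged = AutoPeftModelForCausalLM.from_pretrained(output_dir, trust_remote_code=True)
    #   merged = merged.merge_and_unload()
    #   merged.save_pretrained("merged-model")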

    # Prefer a dataset saved with `datasets.save_to_disk`; otherwise fall back
    # to `load_dataset` (e.g. a Hub dataset name or a local loading script).
    try:
        dataset = load_from_disk(data_args.data_path)
        if "train" in dataset:
            dataset = dataset["train"]
    except Exception:
        dataset = load_dataset(data_args.data_path, split="train")

    # Tokenize on the fly when examples are accessed, instead of pre-processing
    # the whole dataset up front.
    dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
    dataset = dataset.train_test_split(test_size=1000, seed=42)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        # Pads input_ids and attention_mask to the longest sequence in each batch.
        data_collator=DataCollatorWithPadding(tokenizer),
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
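

# Example invocation (illustrative; the paths and hyperparameters below are
# assumptions, not taken from this file). `--remove_unused_columns False` keeps
# the raw "prompt" column available to the on-the-fly transform:
#
#   torchrun --nproc_per_node=8 fine-tune.py \
#       --model_name_or_path baichuan-inc/Baichuan2-7B-Base \
#       --data_path ./data \
#       --output_dir ./checkpoints \
#       --per_device_train_batch_size 1 \
#       --num_train_epochs 2 \
#       --learning_rate 2e-5 \
#       --remove_unused_columns False \
#       --use_lora True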