fine-tune.py

from dataclasses import dataclass, field
from functools import partial
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithPadding,
    HfArgumentParser,
    Trainer,
)
from transformers import TrainingArguments as _TrainingArguments


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(_TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


def dataset_transform(batch, tokenizer: AutoTokenizer = None):
    # Pad every example to the model's maximum length so the per-example outputs
    # produced by set_transform stack cleanly into batches in the collator.
    outputs = tokenizer(
        batch["prompt"], padding="max_length", truncation=True,
        max_length=tokenizer.model_max_length, return_tensors="pt",
    )
    labels = outputs.input_ids.clone()
    # Mask padding positions with -100 so they are ignored by the loss.
    labels[outputs.attention_mask == 0] = -100
    return {
        "input_ids": outputs.input_ids,
        "attention_mask": outputs.attention_mask,
        "labels": labels,
    }


def train():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir,
    )
    # Use the EOS token for padding so inputs can be batched.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    if training_args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        # "W_pack" is the fused QKV projection in the Baichuan2 architecture.
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack"],
            inference_mode=False,
            r=16,
            lora_alpha=64,
            lora_dropout=0.1,
        )
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Accept either a dataset saved with save_to_disk or a Hub dataset / loading script.
    try:
        dataset = load_from_disk(data_args.data_path)
        if "train" in dataset:
            dataset = dataset["train"]
    except FileNotFoundError:
        dataset = load_dataset(data_args.data_path, split="train")

    # Tokenize lazily on access instead of materializing a tokenized copy of the dataset.
    dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
    dataset = dataset.train_test_split(test_size=1000, seed=42)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer),
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
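
# Example launch (illustrative only; the data path and hyperparameter values
# below are assumptions, not part of the original script, so adjust them to
# your setup). Flags map to the dataclass fields above and to the standard
# HuggingFace TrainingArguments:
#
#   python fine-tune.py \
#       --model_name_or_path baichuan-inc/Baichuan2-7B-Base \
#       --data_path ./data \
#       --output_dir ./output \
#       --use_lora True \
#       --per_device_train_batch_size 4 \
#       --num_train_epochs 2 \
#       --learning_rate 2e-5 \
#       --logging_steps 10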