from dataclasses import dataclass, field
from functools import partial
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithPadding,
    HfArgumentParser,
    Trainer,
)
from transformers import TrainingArguments as _TrainingArguments
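# Command-line arguments, split into three groups and parsed together by HfArgumentParser.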
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(_TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)
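# On-the-fly tokenization: applied lazily via `Dataset.set_transform` below, so the
# raw "prompt" column is only tokenized when examples are fetched by the dataloader.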
def dataset_transform(batch, tokenizer: AutoTokenizer = None):
    outputs = tokenizer(
        batch["prompt"],
        padding="longest",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = outputs.input_ids.clone()
    # For causal LM training the labels are the input ids; mask padded positions
    # with -100 so they are ignored by the loss.
    labels[outputs.attention_mask == 0] = -100
    return {
        "input_ids": outputs.input_ids,
        "attention_mask": outputs.attention_mask,
        "labels": labels,
    }
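# Parse arguments, load the model and tokenizer, optionally attach LoRA adapters,
# prepare the dataset, and run the Trainer.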
def train():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=False,
        trust_remote_code=True,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir,
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id
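    # Optional parameter-efficient fine-tuning with LoRA (PEFT). "W_pack" is the
    # fused QKV projection module used by Baichuan models.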
    if training_args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack"],
            inference_mode=False,
            r=16,
            lora_alpha=64,
            lora_dropout=0.1,
        )
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
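    # Load the training data: prefer a dataset previously saved with `save_to_disk`,
    # and fall back to `load_dataset` (a Hub name or loading script).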
    try:
        dataset = load_from_disk(data_args.data_path)
        if "train" in dataset:
            dataset = dataset["train"]
    except Exception:
        dataset = load_dataset(data_args.data_path, split="train")
    dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
    dataset = dataset.train_test_split(test_size=1000, seed=42)
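    # DataCollatorWithPadding pads "input_ids" and "attention_mask" to the longest
    # sequence in each batch.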
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer),
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
if __name__ == "__main__":
    train()
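# Example launch (hypothetical script name and paths; adjust to your setup). All flags
# below come from the dataclasses above or from transformers.TrainingArguments:
#   python finetune.py \
#       --data_path ./data/my_dataset \
#       --output_dir ./checkpoints/baichuan2-sft \
#       --per_device_train_batch_size 2 \
#       --num_train_epochs 1 \
#       --bf16 True \
#       --use_lora True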