@@ -0,0 +1,112 @@
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Optional
+
+from datasets import DatasetDict, load_dataset, load_from_disk
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorForSeq2Seq,
+    HfArgumentParser,
+    Trainer,
+)
+from transformers import TrainingArguments as _TrainingArguments
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
+
+
+@dataclass
+class DataArguments:
+    data_path: Optional[str] = field(
+        default=None, metadata={"help": "Path to the training data."}
+    )
+
+
+@dataclass
+class TrainingArguments(_TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    model_max_length: int = field(
+        default=512,
+        metadata={
+            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+        },
+    )
+    use_lora: bool = field(default=False)
+
+
+def dataset_transform(batch, tokenizer: AutoTokenizer = None):
+    outputs = tokenizer(batch["prompt"], padding="longest", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt")
+    labels = outputs.input_ids.clone()
+
+    # Mask padded positions with -100 so they are ignored when the loss is computed.
+    labels[outputs.attention_mask == 0] = -100
+
+    return {
+        "input_ids": outputs.input_ids,
+        "attention_mask": outputs.attention_mask,
+        "labels": labels,
+    }
+
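+# Note: `dataset_transform` assumes each example exposes a single text column named
+# "prompt" that holds the full training sequence. An illustrative record (the field
+# name comes from the code above; the content is hypothetical) might look like:
+#   {"prompt": "Human: How do I sort a list in Python?\nAssistant: Use sorted(...)."}
+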
+def train():
+    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=True,
+        cache_dir=training_args.cache_dir,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        use_fast=False,
+        trust_remote_code=True,
+        model_max_length=training_args.model_max_length,
+        cache_dir=training_args.cache_dir,
+    )
+    # Reuse the EOS token for padding so batches can be padded to a common length.
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    if training_args.use_lora:
+        from peft import LoraConfig, TaskType, get_peft_model
+
+        peft_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            # "W_pack" is Baichuan's fused query/key/value projection.
+            target_modules=["W_pack"],
+            inference_mode=False,
+            r=16,
+            lora_alpha=64,
+            lora_dropout=0.1,
+        )
+        # Make the embedding outputs require grad so gradient checkpointing works
+        # while the frozen base weights carry no gradients of their own.
+        model.enable_input_require_grads()
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+
+    # Prefer a dataset saved with `save_to_disk`; otherwise fall back to `load_dataset`.
+    try:
+        dataset = load_from_disk(data_args.data_path)
+        if isinstance(dataset, DatasetDict):
+            dataset = dataset["train"]
+    except Exception:
+        dataset = load_dataset(data_args.data_path, split="train")
+
+    # Tokenization happens lazily: the transform runs whenever examples are accessed.
+    dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
+    dataset = dataset.train_test_split(test_size=1000, seed=42)
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["test"],
+        tokenizer=tokenizer,
+        # DataCollatorForSeq2Seq pads input_ids, attention_mask, and labels
+        # (labels with -100), so variable-length examples batch correctly.
+        data_collator=DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100),
+    )
+    trainer.train()
+    trainer.save_state()
+    trainer.save_model(output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+    train()
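+
+
+# Example invocation (illustrative only; the script name, data path, and
+# hyperparameter values are placeholders to adapt to your setup):
+#
+#   torchrun --nproc_per_node=8 fine_tune.py \
+#       --data_path ./data/my_dataset \
+#       --output_dir ./checkpoints \
+#       --model_max_length 512 \
+#       --per_device_train_batch_size 4 \
+#       --num_train_epochs 1 \
+#       --use_lora True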