|
@@ -38,8 +38,14 @@ class TrainingArguments(_TrainingArguments):
|
|
|
use_lora: bool = field(default=False)
|
|
use_lora: bool = field(default=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
-def dataset_transform(batch, tokenizer: AutoTokenizer=None):
|
|
|
|
|
- outputs = tokenizer(batch["prompt"], padding="longest", truncation=True, max_length=512, return_tensors="pt")
|
|
|
|
|
|
|
+def dataset_transform(batch, tokenizer: AutoTokenizer = None):
|
|
|
|
|
+ outputs = tokenizer(
|
|
|
|
|
+ batch["prompt"],
|
|
|
|
|
+ padding="longest",
|
|
|
|
|
+ truncation=True,
|
|
|
|
|
+ max_length=512,
|
|
|
|
|
+ return_tensors="pt",
|
|
|
|
|
+ )
|
|
|
labels = outputs.input_ids.clone()
|
|
labels = outputs.input_ids.clone()
|
|
|
|
|
|
|
|
# Set the labels to -100 so that the logits are not affected by loss
|
|
# Set the labels to -100 so that the logits are not affected by loss
|
|
@@ -51,6 +57,7 @@ def dataset_transform(batch, tokenizer: AutoTokenizer=None):
|
|
|
"labels": labels,
|
|
"labels": labels,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+
|
|
|
def train():
|
|
def train():
|
|
|
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
|
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
|
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
|
@@ -87,11 +94,11 @@ def train():
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
dataset = load_from_disk(data_args.data_path)
|
|
dataset = load_from_disk(data_args.data_path)
|
|
|
- if 'train' in dataset:
|
|
|
|
|
- dataset = dataset['train']
|
|
|
|
|
|
|
+ if "train" in dataset:
|
|
|
|
|
+ dataset = dataset["train"]
|
|
|
except:
|
|
except:
|
|
|
dataset = load_dataset(data_args.data_path, split="train")
|
|
dataset = load_dataset(data_args.data_path, split="train")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
|
|
dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
|
|
|
dataset = dataset.train_test_split(test_size=1000, seed=42)
|
|
dataset = dataset.train_test_split(test_size=1000, seed=42)
|
|
|
|
|
|