@@ -5,29 +5,40 @@ import argparse
 import json

 import pandas as pd
+from sklearn.utils import shuffle
 from tqdm import tqdm

 image_tag = "<image>"

-prompt = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息;昵称:%s,头像:%s"
+system = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息。"
+prompt = "昵称:%s,头像:%s"


 # LLaMA-Factory training template
 # https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/train_lora/qwen2_5vl_lora_sft.yaml
+# [OpenAI format] https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/data_preparation.html
+
+def get_system_message():
+    return {
+        "role": "system",
+        "content": system}
+
+
 def get_user_message(content):
     return {
-        "content": prompt % (content, image_tag),
-        "role": "user"}
+        "role": "user",
+        "content": prompt % (content, image_tag)}


 def get_assistant_message(content):
     return {
-        "content": content,
-        "role": "assistant"}
+        "role": "assistant",
+        "content": content}


 def get_conversation(row):
     messages = [
+        get_system_message(),
         get_user_message(row["nick_name"]),
         get_assistant_message(row["gender"])
     ]

@@ -40,6 +51,7 @@ def get_conversation(row):
 def process(input_file, output_file):
     conversation_list = []
     df = pd.read_csv(input_file)
+    df = shuffle(df)
     for index, row in tqdm(df.iterrows(), total=len(df)):
         conversation_list.append(get_conversation(row))
     if conversation_list:
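For reference, one sample emitted by get_conversation after this change looks roughly like the sketch below; the nickname and label are hypothetical. The system prompt tells the model to infer the user's gender from the nickname and avatar, to fall back to the nickname alone when the avatar is unavailable, and to answer with exactly one of 男性/女性/未知 (male/female/unknown); <image> is the placeholder that LLaMA-Factory replaces with the avatar image.

# Hypothetical row: nick_name = "小红花", gender = "女性"
[
    {"role": "system", "content": system},
    {"role": "user", "content": "昵称:小红花,头像:<image>"},
    {"role": "assistant", "content": "女性"}
]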