浏览代码

add system

jch 1 周之前
父节点
当前提交
2997c09d1b
共有 2 个文件被更改,包括 18 次插入和 5 次删除
  1. 1 0
      README.md
  2. 17 5
      src/preprocess/generate_qw2_5_lora_sft_json.py

+ 1 - 0
README.md

@@ -25,6 +25,7 @@
 - [qwen2_5vl_lora_sft](https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/README_zh.md#%E5%A4%9A%E6%A8%A1%E6%80%81%E6%8C%87%E4%BB%A4%E7%9B%91%E7%9D%A3%E5%BE%AE%E8%B0%83)
 - [qwen2_5vl_lora_dpo](https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/README_zh.md#%E5%A4%9A%E6%A8%A1%E6%80%81-dpoorposimpo-%E8%AE%AD%E7%BB%83)
 - [qwen2_5vl_full_sft](https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/README_zh.md#%E5%A4%9A%E6%A8%A1%E6%80%81%E6%8C%87%E4%BB%A4%E7%9B%91%E7%9D%A3%E5%BE%AE%E8%B0%83-1)
+- llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml <font face="仿宋" color=red size=4>[修改数据集]</font>
 
 ## 1.9 llamafactory部署
 - [llamafactory仓库](https://github.com/hiyouga/LLaMA-Factory)

+ 17 - 5
src/preprocess/generate_qw2_5_lora_sft_json.py

@@ -5,29 +5,40 @@ import argparse
 import json
 
 import pandas as pd
+from sklearn.utils import shuffle
 from tqdm import tqdm
 
 image_tag = "<image>"
 
-prompt = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息;昵称:%s,头像:%s"
+system = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息。"
+prompt = "昵称:%s,头像:%s"
 
 
 # LLaMA-Factory 训练模版
 # https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/train_lora/qwen2_5vl_lora_sft.yaml
+# [OpenAI格式]https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/data_preparation.html
+
+def get_system_message():
+    return {
+        "role": "system",
+        "content": system}
+
+
 def get_user_message(content):
     return {
-        "content": prompt % (content, image_tag),
-        "role": "user"}
+        "role": "user",
+        "content": prompt % (content, image_tag)}
 
 
 def get_assistant_message(content):
     return {
-        "content": content,
-        "role": "assistant"}
+        "role": "assistant",
+        "content": content}
 
 
 def get_conversation(row):
     messages = [
+        get_system_message(),
         get_user_message(row["nick_name"]),
         get_assistant_message(row["gender"])
     ]
@@ -40,6 +51,7 @@ def get_conversation(row):
 def process(input_file, output_file):
     conversation_list = []
     df = pd.read_csv(input_file)
+    df = shuffle(df)
     for index, row in tqdm(df.iterrows(), total=len(df)):
         conversation_list.append(get_conversation(row))
     if conversation_list: