Forráskód Böngészése

generate_data_set.py

luojunhui 5 napja
szülő
commit
e90441dccc
1 módosított fájl, 59 hozzáadás és 86 törlés
  1. 59 86
      generate_data_set.py

+ 59 - 86
generate_data_set.py

@@ -1,6 +1,7 @@
 import json
 import time
 import random
+import traceback
 
 from datetime import datetime
 
@@ -84,6 +85,23 @@ def get_dialogue_history(room_id_):
     return mysql_client.select(sql=sql, cursor_type=DictCursor, args=(room_id_, (1, 2, 4)))
 
 
+def get_dialogue_history_by_id(staff_id, dialogue_id_tuple):
+    sql = f"""
+        select sender, sendtime, content
+        from qywx_chat_history
+        where id in %s;
+    """
+    conversation_list = mysql_client.select(sql=sql, cursor_type=DictCursor, args=(dialogue_id_tuple,))
+    history_conversation = [
+        {
+            "content": i['content'],
+            "role": "assistant" if i['sender'] == staff_id else "user",
+            "timestamp": int(i['sendtime'] / 1000)
+        } for i in conversation_list
+    ]
+    return history_conversation
+
+
 def get_profile_info(user_id_, user_type):
     match user_type:
         case "user":
@@ -178,77 +196,17 @@ def generate_reply_dataset():
 
 
 def generate_push_dataset():
-    conversation_info_list = get_conversation_info()
-    data_set = []
-    for conversation_info in conversation_info_list:
-        room_id = conversation_info["roomid"]
-        staff_id = room_id.split(":")[1]
-        # if staff_id in ('1688854974625870', '1688856125791790', '1688856125791452'):
-        user_id = room_id.split(":")[2]
-        if staff_id and user_id:
-            history_push_message = []  # 所有 push 消息历史(这个变量会保留到每次循环)
-            dialogue_history = get_dialogue_history(room_id)
-
-            for idx, dialogue_info in enumerate(dialogue_history):
-                if idx == 0:
-                    continue  # 防止访问 idx - 1
-
-                sender = dialogue_info["sender"]
-                send_timestamp = int(dialogue_info["sendtime"] / 1000)
-
-                before_message = dialogue_history[idx - 1]
-                before_send_timestamp = int(before_message["sendtime"] / 1000)
-
-                if sender == staff_id and (send_timestamp - before_send_timestamp) >= 86400:
-                    push_msg = dialogue_info['content']
-                    if push_msg == '早安,新的一天,愿你拥有最好的心情去迎接一切美好!爆款视频抢先观看,点击下方精彩不断~':
-                        continue
-                    else:
-                        print( datetime.fromtimestamp(send_timestamp), push_msg)
-                        conversation = [
-                            i for i in dialogue_history[: idx] if i['content'] != '早安,新的一天,愿你拥有最好的心情去迎接一切美好!爆款视频抢先观看,点击下方精彩不断~'
-                                        ]
 
-                        history_conversation = [
-                            {
-                                "content": i['content'],
-                                "role": "assistant" if i['sender'] == staff_id else "user",
-                                "timestamp": int(i['sendtime'] / 1000)
-                            }
-                            for i in conversation
-                            if int(i['sendtime'] / 1000) > send_timestamp - 86400 * 5
-                        ]
-
-                        if len(history_conversation) > 50:
-                            history_conversation = history_conversation[-50:]
+    fetch_query = f"""
+       select staff_id, user_id, conversation, content, send_time, user_active_rate
+       from internal_conversation_data
+       where dataset_id = 1;
+    """
+    data_set = mysql_client.select(fetch_query, cursor_type=DictCursor)
+    filter_conversation = [i for i in data_set if len(json.loads(i['conversation'])) >= 20]
+    print(len(filter_conversation))
 
-                        if history_conversation:
-                            push_msg = dialogue_info['content']
-                            obj = {
-                                "staff_id": staff_id,
-                                "user_id": user_id,
-                                "conversation": history_conversation,
-                                "push_msg": push_msg,
-                                "push_time": datetime.fromtimestamp(send_timestamp).strftime("%Y-%m-%d %H:%M:%S"),
-                                "history_push_messages": history_push_message.copy()
-                            }
-                            data_set.append(obj)
-                            history_push_message.append(push_msg)
-
-    with open("push_message_dataset_v4.json", "w", encoding="utf-8") as f:
-        f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
-
-
-def generate_push_dataset_new():
-    import json
-    with open("reply_data_set_filter_2.json", "r", encoding="utf-8") as f:
-        data_set = json.loads(f.read())
-    #
-    filter_conversation = [i for i in data_set if len(i['conversation']) >= 20]
-
-    samples =random.sample(filter_conversation, 50)
-    # with open("push_dataset_new_0613_24h.json", encoding="utf-8") as f:
-    #     samples = json.load(f)
+    samples =random.sample(filter_conversation, 100)
 
     # init message push agent
     agent = MessagePushAgent()
@@ -257,14 +215,23 @@ def generate_push_dataset_new():
         agent_profile = json.loads(agent_profile)
         user_profile = get_profile_info(sample["user_id"], "user")[0]['profile']
         user_profile = json.loads(user_profile)
-        # agent_profile = sample["agent_profile"]
-        # user_profile = sample["user_profile"]
-        conversation = sorted(sample["conversation"], key=lambda i: i['timestamp'], reverse=False)
+        conversation = get_dialogue_history_by_id(
+            sample["staff_id"],
+            tuple(sample["conversation"])
+        )
+        conversation.append(
+            {
+                "content": sample["content"],
+                "role": "assistant",
+                "timestamp": sample["send_time"]
+            }
+        )
+        conversation = sorted(conversation, key=lambda i: i['timestamp'], reverse=False)
         last_timestamp = int(conversation[-1]["timestamp"])
         push_time = last_timestamp + 48 * 3600
         push_dt =  datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
         try:
-            response = agent.generate_message(
+            push_message = agent.generate_message(
                 context={
                     "formatted_staff_profile": agent_profile,
                     "nickname": user_profile.get('nickname'),
@@ -279,26 +246,32 @@ def generate_push_dataset_new():
                     "current_datetime": push_dt,
                     "avatar": None
                 },
-                dialogue_history=sample["conversation"],
+                dialogue_history=conversation,
                 timestamp_type="s"
             )
-            print("---------push消息----------", response)
-            print("\n")
-            sample['push_msg'] = response
-            sample['user_profile'] = user_profile
-            sample['agent_profile'] = agent_profile
-            sample['push_time'] = push_time
-            data_set.append(sample)
+            insert_query = f"""
+                insert into internal_conversation_data
+                (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
+                values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
+            """
+            mysql_client.execute(insert_query, args=(
+                '3',
+                sample["staff_id"],
+                sample["user_id"],
+                '2025-06-16',
+                sample["conversation"],
+                push_message,
+                push_time,
+                1,
+                sample["user_active_rate"]
+            ))
         except Exception as e:
             print("error", e)
-
-    with open("push_dataset_new_0613_24h_v2.json", "w", encoding="utf-8") as f:
-        f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
+            print(traceback.format_exc())
 
 
 if __name__ == "__main__":
-    generate_reply_dataset()
-
+    generate_push_dataset()