Ver Fonte

generate_data_set.py

luojunhui há 2 meses atrás
pai
commit
d5d686c19c
1 ficheiros alterados com 198 adições e 33 exclusões
  1. 198 33
      generate_data_set.py

+ 198 - 33
generate_data_set.py

@@ -2,9 +2,17 @@ import json
 import time
 import random
 
+from datetime import datetime
+
+from openai import OpenAI
 from tqdm import tqdm
 from pymysql.cursors import DictCursor
 from pqai_agent.database import MySQLManager
+from pqai_agent.agents.message_push_agent import MessagePushAgent
+from pqai_agent.logging_service import logger
+from pqai_agent import configs, logging_service
+
+logging_service.setup_root_logger()
 
 
 config = {
@@ -56,8 +64,8 @@ def split_dialogue_history(dialogue_history_, timeout=30*60*1000):
 def get_conversation_info():
     sql = f"""
         select roomid, count(id) as 'article_num'
-        from qywx_chat_history where msg_type = 1 group by roomid
-        having count(id) > 50;
+        from qywx_chat_history where msg_type in (1,2,4) group by roomid
+        having count(id) > 20;
     """
     return mysql_client.select(sql, cursor_type=DictCursor)
 
@@ -69,11 +77,11 @@ def get_dialogue_history(room_id_):
     :return:
     """
     sql = f"""
-        select sender, receiver, sendtime, content
+        select id, sender, receiver, sendtime, content
         from qywx_chat_history
-        where roomid = %s and msg_type = %s;
+        where roomid = %s and msg_type in %s order by sendtime;
     """
-    return mysql_client.select(sql=sql, cursor_type=DictCursor, args=(room_id_, 1))
+    return mysql_client.select(sql=sql, cursor_type=DictCursor, args=(room_id_, (1, 2, 4)))
 
 
 def get_profile_info(user_id_, user_type):
@@ -94,46 +102,203 @@ def get_profile_info(user_id_, user_type):
     return mysql_client.select(sql, cursor_type=DictCursor, args=(user_id_,))
 
 
-if __name__ == "__main__":
+def generate_reply_dataset():
     conversation_info_list = get_conversation_info()
     data_set = []
     for conversation_info in tqdm(conversation_info_list):
         room_id = conversation_info["roomid"]
         staff_id = room_id.split(":")[1]
+        if staff_id in ('1688854974625870', '1688856125791790', '1688856125791452'):
+            user_id = room_id.split(":")[2]
+            if staff_id and user_id:
+                dialogue_history = get_dialogue_history(room_id)
+                for idx, dialogue_info in enumerate(dialogue_history):
+                    if dialogue_info["sender"] == staff_id:
+                        conversation = dialogue_history[: idx]
+                        history_conversation = [
+                            {
+                                "id": i['id'],
+                                "content": i['content'],
+                                "role": "assistant" if i['sender'] == staff_id else "user",
+                                "timestamp": int(i['sendtime'] / 1000)
+                            } for i in conversation
+                        ]
+                        # filter history_conversation
+                        history_conversation = [i for i in history_conversation if i['timestamp'] > int(dialogue_info['sendtime'] / 1000) - 60 * 60 * 24 * 30]
+
+                        if len(history_conversation) > 20:
+                            history_conversation = history_conversation[-20:]
+
+                        eva_conversation = history_conversation[-10:]
+                        if history_conversation:
+                            user_activate_rate= len([i for i in eva_conversation if i['role'] == 'user']) / len(eva_conversation)
+                            reply_msg = dialogue_info['content']
+                            reply_time = int(dialogue_info['sendtime'] / 1000)
+                            if "早上好" in reply_msg:
+                                continue
+                            elif "早安" in reply_msg:
+                                continue
+                            elif "早" in reply_msg:
+                                continue
+                            elif "下午好" in reply_msg:
+                                continue
+                            elif "晚上好" in reply_msg:
+                                continue
+                            elif user_activate_rate < 0.3:
+                                continue
+                            else:
+                                # obj = {
+                                #     "staff_id": staff_id,
+                                #     "user_id": user_id,
+                                #     "conversation": history_conversation,
+                                #     "reply_msg": reply_msg,
+                                #     "reply_time": reply_time,
+                                #     "user_active_rate": user_activate_rate
+                                # }
+                                conversation_id_list = [i['id'] for i in history_conversation]
+                                insert_query = f"""
+                                    insert into internal_conversation_data
+                                    (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
+                                    values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
+                                """
+                                mysql_client.execute(insert_query, args=(
+                                    '1',
+                                    staff_id,
+                                    user_id,
+                                    '2025-06-16',
+                                    json.dumps(conversation_id_list, ensure_ascii=False),
+                                    reply_msg,
+                                    reply_time,
+                                    0,
+                                    user_activate_rate
+                                ))
+    # print(len(data_set))
+    # with open("reply_data_set_filter_2.json", "w", encoding="utf-8") as f:
+    #     f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
+
+
+def generate_push_dataset():
+    conversation_info_list = get_conversation_info()
+    data_set = []
+    for conversation_info in conversation_info_list:
+        room_id = conversation_info["roomid"]
+        staff_id = room_id.split(":")[1]
+        # if staff_id in ('1688854974625870', '1688856125791790', '1688856125791452'):
         user_id = room_id.split(":")[2]
         if staff_id and user_id:
+            history_push_message = []  # 所有 push 消息历史(这个变量会保留到每次循环)
             dialogue_history = get_dialogue_history(room_id)
+
             for idx, dialogue_info in enumerate(dialogue_history):
-                if dialogue_info["sender"] == staff_id:
-                    conversation = dialogue_history[: idx]
-                    history_conversation = [
-                        {
-                            "content": i['content'],
-                            "role": "assistant" if i['sender'] == staff_id else "user",
-                            "timestamp": int(i['sendtime'] / 1000)
-                        } for i in conversation]
-                    # filter history_conversation
-                    history_conversation = [i for i in history_conversation if i['timestamp'] > int(dialogue_info['sendtime'] / 1000) - 60 * 60 * 24 * 30]
-
-                    if len(history_conversation) > 100:
-                        history_conversation = history_conversation[-100:]
-
-                    reply_msg = dialogue_info['content']
-                    reply_time = int(dialogue_info['sendtime'] / 1000)
-                    obj = {
-                        "staff_id": staff_id,
-                        "user_id": user_id,
-                        "conversation": history_conversation,
-                        "reply_msg": reply_msg,
-                        "reply_time": reply_time,
-                    }
-                    data_set.append(obj)
-
-    print(len(data_set))
-    with open("reply_data_set_filter.json", "w", encoding="utf-8") as f:
+                if idx == 0:
+                    continue  # 防止访问 idx - 1
+
+                sender = dialogue_info["sender"]
+                send_timestamp = int(dialogue_info["sendtime"] / 1000)
+
+                before_message = dialogue_history[idx - 1]
+                before_send_timestamp = int(before_message["sendtime"] / 1000)
+
+                if sender == staff_id and (send_timestamp - before_send_timestamp) >= 86400:
+                    push_msg = dialogue_info['content']
+                    if push_msg == '早安,新的一天,愿你拥有最好的心情去迎接一切美好!爆款视频抢先观看,点击下方精彩不断~':
+                        continue
+                    else:
+                        print( datetime.fromtimestamp(send_timestamp), push_msg)
+                        conversation = [
+                            i for i in dialogue_history[: idx] if i['content'] != '早安,新的一天,愿你拥有最好的心情去迎接一切美好!爆款视频抢先观看,点击下方精彩不断~'
+                                        ]
+
+                        history_conversation = [
+                            {
+                                "content": i['content'],
+                                "role": "assistant" if i['sender'] == staff_id else "user",
+                                "timestamp": int(i['sendtime'] / 1000)
+                            }
+                            for i in conversation
+                            if int(i['sendtime'] / 1000) > send_timestamp - 86400 * 5
+                        ]
+
+                        if len(history_conversation) > 50:
+                            history_conversation = history_conversation[-50:]
+
+                        if history_conversation:
+                            push_msg = dialogue_info['content']
+                            obj = {
+                                "staff_id": staff_id,
+                                "user_id": user_id,
+                                "conversation": history_conversation,
+                                "push_msg": push_msg,
+                                "push_time": datetime.fromtimestamp(send_timestamp).strftime("%Y-%m-%d %H:%M:%S"),
+                                "history_push_messages": history_push_message.copy()
+                            }
+                            data_set.append(obj)
+                            history_push_message.append(push_msg)
+
+    with open("push_message_dataset_v4.json", "w", encoding="utf-8") as f:
         f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
 
 
+def generate_push_dataset_new():
+    import json
+    with open("reply_data_set_filter_2.json", "r", encoding="utf-8") as f:
+        data_set = json.loads(f.read())
+    #
+    filter_conversation = [i for i in data_set if len(i['conversation']) >= 20]
+
+    samples =random.sample(filter_conversation, 50)
+    # with open("push_dataset_new_0613_24h.json", encoding="utf-8") as f:
+    #     samples = json.load(f)
+
+    # init message push agent
+    agent = MessagePushAgent()
+    for sample in tqdm(samples):
+        agent_profile = get_profile_info(sample["staff_id"], "staff")[0]['profile']
+        agent_profile = json.loads(agent_profile)
+        user_profile = get_profile_info(sample["user_id"], "user")[0]['profile']
+        user_profile = json.loads(user_profile)
+        # agent_profile = sample["agent_profile"]
+        # user_profile = sample["user_profile"]
+        conversation = sorted(sample["conversation"], key=lambda i: i['timestamp'], reverse=False)
+        last_timestamp = int(conversation[-1]["timestamp"])
+        push_time = last_timestamp + 48 * 3600
+        push_dt =  datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
+        try:
+            response = agent.generate_message(
+                context={
+                    "formatted_staff_profile": agent_profile,
+                    "nickname": user_profile.get('nickname'),
+                    "name": user_profile.get('name'),
+                    "preferred_nickname": user_profile.get('preferred_nickname'),
+                    "age": user_profile.get('age'),
+                    "region": user_profile.get('region'),
+                    "health_conditions": user_profile.get('health_conditions'),
+                    "gender": user_profile.get('gender'),
+                    "medications": user_profile.get('medications'),
+                    "interests": user_profile.get('interests'),
+                    "current_datetime": push_dt,
+                    "avatar": None
+                },
+                dialogue_history=sample["conversation"],
+                timestamp_type="s"
+            )
+            print("---------push消息----------", response)
+            print("\n")
+            sample['push_msg'] = response
+            sample['user_profile'] = user_profile
+            sample['agent_profile'] = agent_profile
+            sample['push_time'] = push_time
+            data_set.append(sample)
+        except Exception as e:
+            print("error", e)
+
+    with open("push_dataset_new_0613_24h_v2.json", "w", encoding="utf-8") as f:
+        f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
+
+
+if __name__ == "__main__":
+    generate_reply_dataset()
+