浏览代码

merge master

luojunhui 3 天之前
父节点
当前提交
8792f46829
共有 2 个文件被更改，包括 100 次插入和 43 次删除
  1. 99 42
      evaluate_agent_v2.py
  2. 1 1
      generate_data_set.py

+ 99 - 42
evaluate_agent_v2.py

@@ -1,6 +1,7 @@
 import concurrent
 import datetime
 import json
+import random
 import time
 
 from tqdm import tqdm
@@ -47,6 +48,24 @@ def fetch_deepseek_completion(prompt, output_type="text"):
     return response
 
 
+
+def compose_dialogue(dialogue: List[Dict], timestamp_type: str='ms') -> str:
+    """
+    Render a list of chat messages as a single newline-joined transcript.
+
+    Each kept message becomes one line of the form
+    ``[role label][YYYY-MM-DD HH:MM:SS][文本]content``. Messages whose
+    ``content`` is falsy, or whose ``role`` is neither ``user`` nor
+    ``assistant``, are skipped.
+
+    :param dialogue: messages; each dict is read through the keys
+        ``content``, ``role`` and ``timestamp``.
+    :param timestamp_type: ``'ms'`` divides ``timestamp`` by 1000 before
+        conversion; any other value passes ``timestamp`` through unchanged.
+    :return: the formatted lines joined by newlines (an empty string when
+        no message is kept).
+    """
+    # Maps the internal role names to the labels printed in the transcript.
+    role_map = {'user': '用户', 'assistant': '客服'}
+    messages = []
+    for msg in dialogue:
+        # Skip empty messages.
+        if not msg['content']:
+            continue
+        # Skip roles that have no display label.
+        if msg['role'] not in role_map:
+            continue
+        # fromtimestamp() is called without a tz argument, so the rendered
+        # time is in the local timezone of the machine running this code.
+        if timestamp_type == 'ms':
+            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
+        else:
+            format_dt = datetime.datetime.fromtimestamp(msg['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
+        # Every message is labelled with this fixed message-type tag.
+        msg_type = "文本"
+        messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
+    return '\n'.join(messages)
+
+
 class AgentEvaluator:
 
     def __init__(self) -> None:
@@ -90,7 +109,34 @@ class AgentEvaluator:
 
 class PushMessageEvaluator(AgentEvaluator):
 
-    def generate_prompt(self, dialogue_history: List[Dict], message: str,
+    def get_push_dataset(self):
+        """
+        Fetch the rows of ``internal_conversation_data`` with ``dataset_id = 2``.
+
+        Selects the columns ``staff_id``, ``user_id``, ``conversation``,
+        ``content`` and ``send_time`` through ``self.mysql_client.select``
+        using ``DictCursor``.
+
+        :return: whatever ``mysql_client.select`` returns; presumably a list
+            of dict rows, since ``evaluate`` samples from it and
+            ``evaluate_task`` indexes rows by column name — confirm against
+            the client implementation.
+        """
+        # NOTE(review): the f-prefix is not needed here, the query has no placeholders.
+        sql = f"""
+            select staff_id, user_id, conversation, content, send_time
+            from internal_conversation_data
+            where dataset_id = 2;
+        """
+        return self.mysql_client.select(sql, cursor_type=DictCursor)
+
+
+    def get_dialogue_history_by_id(self, staff_id, dialogue_id_tuple):
+        """
+        Load chat messages by id and convert them to role-tagged dicts.
+
+        :param staff_id: sender id of the staff member; rows whose ``sender``
+            equals it are tagged ``assistant``, all other rows ``user``.
+        :param dialogue_id_tuple: ids of the ``qywx_chat_history`` rows to
+            load, bound to the ``in %s`` placeholder.
+        :return: list of dicts with the keys ``content``, ``role`` and
+            ``timestamp`` (taken from the ``sendtime`` column).
+        """
+        # NOTE(review): the query has no ORDER BY, so the message order is
+        # whatever the database returns — confirm this is chronological.
+        sql = f"""
+            select sender, sendtime, content
+            from qywx_chat_history
+            where id in %s;
+        """
+
+        # The whole tuple is passed as one bound parameter; how it expands
+        # into the IN list depends on mysql_client, which is not shown here.
+        # NOTE(review): an empty tuple would presumably yield invalid SQL — confirm.
+        conversation_list = self.mysql_client.select(sql=sql, cursor_type=DictCursor, args=(dialogue_id_tuple,))
+        # evaluate_task passes this result to compose_dialogue with its
+        # default 'ms' setting, so sendtime is presumably epoch milliseconds
+        # — TODO confirm.
+        history_conversation = [
+            {
+                "content": i['content'],
+                "role": "assistant" if i['sender'] == staff_id else "user",
+                "timestamp": i['sendtime']
+            } for i in conversation_list
+        ]
+        return history_conversation
+
+
+    def generate_prompt(self, dialogue_history: str, message: str,
         send_time: str, user_profile: Dict, agent_profile: Dict) -> str:
         """
         生成评估prompt
@@ -221,55 +267,66 @@ class PushMessageEvaluator(AgentEvaluator):
         return prompt
 
     def evaluate_task(self, line):
-        conversation_length = len(line["conversation"])
-        if conversation_length > 5:
-            push_time = line["conversation"][-1]["timestamp"] + 48 * 3600
-            evaluator_prompt = self.generate_prompt(
-                dialogue_history=line["conversation"],
-                message=line["push_msg"],
-                send_time=push_time,
-                agent_profile=line["agent_profile"],
-                user_profile=line["user_profile"],
-            )
-            print(evaluator_prompt)
-            response = fetch_deepseek_completion(evaluator_prompt, output_type='json')
-            return {
-                "user_profile": line["user_profile"],
-                "agent_profile": line["agent_profile"],
-                "dialogue_history": line["conversation"],
-                "push_message": line["push_msg"],
-                "push_time": push_time,
-                "evaluation_result": response
-            }
-        return None
+        """
+        Evaluate one push-message sample with the LLM judge.
+
+        :param line: one dataset row, read through the keys ``staff_id``,
+            ``user_id``, ``conversation`` (a JSON-encoded list of chat-history
+            ids), ``content`` (the push message) and ``send_time``.
+        :return: dict with the decoded user and agent profiles, the formatted
+            dialogue, the push message, the formatted push time and the
+            response of ``fetch_deepseek_completion``.
+        """
+        staff_id = line['staff_id']
+        user_id = line['user_id']
+        # The conversation column stores a JSON array of chat-history ids.
+        conversation_id_list = json.loads(line['conversation'])
+        push_message = line['content']
+        send_time = line['send_time']
+        # send_time goes to fromtimestamp() unscaled, i.e. it is treated as
+        # epoch seconds and rendered in the local timezone.
+        send_date_str = datetime.datetime.fromtimestamp(send_time).strftime('%Y-%m-%d %H:%M:%S')
+        dialogue_list = self.get_dialogue_history_by_id(staff_id, tuple(conversation_id_list))
+        # compose_dialogue runs with its default timestamp_type ('ms').
+        format_dialogue = compose_dialogue(dialogue_list)
+        # get_profile_info is defined outside this view; the first returned
+        # row's 'profile' field is JSON-decoded.
+        # NOTE(review): [0] raises IndexError if no profile row is returned.
+        agent_profile = self.get_profile_info(staff_id, "staff")[0]['profile']
+        agent_profile = json.loads(agent_profile)
+        user_profile = self.get_profile_info(user_id, "user")[0]['profile']
+        user_profile = json.loads(user_profile)
+        evaluator_prompt = self.generate_prompt(
+            dialogue_history=format_dialogue,
+            message=push_message,
+            send_time=send_date_str,
+            agent_profile=agent_profile,
+            user_profile=user_profile,
+        )
+        # The full prompt is echoed to stdout for every sample.
+        print(evaluator_prompt)
+        response = fetch_deepseek_completion(evaluator_prompt, output_type='json')
+        return {
+            "user_profile": user_profile,
+            "agent_profile": agent_profile,
+            "dialogue_history": format_dialogue,
+            "push_message": push_message,
+            "push_time": send_date_str,
+            "evaluation_result": response
+        }
+
 
     def evaluate(self):
+        """
+        Evaluate a random sample of the push dataset and save the results.
+
+        Draws 48 rows from ``get_push_dataset()``, runs ``evaluate_task`` on
+        each in a pool of 8 threads, collects the truthy results and writes
+        them as JSON to ``push_message_0617_eva.json``.
+        """
+        data = self.get_push_dataset()
 
+        # NOTE(review): random.sample raises ValueError when the dataset has
+        # fewer than 48 rows.
+        samples = random.sample(data, 48)
 
-        # data = data[:8]
+        from concurrent.futures import ThreadPoolExecutor
+        from tqdm import tqdm
+        # Main logic, processed with multiple threads
+        L = []
+        with ThreadPoolExecutor(max_workers=8) as executor:  # worker count can be tuned to the number of CPU cores
+            futures = []
+            for line in samples:
+                futures.append(executor.submit(self.evaluate_task, line))
 
-        # from concurrent.futures import ThreadPoolExecutor
-        # from tqdm import tqdm
-        # # # 多线程处理主逻辑
-        # L = []
-        # with ThreadPoolExecutor(max_workers=8) as executor:  # 可根据CPU核心数调整worker数量
-        #     futures = []
-        #     for line in data:
-        #         futures.append(executor.submit(self.evaluate_task, line))
-        #
-        #     # 使用tqdm显示进度
-        #     for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
-        #         result = future.result()
-        #         if result:
-        #             L.append(result)
-        for line in tqdm(data):
-            response = self.evaluate_task(line)
+            # Show progress with tqdm
+            for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
+                # NOTE(review): result() re-raises any exception raised inside
+                # evaluate_task, which aborts the run before results are saved.
+                result = future.result()
+                if result:
+                    L.append(result)
+        # for line in tqdm(samples):
+        #     response = self.evaluate_task(line)
+        #     print("\n")
+        #     print(json.dumps(response, ensure_ascii=False, indent=4))
         #     if response:
         #         L.append(response)
         #
-        # # 保存结果(与原代码相同)
-        # with open("push_message_evaluation_result_0613_24_v2.json", "w", encoding="utf-8") as f:
-        #     json.dump(L, f, ensure_ascii=False, indent=4)
+        # Save the results (same as the original code)
+        with open("push_message_0617_eva.json", "w", encoding="utf-8") as f:
+            json.dump(L, f, ensure_ascii=False, indent=4)
 
 if __name__ == "__main__":
     PushMessageEvaluator().evaluate()  # run the push-message evaluation when executed as a script

+ 1 - 1
generate_data_set.py

@@ -225,7 +225,7 @@ def generate_push_dataset():
     data_set = mysql_client.select(fetch_query, cursor_type=DictCursor)
     filter_conversation = [i for i in data_set if len(json.loads(i['conversation'])) >= 20]
 
-    samples =random.sample(filter_conversation, 300)
+    samples = random.sample(filter_conversation, 300)
 
     # init message push agent
     for sample in tqdm(samples):