generate_data_set.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
import json
import os
import random
import time
import traceback
from datetime import datetime
from typing import Dict, List, Tuple

from openai import OpenAI
from pymysql.cursors import DictCursor
from tqdm import tqdm

from pqai_agent import configs, logging_service
from pqai_agent.agents.message_push_agent import MessagePushAgent
from pqai_agent.database import MySQLManager
from pqai_agent.logging_service import logger
from pqai_agent.mq_message import MessageType
  15. logging_service.setup_root_logger()
  16. config = {
  17. 'host': 'rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com',
  18. 'port': 3306,
  19. 'user': 'wqsd',
  20. 'password': 'wqsd@2025',
  21. 'database': 'ai_agent',
  22. 'charset': 'utf8mb4'
  23. }
  24. mysql_client = MySQLManager(config)
  25. def split_dialogue_history(dialogue_history_, timeout=30*60*1000):
  26. """
  27. :param dialogue_history_:
  28. :param timeout: 30 minutes
  29. :return:
  30. """
  31. messages_sorted = sorted(dialogue_history_, key=lambda x: x['timestamp'])
  32. dialogues = []
  33. current_dialogue = []
  34. for i, msg in enumerate(messages_sorted):
  35. if not current_dialogue:
  36. current_dialogue.append(msg)
  37. continue
  38. prev_msg = messages_sorted[i - 1]
  39. time_diff = msg["timestamp"] - prev_msg["timestamp"]
  40. # 判断是否为新对话
  41. is_new_dialogue = False
  42. if time_diff > timeout:
  43. is_new_dialogue = True
  44. if is_new_dialogue:
  45. dialogues.append(current_dialogue)
  46. current_dialogue = [msg]
  47. else:
  48. current_dialogue.append(msg)
  49. if current_dialogue:
  50. dialogues.append(current_dialogue)
  51. return dialogues
  52. def get_conversation_info():
  53. sql = f"""
  54. select roomid, count(id) as 'article_num'
  55. from qywx_chat_history where msg_type in (1,2,4) group by roomid
  56. having count(id) > 20;
  57. """
  58. return mysql_client.select(sql, cursor_type=DictCursor)
  59. def get_dialogue_history(room_id_):
  60. """
  61. 获取对话历史
  62. :param room_id_:
  63. :return:
  64. """
  65. sql = f"""
  66. select id, sender, receiver, sendtime, content
  67. from qywx_chat_history
  68. where roomid = %s and msg_type in %s order by sendtime;
  69. """
  70. return mysql_client.select(sql=sql, cursor_type=DictCursor, args=(room_id_, (1, 2, 4)))
  71. def get_dialogue_history_by_id(staff_id, dialogue_id_tuple):
  72. sql = f"""
  73. select sender, sendtime, content
  74. from qywx_chat_history
  75. where id in %s;
  76. """
  77. conversation_list = mysql_client.select(sql=sql, cursor_type=DictCursor, args=(dialogue_id_tuple,))
  78. history_conversation = [
  79. {
  80. "content": i['content'],
  81. "role": "assistant" if i['sender'] == staff_id else "user",
  82. "timestamp": i['sendtime']
  83. } for i in conversation_list
  84. ]
  85. return history_conversation
  86. def get_profile_info(user_id_, user_type):
  87. match user_type:
  88. case "user":
  89. sql = f"""
  90. select iconurl as 'avatar', profile_data_v1 as 'profile'
  91. from third_party_user where third_party_user_id = %s;
  92. """
  93. case "staff":
  94. sql = f"""
  95. select agent_profile as 'profile'
  96. from qywx_employee where third_party_user_id = %s;
  97. """
  98. case _:
  99. raise ValueError("user_type must be 'user' or 'staff'")
  100. return mysql_client.select(sql, cursor_type=DictCursor, args=(user_id_,))
  101. def generate_reply_dataset():
  102. conversation_info_list = get_conversation_info()
  103. data_set = []
  104. for conversation_info in tqdm(conversation_info_list):
  105. room_id = conversation_info["roomid"]
  106. staff_id = room_id.split(":")[1]
  107. if staff_id in ('1688854974625870', '1688856125791790', '1688856125791452'):
  108. user_id = room_id.split(":")[2]
  109. if staff_id and user_id:
  110. dialogue_history = get_dialogue_history(room_id)
  111. for idx, dialogue_info in enumerate(dialogue_history):
  112. if dialogue_info["sender"] == staff_id:
  113. conversation = dialogue_history[: idx]
  114. history_conversation = [
  115. {
  116. "id": i['id'],
  117. "content": i['content'],
  118. "role": "assistant" if i['sender'] == staff_id else "user",
  119. "timestamp": int(i['sendtime'] / 1000)
  120. } for i in conversation
  121. ]
  122. # filter history_conversation
  123. history_conversation = [i for i in history_conversation if i['timestamp'] > int(dialogue_info['sendtime'] / 1000) - 60 * 60 * 24 * 30]
  124. if len(history_conversation) > 20:
  125. history_conversation = history_conversation[-20:]
  126. eva_conversation = history_conversation[-10:]
  127. if history_conversation:
  128. user_activate_rate= len([i for i in eva_conversation if i['role'] == 'user']) / len(eva_conversation)
  129. reply_msg = dialogue_info['content']
  130. reply_time = int(dialogue_info['sendtime'] / 1000)
  131. if "早上好" in reply_msg:
  132. continue
  133. elif "早安" in reply_msg:
  134. continue
  135. elif "早" in reply_msg:
  136. continue
  137. elif "下午好" in reply_msg:
  138. continue
  139. elif "晚上好" in reply_msg:
  140. continue
  141. elif user_activate_rate < 0.3:
  142. continue
  143. else:
  144. # obj = {
  145. # "staff_id": staff_id,
  146. # "user_id": user_id,
  147. # "conversation": history_conversation,
  148. # "reply_msg": reply_msg,
  149. # "reply_time": reply_time,
  150. # "user_active_rate": user_activate_rate
  151. # }
  152. conversation_id_list = [i['id'] for i in history_conversation]
  153. insert_query = f"""
  154. insert into internal_conversation_data
  155. (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
  156. values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
  157. """
  158. mysql_client.execute(insert_query, args=(
  159. '1',
  160. staff_id,
  161. user_id,
  162. '2025-06-16',
  163. json.dumps(conversation_id_list, ensure_ascii=False),
  164. reply_msg,
  165. reply_time,
  166. 0,
  167. user_activate_rate
  168. ))
  169. # print(len(data_set))
  170. # with open("reply_data_set_filter_2.json", "w", encoding="utf-8") as f:
  171. # f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
  172. def compose_dialogue(dialogue: List[Dict], timestamp_type: str='ms') -> str:
  173. role_map = {'user': '用户', 'assistant': '客服'}
  174. messages = []
  175. for msg in dialogue:
  176. if not msg['content']:
  177. continue
  178. if msg['role'] not in role_map:
  179. continue
  180. if timestamp_type == 'ms':
  181. format_dt = datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
  182. else:
  183. format_dt = datetime.fromtimestamp(msg['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
  184. msg_type = msg.get('type', MessageType.TEXT).description
  185. messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
  186. return '\n'.join(messages)
  187. def generate_push_dataset():
  188. fetch_query = f"""
  189. select staff_id, user_id, conversation, content, send_time, user_active_rate
  190. from internal_conversation_data
  191. where dataset_id = 1;
  192. """
  193. data_set = mysql_client.select(fetch_query, cursor_type=DictCursor)
  194. filter_conversation = [i for i in data_set if len(json.loads(i['conversation'])) >= 20]
  195. samples = random.sample(filter_conversation, 300)
  196. # init message push agent
  197. for sample in tqdm(samples):
  198. agent = MessagePushAgent()
  199. agent_profile = get_profile_info(sample["staff_id"], "staff")[0]['profile']
  200. agent_profile = json.loads(agent_profile)
  201. user_profile = get_profile_info(sample["user_id"], "user")[0]['profile']
  202. user_profile = json.loads(user_profile)
  203. conversation = get_dialogue_history_by_id(
  204. sample["staff_id"],
  205. tuple(json.loads(sample["conversation"]))
  206. )
  207. conversation.append(
  208. {
  209. "content": sample["content"],
  210. "role": "assistant",
  211. "timestamp": sample["send_time"] * 1000,
  212. # "type": 1
  213. }
  214. )
  215. conversation = sorted(conversation, key=lambda i: i['timestamp'], reverse=False)
  216. last_timestamp = int(conversation[-1]["timestamp"])
  217. push_time = int(last_timestamp / 1000) + 24 * 3600
  218. push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
  219. try:
  220. push_message = agent.generate_message(
  221. context={
  222. "formatted_staff_profile": agent_profile,
  223. "nickname": user_profile.get('nickname'),
  224. "name": user_profile.get('name'),
  225. "preferred_nickname": user_profile.get('preferred_nickname'),
  226. "age": user_profile.get('age'),
  227. "region": user_profile.get('region'),
  228. "health_conditions": user_profile.get('health_conditions'),
  229. "gender": user_profile.get('gender'),
  230. "medications": user_profile.get('medications'),
  231. "interests": user_profile.get('interests'),
  232. "current_datetime": push_dt,
  233. "avatar": None
  234. },
  235. dialogue_history=conversation
  236. )
  237. if not push_message:
  238. print("push message error")
  239. continue
  240. else:
  241. print("push message success", push_message)
  242. insert_query = f"""
  243. insert into internal_conversation_data
  244. (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
  245. values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
  246. """
  247. mysql_client.execute(insert_query, args=(
  248. '2',
  249. sample["staff_id"],
  250. sample["user_id"],
  251. '2025-06-16',
  252. sample["conversation"],
  253. push_message,
  254. push_time,
  255. 1,
  256. sample["user_active_rate"]
  257. ))
  258. except Exception as e:
  259. print("error", e)
  260. print(traceback.format_exc())
# Script entry point: only the push-dataset step runs here. It reads the
# dataset_id = 1 rows that generate_reply_dataset() inserts, so that function
# has to be run first when those rows do not exist yet.
if __name__ == "__main__":
    generate_push_dataset()