generate_data_set.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. import json
  2. import time
  3. import random
  4. import traceback
  5. from datetime import datetime
  6. from openai import OpenAI
  7. from tqdm import tqdm
  8. from pymysql.cursors import DictCursor
  9. from pqai_agent.database import MySQLManager
  10. from pqai_agent.agents.message_push_agent import MessagePushAgent
  11. from pqai_agent.logging_service import logger
  12. from pqai_agent import configs, logging_service
  # Initialise logging before any database work happens.
  logging_service.setup_root_logger()

  # Connection settings handed to MySQLManager.
  # NOTE(review): these are live credentials committed in plain text. They
  # should be rotated and loaded from a secret store or environment variables
  # instead; the values are kept unchanged here so existing runs keep working.
  config = {
      'host': 'rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com',
      'port': 3306,
      'user': 'wqsd',
      'password': 'wqsd@2025',
      'database': 'ai_agent',
      'charset': 'utf8mb4',
  }

  # Shared client used by every query helper in this script.
  mysql_client = MySQLManager(config)
  def split_dialogue_history(dialogue_history_, timeout=30*60*1000):
      """
      Split a message list into separate dialogues based on silence gaps.

      Messages are first sorted by their ``timestamp`` value. A new dialogue
      starts whenever the gap between two consecutive messages is strictly
      greater than ``timeout``.

      :param dialogue_history_: iterable of dicts, each with a numeric
          ``timestamp`` key. The input is not modified.
      :param timeout: maximum gap allowed inside one dialogue, in the same
          unit as the timestamps. The default is 30 minutes expressed in
          milliseconds, so millisecond timestamps are expected with it.
      :return: list of dialogues, each a non-empty list of messages in
          ascending timestamp order. Empty input yields an empty list.
      """
      messages_sorted = sorted(dialogue_history_, key=lambda x: x['timestamp'])
      dialogues = []
      for msg in messages_sorted:
          # The last message of the last dialogue is always the previous
          # message in sorted order, so no index bookkeeping is needed.
          if dialogues and msg["timestamp"] - dialogues[-1][-1]["timestamp"] <= timeout:
              dialogues[-1].append(msg)
          else:
              dialogues.append([msg])
      return dialogues
  def get_conversation_info():
      """
      List the chat rooms that contain more than 20 qualifying messages.

      Only rows of ``qywx_chat_history`` whose ``msg_type`` is 1, 2 or 4 are
      counted.

      :return: whatever ``mysql_client.select`` returns for a ``DictCursor``
          query; each row carries ``roomid`` and ``article_num`` (the
          message count for that room).
      """
      # Plain string: the statement has no interpolated values.
      sql = """
          select roomid, count(id) as 'article_num'
          from qywx_chat_history where msg_type in (1,2,4) group by roomid
          having count(id) > 20;
      """
      return mysql_client.select(sql, cursor_type=DictCursor)
  def get_dialogue_history(room_id_):
      """
      Fetch the dialogue history of one chat room.

      :param room_id_: value matched against the ``roomid`` column.
      :return: result of ``mysql_client.select`` with a ``DictCursor``; rows
          expose ``id``, ``sender``, ``receiver``, ``sendtime`` and
          ``content``, ordered by ``sendtime`` ascending. Only messages whose
          ``msg_type`` is 1, 2 or 4 are returned.
      """
      # Both values are bound by the driver; the tuple argument expands into
      # the parenthesised list required by ``in``.
      sql = """
          select id, sender, receiver, sendtime, content
          from qywx_chat_history
          where roomid = %s and msg_type in %s order by sendtime;
      """
      return mysql_client.select(sql=sql, cursor_type=DictCursor, args=(room_id_, (1, 2, 4)))
  def get_dialogue_history_by_id(staff_id, dialogue_id_tuple):
      """
      Load specific chat messages by primary key and convert them to
      role-tagged dialogue entries.

      :param staff_id: sender id of the staff member; messages sent by this
          id get the role ``"assistant"``, all others ``"user"``.
      :param dialogue_id_tuple: collection of ``qywx_chat_history.id`` values
          to load.
      :return: list of dicts with ``content``, ``role`` and ``timestamp``
          (``sendtime`` divided by 1000 and truncated to int; presumably
          milliseconds to seconds - confirm against the table definition).
          The query has no ``order by``, so callers must sort if order
          matters. An empty id collection yields an empty list.
      """
      # ``where id in ()`` is not valid SQL, so skip the round trip entirely
      # when there is nothing to look up.
      if not dialogue_id_tuple:
          return []

      sql = """
          select sender, sendtime, content
          from qywx_chat_history
          where id in %s;
      """
      conversation_list = mysql_client.select(sql=sql, cursor_type=DictCursor, args=(dialogue_id_tuple,))
      return [
          {
              "content": row['content'],
              "role": "assistant" if row['sender'] == staff_id else "user",
              "timestamp": int(row['sendtime'] / 1000),
          }
          for row in conversation_list
      ]
  def get_profile_info(user_id_, user_type):
      """
      Look up the stored profile of a customer or of a staff member.

      :param user_id_: value matched against ``third_party_user_id``.
      :param user_type: ``"user"`` reads ``third_party_user`` (columns
          ``avatar`` and ``profile``); ``"staff"`` reads ``qywx_employee``
          (column ``profile`` only).
      :return: result of ``mysql_client.select`` with a ``DictCursor``.
      :raises ValueError: if ``user_type`` is neither ``"user"`` nor
          ``"staff"``; raised before any database access.
      """
      if user_type == "user":
          sql = """
              select iconurl as 'avatar', profile_data_v1 as 'profile'
              from third_party_user where third_party_user_id = %s;
          """
      elif user_type == "staff":
          sql = """
              select agent_profile as 'profile'
              from qywx_employee where third_party_user_id = %s;
          """
      else:
          raise ValueError("user_type must be 'user' or 'staff'")
      return mysql_client.select(sql, cursor_type=DictCursor, args=(user_id_,))
  def generate_reply_dataset():
      """
      Build the reply dataset (``dataset_id`` '1') from real staff replies.

      For every room returned by ``get_conversation_info`` that belongs to one
      of the selected staff accounts, each staff message becomes a candidate
      sample. Its context is the preceding messages from the 30 days before
      the reply, capped at the latest 20. A candidate is skipped when:

      * it has no context at all,
      * the reply contains a greeting marker, or
      * fewer than 30% of the last 10 context messages come from the user.

      Each accepted sample is inserted into ``internal_conversation_data``
      with the context stored as a JSON list of message ids.

      :return: None. The function only writes to the database.
      """
      target_staff_ids = ('1688854974625870', '1688856125791790', '1688856125791452')
      # "早" already covers "早上好" and "早安"; the longer forms are kept so
      # the intended greetings stay visible.
      # NOTE(review): the bare "早" also drops replies such as "早餐" -
      # confirm that this is intended.
      greeting_markers = ("早上好", "早安", "早", "下午好", "晚上好")
      history_window_seconds = 60 * 60 * 24 * 30
      max_history = 20
      activity_window = 10
      min_user_active_rate = 0.3
      insert_query = """
          insert into internal_conversation_data
          (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
          values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
      """

      for conversation_info in tqdm(get_conversation_info()):
          room_id = conversation_info["roomid"]
          # The code relies on the staff id and the user id being the 2nd
          # and 3rd ':'-separated fields; rooms that do not match are
          # skipped instead of raising IndexError.
          room_parts = room_id.split(":")
          if len(room_parts) < 3:
              continue
          staff_id, user_id = room_parts[1], room_parts[2]
          if staff_id not in target_staff_ids or not user_id:
              continue

          dialogue_history = get_dialogue_history(room_id)
          # Convert every row once per room instead of once per staff reply.
          normalized = [
              {
                  "id": row['id'],
                  "content": row['content'],
                  "role": "assistant" if row['sender'] == staff_id else "user",
                  "timestamp": int(row['sendtime'] / 1000),
              }
              for row in dialogue_history
          ]

          for idx, dialogue_info in enumerate(dialogue_history):
              if dialogue_info["sender"] != staff_id:
                  continue

              reply_time = int(dialogue_info['sendtime'] / 1000)
              cutoff = reply_time - history_window_seconds
              # Context = earlier messages inside the window, latest 20 only.
              history_conversation = [
                  message for message in normalized[:idx]
                  if message['timestamp'] > cutoff
              ][-max_history:]
              if not history_conversation:
                  continue

              eva_conversation = history_conversation[-activity_window:]
              user_activate_rate = (
                  sum(1 for message in eva_conversation if message['role'] == 'user')
                  / len(eva_conversation)
              )

              reply_msg = dialogue_info['content']
              if any(marker in reply_msg for marker in greeting_markers):
                  continue
              if user_activate_rate < min_user_active_rate:
                  continue

              conversation_id_list = [message['id'] for message in history_conversation]
              mysql_client.execute(insert_query, args=(
                  '1',
                  staff_id,
                  user_id,
                  '2025-06-16',
                  json.dumps(conversation_id_list, ensure_ascii=False),
                  reply_msg,
                  reply_time,
                  0,
                  user_activate_rate,
              ))
  def generate_push_dataset():
      """
      Build the push dataset (``dataset_id`` '3') from reply-dataset samples.

      Reads the rows with ``dataset_id = 1`` from
      ``internal_conversation_data`` and keeps those whose stored context
      holds at least 20 message ids. Up to 100 of them are sampled at random.
      For each sample the dialogue is rebuilt (context plus the recorded
      staff reply) and ``MessagePushAgent`` is asked for a push message dated
      48 hours after the last message. The result is inserted with
      ``send_type`` 1.

      A failure on one sample is printed with its traceback and the loop
      moves on to the next sample.

      :return: None. The function only writes to the database.
      """
      fetch_query = """
          select staff_id, user_id, conversation, content, send_time, user_active_rate
          from internal_conversation_data
          where dataset_id = 1;
      """
      insert_query = """
          insert into internal_conversation_data
          (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
          values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
      """
      sample_size = 100
      profile_keys = (
          "nickname", "name", "preferred_nickname", "age", "region",
          "health_conditions", "gender", "medications", "interests",
      )

      data_set = mysql_client.select(fetch_query, cursor_type=DictCursor)
      filter_conversation = [i for i in data_set if len(json.loads(i['conversation'])) >= 20]
      print(len(filter_conversation))
      # random.sample raises ValueError when asked for more items than
      # exist, so cap the request at the population size.
      samples = random.sample(filter_conversation, min(sample_size, len(filter_conversation)))

      # init message push agent
      agent = MessagePushAgent()
      for sample in tqdm(samples):
          try:
              # Profile lookups live inside the try so that one missing or
              # malformed profile skips this sample instead of ending the run.
              agent_profile = json.loads(get_profile_info(sample["staff_id"], "staff")[0]['profile'])
              user_profile = json.loads(get_profile_info(sample["user_id"], "user")[0]['profile'])

              # The column holds a JSON-encoded id list; decode it first.
              # Calling tuple() on the raw string would yield its characters.
              conversation_ids = tuple(json.loads(sample["conversation"]))
              conversation = get_dialogue_history_by_id(sample["staff_id"], conversation_ids)
              conversation.append(
                  {
                      "content": sample["content"],
                      "role": "assistant",
                      "timestamp": sample["send_time"],
                  }
              )
              conversation = sorted(conversation, key=lambda i: i['timestamp'], reverse=False)

              last_timestamp = int(conversation[-1]["timestamp"])
              push_time = last_timestamp + 48 * 3600
              push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')

              context = {
                  "formatted_staff_profile": agent_profile,
                  **{key: user_profile.get(key) for key in profile_keys},
                  "current_datetime": push_dt,
                  "avatar": None,
              }
              push_message = agent.generate_message(
                  context=context,
                  dialogue_history=conversation,
                  timestamp_type="s"
              )
              mysql_client.execute(insert_query, args=(
                  '3',
                  sample["staff_id"],
                  sample["user_id"],
                  '2025-06-16',
                  sample["conversation"],
                  push_message,
                  push_time,
                  1,
                  sample["user_active_rate"],
              ))
          except Exception as e:
              # Best-effort batch job: report the failure and keep going.
              print("error", e)
              print(traceback.format_exc())
  if __name__ == "__main__":
      # Script entry point: only the push-dataset step runs here;
      # generate_reply_dataset() is not called.
      generate_push_dataset()