# generate_data_set.py — builds staff-reply and push-message datasets from the
# qywx_chat_history table.
import json
import os
import random
import time
from datetime import datetime

from openai import OpenAI
from pymysql.cursors import DictCursor
from tqdm import tqdm

from pqai_agent import configs, logging_service
from pqai_agent.agents.message_push_agent import MessagePushAgent
from pqai_agent.database import MySQLManager
from pqai_agent.logging_service import logger
  12. logging_service.setup_root_logger()
  13. config = {
  14. 'host': 'rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com',
  15. 'port': 3306,
  16. 'user': 'wqsd',
  17. 'password': 'wqsd@2025',
  18. 'database': 'ai_agent',
  19. 'charset': 'utf8mb4'
  20. }
  21. mysql_client = MySQLManager(config)
  22. def split_dialogue_history(dialogue_history_, timeout=30*60*1000):
  23. """
  24. :param dialogue_history_:
  25. :param timeout: 30 minutes
  26. :return:
  27. """
  28. messages_sorted = sorted(dialogue_history_, key=lambda x: x['timestamp'])
  29. dialogues = []
  30. current_dialogue = []
  31. for i, msg in enumerate(messages_sorted):
  32. if not current_dialogue:
  33. current_dialogue.append(msg)
  34. continue
  35. prev_msg = messages_sorted[i - 1]
  36. time_diff = msg["timestamp"] - prev_msg["timestamp"]
  37. # 判断是否为新对话
  38. is_new_dialogue = False
  39. if time_diff > timeout:
  40. is_new_dialogue = True
  41. if is_new_dialogue:
  42. dialogues.append(current_dialogue)
  43. current_dialogue = [msg]
  44. else:
  45. current_dialogue.append(msg)
  46. if current_dialogue:
  47. dialogues.append(current_dialogue)
  48. return dialogues
  49. def get_conversation_info():
  50. sql = f"""
  51. select roomid, count(id) as 'article_num'
  52. from qywx_chat_history where msg_type in (1,2,4) group by roomid
  53. having count(id) > 20;
  54. """
  55. return mysql_client.select(sql, cursor_type=DictCursor)
  56. def get_dialogue_history(room_id_):
  57. """
  58. 获取对话历史
  59. :param room_id_:
  60. :return:
  61. """
  62. sql = f"""
  63. select id, sender, receiver, sendtime, content
  64. from qywx_chat_history
  65. where roomid = %s and msg_type in %s order by sendtime;
  66. """
  67. return mysql_client.select(sql=sql, cursor_type=DictCursor, args=(room_id_, (1, 2, 4)))
  68. def get_profile_info(user_id_, user_type):
  69. match user_type:
  70. case "user":
  71. sql = f"""
  72. select iconurl as 'avatar', profile_data_v1 as 'profile'
  73. from third_party_user where third_party_user_id = %s;
  74. """
  75. case "staff":
  76. sql = f"""
  77. select agent_profile as 'profile'
  78. from qywx_employee where third_party_user_id = %s;
  79. """
  80. case _:
  81. raise ValueError("user_type must be 'user' or 'staff'")
  82. return mysql_client.select(sql, cursor_type=DictCursor, args=(user_id_,))
  83. def generate_reply_dataset():
  84. conversation_info_list = get_conversation_info()
  85. data_set = []
  86. for conversation_info in tqdm(conversation_info_list):
  87. room_id = conversation_info["roomid"]
  88. staff_id = room_id.split(":")[1]
  89. if staff_id in ('1688854974625870', '1688856125791790', '1688856125791452'):
  90. user_id = room_id.split(":")[2]
  91. if staff_id and user_id:
  92. dialogue_history = get_dialogue_history(room_id)
  93. for idx, dialogue_info in enumerate(dialogue_history):
  94. if dialogue_info["sender"] == staff_id:
  95. conversation = dialogue_history[: idx]
  96. history_conversation = [
  97. {
  98. "id": i['id'],
  99. "content": i['content'],
  100. "role": "assistant" if i['sender'] == staff_id else "user",
  101. "timestamp": int(i['sendtime'] / 1000)
  102. } for i in conversation
  103. ]
  104. # filter history_conversation
  105. history_conversation = [i for i in history_conversation if i['timestamp'] > int(dialogue_info['sendtime'] / 1000) - 60 * 60 * 24 * 30]
  106. if len(history_conversation) > 20:
  107. history_conversation = history_conversation[-20:]
  108. eva_conversation = history_conversation[-10:]
  109. if history_conversation:
  110. user_activate_rate= len([i for i in eva_conversation if i['role'] == 'user']) / len(eva_conversation)
  111. reply_msg = dialogue_info['content']
  112. reply_time = int(dialogue_info['sendtime'] / 1000)
  113. if "早上好" in reply_msg:
  114. continue
  115. elif "早安" in reply_msg:
  116. continue
  117. elif "早" in reply_msg:
  118. continue
  119. elif "下午好" in reply_msg:
  120. continue
  121. elif "晚上好" in reply_msg:
  122. continue
  123. elif user_activate_rate < 0.3:
  124. continue
  125. else:
  126. # obj = {
  127. # "staff_id": staff_id,
  128. # "user_id": user_id,
  129. # "conversation": history_conversation,
  130. # "reply_msg": reply_msg,
  131. # "reply_time": reply_time,
  132. # "user_active_rate": user_activate_rate
  133. # }
  134. conversation_id_list = [i['id'] for i in history_conversation]
  135. insert_query = f"""
  136. insert into internal_conversation_data
  137. (dataset_id, staff_id, user_id, version_date, conversation, content, send_time, send_type, user_active_rate)
  138. values (%s, %s, %s, %s, %s, %s, %s, %s, %s);
  139. """
  140. mysql_client.execute(insert_query, args=(
  141. '1',
  142. staff_id,
  143. user_id,
  144. '2025-06-16',
  145. json.dumps(conversation_id_list, ensure_ascii=False),
  146. reply_msg,
  147. reply_time,
  148. 0,
  149. user_activate_rate
  150. ))
  151. # print(len(data_set))
  152. # with open("reply_data_set_filter_2.json", "w", encoding="utf-8") as f:
  153. # f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
  154. def generate_push_dataset():
  155. conversation_info_list = get_conversation_info()
  156. data_set = []
  157. for conversation_info in conversation_info_list:
  158. room_id = conversation_info["roomid"]
  159. staff_id = room_id.split(":")[1]
  160. # if staff_id in ('1688854974625870', '1688856125791790', '1688856125791452'):
  161. user_id = room_id.split(":")[2]
  162. if staff_id and user_id:
  163. history_push_message = [] # 所有 push 消息历史(这个变量会保留到每次循环)
  164. dialogue_history = get_dialogue_history(room_id)
  165. for idx, dialogue_info in enumerate(dialogue_history):
  166. if idx == 0:
  167. continue # 防止访问 idx - 1
  168. sender = dialogue_info["sender"]
  169. send_timestamp = int(dialogue_info["sendtime"] / 1000)
  170. before_message = dialogue_history[idx - 1]
  171. before_send_timestamp = int(before_message["sendtime"] / 1000)
  172. if sender == staff_id and (send_timestamp - before_send_timestamp) >= 86400:
  173. push_msg = dialogue_info['content']
  174. if push_msg == '早安,新的一天,愿你拥有最好的心情去迎接一切美好!爆款视频抢先观看,点击下方精彩不断~':
  175. continue
  176. else:
  177. print( datetime.fromtimestamp(send_timestamp), push_msg)
  178. conversation = [
  179. i for i in dialogue_history[: idx] if i['content'] != '早安,新的一天,愿你拥有最好的心情去迎接一切美好!爆款视频抢先观看,点击下方精彩不断~'
  180. ]
  181. history_conversation = [
  182. {
  183. "content": i['content'],
  184. "role": "assistant" if i['sender'] == staff_id else "user",
  185. "timestamp": int(i['sendtime'] / 1000)
  186. }
  187. for i in conversation
  188. if int(i['sendtime'] / 1000) > send_timestamp - 86400 * 5
  189. ]
  190. if len(history_conversation) > 50:
  191. history_conversation = history_conversation[-50:]
  192. if history_conversation:
  193. push_msg = dialogue_info['content']
  194. obj = {
  195. "staff_id": staff_id,
  196. "user_id": user_id,
  197. "conversation": history_conversation,
  198. "push_msg": push_msg,
  199. "push_time": datetime.fromtimestamp(send_timestamp).strftime("%Y-%m-%d %H:%M:%S"),
  200. "history_push_messages": history_push_message.copy()
  201. }
  202. data_set.append(obj)
  203. history_push_message.append(push_msg)
  204. with open("push_message_dataset_v4.json", "w", encoding="utf-8") as f:
  205. f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
  206. def generate_push_dataset_new():
  207. import json
  208. with open("reply_data_set_filter_2.json", "r", encoding="utf-8") as f:
  209. data_set = json.loads(f.read())
  210. #
  211. filter_conversation = [i for i in data_set if len(i['conversation']) >= 20]
  212. samples =random.sample(filter_conversation, 50)
  213. # with open("push_dataset_new_0613_24h.json", encoding="utf-8") as f:
  214. # samples = json.load(f)
  215. # init message push agent
  216. agent = MessagePushAgent()
  217. for sample in tqdm(samples):
  218. agent_profile = get_profile_info(sample["staff_id"], "staff")[0]['profile']
  219. agent_profile = json.loads(agent_profile)
  220. user_profile = get_profile_info(sample["user_id"], "user")[0]['profile']
  221. user_profile = json.loads(user_profile)
  222. # agent_profile = sample["agent_profile"]
  223. # user_profile = sample["user_profile"]
  224. conversation = sorted(sample["conversation"], key=lambda i: i['timestamp'], reverse=False)
  225. last_timestamp = int(conversation[-1]["timestamp"])
  226. push_time = last_timestamp + 48 * 3600
  227. push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
  228. try:
  229. response = agent.generate_message(
  230. context={
  231. "formatted_staff_profile": agent_profile,
  232. "nickname": user_profile.get('nickname'),
  233. "name": user_profile.get('name'),
  234. "preferred_nickname": user_profile.get('preferred_nickname'),
  235. "age": user_profile.get('age'),
  236. "region": user_profile.get('region'),
  237. "health_conditions": user_profile.get('health_conditions'),
  238. "gender": user_profile.get('gender'),
  239. "medications": user_profile.get('medications'),
  240. "interests": user_profile.get('interests'),
  241. "current_datetime": push_dt,
  242. "avatar": None
  243. },
  244. dialogue_history=sample["conversation"],
  245. timestamp_type="s"
  246. )
  247. print("---------push消息----------", response)
  248. print("\n")
  249. sample['push_msg'] = response
  250. sample['user_profile'] = user_profile
  251. sample['agent_profile'] = agent_profile
  252. sample['push_time'] = push_time
  253. data_set.append(sample)
  254. except Exception as e:
  255. print("error", e)
  256. with open("push_dataset_new_0613_24h_v2.json", "w", encoding="utf-8") as f:
  257. f.write(json.dumps(data_set, ensure_ascii=False, indent=4))
  258. if __name__ == "__main__":
  259. generate_reply_dataset()