generate_data_set.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import json
  2. import time
  3. import random
  4. from tqdm import tqdm
  5. from pymysql.cursors import DictCursor
  6. from pqai_agent.database import MySQLManager
  import os

  # Connection settings for the ai_agent MySQL database.
  # Every value can be overridden through an environment variable so that
  # credentials do not have to live in source control. The literals below are
  # kept only as backward-compatible fallbacks.
  # NOTE(review): the fallback password is committed in plain text - rotate it
  # and drop the default once the environment variables are provisioned.
  config = {
      'host': os.environ.get('PQAI_MYSQL_HOST', 'rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com'),
      'port': int(os.environ.get('PQAI_MYSQL_PORT', '3306')),
      'user': os.environ.get('PQAI_MYSQL_USER', 'wqsd'),
      'password': os.environ.get('PQAI_MYSQL_PASSWORD', 'wqsd@2025'),
      'database': os.environ.get('PQAI_MYSQL_DATABASE', 'ai_agent'),
      'charset': os.environ.get('PQAI_MYSQL_CHARSET', 'utf8mb4'),
  }

  # Shared client used by every query helper in this script.
  mysql_client = MySQLManager(config)
  def split_dialogue_history(dialogue_history_, timeout=30 * 60 * 1000):
      """
      Group messages into separate dialogues based on gaps between them.

      Messages are first ordered by their ``timestamp`` value. A new dialogue
      starts whenever the gap between two consecutive messages is strictly
      greater than ``timeout``.

      :param dialogue_history_: iterable of dicts, each with a ``timestamp`` key
      :param timeout: largest gap that still keeps two messages in the same
          dialogue, in the same unit as ``timestamp`` (the default is
          30 minutes expressed in milliseconds)
      :return: list of dialogues, each a non-empty list of messages in
          ascending timestamp order; an empty input yields an empty list
      """
      ordered = sorted(dialogue_history_, key=lambda m: m['timestamp'])
      sessions = []
      for message in ordered:
          # The previous message is always the tail of the latest session.
          starts_new = (
              not sessions
              or message["timestamp"] - sessions[-1][-1]["timestamp"] > timeout
          )
          if starts_new:
              sessions.append([message])
          else:
              sessions[-1].append(message)
      return sessions
  def get_conversation_info():
      """
      List the rooms that have more than 20 rows with ``msg_type = 1``.

      :return: whatever ``mysql_client.select`` returns for the query when run
          with a ``DictCursor``; each row carries ``roomid`` and ``article_num``
      """
      query = """
          select roomid, count(id) as 'article_num'
          from qywx_chat_history where msg_type = 1 group by roomid
          having count(id) > 20;
      """
      rows = mysql_client.select(query, cursor_type=DictCursor)
      return rows
  def get_dialogue_history(room_id_):
      """
      Fetch the dialogue history of one room, oldest message first.

      Only rows with ``msg_type = 1`` are returned. The rows are ordered by
      ``sendtime`` because the caller treats the rows before a given index as
      the conversation that preceded that message; without an explicit
      ``order by`` the database gives no guarantee about row order.

      :param room_id_: value matched against the ``roomid`` column
      :return: whatever ``mysql_client.select`` returns for the query when run
          with a ``DictCursor``; each row carries ``sender``, ``receiver``,
          ``sendtime`` and ``content``
      """
      sql = """
          select sender, receiver, sendtime, content
          from qywx_chat_history
          where roomid = %s and msg_type = %s
          order by sendtime;
      """
      return mysql_client.select(sql=sql, cursor_type=DictCursor, args=(room_id_, 1))
  def get_profile_info(user_id_, user_type):
      """
      Look up the stored profile of a user or a staff member.

      :param user_id_: value matched against ``third_party_user_id``
      :param user_type: ``"user"`` reads ``third_party_user`` (avatar and
          profile); ``"staff"`` reads ``qywx_employee`` (profile only)
      :return: whatever ``mysql_client.select`` returns for the query when run
          with a ``DictCursor``
      :raises ValueError: if ``user_type`` is neither ``"user"`` nor ``"staff"``
      """
      if user_type == "user":
          query = """
              select iconurl as 'avatar', profile_data_v1 as 'profile'
              from third_party_user where third_party_user_id = %s;
          """
      elif user_type == "staff":
          query = """
              select agent_profile as 'profile'
              from qywx_employee where third_party_user_id = %s;
          """
      else:
          raise ValueError("user_type must be 'user' or 'staff'")
      params = (user_id_,)
      return mysql_client.select(query, cursor_type=DictCursor, args=params)
  # History older than this many seconds before the reply is dropped.
  HISTORY_WINDOW_SECONDS = 60 * 60 * 24 * 30
  # At most this many of the most recent history messages are kept per sample.
  MAX_HISTORY_MESSAGES = 100


  def parse_room_id(room_id):
      """
      Extract ``(staff_id, user_id)`` from a colon-separated room id.

      The staff id is the second field and the user id is the third.

      :param room_id: room identifier such as ``"<prefix>:<staff>:<user>"``
      :return: ``(staff_id, user_id)``, or ``None`` when the value is not a
          string, has fewer than three fields, or either id is empty
      """
      if not isinstance(room_id, str):
          return None
      parts = room_id.split(":")
      if len(parts) < 3 or not parts[1] or not parts[2]:
          return None
      return parts[1], parts[2]


  def build_reply_samples(staff_id, user_id, dialogue_history):
      """
      Build one training sample per message sent by the staff member.

      Each sample pairs a staff reply with the messages that preceded it,
      limited to the last 30 days before the reply and to the most recent
      ``MAX_HISTORY_MESSAGES`` entries.

      :param staff_id: sender id whose messages are treated as replies
      :param user_id: id stored in each sample as ``user_id``
      :param dialogue_history: iterable of rows with ``sender``, ``sendtime``
          and ``content`` keys
      :return: list of dicts with ``staff_id``, ``user_id``, ``conversation``
          and ``reply_msg`` keys
      """
      # Sort defensively: the history slice below is only meaningful when the
      # rows are in chronological order.
      ordered = sorted(dialogue_history, key=lambda row: row["sendtime"])

      # Convert every row once instead of once per staff reply. The resulting
      # dicts are shared between samples; they are only serialized afterwards.
      # sendtime is divided by 1000 (presumably milliseconds -> seconds).
      turns = [
          {
              "content": row["content"],
              "role": "assistant" if row["sender"] == staff_id else "user",
              "timestamp": int(row["sendtime"] / 1000),
          }
          for row in ordered
      ]

      samples = []
      for idx, row in enumerate(ordered):
          if row["sender"] != staff_id:
              continue
          cutoff = turns[idx]["timestamp"] - HISTORY_WINDOW_SECONDS
          recent = [turn for turn in turns[:idx] if turn["timestamp"] > cutoff]
          samples.append(
              {
                  "staff_id": staff_id,
                  "user_id": user_id,
                  "conversation": recent[-MAX_HISTORY_MESSAGES:],
                  "reply_msg": row["content"],
              }
          )
      return samples


  def main():
      """Export reply samples for every qualifying room to a JSON file."""
      data_set = []
      for conversation_info in tqdm(get_conversation_info()):
          room_id = conversation_info["roomid"]
          ids = parse_room_id(room_id)
          if ids is None:
              # Malformed room id: skip it instead of aborting the whole export.
              continue
          staff_id, user_id = ids
          dialogue_history = get_dialogue_history(room_id)
          data_set.extend(build_reply_samples(staff_id, user_id, dialogue_history))

      with open("reply_data_set_filter.json", "w", encoding="utf-8") as f:
          json.dump(data_set, f, ensure_ascii=False, indent=4)


  if __name__ == "__main__":
      main()