dataset_service.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import json
  2. from cgitb import reset
  3. from typing import List
  4. from sqlalchemy import func
  5. from pqai_agent.data_models.dataset_model import DatasetModule
  6. from pqai_agent.data_models.datasets import Datasets
  7. from pqai_agent.data_models.internal_conversation_data import InternalConversationData
  8. from pqai_agent.data_models.qywx_chat_history import QywxChatHistory
  9. from pqai_agent.data_models.qywx_employee import QywxEmployee
  10. from pqai_agent_server.const.type_enum import get_dataset_type_desc
  11. from pqai_agent_server.utils.odps_utils import ODPSUtils
  12. class DatasetService:
  13. def __init__(self, session_maker):
  14. self.session_maker = session_maker
  15. odps_utils = ODPSUtils()
  16. self.odps_utils = odps_utils
  17. def get_user_profile_data(self, third_party_user_id: str, date_version: str):
  18. sql = f"""
  19. SELECT * FROM third_party_user_date_version
  20. WHERE dt >= '20250612' and dt < {date_version} -- 添加分区条件
  21. and third_party_user_id = {third_party_user_id}
  22. and profile_data_v1 is not null
  23. order by dt desc
  24. limit 1
  25. """
  26. result_df = self.odps_utils.execute_sql(sql)
  27. if not result_df.empty:
  28. return result_df.iloc[0].to_dict() # 获取第一行
  29. return None
  30. def get_dataset_module_list_by_module(self, module_id: int):
  31. with self.session_maker() as session:
  32. return session.query(DatasetModule).filter(DatasetModule.module_id == module_id).filter(
  33. DatasetModule.is_delete == 0).all()
  34. def get_conversation_data_list_by_dataset(self, dataset_id: int):
  35. with self.session_maker() as session:
  36. return session.query(InternalConversationData).filter(
  37. InternalConversationData.dataset_id == dataset_id).filter(
  38. InternalConversationData.is_delete == 0).order_by(
  39. InternalConversationData.id.asc()
  40. ).all()
  41. def get_conversation_data_by_id(self, conversation_data_id: int):
  42. with self.session_maker() as session:
  43. return session.query(InternalConversationData).filter(
  44. InternalConversationData.id == conversation_data_id).one()
  45. def get_staff_profile_data(self, third_party_user_id: str):
  46. with self.session_maker() as session:
  47. return session.query(QywxEmployee).filter(
  48. QywxEmployee.third_party_user_id == third_party_user_id).one()
  49. def get_conversation_list_by_ids(self, conversation_ids: List[int]):
  50. with self.session_maker() as session:
  51. conversations = session.query(QywxChatHistory).filter(QywxChatHistory.id.in_(conversation_ids)).order_by(
  52. QywxChatHistory.id.asc()).all()
  53. result = []
  54. for conversation in conversations:
  55. data = {}
  56. data["id"] = conversation.id
  57. data["sender"] = conversation.sender
  58. data["receiver"] = conversation.receiver
  59. data["roomid"] = conversation.roomid
  60. data["sendtime"] = conversation.sendtime / 1000
  61. data["msg_type"] = conversation.msg_type
  62. data["content"] = conversation.content
  63. result.append(data)
  64. return result
  65. def get_chat_conversation_list_by_ids(self, conversation_ids: List[int], staff_id):
  66. result = self.get_conversation_list_by_ids(conversation_ids)
  67. conversations = [
  68. {
  69. "content": conversation['content'],
  70. "role": "assistant" if conversation['sender'] == staff_id else "user",
  71. "timestamp": conversation['sendtime']
  72. } for conversation in result
  73. ]
  74. return conversations
  75. def get_dataset_list(self, page_num: int, page_size: int):
  76. with self.session_maker() as session:
  77. # 计算偏移量
  78. offset = (page_num - 1) * page_size
  79. # 查询分页数据
  80. result = (session.query(Datasets)
  81. .filter(Datasets.is_delete == 0)
  82. .limit(page_size).offset(offset).all())
  83. # 查询总记录数
  84. total = session.query(func.count(Datasets.id)).filter(Datasets.is_delete == 0).scalar()
  85. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  86. total_page = 1 if total_page <= 0 else total_page
  87. response_data = [
  88. {
  89. "id": dataset.id,
  90. "name": dataset.name,
  91. "type": get_dataset_type_desc(dataset.type),
  92. "description": dataset.description,
  93. "createTime": dataset.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  94. "updateTime": dataset.update_time.strftime("%Y-%m-%d %H:%M:%S")
  95. }
  96. for dataset in result
  97. ]
  98. return {
  99. "currentPage": page_num,
  100. "pageSize": page_size,
  101. "totalSize": total_page,
  102. "total": total,
  103. "list": response_data,
  104. }
  105. def get_conversation_data_list(self, dataset_id: int, page_num: int, page_size: int):
  106. with self.session_maker() as session:
  107. # 计算偏移量
  108. offset = (page_num - 1) * page_size
  109. # 查询分页数据
  110. result = (session.query(InternalConversationData)
  111. .filter(InternalConversationData.dataset_id == dataset_id)
  112. .filter(InternalConversationData.is_delete == 0)
  113. .limit(page_size).offset(offset).all())
  114. # 查询总记录数
  115. total = session.query(func.count(InternalConversationData.id)).filter(
  116. InternalConversationData.is_delete == 0).scalar()
  117. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  118. total_page = 1 if total_page <= 0 else total_page
  119. response_data = []
  120. for conversation_data in result:
  121. data = {}
  122. data["id"] = conversation_data.id
  123. data["datasetId"] = conversation_data.dataset_id
  124. data["staff"] = self.get_staff_profile_data(conversation_data.staff_id).agent_profile
  125. data["user"] = self.get_user_profile_data(conversation_data.user_id,
  126. conversation_data.version_date.replace("-", ""))[
  127. 'profile_data_v1']
  128. data["conversation"] = self.get_conversation_list_by_ids(json.loads(conversation_data.conversation))
  129. data["content"] = conversation_data.content
  130. data["sendTime"] = conversation_data.send_time
  131. data["sendType"] = conversation_data.send_type
  132. data["userActiveRate"] = conversation_data.user_active_rate
  133. data["createTime"]: conversation_data.create_time.strftime("%Y-%m-%d %H:%M:%S")
  134. data["updateTime"]: conversation_data.update_time.strftime("%Y-%m-%d %H:%M:%S")
  135. response_data.append(data)
  136. return {
  137. "currentPage": page_num,
  138. "pageSize": page_size,
  139. "totalSize": total_page,
  140. "total": total,
  141. "list": response_data,
  142. }