# dataset_service.py

import json
from typing import List

from sqlalchemy import func

from pqai_agent.data_models.dataset_model import DatasetModule
from pqai_agent.data_models.datasets import Datasets
from pqai_agent.data_models.internal_conversation_data import InternalConversationData
from pqai_agent.data_models.qywx_chat_history import QywxChatHistory
from pqai_agent.data_models.qywx_employee import QywxEmployee
from pqai_agent_server.const.type_enum import get_dataset_type_desc
from pqai_agent_server.utils.odps_utils import ODPSUtils


class DatasetService:
    def __init__(self, session_maker):
        self.session_maker = session_maker
        self.odps_utils = ODPSUtils()

    def get_user_profile_data(self, third_party_user_id: str, date_version: str):
        """Fetch the latest non-null profile snapshot for a user from ODPS."""
        sql = f"""
            SELECT * FROM third_party_user_date_version
            WHERE dt BETWEEN '20250612' AND '{date_version}'  -- partition pruning
              AND third_party_user_id = '{third_party_user_id}'
              AND profile_data_v1 IS NOT NULL
            ORDER BY dt DESC
            LIMIT 1
        """
        result_df = self.odps_utils.execute_sql(sql)
        if not result_df.empty:
            return result_df.iloc[0].to_dict()  # first (latest) row
        return None

    def get_dataset_list_by_module(self, module_id: int):
        with self.session_maker() as session:
            return session.query(DatasetModule).filter(DatasetModule.module_id == module_id).filter(
                DatasetModule.is_delete == 0).all()

    def get_conversation_data_list_by_dataset(self, dataset_id: int):
        with self.session_maker() as session:
            return session.query(InternalConversationData).filter(
                InternalConversationData.dataset_id == dataset_id).filter(
                InternalConversationData.is_delete == 0).order_by(
                InternalConversationData.id.asc()
            ).all()

    def get_conversation_data_by_id(self, conversation_data_id: int):
        with self.session_maker() as session:
            return session.query(InternalConversationData).filter(
                InternalConversationData.id == conversation_data_id).one()

    def get_staff_profile_data(self, third_party_user_id: str):
        with self.session_maker() as session:
            return session.query(QywxEmployee).filter(
                QywxEmployee.third_party_user_id == third_party_user_id).one()

    def get_conversation_list_by_ids(self, conversation_ids: List[int]):
        with self.session_maker() as session:
            conversations = session.query(QywxChatHistory).filter(
                QywxChatHistory.id.in_(conversation_ids)).all()
            result = []
            for conversation in conversations:
                result.append({
                    "id": conversation.id,
                    "sender": conversation.sender,
                    "receiver": conversation.receiver,
                    "roomid": conversation.roomid,
                    "sendtime": conversation.sendtime / 1000,  # millisecond timestamp -> seconds
                    "msg_type": conversation.msg_type,
                    "content": conversation.content,
                })
            return result

    def get_dataset_list(self, page_num: int, page_size: int):
        with self.session_maker() as session:
            # compute the offset for this page
            offset = (page_num - 1) * page_size
            # fetch the requested page of datasets
            result = (session.query(Datasets)
                      .filter(Datasets.is_delete == 0)
                      .limit(page_size).offset(offset).all())
            # count the total number of records
            total = session.query(func.count(Datasets.id)).filter(Datasets.is_delete == 0).scalar()
            total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
            total_page = 1 if total_page <= 0 else total_page
            response_data = [
                {
                    "id": dataset.id,
                    "name": dataset.name,
                    "type": get_dataset_type_desc(dataset.type),
                    "description": dataset.description,
                    "createTime": dataset.create_time.strftime("%Y-%m-%d %H:%M:%S"),
                    "updateTime": dataset.update_time.strftime("%Y-%m-%d %H:%M:%S"),
                }
                for dataset in result
            ]
            return {
                "currentPage": page_num,
                "pageSize": page_size,
                "totalSize": total_page,
                "total": total,
                "list": response_data,
            }

    def get_conversation_data_list(self, dataset_id: int, page_num: int, page_size: int):
        with self.session_maker() as session:
            # compute the offset for this page
            offset = (page_num - 1) * page_size
            # fetch the requested page of conversation data
            result = (session.query(InternalConversationData)
                      .filter(InternalConversationData.dataset_id == dataset_id)
                      .filter(InternalConversationData.is_delete == 0)
                      .limit(page_size).offset(offset).all())
            # count the total number of records for this dataset
            total = session.query(func.count(InternalConversationData.id)).filter(
                InternalConversationData.dataset_id == dataset_id).filter(
                InternalConversationData.is_delete == 0).scalar()
            total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
            total_page = 1 if total_page <= 0 else total_page
            response_data = []
            for conversation_data in result:
                data = {}
                data["id"] = conversation_data.id
                data["datasetId"] = conversation_data.dataset_id
                data["staff"] = self.get_staff_profile_data(conversation_data.staff_id).agent_profile
                data["user"] = self.get_user_profile_data(
                    conversation_data.user_id,
                    conversation_data.version_date.replace("-", ""))['profile_data_v1']
                data["conversation"] = self.get_conversation_list_by_ids(
                    json.loads(conversation_data.conversation))
                data["content"] = conversation_data.content
                data["sendTime"] = conversation_data.send_time
                data["sendType"] = conversation_data.send_type
                data["userActiveRate"] = conversation_data.user_active_rate
                data["createTime"] = conversation_data.create_time.strftime("%Y-%m-%d %H:%M:%S")
                data["updateTime"] = conversation_data.update_time.strftime("%Y-%m-%d %H:%M:%S")
                response_data.append(data)
            return {
                "currentPage": page_num,
                "pageSize": page_size,
                "totalSize": total_page,
                "total": total,
                "list": response_data,
            }
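

# --- Usage sketch (illustrative only, not part of the service) ---
# Assumptions: a SQLAlchemy engine/sessionmaker configured for the project's database
# (the DSN below is a placeholder, not the real connection string), the ORM tables
# already populated, and ODPSUtils able to load its credentials from its own config.
if __name__ == "__main__":
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    engine = create_engine("mysql+pymysql://user:password@localhost:3306/pqai")  # placeholder DSN
    service = DatasetService(sessionmaker(bind=engine))
    page = service.get_dataset_list(page_num=1, page_size=20)
    print(page["total"], [item["name"] for item in page["list"]])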