dataset_service.py
  1. import json
  2. from cgitb import reset
  3. from typing import List
  4. from sqlalchemy import func
  5. from pqai_agent.data_models.dataset_model import DatasetModule
  6. from pqai_agent.data_models.datasets import Datasets
  7. from pqai_agent.data_models.internal_conversation_data import InternalConversationData
  8. from pqai_agent.data_models.qywx_chat_history import QywxChatHistory
  9. from pqai_agent.data_models.qywx_employee import QywxEmployee
  10. from pqai_agent_server.const.type_enum import get_dataset_type_desc
  11. from pqai_agent_server.utils.odps_utils import ODPSUtils
  12. class DatasetService:
  13. def __init__(self, session_maker):
  14. self.session_maker = session_maker
  15. odps_utils = ODPSUtils()
  16. self.odps_utils = odps_utils
  17. def get_user_profile_data(self, third_party_user_id: str, date_version: str):
  18. sql = f"""
  19. SELECT * FROM third_party_user_date_version
  20. WHERE dt between '20250612' and {date_version} -- 添加分区条件
  21. and third_party_user_id = {third_party_user_id}
  22. and profile_data_v1 is not null
  23. order by dt desc
  24. limit 1
  25. """
  26. result_df = self.odps_utils.execute_sql(sql)
  27. if not result_df.empty:
  28. return result_df.iloc[0].to_dict() # 获取第一行
  29. return None
  30. def get_dataset_list_by_module(self, module_id: int):
  31. with self.session_maker() as session:
  32. return session.query(DatasetModule).filter(DatasetModule.module_id == module_id).filter(
  33. DatasetModule.is_delete == 0).all()
  34. def get_conversation_data_list_by_dataset(self, dataset_id: int):
  35. with self.session_maker() as session:
  36. return session.query(InternalConversationData).filter(
  37. InternalConversationData.dataset_id == dataset_id).filter(
  38. InternalConversationData.is_delete == 0).all()
  39. def get_conversation_data_by_id(self, conversation_data_id: int):
  40. with self.session_maker() as session:
  41. return session.query(InternalConversationData).filter(
  42. InternalConversationData.id == conversation_data_id).one()
  43. def get_staff_profile_data(self, third_party_user_id: str):
  44. with self.session_maker() as session:
  45. return session.query(QywxEmployee).filter(
  46. QywxEmployee.third_party_user_id == third_party_user_id).one()
  47. def get_conversation_list_by_ids(self, conversation_ids: List[int]):
  48. with self.session_maker() as session:
  49. conversations = session.query(QywxChatHistory).filter(QywxChatHistory.id.in_(conversation_ids)).all()
  50. result = []
  51. for conversation in conversations:
  52. data = {}
  53. data["id"] = conversation.id
  54. data["sender"] = conversation.sender
  55. data["receiver"] = conversation.receiver
  56. data["roomid"] = conversation.roomid
  57. data["sendtime"] = conversation.sendtime / 1000
  58. data["msg_type"] = conversation.msg_type
  59. data["content"] = conversation.content
  60. result.append(data)
  61. return result
  62. def get_dataset_list(self, page_num: int, page_size: int):
  63. with self.session_maker() as session:
  64. # 计算偏移量
  65. offset = (page_num - 1) * page_size
  66. # 查询分页数据
  67. result = (session.query(Datasets)
  68. .filter(Datasets.is_delete == 0)
  69. .limit(page_size).offset(offset).all())
  70. # 查询总记录数
  71. total = session.query(func.count(Datasets.id)).filter(Datasets.is_delete == 0).scalar()
  72. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  73. total_page = 1 if total_page <= 0 else total_page
  74. response_data = [
  75. {
  76. "id": dataset.id,
  77. "name": dataset.name,
  78. "type": get_dataset_type_desc(dataset.type),
  79. "description": dataset.description,
  80. "createTime": dataset.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  81. "updateTime": dataset.update_time.strftime("%Y-%m-%d %H:%M:%S")
  82. }
  83. for dataset in result
  84. ]
  85. return {
  86. "currentPage": page_num,
  87. "pageSize": page_size,
  88. "totalSize": total_page,
  89. "total": total,
  90. "list": response_data,
  91. }
  92. def get_conversation_data_list(self, dataset_id: int, page_num: int, page_size: int):
  93. with self.session_maker() as session:
  94. # 计算偏移量
  95. offset = (page_num - 1) * page_size
  96. # 查询分页数据
  97. result = (session.query(InternalConversationData)
  98. .filter(InternalConversationData.dataset_id == dataset_id)
  99. .filter(InternalConversationData.is_delete == 0)
  100. .limit(page_size).offset(offset).all())
  101. # 查询总记录数
  102. total = session.query(func.count(InternalConversationData.id)).filter(
  103. InternalConversationData.is_delete == 0).scalar()
  104. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  105. total_page = 1 if total_page <= 0 else total_page
  106. response_data = []
  107. for conversation_data in result:
  108. data = {}
  109. data["id"] = conversation_data.id
  110. data["datasetId"] = conversation_data.dataset_id
  111. data["staff"] = self.get_staff_profile_data(conversation_data.staff_id).agent_profile
  112. data["user"] = self.get_user_profile_data(conversation_data.user_id,
  113. conversation_data.version_date.replace("-", ""))['profile_data_v1']
  114. data["conversation"] = self.get_conversation_list_by_ids(json.loads(conversation_data.conversation))
  115. data["content"] = conversation_data.content
  116. data["sendTime"] = conversation_data.send_time
  117. data["sendType"] = conversation_data.send_type
  118. data["userActiveRate"] = conversation_data.user_active_rate
  119. data["createTime"]: conversation_data.create_time.strftime("%Y-%m-%d %H:%M:%S")
  120. data["updateTime"]: conversation_data.update_time.strftime("%Y-%m-%d %H:%M:%S")
  121. response_data.append(data)
  122. return {
  123. "currentPage": page_num,
  124. "pageSize": page_size,
  125. "totalSize": total_page,
  126. "total": total,
  127. "list": response_data,
  128. }