mysql_session_manager.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import abc
  2. import pymysql.cursors
  3. from typing import Dict, List, Optional
  4. from pqai_agent.database import MySQLManager
  5. class SessionManager(abc.ABC):
  6. @abc.abstractmethod
  7. def get_staff_sessions(
  8. self,
  9. staff_id: str,
  10. page_id: int = 1,
  11. page_size: int = 10,
  12. session_type: str = "default",
  13. ) -> List[Dict]:
  14. pass
  15. @abc.abstractmethod
  16. def get_staff_sessions_summary(
  17. self,
  18. staff_id: str,
  19. page_id: int,
  20. page_size: int,
  21. status: int,
  22. ) -> Dict:
  23. pass
  24. @abc.abstractmethod
  25. def get_staff_session_list(
  26. self, staff_id: str, page_id: int, page_size: int
  27. ) -> Dict:
  28. pass
  29. @abc.abstractmethod
  30. def get_conversation_list(
  31. self, staff_id: str, user_id: str, page: Optional[int], page_size: int
  32. ) -> Dict:
  33. pass
  34. class MySQLSessionManager(SessionManager):
  35. def __init__(self, db_config, staff_table, user_table, agent_state_table, chat_history_table):
  36. self.db = MySQLManager(db_config)
  37. self.staff_table = staff_table
  38. self.user_table = user_table
  39. self.agent_state_table = agent_state_table
  40. self.chat_history_table = chat_history_table
  41. def get_staff_sessions(
  42. self,
  43. staff_id,
  44. page_id: int = 1,
  45. page_size: int = 10,
  46. session_type: str = "default",
  47. ) -> List[Dict]:
  48. """
  49. :param page_size:
  50. :param page_id:
  51. :param session_type:
  52. :param staff_id:
  53. :return:
  54. """
  55. match session_type:
  56. case "active":
  57. sql = f"""
  58. select staff_id, current_state, user_id
  59. from {self.agent_state_table}
  60. where staff_id = %s and update_timestamp >= DATE_SUB(NOW(), INTERVAL 2 HOUR)
  61. order by update_timestamp desc;
  62. """
  63. case "human_intervention":
  64. sql = f"""
  65. select staff_id, current_state, user_id
  66. from {self.agent_state_table}
  67. where staff_id = %s and current_state = 5 order by update_timestamp desc;
  68. """
  69. case _:
  70. sql = f"""
  71. select t1.staff_id, t1.current_state, t1.user_id, t2.name, t2.iconurl
  72. from {self.agent_state_table} t1 join {self.user_table} t2
  73. on t1.user_id = t2.third_party_user_id
  74. where t1.staff_id = %s
  75. order by
  76. IF(t1.current_state = 5, 0, 1),
  77. t1.update_timestamp desc
  78. limit {page_size + 1} offset {page_size * (page_id - 1)};
  79. """
  80. staff_sessions = self.db.select(
  81. sql, cursor_type=pymysql.cursors.DictCursor, args=(staff_id,)
  82. )
  83. return staff_sessions
  84. def get_staff_sessions_summary(
  85. self, staff_id, page_id: int, page_size: int, status: int
  86. ) -> Dict:
  87. """
  88. :param status: staff status(0: unemployed, 1: employed)
  89. :param staff_id: staff
  90. :param page_id: page id
  91. :param page_size: page size
  92. :return:
  93. :todo: 未使用 Mysql 连接池,每次查询均需要与 MySQL 建立连接,性能较低,需要优化
  94. """
  95. if not staff_id:
  96. get_staff_query = f"""
  97. select third_party_user_id, name from {self.staff_table} where status = %s
  98. limit %s offset %s;
  99. """
  100. staff_id_list = self.db.select(
  101. sql=get_staff_query,
  102. cursor_type=pymysql.cursors.DictCursor,
  103. args=(status, page_size + 1, (page_id - 1) * page_size),
  104. )
  105. if not staff_id_list:
  106. return {}
  107. if len(staff_id_list) > page_size:
  108. has_next_page = True
  109. next_page_id = page_id + 1
  110. staff_id_list = staff_id_list[:page_size]
  111. else:
  112. has_next_page = False
  113. next_page_id = None
  114. else:
  115. get_staff_query = f"""
  116. select third_party_user_id, name from {self.staff_table}
  117. where status = %s and third_party_user_id = %s;
  118. """
  119. staff_id_list = self.db.select(
  120. sql=get_staff_query,
  121. cursor_type=pymysql.cursors.DictCursor,
  122. args=(status, staff_id),
  123. )
  124. if not staff_id_list:
  125. return {}
  126. has_next_page = False
  127. next_page_id = None
  128. response_data = [
  129. {
  130. "staff_id": staff["third_party_user_id"],
  131. "staff_name": staff["name"],
  132. "active_sessions": len(
  133. self.get_staff_sessions(
  134. staff["third_party_user_id"], session_type="active"
  135. )
  136. ),
  137. "human_intervention_sessions": len(
  138. self.get_staff_sessions(
  139. staff["third_party_user_id"], session_type="human_intervention"
  140. )
  141. ),
  142. }
  143. for staff in staff_id_list
  144. ]
  145. return {
  146. "has_next_page": has_next_page,
  147. "next_page_id": next_page_id,
  148. "data": response_data,
  149. }
  150. def get_staff_session_list(self, staff_id, page_id: int, page_size: int) -> Dict:
  151. """
  152. :param page_size:
  153. :param page_id:
  154. :param staff_id:
  155. :return:
  156. """
  157. session_list = self.get_staff_sessions(staff_id, page_id, page_size)
  158. if len(session_list) > page_size:
  159. has_next_page = True
  160. next_page_id = page_id + 1
  161. session_list = session_list[:page_size]
  162. else:
  163. has_next_page = False
  164. next_page_id = None
  165. response_data = []
  166. for session in session_list:
  167. temp_obj = {}
  168. user_id = session["user_id"]
  169. room_id = ":".join(["private", staff_id, user_id])
  170. select_query = f"""
  171. select content, sendtime as max_timestamp, msg_type
  172. from {self.chat_history_table}
  173. where roomid = %s
  174. order by sendtime desc limit %s;
  175. """
  176. last_message = self.db.select(
  177. sql=select_query,
  178. cursor_type=pymysql.cursors.DictCursor,
  179. args=(room_id, 1),
  180. )
  181. if not last_message:
  182. temp_obj["message"] = None
  183. temp_obj["timestamp"] = 0
  184. temp_obj["msg_type"] = None
  185. else:
  186. temp_obj["message"] = last_message[0]["content"]
  187. temp_obj["timestamp"] = last_message[0]["max_timestamp"]
  188. temp_obj["msg_type"] = last_message[0]["msg_type"]
  189. temp_obj["user_id"] = user_id
  190. temp_obj["user_name"] = session["name"]
  191. temp_obj["avatar"] = session["iconurl"]
  192. temp_obj["current_state"] = session["current_state"]
  193. response_data.append(temp_obj)
  194. return {
  195. "staff_id": staff_id,
  196. "has_next_page": has_next_page,
  197. "next_page_id": next_page_id,
  198. "data": response_data,
  199. }
  200. def get_conversation_list(
  201. self, staff_id: str, user_id: str, page: Optional[int], page_size: int
  202. ):
  203. """
  204. :param page_size:
  205. :param staff_id:
  206. :param user_id:
  207. :param page: timestamp
  208. :return:
  209. """
  210. room_id = ":".join(["private", staff_id, user_id])
  211. if not page:
  212. fetch_query = f"""
  213. select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl, t1.msg_type
  214. from {self.chat_history_table} t1
  215. join {self.user_table} t2 on t1.sender = t2.third_party_user_id
  216. where roomid = %s
  217. order by sendtime desc
  218. limit %s;
  219. """
  220. messages = self.db.select(
  221. sql=fetch_query,
  222. cursor_type=pymysql.cursors.DictCursor,
  223. args=(room_id, page_size + 1),
  224. )
  225. else:
  226. fetch_query = f"""
  227. select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl, t1.msg_type
  228. from {self.chat_history_table} t1
  229. join {self.user_table} t2 on t1.sender = t2.third_party_user_id
  230. where t1.roomid = %s and t1.sendtime <= %s
  231. order by sendtime desc
  232. limit %s;
  233. """
  234. messages = self.db.select(
  235. sql=fetch_query,
  236. cursor_type=pymysql.cursors.DictCursor,
  237. args=(room_id, page, page_size + 1),
  238. )
  239. if messages:
  240. if len(messages) > page_size:
  241. has_next_page = True
  242. next_page = messages[-1]["sendtime"]
  243. else:
  244. has_next_page = False
  245. next_page = None
  246. response_data = [
  247. {
  248. "sender_id": message["sender"],
  249. "sender_name": message["name"],
  250. "avatar": message["iconurl"],
  251. "content": message["content"],
  252. "timestamp": message["sendtime"],
  253. "msg_type": message["msg_type"],
  254. "role": "user" if message["sender"] == user_id else "staff",
  255. }
  256. for message in messages
  257. ]
  258. return {
  259. "staff_id": staff_id,
  260. "user_id": user_id,
  261. "has_next_page": has_next_page,
  262. "next_page": next_page,
  263. "data": response_data,
  264. }
  265. else:
  266. has_next_page = False
  267. next_page = None
  268. return {
  269. "staff_id": staff_id,
  270. "user_id": user_id,
  271. "has_next_page": has_next_page,
  272. "next_page": next_page,
  273. "data": [],
  274. }