# mapper.py (4.1 KB, 125 lines)

  1. from .base import BaseMySQLClient
  2. class Dataset(BaseMySQLClient):
  3. async def update_dataset_status(self, dataset_id, ori_status, new_status):
  4. query = """
  5. UPDATE dataset SET status = %s WHERE id = %s AND status = %s;
  6. """
  7. return await self.pool.async_save(
  8. query=query, params=(new_status, dataset_id, ori_status)
  9. )
  10. async def select_dataset(self, status=1):
  11. query = """
  12. SELECT * FROM dataset WHERE status = %s;
  13. """
  14. return await self.pool.async_fetch(query=query, params=(status,))
  15. async def add_dataset(self, name):
  16. query = """
  17. INSERT INTO dataset (name) VALUES (%s);
  18. """
  19. return await self.pool.async_save(query=query, params=(name,))
  20. async def select_dataset_by_id(self, id_, status: int = 1):
  21. query = """
  22. SELECT * FROM dataset WHERE id = %s AND status = %s;
  23. """
  24. return await self.pool.async_fetch(query=query, params=(id_, status))
  25. async def select_dataset_by_name(self, name, status: int = 1):
  26. query = """
  27. SELECT * FROM dataset WHERE name = %s AND status = %s;
  28. """
  29. return await self.pool.async_fetch(query=query, params=(name, status))
  30. class ChatResult(BaseMySQLClient):
  31. async def insert_chat_result(
  32. self,
  33. query_text,
  34. dataset_ids,
  35. search_res,
  36. chat_res,
  37. score,
  38. has_answer,
  39. ai_answer,
  40. ai_source,
  41. ai_status,
  42. final_result,
  43. is_web=None,
  44. ):
  45. query = """
  46. INSERT INTO chat_res
  47. (query, dataset_ids, search_res, chat_res, score, has_answer, ai_answer, ai_source, ai_status, is_web, final_result)
  48. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
  49. """
  50. return await self.pool.async_save(
  51. query=query,
  52. params=(
  53. query_text,
  54. dataset_ids,
  55. search_res,
  56. chat_res,
  57. score,
  58. has_answer,
  59. ai_answer,
  60. ai_source,
  61. ai_status,
  62. is_web,
  63. final_result,
  64. ),
  65. )
  66. async def select_chat_results(
  67. self, page_num: int, page_size: int, order_by=None, is_web: int = 1
  68. ):
  69. """
  70. 分页查询 chat_res 表,并返回分页信息
  71. :param page_num: 页码,从 1 开始
  72. :param page_size: 每页数量
  73. :param order_by: 排序条件,例如 {"id": "desc"} 或 {"created_at": "asc"}
  74. :param is_web: 是否为 Web 数据(默认 1)
  75. :return: dict,包含 entities、total_count、page、page_size、total_pages
  76. """
  77. if order_by is None:
  78. order_by = {"id": "desc"}
  79. offset = (page_num - 1) * page_size
  80. # 动态拼接 where 条件
  81. where_clauses = ["is_web = %s"]
  82. params = [is_web]
  83. where_sql = " AND ".join(where_clauses)
  84. # 动态拼接 order by
  85. order_field, order_direction = list(order_by.items())[0]
  86. order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
  87. # 查询总数
  88. count_query = f"SELECT COUNT(*) as total_count FROM chat_res WHERE {where_sql};"
  89. count_result = await self.pool.async_fetch(
  90. query=count_query, params=tuple(params)
  91. )
  92. total_count = count_result[0]["total_count"] if count_result else 0
  93. # 查询分页数据
  94. query = f"""
  95. SELECT search_res, query,create_time, chat_res, ai_answer, final_result FROM chat_res
  96. WHERE {where_sql}
  97. {order_sql}
  98. LIMIT %s OFFSET %s;
  99. """
  100. params.extend([page_size, offset])
  101. entities = await self.pool.async_fetch(query=query, params=tuple(params))
  102. total_pages = (total_count + page_size - 1) // page_size # 向上取整
  103. return {
  104. "entities": entities,
  105. "total_count": total_count,
  106. "page": page_num,
  107. "page_size": page_size,
  108. "total_pages": total_pages,
  109. }