mapper.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from .base import BaseMySQLClient
  2. class Dataset(BaseMySQLClient):
  3. async def update_dataset_status(self, dataset_id, ori_status, new_status):
  4. query = """
  5. UPDATE dataset SET status = %s WHERE id = %s AND status = %s;
  6. """
  7. return await self.pool.async_save(
  8. query=query, params=(new_status, dataset_id, ori_status)
  9. )
  10. async def select_dataset(self, status=1):
  11. query = """
  12. SELECT * FROM dataset WHERE status = %s;
  13. """
  14. return await self.pool.async_fetch(query=query, params=(status,))
  15. async def add_dataset(self, name):
  16. query = """
  17. INSERT INTO dataset (name) VALUES (%s);
  18. """
  19. return await self.pool.async_save(query=query, params=(name,))
  20. async def select_dataset_by_id(self, id_, status: int = 1):
  21. query = """
  22. SELECT * FROM dataset WHERE id = %s AND status = %s;
  23. """
  24. return await self.pool.async_fetch(query=query, params=(id_, status))
  25. async def select_dataset_by_name(self, name, status: int = 1):
  26. query = """
  27. SELECT * FROM dataset WHERE name = %s AND status = %s;
  28. """
  29. return await self.pool.async_fetch(query=query, params=(name, status))
  30. class ChatResult(BaseMySQLClient):
  31. async def insert_chat_result(
  32. self,
  33. query_text,
  34. dataset_ids,
  35. search_res,
  36. chat_res,
  37. score,
  38. has_answer,
  39. ai_answer,
  40. ai_source,
  41. ai_status,
  42. final_result,
  43. study_task_id,
  44. is_web=None,
  45. ):
  46. query = """
  47. INSERT INTO chat_res
  48. (query, dataset_ids, search_res, chat_res, score, has_answer, ai_answer, ai_source, ai_status, is_web, final_result, study_task_id)
  49. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
  50. """
  51. return await self.pool.async_save(
  52. query=query,
  53. params=(
  54. query_text,
  55. dataset_ids,
  56. search_res,
  57. chat_res,
  58. score,
  59. has_answer,
  60. ai_answer,
  61. ai_source,
  62. ai_status,
  63. is_web,
  64. final_result,
  65. study_task_id
  66. ),
  67. )
  68. async def select_chat_results(
  69. self, page_num: int, page_size: int, order_by=None, is_web: int = 1
  70. ):
  71. """
  72. 分页查询 chat_res 表,并返回分页信息
  73. :param page_num: 页码,从 1 开始
  74. :param page_size: 每页数量
  75. :param order_by: 排序条件,例如 {"id": "desc"} 或 {"created_at": "asc"}
  76. :param is_web: 是否为 Web 数据(默认 1)
  77. :return: dict,包含 entities、total_count、page、page_size、total_pages
  78. """
  79. if order_by is None:
  80. order_by = {"id": "desc"}
  81. offset = (page_num - 1) * page_size
  82. # 动态拼接 where 条件
  83. where_clauses = ["is_web = %s"]
  84. params = [is_web]
  85. where_sql = " AND ".join(where_clauses)
  86. # 动态拼接 order by
  87. order_field, order_direction = list(order_by.items())[0]
  88. order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
  89. # 查询总数
  90. count_query = f"SELECT COUNT(*) as total_count FROM chat_res WHERE {where_sql};"
  91. count_result = await self.pool.async_fetch(
  92. query=count_query, params=tuple(params)
  93. )
  94. total_count = count_result[0]["total_count"] if count_result else 0
  95. # 查询分页数据
  96. query = f"""
  97. SELECT search_res, query,create_time, chat_res, ai_answer, final_result FROM chat_res
  98. WHERE {where_sql}
  99. {order_sql}
  100. LIMIT %s OFFSET %s;
  101. """
  102. params.extend([page_size, offset])
  103. entities = await self.pool.async_fetch(query=query, params=tuple(params))
  104. total_pages = (total_count + page_size - 1) // page_size # 向上取整
  105. return {
  106. "entities": entities,
  107. "total_count": total_count,
  108. "page": page_num,
  109. "page_size": page_size,
  110. "total_pages": total_pages,
  111. }