mapper.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from .base import BaseMySQLClient
  2. class Dataset(BaseMySQLClient):
  3. async def update_dataset_status(self, dataset_id, ori_status, new_status):
  4. query = """
  5. UPDATE dataset SET status = %s WHERE id = %s AND status = %s;
  6. """
  7. return await self.pool.async_save(
  8. query=query, params=(new_status, dataset_id, ori_status)
  9. )
  10. async def select_dataset(self, status=1):
  11. query = """
  12. SELECT * FROM dataset WHERE status = %s;
  13. """
  14. return await self.pool.async_fetch(query=query, params=(status,))
  15. async def add_dataset(self, name):
  16. query = """
  17. INSERT INTO dataset (name) VALUES (%s);
  18. """
  19. return await self.pool.async_save(query=query, params=(name,))
  20. async def select_dataset_by_id(self, id_, status: int = 1):
  21. query = """
  22. SELECT * FROM dataset WHERE id = %s AND status = %s;
  23. """
  24. return await self.pool.async_fetch(query=query, params=(id_, status))
  25. async def select_dataset_by_name(self, name, status: int = 1):
  26. query = """
  27. SELECT * FROM dataset WHERE name = %s AND status = %s;
  28. """
  29. return await self.pool.async_fetch(query=query, params=(name, status))
  30. class ChatResult(BaseMySQLClient):
  31. async def insert_chat_result(
  32. self,
  33. query_text,
  34. dataset_ids,
  35. search_res,
  36. chat_res,
  37. score,
  38. has_answer,
  39. ai_answer,
  40. ai_source,
  41. ai_status,
  42. is_web=None,
  43. ):
  44. query = """
  45. INSERT INTO chat_res
  46. (query, dataset_ids, search_res, chat_res, score, has_answer, ai_answer, ai_source, ai_status, is_web)
  47. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
  48. """
  49. return await self.pool.async_save(
  50. query=query,
  51. params=(
  52. query_text,
  53. dataset_ids,
  54. search_res,
  55. chat_res,
  56. score,
  57. has_answer,
  58. ai_answer,
  59. ai_source,
  60. ai_status,
  61. is_web,
  62. ),
  63. )
  64. async def select_chat_results(
  65. self, page_num: int, page_size: int, order_by=None, is_web: int = 1
  66. ):
  67. """
  68. 分页查询 chat_res 表,并返回分页信息
  69. :param page_num: 页码,从 1 开始
  70. :param page_size: 每页数量
  71. :param order_by: 排序条件,例如 {"id": "desc"} 或 {"created_at": "asc"}
  72. :param is_web: 是否为 Web 数据(默认 1)
  73. :return: dict,包含 entities、total_count、page、page_size、total_pages
  74. """
  75. if order_by is None:
  76. order_by = {"id": "desc"}
  77. offset = (page_num - 1) * page_size
  78. # 动态拼接 where 条件
  79. where_clauses = ["is_web = %s"]
  80. params = [is_web]
  81. where_sql = " AND ".join(where_clauses)
  82. # 动态拼接 order by
  83. order_field, order_direction = list(order_by.items())[0]
  84. order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
  85. # 查询总数
  86. count_query = f"SELECT COUNT(*) as total_count FROM chat_res WHERE {where_sql};"
  87. count_result = await self.pool.async_fetch(
  88. query=count_query, params=tuple(params)
  89. )
  90. total_count = count_result[0]["total_count"] if count_result else 0
  91. # 查询分页数据
  92. query = f"""
  93. SELECT search_res, query,create_time, chat_res, ai_answer FROM chat_res
  94. WHERE {where_sql}
  95. {order_sql}
  96. LIMIT %s OFFSET %s;
  97. """
  98. params.extend([page_size, offset])
  99. entities = await self.pool.async_fetch(query=query, params=tuple(params))
  100. total_pages = (total_count + page_size - 1) // page_size # 向上取整
  101. return {
  102. "entities": entities,
  103. "total_count": total_count,
  104. "page": page_num,
  105. "page_size": page_size,
  106. "total_pages": total_pages,
  107. }