from .base import BaseMySQLClient


class Dataset(BaseMySQLClient):
    """Data-access helpers for the ``dataset`` table."""

    async def update_dataset_status(self, dataset_id, ori_status, new_status):
        """Change a dataset's status from ``ori_status`` to ``new_status``.

        The ``AND status = %s`` guard makes the update conditional: the row is
        only modified if it currently holds ``ori_status``.

        :param dataset_id: ``id`` of the dataset row.
        :param ori_status: status the row is expected to have right now.
        :param new_status: status to write.
        :return: the result of ``self.pool.async_save`` (presumably the
            affected-row count -- confirm against the pool implementation).
        """
        query = """
            UPDATE dataset SET status = %s WHERE id = %s AND status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, dataset_id, ori_status)
        )

    async def select_dataset(self, status=1):
        """Return all ``dataset`` rows whose ``status`` equals *status*.

        :param status: status filter (default 1).
        :return: the rows returned by ``self.pool.async_fetch``.
        """
        query = """
            SELECT * FROM dataset WHERE status = %s;
        """
        return await self.pool.async_fetch(query=query, params=(status,))

    async def add_dataset(self, name):
        """Insert a new dataset row with the given *name*.

        Only ``name`` is written; every other column takes its database
        default.

        :param name: dataset name.
        :return: the result of ``self.pool.async_save``.
        """
        query = """
            INSERT INTO dataset (name) VALUES (%s);
        """
        return await self.pool.async_save(query=query, params=(name,))

    async def select_dataset_by_id(self, id_, status: int = 1):
        """Return ``dataset`` rows matching both *id_* and *status*.

        :param id_: dataset ``id``.
        :param status: status filter (default 1).
        :return: the rows returned by ``self.pool.async_fetch``.
        """
        query = """
            SELECT * FROM dataset WHERE id = %s AND status = %s;
        """
        return await self.pool.async_fetch(query=query, params=(id_, status))

    async def select_dataset_by_name(self, name, status: int = 1):
        """Return ``dataset`` rows matching both *name* and *status*.

        :param name: dataset name.
        :param status: status filter (default 1).
        :return: the rows returned by ``self.pool.async_fetch``.
        """
        query = """
            SELECT * FROM dataset WHERE name = %s AND status = %s;
        """
        return await self.pool.async_fetch(query=query, params=(name, status))


class ChatResult(BaseMySQLClient):
    """Data-access helpers for the ``chat_res`` table."""

    # Sort directions accepted by ``select_chat_results`` (compared after
    # upper-casing the caller's value).
    _ORDER_DIRECTIONS = frozenset({"ASC", "DESC"})

    async def insert_chat_result(
        self,
        query_text,
        dataset_ids,
        search_res,
        chat_res,
        score,
        has_answer,
        ai_answer,
        ai_source,
        ai_status,
        final_result,
        study_task_id,
        is_web=None,
    ):
        """Insert one row into ``chat_res``.

        Values are passed to the driver unchanged. Non-scalar arguments
        (e.g. ``dataset_ids`` or ``search_res``) are presumably serialized by
        the caller -- confirm against call sites.

        :param query_text: user query, stored in the ``query`` column.
        :param is_web: value for the ``is_web`` column (default ``None``).
        :return: the result of ``self.pool.async_save``.
        """
        query = """
            INSERT INTO chat_res
                (query, dataset_ids, search_res, chat_res, score, has_answer,
                 ai_answer, ai_source, ai_status, is_web, final_result,
                 study_task_id)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
        """
        return await self.pool.async_save(
            query=query,
            params=(
                query_text,
                dataset_ids,
                search_res,
                chat_res,
                score,
                has_answer,
                ai_answer,
                ai_source,
                ai_status,
                is_web,
                final_result,
                study_task_id,
            ),
        )

    @classmethod
    def _build_order_sql(cls, order_by):
        """Build a validated ``ORDER BY`` clause from the first entry of *order_by*.

        Column names cannot be bound as query parameters, so the field is
        checked to be an ASCII identifier (optionally dotted, e.g.
        ``chat_res.id``) and backtick-quoted. The direction must be ASC or
        DESC. This prevents SQL injection through caller-supplied sort
        options.

        :raises ValueError: if the field or direction is not acceptable.
        """
        order_field, order_direction = next(iter(order_by.items()))
        parts = str(order_field).split(".")
        if not all(part.isidentifier() and part.isascii() for part in parts):
            raise ValueError(f"invalid order_by field: {order_field!r}")
        direction = str(order_direction).strip().upper()
        if direction not in cls._ORDER_DIRECTIONS:
            raise ValueError(f"invalid order_by direction: {order_direction!r}")
        quoted_field = ".".join(f"`{part}`" for part in parts)
        return f"ORDER BY {quoted_field} {direction}"

    async def select_chat_results(
        self, page_num: int, page_size: int, order_by=None, is_web: int = 1
    ):
        """Return one page of ``chat_res`` rows together with paging metadata.

        :param page_num: 1-based page number.
        :param page_size: number of rows per page (must be >= 1).
        :param order_by: sort spec such as ``{"id": "desc"}`` or
            ``{"create_time": "asc"}``. Only the first entry is used. ``None``
            or an empty dict falls back to ``{"id": "desc"}``.
        :param is_web: value the ``is_web`` column must equal (default 1).
        :return: dict with keys ``entities``, ``total_count``, ``page``,
            ``page_size`` and ``total_pages``.
        :raises ValueError: if ``page_num`` or ``page_size`` is < 1, or if
            ``order_by`` names an invalid column or direction.
        """
        # Fail fast: page_size 0 would divide by zero below, and page_num < 1
        # yields a negative OFFSET, which MySQL rejects.
        if page_num < 1:
            raise ValueError(f"page_num must be >= 1, got {page_num}")
        if page_size < 1:
            raise ValueError(f"page_size must be >= 1, got {page_size}")
        if not order_by:
            order_by = {"id": "desc"}
        order_sql = self._build_order_sql(order_by)
        offset = (page_num - 1) * page_size

        # Shared filter for both the count query and the page query.
        where_sql = "is_web = %s"
        where_params = (is_web,)

        # Total number of matching rows.
        count_query = (
            f"SELECT COUNT(*) as total_count FROM chat_res WHERE {where_sql};"
        )
        count_result = await self.pool.async_fetch(
            query=count_query, params=where_params
        )
        total_count = count_result[0]["total_count"] if count_result else 0

        # The requested page of rows.
        query = f"""
            SELECT search_res, query, create_time, chat_res, ai_answer, final_result
            FROM chat_res
            WHERE {where_sql}
            {order_sql}
            LIMIT %s OFFSET %s;
        """
        entities = await self.pool.async_fetch(
            query=query, params=(*where_params, page_size, offset)
        )

        # Ceiling division: a partial last page still counts as a page.
        total_pages = (total_count + page_size - 1) // page_size
        return {
            "entities": entities,
            "total_count": total_count,
            "page": page_num,
            "page_size": page_size,
            "total_pages": total_pages,
        }