123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- from .base import BaseMySQLClient
class Dataset(BaseMySQLClient):
    """Query helpers for the ``dataset`` table.

    Every method forwards a parameterized statement to ``self.pool`` and
    returns whatever the pool call returns.
    """

    async def update_dataset_status(self, dataset_id, ori_status, new_status):
        """Set a dataset's status, but only if it currently equals ``ori_status``.

        :param dataset_id: value matched against ``dataset.id``
        :param ori_status: status the row must currently have
        :param new_status: status written to the row
        :return: result of ``self.pool.async_save``
        """
        sql = """
            UPDATE dataset SET status = %s WHERE id = %s AND status = %s;
        """
        # Bind order follows the placeholders: SET value first, then the
        # two WHERE operands.
        bind_values = (new_status, dataset_id, ori_status)
        return await self.pool.async_save(query=sql, params=bind_values)

    async def select_dataset(self, status=1):
        """Fetch every dataset row whose status equals ``status``.

        :param status: status to filter on (default 1)
        :return: result of ``self.pool.async_fetch``
        """
        sql = """
            SELECT * FROM dataset WHERE status = %s;
        """
        bind_values = (status,)
        return await self.pool.async_fetch(query=sql, params=bind_values)

    async def add_dataset(self, name):
        """Insert a new dataset row carrying only its name.

        :param name: value stored in ``dataset.name``
        :return: result of ``self.pool.async_save``
        """
        sql = """
            INSERT INTO dataset (name) VALUES (%s);
        """
        bind_values = (name,)
        return await self.pool.async_save(query=sql, params=bind_values)

    async def select_dataset_by_id(self, id_, status: int = 1):
        """Fetch dataset rows matching both an id and a status.

        :param id_: value matched against ``dataset.id``
        :param status: status to filter on (default 1)
        :return: result of ``self.pool.async_fetch``
        """
        sql = """
            SELECT * FROM dataset WHERE id = %s AND status = %s;
        """
        bind_values = (id_, status)
        return await self.pool.async_fetch(query=sql, params=bind_values)

    async def select_dataset_by_name(self, name, status: int = 1):
        """Fetch dataset rows matching both a name and a status.

        :param name: value matched against ``dataset.name``
        :param status: status to filter on (default 1)
        :return: result of ``self.pool.async_fetch``
        """
        sql = """
            SELECT * FROM dataset WHERE name = %s AND status = %s;
        """
        bind_values = (name, status)
        return await self.pool.async_fetch(query=sql, params=bind_values)
class ChatResult(BaseMySQLClient):
    """Query helpers for the ``chat_res`` table."""

    # The only sort directions that may be spliced into an ORDER BY clause.
    _ORDER_DIRECTIONS = ("ASC", "DESC")

    async def insert_chat_result(
        self,
        query_text,
        dataset_ids,
        search_res,
        chat_res,
        score,
        has_answer,
        ai_answer,
        ai_source,
        ai_status,
        final_result,
        study_task_id=None,
        is_web=None,
    ):
        """Insert one row into ``chat_res``.

        Each argument is bound to the column of the same name, except
        ``query_text``, which is stored in the ``query`` column.

        :return: result of ``self.pool.async_save``
        """
        query = """
            INSERT INTO chat_res
            (query, dataset_ids, search_res, chat_res, score, has_answer, ai_answer, ai_source, ai_status, is_web, final_result, study_task_id)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
        """
        # NOTE: the tuple follows the column list above, not the parameter
        # order of this method (is_web comes before final_result).
        return await self.pool.async_save(
            query=query,
            params=(
                query_text,
                dataset_ids,
                search_res,
                chat_res,
                score,
                has_answer,
                ai_answer,
                ai_source,
                ai_status,
                is_web,
                final_result,
                study_task_id,
            ),
        )

    @staticmethod
    def _build_order_sql(order_by):
        """Return a validated ``ORDER BY`` clause for ``order_by``.

        Column names and sort directions cannot be sent as bound parameters,
        so they are checked before being interpolated into the statement.
        Only the first item of the mapping is used.

        :param order_by: mapping such as ``{"id": "desc"}``; ``None`` or an
            empty mapping selects the default ``{"id": "desc"}``
        :raises ValueError: if the column is not a plain ASCII identifier or
            the direction is not ``asc``/``desc`` (case-insensitive)
        """
        if not order_by:
            order_by = {"id": "desc"}
        field, direction = next(iter(order_by.items()))

        if not isinstance(field, str) or not isinstance(direction, str):
            raise ValueError("order_by must map a column name to 'asc' or 'desc'")

        field = field.strip()
        # A bare ASCII identifier cannot contain spaces, quotes, commas or
        # semicolons, so it cannot smuggle additional SQL into the query.
        if not (field.isascii() and field.isidentifier()):
            raise ValueError(f"invalid order_by column: {field!r}")

        direction = direction.strip().upper()
        if direction not in ChatResult._ORDER_DIRECTIONS:
            raise ValueError(f"invalid order_by direction: {direction!r}")

        return f"ORDER BY {field} {direction}"

    async def select_chat_results(
        self, page_num: int, page_size: int, order_by=None, is_web: int = 1
    ):
        """Query ``chat_res`` one page at a time and return paging metadata.

        :param page_num: page number, starting at 1
        :param page_size: number of rows per page, at least 1
        :param order_by: sort spec, e.g. ``{"id": "desc"}`` or
            ``{"create_time": "asc"}``; defaults to ``{"id": "desc"}``
        :param is_web: value matched against the ``is_web`` column (default 1)
        :return: dict with ``entities``, ``total_count``, ``page``,
            ``page_size`` and ``total_pages``
        :raises ValueError: if ``page_num`` or ``page_size`` is below 1, or
            if ``order_by`` names an invalid column or direction
        """
        # Reject bad paging values before touching the database: a negative
        # OFFSET/LIMIT is invalid SQL and page_size == 0 would divide by zero
        # when computing total_pages.
        if page_num < 1:
            raise ValueError("page_num must be >= 1")
        if page_size < 1:
            raise ValueError("page_size must be >= 1")

        order_sql = self._build_order_sql(order_by)
        offset = (page_num - 1) * page_size

        # WHERE conditions and their bound values are kept in parallel lists.
        where_clauses = ["is_web = %s"]
        params = [is_web]
        where_sql = " AND ".join(where_clauses)

        # Total number of matching rows.
        count_query = f"SELECT COUNT(*) as total_count FROM chat_res WHERE {where_sql};"
        count_result = await self.pool.async_fetch(
            query=count_query, params=tuple(params)
        )
        # The first row is indexed by column name.
        total_count = count_result[0]["total_count"] if count_result else 0

        # Rows for the requested page.
        query = f"""
            SELECT search_res, query,create_time, chat_res, ai_answer, final_result FROM chat_res
            WHERE {where_sql}
            {order_sql}
            LIMIT %s OFFSET %s;
        """
        params.extend([page_size, offset])
        entities = await self.pool.async_fetch(query=query, params=tuple(params))

        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
        return {
            "entities": entities,
            "total_count": total_count,
            "page": page_num,
            "page_size": page_size,
            "total_pages": total_pages,
        }
|