xueyiming 2 weeks ago
parent
commit 01207b3e68

+ 1 - 4
applications/utils/chat/__init__.py

@@ -1,7 +1,4 @@
 from applications.utils.chat.chat_classifier import ChatClassifier
 
 
-__all__ = [
-    "ChatClassifier"
-]
-
+__all__ = ["ChatClassifier"]

+ 4 - 4
applications/utils/chat/chat_classifier.py

@@ -21,9 +21,9 @@ class ChatClassifier:
         weighted_contents = []
 
         for result in search_results:
-            content = result['content']
-            content_summary = result['contentSummary']
-            score = result['score']
+            content = result["content"]
+            content_summary = result["contentSummary"]
+            score = result["score"]
 
             weighted_summaries.append((content_summary, score))
             weighted_contents.append((content, score))
@@ -59,6 +59,6 @@ class ChatClassifier:
     async def chat_with_deepseek(self, query, search_results):
         prompt = self.generate_summary_prompt(query, search_results)
         response = await fetch_deepseek_completion(
-            model="DeepSeek-V3", prompt=prompt, output_type='json'
+            model="DeepSeek-V3", prompt=prompt, output_type="json"
         )
         return response
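
Usage is unchanged by the reformat; a minimal caller sketch follows (the result fields and the DeepSeek-V3 model come from this diff, while the asyncio wiring and the sample values are assumptions, and a configured backend for fetch_deepseek_completion is required to actually run it):

    import asyncio

    from applications.utils.chat import ChatClassifier

    # Hypothetical retrieval hits shaped like the rows the classifier consumes above.
    search_results = [
        {"content": "Full chunk text ...", "contentSummary": "short summary", "score": 0.92},
        {"content": "Another chunk ...", "contentSummary": "another summary", "score": 0.71},
    ]

    async def main():
        classifier = ChatClassifier()
        # Returns the parsed JSON from DeepSeek-V3, e.g. {"summary": ..., "relevance_score": ...}.
        answer = await classifier.chat_with_deepseek("what is topic-aware chunking?", search_results)
        print(answer["summary"])

    asyncio.run(main())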

+ 16 - 9
applications/utils/chunks/topic_aware_chunking.py

@@ -46,8 +46,13 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
 
 
 class TopicAwarePackerV1(TopicAwareChunker):
-
-    def _pack_v1(self, sentence_list: List[str], boundaries: List[int], text_type: int, dataset_id: int) -> List[Chunk]:
+    def _pack_v1(
+        self,
+        sentence_list: List[str],
+        boundaries: List[int],
+        text_type: int,
+        dataset_id: int,
+    ) -> List[Chunk]:
         boundary_set = set(boundaries)
         chunks: List[Chunk] = []
         start = 0
@@ -99,22 +104,25 @@ class TopicAwarePackerV1(TopicAwareChunker):
 
 
 class TopicAwarePackerV2(TopicAwareChunker):
-
     def _pack_v2(
-        self, sentence_list: List[str], boundaries: List[int], embeddings: np.ndarray, text_type: int, dataset_id: int
+        self,
+        sentence_list: List[str],
+        boundaries: List[int],
+        embeddings: np.ndarray,
+        text_type: int,
+        dataset_id: int,
     ) -> List[Chunk]:
         segments = []
         seg_embs = []
         last_idx = 0
         for b in boundaries + [len(sentence_list) - 1]:
-            seg = sentence_list[last_idx:b + 1]
-            seg_emb = np.mean(embeddings[last_idx:b + 1], axis=0)
+            seg = sentence_list[last_idx : b + 1]
+            seg_emb = np.mean(embeddings[last_idx : b + 1], axis=0)
             if seg:
                 segments.append(seg)
                 seg_embs.append(seg_emb)
             last_idx = b + 1
 
-
         final_segments = []
         for seg in segments:
             tokens = num_tokens("".join(seg))
@@ -139,7 +147,7 @@ class TopicAwarePackerV2(TopicAwareChunker):
                     chunk_id=index,
                     tokens=num_tokens(text),
                     text_type=text_type,
-                    status=status
+                    status=status,
                 )
             )
 
@@ -157,4 +165,3 @@ class TopicAwarePackerV2(TopicAwareChunker):
             text_type=text_type,
             dataset_id=dataset_id,
         )
-
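
The V2 packer cuts the sentence list at each boundary index and mean-pools the sentence embeddings per segment. A standalone sketch of just that slicing step, with toy sentences and random vectors (Chunk, num_tokens and the token-budget handling are omitted):

    import numpy as np

    sentence_list = ["s0", "s1", "s2", "s3", "s4"]
    embeddings = np.random.rand(len(sentence_list), 4)  # toy 4-dim sentence vectors
    boundaries = [1, 3]                                  # segments end after s1 and s3

    segments, seg_embs = [], []
    last_idx = 0
    for b in boundaries + [len(sentence_list) - 1]:      # final segment ends at the last sentence
        seg = sentence_list[last_idx : b + 1]
        seg_emb = np.mean(embeddings[last_idx : b + 1], axis=0)
        if seg:
            segments.append(seg)
            seg_embs.append(seg_emb)
        last_idx = b + 1

    print(segments)  # [['s0', 's1'], ['s2', 's3'], ['s4']]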

+ 33 - 42
applications/utils/mysql/mapper.py

@@ -32,28 +32,21 @@ class Dataset(BaseMySQLClient):
         query = """
             select * from dataset where status = %s;
         """
-        return await self.pool.async_fetch(
-            query=query,
-            params=(status,)
-        )
+        return await self.pool.async_fetch(query=query, params=(status,))
 
     async def add_dataset(self, name):
         query = """
             insert into dataset (name, created_at, updated_at, status) values (%s, %s, %s, %s);
         """
         return await self.pool.async_save(
-            query=query,
-            params=(name, datetime.now(), datetime.now(), 1)
+            query=query, params=(name, datetime.now(), datetime.now(), 1)
         )
 
     async def select_dataset_by_id(self, id, status=1):
         query = """
             select * from dataset where id = %s and status = %s;
         """
-        return await self.pool.async_fetch(
-            query=query,
-            params=(id, status)
-        )
+        return await self.pool.async_fetch(query=query, params=(id, status))
 
 
 class Contents(BaseMySQLClient):
@@ -114,18 +107,15 @@ class Contents(BaseMySQLClient):
         query = """
             select * from contents where doc_id = %s;
         """
-        return await self.pool.async_fetch(
-            query=query,
-            params=(doc_id,)
-        )
+        return await self.pool.async_fetch(query=query, params=(doc_id,))
 
     async def select_contents(
-            self,
-            page_num: int,
-            page_size: int,
-            order_by: dict = {"id": "desc"},
-            dataset_id: int = None,
-            doc_status: int = 1,
+        self,
+        page_num: int,
+        page_size: int,
+        order_by: dict = {"id": "desc"},
+        dataset_id: int = None,
+        doc_status: int = 1,
     ):
         """
         Paginated query over the contents table, returning pagination info
@@ -154,7 +144,9 @@ class Contents(BaseMySQLClient):
 
         # Query the total row count
         count_query = f"SELECT COUNT(*) as total_count FROM contents WHERE {where_sql};"
-        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
+        count_result = await self.pool.async_fetch(
+            query=count_query, params=tuple(params)
+        )
         total_count = count_result[0]["total_count"] if count_result else 0
 
         # Query the current page of rows
@@ -174,7 +166,7 @@ class Contents(BaseMySQLClient):
             "total_count": total_count,
             "page": page_num,
             "page_size": page_size,
-            "total_pages": total_pages
+            "total_pages": total_pages,
         }
 
 
@@ -206,7 +198,8 @@ class ContentChunks(BaseMySQLClient):
             WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s and status = %s;
         """
         return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, chunk_id, ori_status, self.CHUNK_USEFUL_STATUS)
+            query=query,
+            params=(new_status, doc_id, chunk_id, ori_status, self.CHUNK_USEFUL_STATUS),
         )
 
     async def update_embedding_status(self, doc_id, chunk_id, ori_status, new_status):
@@ -289,14 +282,13 @@ class ContentChunks(BaseMySQLClient):
         )
 
     async def select_chunk_contents(
-            self,
-            page_num: int,
-            page_size: int,
-            order_by: dict = {"chunk_id": "asc"},
-            doc_id: str = None,
-            doc_status: int = None
+        self,
+        page_num: int,
+        page_size: int,
+        order_by: dict = {"chunk_id": "asc"},
+        doc_id: str = None,
+        doc_status: int = None,
     ):
-
         offset = (page_num - 1) * page_size
 
         # Dynamically assemble the WHERE conditions
@@ -318,8 +310,12 @@ class ContentChunks(BaseMySQLClient):
         order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
 
         # Query the total row count
-        count_query = f"SELECT COUNT(*) as total_count FROM content_chunks WHERE {where_sql};"
-        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
+        count_query = (
+            f"SELECT COUNT(*) as total_count FROM content_chunks WHERE {where_sql};"
+        )
+        count_result = await self.pool.async_fetch(
+            query=count_query, params=tuple(params)
+        )
         total_count = count_result[0]["total_count"] if count_result else 0
 
         # Query the current page of rows
@@ -339,24 +335,19 @@ class ContentChunks(BaseMySQLClient):
             "total_count": total_count,
             "page": page_num,
             "page_size": page_size,
-            "total_pages": total_pages
+            "total_pages": total_pages,
         }
 
 
 class ChatRes(BaseMySQLClient):
-    async def insert_chat_res(self, query_text, dataset_ids, search_res, chat_res, score):
+    async def insert_chat_res(
+        self, query_text, dataset_ids, search_res, chat_res, score
+    ):
         query = """
                     INSERT INTO chat_res
                         (query, dataset_ids, search_res, chat_res, score) 
                         VALUES (%s, %s, %s, %s, %s);
                 """
         return await self.pool.async_save(
-            query=query,
-            params=(
-                query_text,
-                dataset_ids,
-                search_res,
-                chat_res,
-                score
-            )
+            query=query, params=(query_text, dataset_ids, search_res, chat_res, score)
         )
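
A caller sketch for the reformatted mappers, assuming the same resource-manager wiring used in routes/buleprint.py; the dataset id and page size here are placeholders:

    from applications.utils.mysql.mapper import Contents, Dataset

    async def demo(resource):
        dataset_mapper = Dataset(resource.mysql_client)
        contents_mapper = Contents(resource.mysql_client)

        await dataset_mapper.add_dataset("demo-dataset")
        rows = await dataset_mapper.select_dataset_by_id(1)  # status defaults to 1

        # Paginated listing; returns entities plus total_count / page / page_size / total_pages.
        page = await contents_mapper.select_contents(page_num=1, page_size=10, dataset_id=1)
        return rows, page["entities"], page["total_pages"]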

+ 124 - 115
routes/buleprint.py

@@ -163,11 +163,7 @@ async def dataset_list():
         for dataset, count in zip(datasets, counts)
     ]
 
-    return jsonify({
-        "status_code": 200,
-        "detail": "success",
-        "data": data_list
-    })
+    return jsonify({"status_code": 200, "detail": "success", "data": data_list})
 
 
 @server_bp.route("/dataset/add", methods=["POST"])
@@ -178,16 +174,10 @@ async def add_dataset():
     body = await request.get_json()
     name = body.get("name")
     if not name:
-        return jsonify({
-            "status_code": 400,
-            "detail": "name is required"
-        })
+        return jsonify({"status_code": 400, "detail": "name is required"})
     # Perform the insert
     await dataset.add_dataset(name)
-    return jsonify({
-        "status_code": 200,
-        "detail": "success"
-    })
+    return jsonify({"status_code": 200, "detail": "success"})
 
 
 @server_bp.route("/content/get", methods=["GET"])
@@ -198,33 +188,27 @@ async def get_content():
     # Read request parameters
     doc_id = request.args.get("docId")
     if not doc_id:
-        return jsonify({
-            "status_code": 400,
-            "detail": "doc_id is required",
-            "data": {}
-        })
+        return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
 
     # Look up the content
     rows = await contents.select_content_by_doc_id(doc_id)
 
     if not rows:
-        return jsonify({
-            "status_code": 404,
-            "detail": "content not found",
-            "data": {}
-        })
+        return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
 
     row = rows[0]
 
-    return jsonify({
-        "status_code": 200,
-        "detail": "success",
-        "data": {
-            "title": row.get("title", ""),
-            "text": row.get("text", ""),
-            "doc_id": row.get("doc_id", "")
+    return jsonify(
+        {
+            "status_code": 200,
+            "detail": "success",
+            "data": {
+                "title": row.get("title", ""),
+                "text": row.get("text", ""),
+                "doc_id": row.get("doc_id", ""),
+            },
         }
-    })
+    )
 
 
 @server_bp.route("/content/list", methods=["GET"])
@@ -240,6 +224,7 @@ async def content_list():
 
     # order_by can be passed as a JSON string
     import json
+
     order_by_str = request.args.get("order_by", '{"id":"desc"}')
     try:
         order_by = json.loads(order_by_str)
@@ -265,22 +250,33 @@ async def content_list():
         for row in result["entities"]
     ]
 
-    return jsonify({
-        "status_code": 200,
-        "detail": "success",
-        "data": {
-            "entities": entities,
-            "total_count": result["total_count"],
-            "page": result["page"],
-            "page_size": result["page_size"],
-            "total_pages": result["total_pages"]
+    return jsonify(
+        {
+            "status_code": 200,
+            "detail": "success",
+            "data": {
+                "entities": entities,
+                "total_count": result["total_count"],
+                "page": result["page"],
+                "page_size": result["page_size"],
+                "total_pages": result["total_pages"],
+            },
         }
-    })
+    )
 
 
-async def query_search(query_text, filters=None, search_type='', anns_field='vector_text',
-                       search_params=BASE_MILVUS_SEARCH_PARAMS, _source=False, es_size=10000, sort_by=None,
-                       milvus_size=20, limit=10):
+async def query_search(
+    query_text,
+    filters=None,
+    search_type="",
+    anns_field="vector_text",
+    search_params=BASE_MILVUS_SEARCH_PARAMS,
+    _source=False,
+    es_size=10000,
+    sort_by=None,
+    milvus_size=20,
+    limit=10,
+):
     if filters is None:
         filters = {}
     query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
@@ -322,40 +318,46 @@ async def query():
     query_text = request.args.get("query")
     dataset_ids = request.args.get("datasetIds").split(",")
     search_type = request.args.get("search_type", "hybrid")
-    query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
-                                       search_type=search_type)
+    query_results = await query_search(
+        query_text=query_text,
+        filters={"dataset_id": dataset_ids},
+        search_type=search_type,
+    )
     resource = get_resource_manager()
     content_chunk_mapper = ContentChunks(resource.mysql_client)
     dataset_mapper = Dataset(resource.mysql_client)
     res = []
-    for result in query_results['results']:
-        content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
-                                                                         chunk_id=result['chunk_id'])
+    for result in query_results["results"]:
+        content_chunks = await content_chunk_mapper.select_chunk_content(
+            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
+        )
         if not content_chunks:
-            return jsonify({
-                "status_code": 500,
-                "detail": "content_chunk not found",
-                "data": {}
-            })
+            return jsonify(
+                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
+            )
         content_chunk = content_chunks[0]
-        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        datasets = await dataset_mapper.select_dataset_by_id(
+            content_chunk["dataset_id"]
+        )
         if not datasets:
-            return jsonify({
-                "status_code": 500,
-                "detail": "dataset not found",
-                "data": {}
-            })
+            return jsonify(
+                {"status_code": 500, "detail": "dataset not found", "data": {}}
+            )
         dataset = datasets[0]
         dataset_name = None
         if dataset:
-            dataset_name = dataset['name']
+            dataset_name = dataset["name"]
         res.append(
-            {'docId': content_chunk['doc_id'], 'content': content_chunk['text'],
-             'contentSummary': content_chunk['summary'], 'score': result['score'], 'datasetName': dataset_name})
-    data = {'results': res}
-    return jsonify({'status_code': 200,
-                    'detail': "success",
-                    'data': data})
+            {
+                "docId": content_chunk["doc_id"],
+                "content": content_chunk["text"],
+                "contentSummary": content_chunk["summary"],
+                "score": result["score"],
+                "datasetName": dataset_name,
+            }
+        )
+    data = {"results": res}
+    return jsonify({"status_code": 200, "detail": "success", "data": data})
 
 
 @server_bp.route("/chat", methods=["GET"])
@@ -364,46 +366,57 @@ async def chat():
     dataset_id_strs = request.args.get("datasetIds")
     dataset_ids = dataset_id_strs.split(",")
     search_type = request.args.get("search_type", "hybrid")
-    query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
-                                       search_type=search_type)
+    query_results = await query_search(
+        query_text=query_text,
+        filters={"dataset_id": dataset_ids},
+        search_type=search_type,
+    )
     resource = get_resource_manager()
     content_chunk_mapper = ContentChunks(resource.mysql_client)
     dataset_mapper = Dataset(resource.mysql_client)
     chat_res_mapper = ChatRes(resource.mysql_client)
     res = []
-    for result in query_results['results']:
-        content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
-                                                                         chunk_id=result['chunk_id'])
+    for result in query_results["results"]:
+        content_chunks = await content_chunk_mapper.select_chunk_content(
+            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
+        )
         if not content_chunks:
-            return jsonify({
-                "status_code": 500,
-                "detail": "content_chunk not found",
-                "data": {}
-            })
+            return jsonify(
+                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
+            )
         content_chunk = content_chunks[0]
-        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        datasets = await dataset_mapper.select_dataset_by_id(
+            content_chunk["dataset_id"]
+        )
         if not datasets:
-            return jsonify({
-                "status_code": 500,
-                "detail": "dataset not found",
-                "data": {}
-            })
+            return jsonify(
+                {"status_code": 500, "detail": "dataset not found", "data": {}}
+            )
         dataset = datasets[0]
         dataset_name = None
         if dataset:
-            dataset_name = dataset['name']
+            dataset_name = dataset["name"]
         res.append(
-            {'docId': content_chunk['doc_id'], 'content': content_chunk['text'],
-             'contentSummary': content_chunk['summary'], 'score': result['score'], 'datasetName': dataset_name})
+            {
+                "docId": content_chunk["doc_id"],
+                "content": content_chunk["text"],
+                "contentSummary": content_chunk["summary"],
+                "score": result["score"],
+                "datasetName": dataset_name,
+            }
+        )
 
     chat_classifier = ChatClassifier()
     chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
-    data = {'results': res, 'chat_res': chat_res['summary']}
-    await chat_res_mapper.insert_chat_res(query_text, dataset_id_strs, json.dumps(data, ensure_ascii=False),
-                                          chat_res['summary'], chat_res['relevance_score'])
-    return jsonify({'status_code': 200,
-                    'detail': "success",
-                    'data': data})
+    data = {"results": res, "chat_res": chat_res["summary"]}
+    await chat_res_mapper.insert_chat_res(
+        query_text,
+        dataset_id_strs,
+        json.dumps(data, ensure_ascii=False),
+        chat_res["summary"],
+        chat_res["relevance_score"],
+    )
+    return jsonify({"status_code": 200, "detail": "success", "data": data})
 
 
 @server_bp.route("/chunk/list", methods=["GET"])
@@ -416,26 +429,20 @@ async def chunk_list():
     page_size = int(request.args.get("pageSize", 10))
     doc_id = request.args.get("docId")
     if not doc_id:
-        return jsonify({
-            "status_code": 500,
-            "detail": "docId not found",
-            "data": {}
-        })
+        return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
 
     # Call select_contents to get the pagination dict
-    result = await content_chunk.select_chunk_contents(page_num=page_num, page_size=page_size, doc_id=doc_id)
+    result = await content_chunk.select_chunk_contents(
+        page_num=page_num, page_size=page_size, doc_id=doc_id
+    )
 
     if not result:
-        return jsonify({
-            "status_code": 500,
-            "detail": "chunk is empty",
-            "data": {}
-        })
+        return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
     # Format entities, keeping only the required fields
     entities = [
         {
-            "id": row['id'],
-            "chunk_id": row['chunk_id'],
+            "id": row["id"],
+            "chunk_id": row["chunk_id"],
             "doc_id": row["doc_id"],
             "summary": row.get("summary") or "",
             "text": row.get("text") or "",
@@ -443,14 +450,16 @@ async def chunk_list():
         for row in result["entities"]
     ]
 
-    return jsonify({
-        "status_code": 200,
-        "detail": "success",
-        "data": {
-            "entities": entities,
-            "total_count": result["total_count"],
-            "page": result["page"],
-            "page_size": result["page_size"],
-            "total_pages": result["total_pages"]
+    return jsonify(
+        {
+            "status_code": 200,
+            "detail": "success",
+            "data": {
+                "entities": entities,
+                "total_count": result["total_count"],
+                "page": result["page"],
+                "page_size": result["page_size"],
+                "total_pages": result["total_pages"],
+            },
         }
-    })
+    )
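
End to end, the reformatted routes behave as before; a hedged client sketch against them (httpx is just a convenient choice for the example, and the host, port and blueprint prefix are assumptions):

    import asyncio
    import httpx

    BASE = "http://localhost:8000"  # assumed host/port; adjust to wherever server_bp is mounted

    async def demo():
        async with httpx.AsyncClient(base_url=BASE) as client:
            await client.post("/dataset/add", json={"name": "demo"})
            datasets = (await client.get("/dataset/list")).json()

            # datasetIds is a comma-separated string of dataset ids.
            params = {"query": "what is RAG?", "datasetIds": "1", "search_type": "hybrid"}
            hits = (await client.get("/query", params=params)).json()
            answer = (await client.get("/chat", params=params)).json()

            print(datasets["data"])
            print(hits["data"]["results"])
            print(answer["data"]["chat_res"])

    asyncio.run(demo())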