luojunhui 2 weeks ago
parent
commit
2dcbbddd6e

+ 1 - 4
applications/utils/chat/__init__.py

@@ -1,7 +1,4 @@
 from applications.utils.chat.chat_classifier import ChatClassifier
 
 
-__all__ = [
-    "ChatClassifier"
-]
-
+__all__ = ["ChatClassifier"]

+ 4 - 6
applications/utils/chat/chat_classifier.py

@@ -23,9 +23,9 @@ class ChatClassifier:
         weighted_contents = []
 
         for result in search_results:
-            content = result['content']
-            content_summary = result['contentSummary']
-            score = result['score']
+            content = result["content"]
+            content_summary = result["contentSummary"]
+            score = result["score"]
 
             # Weight the content summary and the content
             weighted_summaries.append((content_summary, score))
@@ -51,7 +51,5 @@ class ChatClassifier:
 
     async def chat_with_deepseek(self, query, search_results):
         prompt = self.generate_summary_prompt(query, search_results)
-        response = await fetch_deepseek_completion(
-            model="DeepSeek-V3", prompt=prompt
-        )
+        response = await fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt)
         return response
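
Note: a minimal usage sketch of the reformatted classifier above, assuming ChatClassifier can be constructed without arguments and that the DeepSeek credentials the module expects are configured. The result values are placeholders; only the keys (content, contentSummary, score) and the chat_with_deepseek(query, search_results) signature come from this diff.

import asyncio

from applications.utils.chat import ChatClassifier

async def main():
    # Results in the shape the classifier reads: content, contentSummary, score.
    search_results = [
        {"content": "full chunk text", "contentSummary": "short summary", "score": 0.87},
    ]
    classifier = ChatClassifier()  # assumes a no-argument constructor
    # Builds the weighted prompt from the results, then asks DeepSeek-V3.
    answer = await classifier.chat_with_deepseek("example query", search_results)
    print(answer)

asyncio.run(main())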

+ 16 - 9
applications/utils/chunks/topic_aware_chunking.py

@@ -46,8 +46,13 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
 
 
 class TopicAwarePackerV1(TopicAwareChunker):
-
-    def _pack_v1(self, sentence_list: List[str], boundaries: List[int], text_type: int, dataset_id: int) -> List[Chunk]:
+    def _pack_v1(
+        self,
+        sentence_list: List[str],
+        boundaries: List[int],
+        text_type: int,
+        dataset_id: int,
+    ) -> List[Chunk]:
         boundary_set = set(boundaries)
         chunks: List[Chunk] = []
         start = 0
@@ -99,22 +104,25 @@ class TopicAwarePackerV1(TopicAwareChunker):
 
 
 class TopicAwarePackerV2(TopicAwareChunker):
-
     def _pack_v2(
-        self, sentence_list: List[str], boundaries: List[int], embeddings: np.ndarray, text_type: int, dataset_id: int
+        self,
+        sentence_list: List[str],
+        boundaries: List[int],
+        embeddings: np.ndarray,
+        text_type: int,
+        dataset_id: int,
     ) -> List[Chunk]:
         segments = []
         seg_embs = []
         last_idx = 0
         for b in boundaries + [len(sentence_list) - 1]:
-            seg = sentence_list[last_idx:b + 1]
-            seg_emb = np.mean(embeddings[last_idx:b + 1], axis=0)
+            seg = sentence_list[last_idx : b + 1]
+            seg_emb = np.mean(embeddings[last_idx : b + 1], axis=0)
             if seg:
                 segments.append(seg)
                 seg_embs.append(seg_emb)
             last_idx = b + 1
 
-
         final_segments = []
         for seg in segments:
             tokens = num_tokens("".join(seg))
@@ -139,7 +147,7 @@ class TopicAwarePackerV2(TopicAwareChunker):
                     chunk_id=index,
                     tokens=num_tokens(text),
                     text_type=text_type,
-                    status=status
+                    status=status,
                 )
             )
 
@@ -157,4 +165,3 @@ class TopicAwarePackerV2(TopicAwareChunker):
             text_type=text_type,
             dataset_id=dataset_id,
         )
-
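
Note: the boundary-slicing loop reformatted in _pack_v2 above, lifted into a standalone sketch for readability. It assumes boundaries are inclusive sentence indices and embeddings holds one vector per sentence; the helper name split_into_segments is illustrative, not part of the module.

from typing import List, Tuple
import numpy as np

def split_into_segments(
    sentence_list: List[str], boundaries: List[int], embeddings: np.ndarray
) -> Tuple[List[List[str]], List[np.ndarray]]:
    """Group sentences into segments at the detected boundaries and
    average the sentence embeddings inside each segment."""
    segments, seg_embs = [], []
    last_idx = 0
    for b in boundaries + [len(sentence_list) - 1]:
        seg = sentence_list[last_idx : b + 1]
        seg_emb = np.mean(embeddings[last_idx : b + 1], axis=0)
        if seg:
            segments.append(seg)
            seg_embs.append(seg_emb)
        last_idx = b + 1
    return segments, seg_embs

# e.g. four sentences with one boundary after the second sentence:
# split_into_segments(["a", "b", "c", "d"], [1], np.random.rand(4, 8))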

+ 16 - 23
applications/utils/mysql/mapper.py

@@ -32,28 +32,21 @@ class Dataset(BaseMySQLClient):
         query = """
             select * from dataset where status = %s;
         """
-        return await self.pool.async_fetch(
-            query=query,
-            params=(status,)
-        )
+        return await self.pool.async_fetch(query=query, params=(status,))
 
     async def add_dataset(self, name):
         query = """
             insert into dataset (name, created_at, updated_at, status) values (%s, %s, %s, %s);
         """
         return await self.pool.async_save(
-            query=query,
-            params=(name, datetime.now(), datetime.now(), 1)
+            query=query, params=(name, datetime.now(), datetime.now(), 1)
         )
 
     async def select_dataset_by_id(self, id, status=1):
         query = """
             select * from dataset where id = %s and status = %s;
         """
-        return await self.pool.async_fetch(
-            query=query,
-            params=(id, status)
-        )
+        return await self.pool.async_fetch(query=query, params=(id, status))
 
 
 class Contents(BaseMySQLClient):
@@ -114,18 +107,15 @@ class Contents(BaseMySQLClient):
         query = """
             select * from contents where doc_id = %s;
         """
-        return await self.pool.async_fetch(
-            query=query,
-            params=(doc_id,)
-        )
+        return await self.pool.async_fetch(query=query, params=(doc_id,))
 
     async def select_contents(
-            self,
-            page_num: int,
-            page_size: int,
-            order_by: dict = {"id": "desc"},
-            dataset_id: int = None,
-            doc_status: int = 1,
+        self,
+        page_num: int,
+        page_size: int,
+        order_by: dict = {"id": "desc"},
+        dataset_id: int = None,
+        doc_status: int = 1,
     ):
         """
         Paginated query over the contents table, returning pagination info
@@ -154,7 +144,9 @@ class Contents(BaseMySQLClient):
 
         # Query the total count
         count_query = f"SELECT COUNT(*) as total_count FROM contents WHERE {where_sql};"
-        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
+        count_result = await self.pool.async_fetch(
+            query=count_query, params=tuple(params)
+        )
         total_count = count_result[0]["total_count"] if count_result else 0
 
         # Query the page of data
@@ -174,7 +166,7 @@ class Contents(BaseMySQLClient):
             "total_count": total_count,
             "page": page_num,
             "page_size": page_size,
-            "total_pages": total_pages
+            "total_pages": total_pages,
         }
 
 
@@ -206,7 +198,8 @@ class ContentChunks(BaseMySQLClient):
             WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s and status = %s;
         """
         return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, chunk_id, ori_status, self.CHUNK_USEFUL_STATUS)
+            query=query,
+            params=(new_status, doc_id, chunk_id, ori_status, self.CHUNK_USEFUL_STATUS),
         )
 
     async def update_embedding_status(self, doc_id, chunk_id, ori_status, new_status):
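
Note: a hedged sketch of consuming the paginated select_contents call reformatted above. Constructing the mapper from a MySQL client mirrors the usage in routes/buleprint.py (e.g. Dataset(resource.mysql_client)); the total_pages formula is an assumption, since that line sits outside these hunks.

from applications.utils.mysql.mapper import Contents

async def first_page(mysql_client, dataset_id: int):
    contents = Contents(mysql_client)
    result = await contents.select_contents(
        page_num=1, page_size=10, order_by={"id": "desc"}, dataset_id=dataset_id
    )
    # Keys returned: entities, total_count, page, page_size, total_pages
    # (total_pages is presumably ceil(total_count / page_size)).
    return result["entities"]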

+ 98 - 90
routes/buleprint.py

@@ -161,11 +161,7 @@ async def dataset_list():
         for dataset, count in zip(datasets, counts)
     ]
 
-    return jsonify({
-        "status_code": 200,
-        "detail": "success",
-        "data": data_list
-    })
+    return jsonify({"status_code": 200, "detail": "success", "data": data_list})
 
 
 @server_bp.route("/dataset/add", methods=["POST"])
@@ -176,16 +172,10 @@ async def add_dataset():
     body = await request.get_json()
     name = body.get("name")
     if not name:
-        return jsonify({
-            "status_code": 400,
-            "detail": "name is required"
-        })
+        return jsonify({"status_code": 400, "detail": "name is required"})
     # Perform the insert
     await dataset.add_dataset(name)
-    return jsonify({
-        "status_code": 200,
-        "detail": "success"
-    })
+    return jsonify({"status_code": 200, "detail": "success"})
 
 
 @server_bp.route("/content/get", methods=["GET"])
@@ -196,33 +186,27 @@ async def get_content():
     # Get the request parameters
     doc_id = request.args.get("docId")
     if not doc_id:
-        return jsonify({
-            "status_code": 400,
-            "detail": "doc_id is required",
-            "data": {}
-        })
+        return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
 
     # Query the content
     rows = await contents.select_content_by_doc_id(doc_id)
 
     if not rows:
-        return jsonify({
-            "status_code": 404,
-            "detail": "content not found",
-            "data": {}
-        })
+        return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
 
     row = rows[0]
 
-    return jsonify({
-        "status_code": 200,
-        "detail": "success",
-        "data": {
-            "title": row.get("title", ""),
-            "text": row.get("text", ""),
-            "doc_id": row.get("doc_id", "")
+    return jsonify(
+        {
+            "status_code": 200,
+            "detail": "success",
+            "data": {
+                "title": row.get("title", ""),
+                "text": row.get("text", ""),
+                "doc_id": row.get("doc_id", ""),
+            },
         }
-    })
+    )
 
 
 @server_bp.route("/content/list", methods=["GET"])
@@ -238,6 +222,7 @@ async def content_list():
 
     # order_by can be passed as a JSON string
     import json
+
     order_by_str = request.args.get("order_by", '{"id":"desc"}')
     try:
         order_by = json.loads(order_by_str)
@@ -263,22 +248,33 @@ async def content_list():
         for row in result["entities"]
     ]
 
-    return jsonify({
-        "status_code": 200,
-        "detail": "success",
-        "data": {
-            "entities": entities,
-            "total_count": result["total_count"],
-            "page": result["page"],
-            "page_size": result["page_size"],
-            "total_pages": result["total_pages"]
+    return jsonify(
+        {
+            "status_code": 200,
+            "detail": "success",
+            "data": {
+                "entities": entities,
+                "total_count": result["total_count"],
+                "page": result["page"],
+                "page_size": result["page_size"],
+                "total_pages": result["total_pages"],
+            },
         }
-    })
+    )
 
 
-async def query_search(query_text, filters=None, search_type='', anns_field='vector_text',
-                       search_params=BASE_MILVUS_SEARCH_PARAMS, _source=False, es_size=10000, sort_by=None,
-                       milvus_size=20, limit=10):
+async def query_search(
+    query_text,
+    filters=None,
+    search_type="",
+    anns_field="vector_text",
+    search_params=BASE_MILVUS_SEARCH_PARAMS,
+    _source=False,
+    es_size=10000,
+    sort_by=None,
+    milvus_size=20,
+    limit=10,
+):
     if filters is None:
         filters = {}
     query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
@@ -320,40 +316,46 @@ async def query():
     query_text = request.args.get("query")
     dataset_ids = request.args.get("datasetIds").split(",")
     search_type = request.args.get("search_type", "hybrid")
-    query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
-                                       search_type=search_type)
+    query_results = await query_search(
+        query_text=query_text,
+        filters={"dataset_id": dataset_ids},
+        search_type=search_type,
+    )
     resource = get_resource_manager()
     content_chunk_mapper = ContentChunks(resource.mysql_client)
     dataset_mapper = Dataset(resource.mysql_client)
     res = []
-    for result in query_results['results']:
-        content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
-                                                                         chunk_id=result['chunk_id'])
+    for result in query_results["results"]:
+        content_chunks = await content_chunk_mapper.select_chunk_content(
+            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
+        )
         if not content_chunks:
-            return jsonify({
-                "status_code": 500,
-                "detail": "content_chunk not found",
-                "data": {}
-            })
+            return jsonify(
+                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
+            )
         content_chunk = content_chunks[0]
-        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        datasets = await dataset_mapper.select_dataset_by_id(
+            content_chunk["dataset_id"]
+        )
         if not datasets:
-            return jsonify({
-                "status_code": 500,
-                "detail": "dataset not found",
-                "data": {}
-            })
+            return jsonify(
+                {"status_code": 500, "detail": "dataset not found", "data": {}}
+            )
         dataset = datasets[0]
         dataset_name = None
         if dataset:
-            dataset_name = dataset['name']
+            dataset_name = dataset["name"]
         res.append(
-            {'docId': content_chunk['doc_id'], 'content': content_chunk['text'],
-             'contentSummary': content_chunk['summary'], 'score': result['score'], 'datasetName': dataset_name})
-    data = {'results': res}
-    return jsonify({'status_code': 200,
-                    'detail': "success",
-                    'data': data})
+            {
+                "docId": content_chunk["doc_id"],
+                "content": content_chunk["text"],
+                "contentSummary": content_chunk["summary"],
+                "score": result["score"],
+                "datasetName": dataset_name,
+            }
+        )
+    data = {"results": res}
+    return jsonify({"status_code": 200, "detail": "success", "data": data})
 
 
 @server_bp.route("/chat", methods=["GET"])
@@ -361,40 +363,46 @@ async def chat():
     query_text = request.args.get("query")
     dataset_ids = request.args.get("datasetIds").split(",")
     search_type = request.args.get("search_type", "hybrid")
-    query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
-                                       search_type=search_type)
+    query_results = await query_search(
+        query_text=query_text,
+        filters={"dataset_id": dataset_ids},
+        search_type=search_type,
+    )
     resource = get_resource_manager()
     content_chunk_mapper = ContentChunks(resource.mysql_client)
     dataset_mapper = Dataset(resource.mysql_client)
     res = []
-    for result in query_results['results']:
-        content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
-                                                                         chunk_id=result['chunk_id'])
+    for result in query_results["results"]:
+        content_chunks = await content_chunk_mapper.select_chunk_content(
+            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
+        )
         if not content_chunks:
-            return jsonify({
-                "status_code": 500,
-                "detail": "content_chunk not found",
-                "data": {}
-            })
+            return jsonify(
+                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
+            )
         content_chunk = content_chunks[0]
-        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        datasets = await dataset_mapper.select_dataset_by_id(
+            content_chunk["dataset_id"]
+        )
         if not datasets:
-            return jsonify({
-                "status_code": 500,
-                "detail": "dataset not found",
-                "data": {}
-            })
+            return jsonify(
+                {"status_code": 500, "detail": "dataset not found", "data": {}}
+            )
         dataset = datasets[0]
         dataset_name = None
         if dataset:
-            dataset_name = dataset['name']
+            dataset_name = dataset["name"]
         res.append(
-            {'docId': content_chunk['doc_id'], 'content': content_chunk['text'],
-             'contentSummary': content_chunk['summary'], 'score': result['score'], 'datasetName': dataset_name})
+            {
+                "docId": content_chunk["doc_id"],
+                "content": content_chunk["text"],
+                "contentSummary": content_chunk["summary"],
+                "score": result["score"],
+                "datasetName": dataset_name,
+            }
+        )
 
     chat_classifier = ChatClassifier()
     chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
-    data = {'results': res, 'chat_res': chat_res}
-    return jsonify({'status_code': 200,
-                    'detail': "success",
-                    'data': data})
+    data = {"results": res, "chat_res": chat_res}
+    return jsonify({"status_code": 200, "detail": "success", "data": data})
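
Note: a client-side sketch of the /chat route as reformatted above. The host, port, and any blueprint URL prefix are assumptions; the query-string names (query, datasetIds, search_type) and the response shape come from the handler in this diff, and requests is used only as an illustrative synchronous client.

import requests

resp = requests.get(
    "http://localhost:8080/chat",  # host/port and URL prefix are assumptions
    params={"query": "example question", "datasetIds": "1,2", "search_type": "hybrid"},
)
payload = resp.json()
# Shape per the handler: {"status_code": 200, "detail": "success",
#   "data": {"results": [{"docId", "content", "contentSummary", "score",
#                         "datasetName"}, ...], "chat_res": "..."}}
for hit in payload["data"]["results"]:
    print(hit["docId"], hit["score"], hit["datasetName"])
print(payload["data"]["chat_res"])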