Browse source

Add search-result storage and relevance scoring

xueyiming 2 weeks ago
parent
commit
7a0ec79bf6
3 changed files with 148 additions and 12 deletions
  1. applications/utils/chat/chat_classifier.py (+17 -10)
  2. applications/utils/mysql/mapper.py (+73 -0)
  3. routes/buleprint.py (+58 -2)

+ 17 - 10
applications/utils/chat/chat_classifier.py

@@ -15,10 +15,8 @@ class ChatClassifier:
         :return: the generated summary prompt
         """

-        # To help the AI understand better, build the prompt in the following format:
         prompt = f"问题: {query}\n\n请结合以下搜索结果,生成一个总结:\n"

-        # First build the similarity-weighted summaries
         weighted_summaries = []
         weighted_contents = []

@@ -27,15 +25,12 @@ class ChatClassifier:
             content_summary = result['contentSummary']
             score = result['score']

-            # Weight the content summary and content by score
             weighted_summaries.append((content_summary, score))
             weighted_contents.append((content, score))

-        # Sort the weighted contents and summaries by similarity for a more accurate summary
-        weighted_summaries.sort(key=lambda x: x[1], reverse=True)  # descending by similarity
-        weighted_contents.sort(key=lambda x: x[1], reverse=True)  # descending by similarity
+        weighted_summaries.sort(key=lambda x: x[1], reverse=True)
+        weighted_contents.sort(key=lambda x: x[1], reverse=True)

-        # Append the weighted summaries and contents to the prompt
         prompt += "\n-- 加权内容摘要 --\n"
         for summary, score in weighted_summaries:
             prompt += f"摘要: {summary} | 相似度: {score:.2f}\n"
@@ -44,14 +39,26 @@ class ChatClassifier:
         for content, score in weighted_contents:
             prompt += f"内容: {content} | 相似度: {score:.2f}\n"

-        # Finally, ask the AI to summarize
-        prompt += "\n基于上述内容,请帮我生成一个简洁的总结。"
+        # Constrain the AI to output JSON
+        prompt += """
+    请基于上述内容生成一个总结,并返回 JSON 格式,结构如下:
+
+    {
+      "query": "<原始问题>",
+      "summary": "<简洁总结>",
+      "relevance_score": <0到1之间的小数,表示总结与问题的相关度>
+    }
+
+    注意:
+    - 只输出 JSON,不要额外解释。
+    - relevance_score 数字越大,表示总结和问题越相关。
+    """

         return prompt

     async def chat_with_deepseek(self, query, search_results):
         prompt = self.generate_summary_prompt(query, search_results)
         response = await fetch_deepseek_completion(
-            model="DeepSeek-V3", prompt=prompt
+            model="DeepSeek-V3", prompt=prompt, output_type='json'
         )
         return response
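
The route code below indexes the model reply as chat_res['summary'] and chat_res['relevance_score'], but the prompt only asks the model for JSON; nothing enforces it. A minimal defensive-parse sketch (parse_summary_response is a hypothetical helper, and the dict-or-string return contract of fetch_deepseek_completion is an assumption, since this diff does not show it):

    import json

    def parse_summary_response(response):
        # The model was asked for {"query", "summary", "relevance_score"},
        # but may not comply; fall back to treating the reply as plain text.
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                return {"summary": response, "relevance_score": 0.0}
        return {
            "summary": response.get("summary", ""),
            "relevance_score": float(response.get("relevance_score", 0.0)),
        }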

+ 73 - 0
applications/utils/mysql/mapper.py

@@ -287,3 +287,76 @@ class ContentChunks(BaseMySQLClient):
         return await self.pool.async_fetch(
             query=query, params=(doc_id, chunk_id, status)
         )
+
+    async def select_chunk_contents(
+            self,
+            page_num: int,
+            page_size: int,
+            order_by: dict = None,
+            doc_id: str = None,
+            doc_status: int = None
+    ):
+        # Avoid a mutable default argument
+        order_by = order_by or {"chunk_id": "asc"}
+        offset = (page_num - 1) * page_size
+
+        # Build the WHERE clause dynamically
+        where_clauses = []
+        params = []
+
+        if doc_id:
+            where_clauses.append("doc_id = %s")
+            params.append(doc_id)
+
+        if doc_status is not None:
+            where_clauses.append("doc_status = %s")
+            params.append(doc_status)
+
+        # Fall back to a no-op condition so the SQL stays valid with no filters
+        where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
+
+        # Build the ORDER BY clause; the field name is interpolated into the
+        # SQL, so it must come from trusted code, never from user input
+        order_field, order_direction = list(order_by.items())[0]
+        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
+
+        # Total row count for pagination
+        count_query = f"SELECT COUNT(*) AS total_count FROM content_chunks WHERE {where_sql};"
+        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
+        total_count = count_result[0]["total_count"] if count_result else 0
+
+        # Fetch the requested page
+        query = f"""
+            SELECT * FROM content_chunks
+            WHERE {where_sql}
+            {order_sql}
+            LIMIT %s OFFSET %s;
+        """
+        params.extend([page_size, offset])
+        entities = await self.pool.async_fetch(query=query, params=tuple(params))
+
+        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
+        return {
+            "entities": entities,
+            "total_count": total_count,
+            "page": page_num,
+            "page_size": page_size,
+            "total_pages": total_pages
+        }
+
+
+class ChatRes(BaseMySQLClient):
+    async def insert_chat_res(self, query_text, dataset_ids, search_res, chat_res, score):
+        query = """
+            INSERT INTO chat_res
+                (query, dataset_ids, search_res, chat_res, score)
+            VALUES (%s, %s, %s, %s, %s);
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(query_text, dataset_ids, search_res, chat_res, score)
+        )
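
The commit adds the ChatRes mapper but no migration for the chat_res table it writes to. A minimal sketch of a schema that would satisfy insert_chat_res, in the repo's SQL-in-Python style; every column type here is an assumption inferred from the insert parameters, not the project's actual DDL:

    # Hypothetical DDL for chat_res; types are guesses based on insert_chat_res
    CREATE_CHAT_RES = """
        CREATE TABLE IF NOT EXISTS chat_res (
            id BIGINT AUTO_INCREMENT PRIMARY KEY,
            query TEXT NOT NULL,
            dataset_ids VARCHAR(255),
            search_res LONGTEXT,
            chat_res TEXT,
            score DECIMAL(4, 3),
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
    """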

+ 58 - 2
routes/buleprint.py

@@ -1,4 +1,5 @@
 import asyncio
+import json
 import traceback
 import uuid
 from typing import Dict, Any
@@ -19,6 +20,7 @@ from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.search import HybridSearch
 from applications.utils.chat import ChatClassifier
 from applications.utils.mysql import Dataset, Contents, ContentChunks
+from applications.utils.mysql.mapper import ChatRes

 server_bp = Blueprint("api", __name__, url_prefix="/api")
 server_bp = cors(server_bp, allow_origin="*")
@@ -359,13 +361,15 @@ async def query():
 @server_bp.route("/chat", methods=["GET"])
 @server_bp.route("/chat", methods=["GET"])
 async def chat():
 async def chat():
     query_text = request.args.get("query")
     query_text = request.args.get("query")
-    dataset_ids = request.args.get("datasetIds").split(",")
+    dataset_id_strs = request.args.get("datasetIds")
+    dataset_ids = dataset_id_strs.split(",")
     search_type = request.args.get("search_type", "hybrid")
     search_type = request.args.get("search_type", "hybrid")
     query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
     query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
                                        search_type=search_type)
                                        search_type=search_type)
     resource = get_resource_manager()
     resource = get_resource_manager()
     content_chunk_mapper = ContentChunks(resource.mysql_client)
     content_chunk_mapper = ContentChunks(resource.mysql_client)
     dataset_mapper = Dataset(resource.mysql_client)
     dataset_mapper = Dataset(resource.mysql_client)
+    chat_res_mapper = ChatRes(resource.mysql_client)
     res = []
     res = []
     for result in query_results['results']:
     for result in query_results['results']:
         content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
         content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
@@ -394,7 +398,59 @@ async def chat():
 
 
     chat_classifier = ChatClassifier()
     chat_classifier = ChatClassifier()
     chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
     chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
-    data = {'results': res, 'chat_res': chat_res}
+    data = {'results': res, 'chat_res': chat_res['summary']}
+    await chat_res_mapper.insert_chat_res(query_text, dataset_id_strs, json.dumps(data, ensure_ascii=False),
+                                          chat_res['summary'], chat_res['relevance_score'])
     return jsonify({'status_code': 200,
     return jsonify({'status_code': 200,
                     'detail': "success",
                     'detail': "success",
                     'data': data})
                     'data': data})
+
+
+@server_bp.route("/chunk/list", methods=["GET"])
+async def chunk_list():
+    resource = get_resource_manager()
+    content_chunk = ContentChunks(resource.mysql_client)
+
+    # Pagination and filter parameters from the URL query string
+    page_num = int(request.args.get("page", 1))
+    page_size = int(request.args.get("pageSize", 10))
+    doc_id = request.args.get("docId")
+    if not doc_id:
+        return jsonify({
+            "status_code": 500,
+            "detail": "docId not found",
+            "data": {}
+        })
+
+    # Call select_chunk_contents to get one page of chunks as a dict
+    result = await content_chunk.select_chunk_contents(page_num=page_num, page_size=page_size, doc_id=doc_id)
+
+    if not result:
+        return jsonify({
+            "status_code": 500,
+            "detail": "chunk is empty",
+            "data": {}
+        })
+    # Shape entities, keeping only the fields the client needs
+    entities = [
+        {
+            "id": row['id'],
+            "chunk_id": row['chunk_id'],
+            "doc_id": row["doc_id"],
+            "summary": row.get("summary") or "",
+            "text": row.get("text") or "",
+        }
+        for row in result["entities"]
+    ]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "entities": entities,
+            "total_count": result["total_count"],
+            "page": result["page"],
+            "page_size": result["page_size"],
+            "total_pages": result["total_pages"]
+        }
+    })
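
A quick way to exercise the new listing endpoint; the host, port, and docId value below are assumptions for illustration:

    import requests

    # Hypothetical host/port; docId must reference an already-ingested document.
    resp = requests.get(
        "http://localhost:8001/api/chunk/list",
        params={"docId": "doc-123", "page": 1, "pageSize": 10},
    )
    body = resp.json()
    print(body["detail"], body["data"].get("total_pages"))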