
Add search-result storage and relevance scoring

xueyiming 2 weeks ago
parent
commit
7a0ec79bf6

+ 17 - 10
applications/utils/chat/chat_classifier.py

@@ -15,10 +15,8 @@ class ChatClassifier:
         :return: the generated summary prompt
         """
 
-        # To help the AI understand better, build the prompt in the following format:
         prompt = f"问题: {query}\n\n请结合以下搜索结果,生成一个总结:\n"
 
-        # First build the similarity-weighted summaries
         weighted_summaries = []
         weighted_contents = []
 
@@ -27,15 +25,12 @@ class ChatClassifier:
             content_summary = result['contentSummary']
             score = result['score']
 
-            # Weight the content summaries and contents
             weighted_summaries.append((content_summary, score))
             weighted_contents.append((content, score))
 
-        # Weight contents and summaries by similarity for a more accurate summary
-        weighted_summaries.sort(key=lambda x: x[1], reverse=True)  # descending by similarity
-        weighted_contents.sort(key=lambda x: x[1], reverse=True)  # descending by similarity
+        weighted_summaries.sort(key=lambda x: x[1], reverse=True)
+        weighted_contents.sort(key=lambda x: x[1], reverse=True)
 
-        # Append the weighted summaries and contents to the prompt
         prompt += "\n-- 加权内容摘要 --\n"
         for summary, score in weighted_summaries:
             prompt += f"摘要: {summary} | 相似度: {score:.2f}\n"
@@ -44,14 +39,26 @@ class ChatClassifier:
         for content, score in weighted_contents:
             prompt += f"内容: {content} | 相似度: {score:.2f}\n"
 
-        # Finally ask the AI to produce the summary
-        prompt += "\n基于上述内容,请帮我生成一个简洁的总结。"
+        # Constrain the AI to return JSON
+        prompt += """
+    请基于上述内容生成一个总结,并返回 JSON 格式,结构如下:
+
+    {
+      "query": "<原始问题>",
+      "summary": "<简洁总结>",
+      "relevance_score": <0到1之间的小数,表示总结与问题的相关度>
+    }
+
+    注意:
+    - 只输出 JSON,不要额外解释。
+    - relevance_score 数字越大,表示总结和问题越相关。
+    """
 
         return prompt
 
     async def chat_with_deepseek(self, query, search_results):
         prompt = self.generate_summary_prompt(query, search_results)
         response = await fetch_deepseek_completion(
-            model="DeepSeek-V3", prompt=prompt
+            model="DeepSeek-V3", prompt=prompt, output_type='json'
         )
         return response
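Note: the route code further down assumes chat_with_deepseek returns a parsed dict (it reads chat_res['summary'] and chat_res['relevance_score']). If fetch_deepseek_completion can hand back the raw completion string instead, a small defensive parser is useful, since models sometimes wrap JSON in a Markdown fence. A minimal sketch; parse_model_json is a hypothetical helper, not part of this commit:

    import json
    import re

    def parse_model_json(raw):
        """Best-effort parse of a model completion that should be JSON."""
        if isinstance(raw, dict):
            return raw
        text = raw.strip()
        # Strip a surrounding ```json ... ``` fence if the model added one
        fence = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
        if fence:
            text = fence.group(1)
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} block in the text
            start, end = text.find("{"), text.rfind("}")
            if start != -1 and end > start:
                return json.loads(text[start:end + 1])
            raise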

+ 73 - 0
applications/utils/mysql/mapper.py

@@ -287,3 +287,76 @@ class ContentChunks(BaseMySQLClient):
         return await self.pool.async_fetch(
             query=query, params=(doc_id, chunk_id, status)
         )
+
+    async def select_chunk_contents(
+            self,
+            page_num: int,
+            page_size: int,
+            order_by: dict = None,
+            doc_id: str = None,
+            doc_status: int = None,
+    ):
+        # Avoid a mutable default argument; fall back to chunk_id ascending
+        order_by = order_by or {"chunk_id": "asc"}
+        offset = (page_num - 1) * page_size
+
+        # Build the WHERE clause dynamically
+        where_clauses = []
+        params = []
+
+        if doc_id:
+            where_clauses.append("doc_id = %s")
+            params.append(doc_id)
+
+        if doc_status is not None:  # 0 is a valid status
+            where_clauses.append("doc_status = %s")
+            params.append(doc_status)
+
+        # Fall back to a tautology so the query stays valid with no filters
+        where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
+
+        # Build ORDER BY dynamically; whitelist the direction since it is
+        # interpolated into the SQL string rather than bound as a parameter
+        order_field, order_direction = list(order_by.items())[0]
+        order_direction = "DESC" if str(order_direction).lower() == "desc" else "ASC"
+        order_sql = f"ORDER BY {order_field} {order_direction}"
+
+        # Count the total number of matching rows
+        count_query = f"SELECT COUNT(*) as total_count FROM content_chunks WHERE {where_sql};"
+        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
+        total_count = count_result[0]["total_count"] if count_result else 0
+
+        # Fetch the requested page
+        query = f"""
+            SELECT * FROM content_chunks
+            WHERE {where_sql}
+            {order_sql}
+            LIMIT %s OFFSET %s;
+        """
+        params.extend([page_size, offset])
+        entities = await self.pool.async_fetch(query=query, params=tuple(params))
+
+        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
+        return {
+            "entities": entities,
+            "total_count": total_count,
+            "page": page_num,
+            "page_size": page_size,
+            "total_pages": total_pages
+        }
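The ceiling division means a partial last page still counts as a full page: 25 rows at page_size=10 gives total_pages=3. A minimal usage sketch, assuming a ContentChunks instance named mapper and a made-up doc id:

    # Hypothetical call; "doc-123" is an illustrative doc id
    page = await mapper.select_chunk_contents(
        page_num=1,
        page_size=10,
        doc_id="doc-123",
        order_by={"chunk_id": "asc"},
    )
    # total_pages is total_count rounded up to whole pages
    assert page["total_pages"] == (page["total_count"] + 10 - 1) // 10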
+
+
+class ChatRes(BaseMySQLClient):
+    async def insert_chat_res(self, query_text, dataset_ids, search_res, chat_res, score):
+        query = """
+            INSERT INTO chat_res
+                (query, dataset_ids, search_res, chat_res, score)
+            VALUES (%s, %s, %s, %s, %s);
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(
+                query_text,
+                dataset_ids,
+                search_res,
+                chat_res,
+                score
+            )
+        )
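The chat_res table itself is not created anywhere in this diff. A plausible schema matching the INSERT column list; every column type here is an assumption:

    # Assumed chat_res DDL (not part of this commit; column types are guesses)
    CHAT_RES_DDL = """
    CREATE TABLE IF NOT EXISTS chat_res (
        id BIGINT AUTO_INCREMENT PRIMARY KEY,
        query TEXT NOT NULL,
        dataset_ids VARCHAR(255) NOT NULL,
        search_res LONGTEXT,
        chat_res TEXT,
        score DECIMAL(4, 3),
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );
    """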

+ 58 - 2
routes/buleprint.py

@@ -1,4 +1,5 @@
 import asyncio
+import json
 import traceback
 import uuid
 from typing import Dict, Any
@@ -19,6 +20,7 @@ from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.search import HybridSearch
 from applications.utils.chat import ChatClassifier
 from applications.utils.mysql import Dataset, Contents, ContentChunks
+from applications.utils.mysql.mapper import ChatRes
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 server_bp = cors(server_bp, allow_origin="*")
@@ -359,13 +361,15 @@ async def query():
 @server_bp.route("/chat", methods=["GET"])
 async def chat():
     query_text = request.args.get("query")
-    dataset_ids = request.args.get("datasetIds").split(",")
+    dataset_id_strs = request.args.get("datasetIds", "")
+    dataset_ids = dataset_id_strs.split(",")
     search_type = request.args.get("search_type", "hybrid")
     query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
                                        search_type=search_type)
     resource = get_resource_manager()
     content_chunk_mapper = ContentChunks(resource.mysql_client)
     dataset_mapper = Dataset(resource.mysql_client)
+    chat_res_mapper = ChatRes(resource.mysql_client)
     res = []
     for result in query_results['results']:
         content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
@@ -394,7 +398,59 @@ async def chat():
 
     chat_classifier = ChatClassifier()
     chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
-    data = {'results': res, 'chat_res': chat_res}
+    data = {'results': res, 'chat_res': chat_res['summary']}
+    await chat_res_mapper.insert_chat_res(query_text, dataset_id_strs, json.dumps(data, ensure_ascii=False),
+                                          chat_res['summary'], chat_res['relevance_score'])
     return jsonify({'status_code': 200,
                     'detail': "success",
                     'data': data})
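A quick way to exercise the endpoint from a script; the host, port, and dataset ids below are assumptions:

    import asyncio
    import httpx

    async def demo():
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.get(
                "http://localhost:8001/api/chat",  # host/port assumed
                params={"query": "什么是混合检索", "datasetIds": "1,2"},
            )
            body = resp.json()
            print(body["data"]["chat_res"])  # the model's summary

    asyncio.run(demo())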
+
+
+@server_bp.route("/chunk/list", methods=["GET"])
+async def chunk_list():
+    resource = get_resource_manager()
+    content_chunk = ContentChunks(resource.mysql_client)
+
+    # Read pagination and filter parameters from the query string
+    page_num = int(request.args.get("page", 1))
+    page_size = int(request.args.get("pageSize", 10))
+    doc_id = request.args.get("docId")
+    if not doc_id:
+        return jsonify({
+            "status_code": 500,
+            "detail": "docId not found",
+            "data": {}
+        })
+
+    # Call select_chunk_contents to get a pagination dict
+    result = await content_chunk.select_chunk_contents(page_num=page_num, page_size=page_size, doc_id=doc_id)
+
+    if not result["entities"]:  # the dict itself is always truthy
+        return jsonify({
+            "status_code": 500,
+            "detail": "chunk is empty",
+            "data": {}
+        })
+    # Shape entities, keeping only the fields the client needs
+    entities = [
+        {
+            "id": row['id'],
+            "chunk_id": row['chunk_id'],
+            "doc_id": row["doc_id"],
+            "summary": row.get("summary") or "",
+            "text": row.get("text") or "",
+        }
+        for row in result["entities"]
+    ]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "entities": entities,
+            "total_count": result["total_count"],
+            "page": result["page"],
+            "page_size": result["page_size"],
+            "total_pages": result["total_pages"]
+        }
+    })