
Merge branch 'dev-xym-relocation' of Server/llm_vector_server into master

luojunhui committed 2 weeks ago
parent commit 5c22088cd2

+ 19 - 10
applications/utils/chat/chat_classifier.py

@@ -15,10 +15,8 @@ class ChatClassifier:
         :return: the generated summary prompt
         """
 
-        # To help the AI understand better, build the prompt in the following format:
         prompt = f"问题: {query}\n\n请结合以下搜索结果,生成一个总结:\n"
 
-        # First build the similarity-weighted summary
         weighted_summaries = []
         weighted_contents = []
 
@@ -27,15 +25,12 @@ class ChatClassifier:
             content_summary = result["contentSummary"]
             score = result["score"]
 
-            # Weight the content summaries and the contents
             weighted_summaries.append((content_summary, score))
             weighted_contents.append((content, score))
 
-        # Weight the contents and summaries by similarity to produce a more accurate summary
-        weighted_summaries.sort(key=lambda x: x[1], reverse=True)  # sort by similarity, descending
-        weighted_contents.sort(key=lambda x: x[1], reverse=True)  # sort by similarity, descending
+        weighted_summaries.sort(key=lambda x: x[1], reverse=True)
+        weighted_contents.sort(key=lambda x: x[1], reverse=True)
 
-        # Append the weighted summaries and contents to the prompt
         prompt += "\n-- 加权内容摘要 --\n"
         for summary, score in weighted_summaries:
             prompt += f"摘要: {summary} | 相似度: {score:.2f}\n"
@@ -44,12 +39,26 @@ class ChatClassifier:
         for content, score in weighted_contents:
             prompt += f"内容: {content} | 相似度: {score:.2f}\n"
 
-        # Finally ask the AI to generate the summary
-        prompt += "\n基于上述内容,请帮我生成一个简洁的总结。"
+        # Constrain the AI to output JSON
+        prompt += """
+    请基于上述内容生成一个总结,并返回 JSON 格式,结构如下:
+
+    {
+      "query": "<原始问题>",
+      "summary": "<简洁总结>",
+      "relevance_score": <0到1之间的小数,表示总结与问题的相关度>
+    }
+
+    注意:
+    - 只输出 JSON,不要额外解释。
+    - relevance_score 数字越大,表示总结和问题越相关。
+    """
 
         return prompt
 
     async def chat_with_deepseek(self, query, search_results):
         prompt = self.generate_summary_prompt(query, search_results)
-        response = await fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt)
+        response = await fetch_deepseek_completion(
+            model="DeepSeek-V3", prompt=prompt, output_type="json"
+        )
         return response

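For reference, with the JSON constraint added above, chat_with_deepseek is now expected to return a parsed dict rather than free text. A minimal sketch of the shape consumed by the /chat route further down (the exact behavior of fetch_deepseek_completion with output_type="json" is assumed; the values are illustrative):

    # Hypothetical parsed result; /chat reads the "summary" and "relevance_score" keys
    chat_res = {
        "query": "原始问题",
        "summary": "简洁总结",
        "relevance_score": 0.87,
    }
    summary = chat_res["summary"]        # returned to the client as data["chat_res"]
    score = chat_res["relevance_score"]  # stored in the score column of chat_res
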
+ 2 - 2
applications/utils/mysql/__init__.py

@@ -1,5 +1,5 @@
 from .pool import DatabaseManager
-from .mapper import Contents, ContentChunks, Dataset
+from .mapper import Contents, ContentChunks, Dataset, ChatResult
 
 
-__all__ = ["Contents", "ContentChunks", "DatabaseManager", "Dataset"]
+__all__ = ["Contents", "ContentChunks", "DatabaseManager", "Dataset", "ChatResult"]

+ 71 - 0
applications/utils/mysql/mapper.py

@@ -284,3 +284,74 @@ class ContentChunks(BaseMySQLClient):
             SELECT * FROM content_chunks WHERE doc_id = %s AND chunk_id = %s;
         """
         return await self.pool.async_fetch(query=query, params=(doc_id, chunk_id))
+
+    async def select_chunk_contents(
+        self,
+        page_num: int,
+        page_size: int,
+        order_by: dict = {"chunk_id": "asc"},
+        doc_id: str = None,
+        doc_status: int = None,
+    ):
+        offset = (page_num - 1) * page_size
+
+        # Build the WHERE clause dynamically
+        where_clauses = []
+        params = []
+
+        if doc_id:
+            where_clauses.append("doc_id = %s")
+            params.append(doc_id)
+
+        if doc_status:
+            where_clauses.append("doc_status = %s")
+            params.append(doc_status)
+
+        where_sql = " AND ".join(where_clauses)
+
+        # Build the ORDER BY clause dynamically
+        order_field, order_direction = list(order_by.items())[0]
+        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
+
+        # Query the total row count
+        count_query = (
+            f"SELECT COUNT(*) as total_count FROM content_chunks WHERE {where_sql};"
+        )
+        count_result = await self.pool.async_fetch(
+            query=count_query, params=tuple(params)
+        )
+        total_count = count_result[0]["total_count"] if count_result else 0
+
+        # Query the paginated rows
+        query = f"""
+            SELECT * FROM content_chunks
+            WHERE {where_sql}
+            {order_sql}
+            LIMIT %s OFFSET %s;
+        """
+        params.extend([page_size, offset])
+        entities = await self.pool.async_fetch(query=query, params=tuple(params))
+
+        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
+        print(total_pages)
+        return {
+            "entities": entities,
+            "total_count": total_count,
+            "page": page_num,
+            "page_size": page_size,
+            "total_pages": total_pages,
+        }
+
+
+class ChatResult(BaseMySQLClient):
+    async def insert_chat_result(
+        self, query_text, dataset_ids, search_res, chat_res, score
+    ):
+        query = """
+                    INSERT INTO chat_res
+                        (query, dataset_ids, search_res, chat_res, score) 
+                        VALUES (%s, %s, %s, %s, %s);
+                """
+        return await self.pool.async_save(
+            query=query, params=(query_text, dataset_ids, search_res, chat_res, score)
+        )

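A short usage sketch of the new mapper methods, mirroring how routes/buleprint.py calls them (assumes an initialized resource manager; identifiers not present in this diff are illustrative). Note that select_chunk_contents builds its WHERE clause only from the filters actually passed, so callers should supply at least doc_id or doc_status, otherwise where_sql is empty and the generated SQL is malformed; the /chunk/list route below always passes docId.

    # Hypothetical caller, inside an async function
    resource = get_resource_manager()
    chunks = ContentChunks(resource.mysql_client)
    page = await chunks.select_chunk_contents(page_num=1, page_size=10, doc_id="doc-123")
    # page -> {"entities": [...], "total_count": N, "page": 1, "page_size": 10, "total_pages": M}

    chat_results = ChatResult(resource.mysql_client)
    await chat_results.insert_chat_result(
        "用户问题", "1,2", '{"results": []}', "总结文本", 0.9
    )
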
+ 62 - 6
routes/buleprint.py

@@ -1,4 +1,5 @@
 import asyncio
+import json
 import traceback
 import uuid
 from typing import Dict, Any
@@ -6,18 +7,18 @@ from typing import Dict, Any
 from quart import Blueprint, jsonify, request
 from quart_cors import cors
 
+from applications.api import get_basic_embedding
+from applications.api import get_img_embedding
+from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.config import (
     DEFAULT_MODEL,
     LOCAL_MODEL_CONFIG,
     BASE_MILVUS_SEARCH_PARAMS,
 )
 from applications.resource import get_resource_manager
-from applications.api import get_basic_embedding
-from applications.api import get_img_embedding
-from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.search import HybridSearch
 from applications.utils.chat import ChatClassifier
-from applications.utils.mysql import Dataset, Contents, ContentChunks
+from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 server_bp = cors(server_bp, allow_origin="*")
@@ -368,7 +369,8 @@ async def query():
 @server_bp.route("/chat", methods=["GET"])
 async def chat():
     query_text = request.args.get("query")
-    dataset_ids = request.args.get("datasetIds").split(",")
+    dataset_id_strs = request.args.get("datasetIds")
+    dataset_ids = dataset_id_strs.split(",")
     search_type = request.args.get("search_type", "hybrid")
     query_results = await query_search(
         query_text=query_text,
@@ -378,6 +380,7 @@ async def chat():
     resource = get_resource_manager()
     content_chunk_mapper = ContentChunks(resource.mysql_client)
     dataset_mapper = Dataset(resource.mysql_client)
+    chat_result_mapper = ChatResult(resource.mysql_client)
     res = []
     for result in query_results["results"]:
         content_chunks = await content_chunk_mapper.select_chunk_content(
@@ -411,5 +414,58 @@ async def chat():
 
     chat_classifier = ChatClassifier()
     chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
-    data = {"results": res, "chat_res": chat_res}
+    data = {"results": res, "chat_res": chat_res["summary"]}
+    await chat_result_mapper.insert_chat_result(
+        query_text,
+        dataset_id_strs,
+        json.dumps(data, ensure_ascii=False),
+        chat_res["summary"],
+        chat_res["relevance_score"],
+    )
     return jsonify({"status_code": 200, "detail": "success", "data": data})
+
+
+@server_bp.route("/chunk/list", methods=["GET"])
+async def chunk_list():
+    resource = get_resource_manager()
+    content_chunk = ContentChunks(resource.mysql_client)
+
+    # Read pagination and filter parameters from the URL query string
+    page_num = int(request.args.get("page", 1))
+    page_size = int(request.args.get("pageSize", 10))
+    doc_id = request.args.get("docId")
+    if not doc_id:
+        return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
+
+    # Call select_chunk_contents to fetch the pagination dict
+    result = await content_chunk.select_chunk_contents(
+        page_num=page_num, page_size=page_size, doc_id=doc_id
+    )
+
+    if not result:
+        return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
+    # Format entities, keeping only the necessary fields
+    entities = [
+        {
+            "id": row["id"],
+            "chunk_id": row["chunk_id"],
+            "doc_id": row["doc_id"],
+            "summary": row.get("summary") or "",
+            "text": row.get("text") or "",
+        }
+        for row in result["entities"]
+    ]
+
+    return jsonify(
+        {
+            "status_code": 200,
+            "detail": "success",
+            "data": {
+                "entities": entities,
+                "total_count": result["total_count"],
+                "page": result["page"],
+                "page_size": result["page_size"],
+                "total_pages": result["total_pages"],
+            },
+        }
+    )