Selaa lähdekoodia

增加api接口

xueyiming 2 viikkoa sitten
vanhempi
commit
e25c748218

+ 7 - 0
applications/utils/chat/__init__.py

@@ -0,0 +1,7 @@
+from applications.utils.chat.chat_classifier import ChatClassifier
+
+
+__all__ = [
+    "ChatClassifier"
+]
+

+ 57 - 0
applications/utils/chat/chat_classifier.py

@@ -0,0 +1,57 @@
+from typing import List
+
+from applications.config import Chunk
+from applications.api import fetch_deepseek_completion
+
+
+class ChatClassifier:
+    @staticmethod
+    def generate_summary_prompt(query, search_results):
+        """
+        生成总结的prompt。
+
+        :param query: 问题
+        :param search_results: 搜索结果列表,每个元素包含 'content', 'contentSummary', 'score'
+        :return: 生成的总结prompt
+        """
+
+        # 为了让AI更好地理解,我们将使用以下格式构建prompt:
+        prompt = f"问题: {query}\n\n请结合以下搜索结果,生成一个总结:\n"
+
+        # 先生成基于相似度加权的summary
+        weighted_summaries = []
+        weighted_contents = []
+
+        for result in search_results:
+            content = result['content']
+            content_summary = result['contentSummary']
+            score = result['score']
+
+            # 加权内容摘要和内容
+            weighted_summaries.append((content_summary, score))
+            weighted_contents.append((content, score))
+
+        # 为了生成更准确的总结,基于相似度加权内容和摘要
+        weighted_summaries.sort(key=lambda x: x[1], reverse=True)  # 按相似度降序排列
+        weighted_contents.sort(key=lambda x: x[1], reverse=True)  # 按相似度降序排列
+
+        # 将加权的摘要和内容加入到prompt中
+        prompt += "\n-- 加权内容摘要 --\n"
+        for summary, score in weighted_summaries:
+            prompt += f"摘要: {summary} | 相似度: {score:.2f}\n"
+
+        prompt += "\n-- 加权内容 --\n"
+        for content, score in weighted_contents:
+            prompt += f"内容: {content} | 相似度: {score:.2f}\n"
+
+        # 最后请求AI进行总结
+        prompt += "\n基于上述内容,请帮我生成一个简洁的总结。"
+
+        return prompt
+
+    async def chat_with_deepseek(self, query, search_results):
+        prompt = self.generate_summary_prompt(query, search_results)
+        response = await fetch_deepseek_completion(
+            model="DeepSeek-V3", prompt=prompt
+        )
+        return response

+ 111 - 0
applications/utils/mysql/mapper.py

@@ -1,4 +1,6 @@
 import json
+from datetime import datetime
+
 from applications.config import Chunk
 
 
@@ -24,6 +26,33 @@ class Dataset(BaseMySQLClient):
             params=(new_status, dataset_id, ori_status),
         )
 
+    async def select_dataset(self, status=1):
+        query = """
+            select * from dataset where status = %s;
+        """
+        return await self.pool.async_fetch(
+            query=query,
+            params=(status,)
+        )
+
+    async def add_dataset(self, name):
+        query = """
+            insert into dataset (name, created_at, updated_at, status) values (%s, %s, %s, %s);
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(name, datetime.now(), datetime.now(), 1)
+        )
+
+    async def select_dataset_by_id(self, id, status=1):
+        query = """
+            select * from dataset where id = %s and status = %s;
+        """
+        return await self.pool.async_fetch(
+            query=query,
+            params=(id, status)
+        )
+
 
 class Contents(BaseMySQLClient):
     async def insert_content(self, doc_id, text, text_type, title, dataset_id):
@@ -72,6 +101,80 @@ class Contents(BaseMySQLClient):
             query=query, params=(new_status, doc_id, ori_status)
         )
 
+    async def select_count(self, dataset_id, doc_status=1):
+        query = """
+            select count(*) as count from contents where dataset_id = %s and doc_status = %s;
+        """
+        rows = await self.pool.async_fetch(query=query, params=(dataset_id, doc_status))
+        return rows[0]["count"] if rows else 0
+
+    async def select_content_by_doc_id(self, doc_id):
+        query = """
+            select * from contents where doc_id = %s;
+        """
+        return await self.pool.async_fetch(
+            query=query,
+            params=(doc_id,)
+        )
+
+    async def select_contents(
+            self,
+            page_num: int,
+            page_size: int,
+            order_by: dict = {"id": "desc"},
+            dataset_id: int = None,
+            doc_status: int = 1,
+    ):
+        """
+        分页查询 contents 表,并返回分页信息
+        :param page_num: 页码,从 1 开始
+        :param page_size: 每页数量
+        :param order_by: 排序条件,例如 {"id": "desc"} 或 {"created_at": "asc"}
+        :param dataset_id: 数据集 ID
+        :param doc_status: 文档状态(默认 1)
+        :return: dict,包含 entities、total_count、page、page_size、total_pages
+        """
+        offset = (page_num - 1) * page_size
+
+        # 动态拼接 where 条件
+        where_clauses = ["doc_status = %s"]
+        params = [doc_status]
+
+        if dataset_id:
+            where_clauses.append("dataset_id = %s")
+            params.append(dataset_id)
+
+        where_sql = " AND ".join(where_clauses)
+
+        # 动态拼接 order by
+        order_field, order_direction = list(order_by.items())[0]
+        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
+
+        # 查询总数
+        count_query = f"SELECT COUNT(*) as total_count FROM contents WHERE {where_sql};"
+        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
+        total_count = count_result[0]["total_count"] if count_result else 0
+
+        # 查询分页数据
+        query = f"""
+            SELECT * FROM contents
+            WHERE {where_sql}
+            {order_sql}
+            LIMIT %s OFFSET %s;
+        """
+        params.extend([page_size, offset])
+        entities = await self.pool.async_fetch(query=query, params=tuple(params))
+
+        total_pages = (total_count + page_size - 1) // page_size  # 向上取整
+
+        return {
+            "entities": entities,
+            "total_count": total_count,
+            "page": page_num,
+            "page_size": page_size,
+            "total_pages": total_pages
+        }
+
 
 class ContentChunks(BaseMySQLClient):
     async def insert_chunk(self, chunk: Chunk) -> int:
@@ -173,3 +276,11 @@ class ContentChunks(BaseMySQLClient):
         return await self.pool.async_save(
             query=query, params=(new_status, dataset_id, ori_status)
         )
+
+    async def select_chunk_content(self, doc_id, chunk_id, status=1):
+        query = """
+            select * from content_chunks where doc_id = %s and chunk_id = %s and status = %s;
+        """
+        return await self.pool.async_fetch(
+            query=query, params=(doc_id, chunk_id, status)
+        )

+ 267 - 1
routes/buleprint.py

@@ -1,8 +1,10 @@
+import asyncio
 import traceback
 import uuid
 from typing import Dict, Any
 
 from quart import Blueprint, jsonify, request
+from quart_cors import cors
 
 from applications.config import (
     DEFAULT_MODEL,
@@ -15,9 +17,11 @@ from applications.api import get_basic_embedding
 from applications.api import get_img_embedding
 from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.search import HybridSearch
-
+from applications.utils.chat import ChatClassifier
+from applications.utils.mysql import Dataset, Contents, ContentChunks
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
+server_bp = cors(server_bp, allow_origin="*")
 
 
 @server_bp.route("/embed", methods=["POST"])
@@ -138,3 +142,265 @@ async def search():
                 return jsonify({"error": "error  search_type"}), 200
     except Exception as e:
         return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
+
+
+@server_bp.route("/dataset/list", methods=["GET"])
+async def dataset_list():
+    resource = get_resource_manager()
+    datasets = await Dataset(resource.mysql_client).select_dataset()
+
+    # 创建所有任务
+    tasks = [
+        Contents(resource.mysql_client).select_count(dataset["id"])
+        for dataset in datasets
+    ]
+    counts = await asyncio.gather(*tasks)
+
+    # 组装数据
+    data_list = [
+        {
+            "dataset_id": dataset["id"],
+            "name": dataset["name"],
+            "count": count,
+            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
+        }
+        for dataset, count in zip(datasets, counts)
+    ]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": data_list
+    })
+
+
+@server_bp.route("/dataset/add", methods=["POST"])
+async def add_dataset():
+    resource = get_resource_manager()
+    dataset = Dataset(resource.mysql_client)
+    # 从请求体里取参数
+    body = await request.get_json()
+    name = body.get("name")
+    if not name:
+        return jsonify({
+            "status_code": 400,
+            "detail": "name is required"
+        })
+    # 执行新增
+    await dataset.add_dataset(name)
+    return jsonify({
+        "status_code": 200,
+        "detail": "success"
+    })
+
+
+@server_bp.route("/content/get", methods=["GET"])
+async def get_content():
+    resource = get_resource_manager()
+    contents = Contents(resource.mysql_client)
+
+    # 获取请求参数
+    doc_id = request.args.get("docId")
+    if not doc_id:
+        return jsonify({
+            "status_code": 400,
+            "detail": "doc_id is required",
+            "data": {}
+        })
+
+    # 查询内容
+    rows = await contents.select_content_by_doc_id(doc_id)
+
+    if not rows:
+        return jsonify({
+            "status_code": 404,
+            "detail": "content not found",
+            "data": {}
+        })
+
+    row = rows[0]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "title": row.get("title", ""),
+            "text": row.get("text", ""),
+            "doc_id": row.get("doc_id", "")
+        }
+    })
+
+
+@server_bp.route("/content/list", methods=["GET"])
+async def content_list():
+    resource = get_resource_manager()
+    contents = Contents(resource.mysql_client)
+
+    # 从 URL 查询参数获取分页和过滤参数
+    page_num = int(request.args.get("page", 1))
+    page_size = int(request.args.get("pageSize", 10))
+    dataset_id = request.args.get("datasetId")
+    doc_status = int(request.args.get("doc_status", 1))
+
+    # order_by 可以用 JSON 字符串传递
+    import json
+    order_by_str = request.args.get("order_by", '{"id":"desc"}')
+    try:
+        order_by = json.loads(order_by_str)
+    except Exception:
+        order_by = {"id": "desc"}
+
+    # 调用 select_contents,获取分页字典
+    result = await contents.select_contents(
+        page_num=page_num,
+        page_size=page_size,
+        dataset_id=dataset_id,
+        doc_status=doc_status,
+        order_by=order_by,
+    )
+
+    # 格式化 entities,只保留必要字段
+    entities = [
+        {
+            "doc_id": row["doc_id"],
+            "title": row.get("title") or "",
+            "text": row.get("text") or "",
+        }
+        for row in result["entities"]
+    ]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "entities": entities,
+            "total_count": result["total_count"],
+            "page": result["page"],
+            "page_size": result["page_size"],
+            "total_pages": result["total_pages"]
+        }
+    })
+
+
+async def query_search(query_text, filters=None, search_type='', anns_field='vector_text',
+                       search_params=BASE_MILVUS_SEARCH_PARAMS, _source=False, es_size=10000, sort_by=None,
+                       milvus_size=20, limit=10):
+    if filters is None:
+        filters = {}
+    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
+    resource = get_resource_manager()
+    search_engine = HybridSearch(
+        milvus_pool=resource.milvus_client, es_pool=resource.es_client
+    )
+    try:
+        match search_type:
+            case "base":
+                response = await search_engine.base_vector_search(
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    limit=limit,
+                )
+                return response
+            case "hybrid":
+                response = await search_engine.hybrid_search(
+                    filters=filters,
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    es_size=es_size,
+                    sort_by=sort_by,
+                    milvus_size=milvus_size,
+                )
+                return response
+            case "strategy":
+                return None
+            case _:
+                return None
+    except Exception as e:
+        return None
+
+
+@server_bp.route("/query", methods=["GET"])
+async def query():
+    query_text = request.args.get("query")
+    dataset_ids = request.args.get("datasetIds").split(",")
+    search_type = request.args.get("search_type", "hybrid")
+    query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
+                                       search_type=search_type)
+    resource = get_resource_manager()
+    content_chunk_mapper = ContentChunks(resource.mysql_client)
+    dataset_mapper = Dataset(resource.mysql_client)
+    res = []
+    for result in query_results['results']:
+        content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
+                                                                         chunk_id=result['chunk_id'])
+        if not content_chunks:
+            return jsonify({
+                "status_code": 500,
+                "detail": "content_chunk not found",
+                "data": {}
+            })
+        content_chunk = content_chunks[0]
+        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        if not datasets:
+            return jsonify({
+                "status_code": 500,
+                "detail": "dataset not found",
+                "data": {}
+            })
+        dataset = datasets[0]
+        dataset_name = None
+        if dataset:
+            dataset_name = dataset['name']
+        res.append(
+            {'docId': content_chunk['doc_id'], 'content': content_chunk['text'],
+             'contentSummary': content_chunk['summary'], 'score': result['score'], 'datasetName': dataset_name})
+    data = {'results': res}
+    return jsonify({'status_code': 200,
+                    'detail': "success",
+                    'data': data})
+
+
+@server_bp.route("/chat", methods=["GET"])
+async def chat():
+    query_text = request.args.get("query")
+    dataset_ids = request.args.get("datasetIds").split(",")
+    search_type = request.args.get("search_type", "hybrid")
+    query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
+                                       search_type=search_type)
+    resource = get_resource_manager()
+    content_chunk_mapper = ContentChunks(resource.mysql_client)
+    dataset_mapper = Dataset(resource.mysql_client)
+    res = []
+    for result in query_results['results']:
+        content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
+                                                                         chunk_id=result['chunk_id'])
+        if not content_chunks:
+            return jsonify({
+                "status_code": 500,
+                "detail": "content_chunk not found",
+                "data": {}
+            })
+        content_chunk = content_chunks[0]
+        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        if not datasets:
+            return jsonify({
+                "status_code": 500,
+                "detail": "dataset not found",
+                "data": {}
+            })
+        dataset = datasets[0]
+        dataset_name = None
+        if dataset:
+            dataset_name = dataset['name']
+        res.append(
+            {'docId': content_chunk['doc_id'], 'content': content_chunk['text'],
+             'contentSummary': content_chunk['summary'], 'score': result['score'], 'datasetName': dataset_name})
+
+    chat_classifier = ChatClassifier()
+    chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
+    data = {'results': res, 'chat_res': chat_res}
+    return jsonify({'status_code': 200,
+                    'detail': "success",
+                    'data': data})