Add API endpoints

xueyiming 2 weeks ago
parent commit e25c748218

+ 7 - 0
applications/utils/chat/__init__.py

@@ -0,0 +1,7 @@
+from applications.utils.chat.chat_classifier import ChatClassifier
+
+
+__all__ = [
+    "ChatClassifier"
+]
+

+ 57 - 0
applications/utils/chat/chat_classifier.py

@@ -0,0 +1,57 @@
+from typing import List
+
+from applications.config import Chunk
+from applications.api import fetch_deepseek_completion
+
+
+class ChatClassifier:
+    @staticmethod
+    def generate_summary_prompt(query, search_results):
+        """
+        Build the summarization prompt.
+
+        :param query: the user's question
+        :param search_results: list of search results, each containing 'content', 'contentSummary', and 'score'
+        :return: the assembled summarization prompt
+        """
+
+        # Structure the prompt so the model can clearly separate the question from the evidence:
+        prompt = f"Question: {query}\n\nPlease draw on the following search results to produce a summary:\n"
+
+        # Collect summaries and contents paired with their similarity scores
+        weighted_summaries = []
+        weighted_contents = []
+
+        for result in search_results:
+            content = result['content']
+            content_summary = result['contentSummary']
+            score = result['score']
+
+            # Pair each summary and each content with its score
+            weighted_summaries.append((content_summary, score))
+            weighted_contents.append((content, score))
+
+        # Rank contents and summaries so the most relevant items come first
+        weighted_summaries.sort(key=lambda x: x[1], reverse=True)  # descending by similarity
+        weighted_contents.sort(key=lambda x: x[1], reverse=True)  # descending by similarity
+
+        # Append the ranked summaries and contents to the prompt
+        prompt += "\n-- Ranked content summaries --\n"
+        for summary, score in weighted_summaries:
+            prompt += f"Summary: {summary} | Similarity: {score:.2f}\n"
+
+        prompt += "\n-- Ranked contents --\n"
+        for content, score in weighted_contents:
+            prompt += f"Content: {content} | Similarity: {score:.2f}\n"
+
+        # Finally, ask the model for a concise summary
+        prompt += "\nBased on the material above, please produce a concise summary."
+
+        return prompt
+
+    async def chat_with_deepseek(self, query, search_results):
+        prompt = self.generate_summary_prompt(query, search_results)
+        response = await fetch_deepseek_completion(
+            model="DeepSeek-V3", prompt=prompt
+        )
+        return response
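
For reference, a minimal sketch of how this prompt builder is driven (the sample rows and scores are invented for illustration; only the 'content', 'contentSummary', and 'score' keys described in the docstring are assumed):

    from applications.utils.chat import ChatClassifier

    search_results = [
        {"content": "Milvus is a vector database.", "contentSummary": "Milvus intro", "score": 0.92},
        {"content": "Elasticsearch supports BM25 scoring.", "contentSummary": "ES scoring", "score": 0.71},
    ]
    prompt = ChatClassifier.generate_summary_prompt("What is Milvus?", search_results)
    # The prompt lists summaries first, then full contents, ranked by
    # similarity in descending order, ending with the summarization instruction.
    print(prompt)
    # To get a model answer instead of just the prompt (requires an async context
    # and configured DeepSeek credentials):
    # answer = await ChatClassifier().chat_with_deepseek("What is Milvus?", search_results)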

+ 111 - 0
applications/utils/mysql/mapper.py

@@ -1,4 +1,6 @@
 import json
+from datetime import datetime
+
 from applications.config import Chunk


@@ -24,6 +26,33 @@ class Dataset(BaseMySQLClient):
             params=(new_status, dataset_id, ori_status),
         )

+    async def select_dataset(self, status=1):
+        query = """
+            select * from dataset where status = %s;
+        """
+        return await self.pool.async_fetch(
+            query=query,
+            params=(status,)
+        )
+
+    async def add_dataset(self, name):
+        query = """
+            insert into dataset (name, created_at, updated_at, status) values (%s, %s, %s, %s);
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(name, datetime.now(), datetime.now(), 1)
+        )
+
+    async def select_dataset_by_id(self, dataset_id, status=1):
+        query = """
+            select * from dataset where id = %s and status = %s;
+        """
+        return await self.pool.async_fetch(
+            query=query,
+            params=(dataset_id, status)
+        )
+
 
 
 class Contents(BaseMySQLClient):
     async def insert_content(self, doc_id, text, text_type, title, dataset_id):
@@ -72,6 +101,80 @@ class Contents(BaseMySQLClient):
             query=query, params=(new_status, doc_id, ori_status)
         )

+    async def select_count(self, dataset_id, doc_status=1):
+        query = """
+            select count(*) as count from contents where dataset_id = %s and doc_status = %s;
+        """
+        rows = await self.pool.async_fetch(query=query, params=(dataset_id, doc_status))
+        return rows[0]["count"] if rows else 0
+
+    async def select_content_by_doc_id(self, doc_id):
+        query = """
+            select * from contents where doc_id = %s;
+        """
+        return await self.pool.async_fetch(
+            query=query,
+            params=(doc_id,)
+        )
+
+    async def select_contents(
+            self,
+            page_num: int,
+            page_size: int,
+            order_by: dict = None,
+            dataset_id: int = None,
+            doc_status: int = 1,
+    ):
+        """
+        Paginated query over the contents table, returning pagination metadata.
+        :param page_num: page number, starting from 1
+        :param page_size: number of rows per page
+        :param order_by: sort spec, e.g. {"id": "desc"} or {"created_at": "asc"}
+        :param dataset_id: dataset ID filter (optional)
+        :param doc_status: document status (defaults to 1)
+        :return: dict with entities, total_count, page, page_size, total_pages
+        """
+        # Avoid a mutable default argument for order_by
+        order_by = order_by or {"id": "desc"}
+        offset = (page_num - 1) * page_size
+
+        # Build the WHERE clause dynamically
+        where_clauses = ["doc_status = %s"]
+        params = [doc_status]
+
+        if dataset_id:
+            where_clauses.append("dataset_id = %s")
+            params.append(dataset_id)
+
+        where_sql = " AND ".join(where_clauses)
+
+        # Build the ORDER BY clause; whitelist field and direction because they
+        # are interpolated into the SQL string and may come from request input
+        order_field, order_direction = list(order_by.items())[0]
+        if order_field not in {"id", "created_at", "updated_at"} or order_direction.lower() not in {"asc", "desc"}:
+            order_field, order_direction = "id", "desc"
+        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
+
+        # Count the total number of matching rows
+        count_query = f"SELECT COUNT(*) as total_count FROM contents WHERE {where_sql};"
+        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
+        total_count = count_result[0]["total_count"] if count_result else 0
+
+        # Fetch the requested page
+        query = f"""
+            SELECT * FROM contents
+            WHERE {where_sql}
+            {order_sql}
+            LIMIT %s OFFSET %s;
+        """
+        params.extend([page_size, offset])
+        entities = await self.pool.async_fetch(query=query, params=tuple(params))
+
+        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
+
+        return {
+            "entities": entities,
+            "total_count": total_count,
+            "page": page_num,
+            "page_size": page_size,
+            "total_pages": total_pages
+        }
+
 
 
 class ContentChunks(BaseMySQLClient):
     async def insert_chunk(self, chunk: Chunk) -> int:
@@ -173,3 +276,11 @@ class ContentChunks(BaseMySQLClient):
         return await self.pool.async_save(
             query=query, params=(new_status, dataset_id, ori_status)
         )
+
+    async def select_chunk_content(self, doc_id, chunk_id, status=1):
+        query = """
+            select * from content_chunks where doc_id = %s and chunk_id = %s and status = %s;
+        """
+        return await self.pool.async_fetch(
+            query=query, params=(doc_id, chunk_id, status)
+        )
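
A usage sketch for the paginated query added above (assumes a configured resource manager; the dataset id is hypothetical):

    from applications.utils.mysql import Contents

    async def list_first_page(resource):
        contents = Contents(resource.mysql_client)
        page = await contents.select_contents(
            page_num=1,
            page_size=10,
            order_by={"id": "desc"},
            dataset_id=3,  # hypothetical dataset id
        )
        # "page" carries the rows plus pagination metadata:
        # entities, total_count, page, page_size, total_pages
        return page["entities"]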

+ 267 - 1
routes/buleprint.py

@@ -1,8 +1,10 @@
+import asyncio
 import traceback
 import uuid
 from typing import Dict, Any

 from quart import Blueprint, jsonify, request
+from quart_cors import cors

 from applications.config import (
     DEFAULT_MODEL,
@@ -15,9 +17,11 @@ from applications.api import get_basic_embedding
 from applications.api import get_img_embedding
 from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.search import HybridSearch
-
+from applications.utils.chat import ChatClassifier
+from applications.utils.mysql import Dataset, Contents, ContentChunks

 server_bp = Blueprint("api", __name__, url_prefix="/api")
+server_bp = cors(server_bp, allow_origin="*")


 @server_bp.route("/embed", methods=["POST"])
@@ -138,3 +142,265 @@ async def search():
                 return jsonify({"error": "error  search_type"}), 200
     except Exception as e:
         return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
+
+
+@server_bp.route("/dataset/list", methods=["GET"])
+async def dataset_list():
+    resource = get_resource_manager()
+    datasets = await Dataset(resource.mysql_client).select_dataset()
+
+    # Launch one count query per dataset and run them concurrently
+    tasks = [
+        Contents(resource.mysql_client).select_count(dataset["id"])
+        for dataset in datasets
+    ]
+    counts = await asyncio.gather(*tasks)
+
+    # Assemble the response payload
+    data_list = [
+        {
+            "dataset_id": dataset["id"],
+            "name": dataset["name"],
+            "count": count,
+            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
+        }
+        for dataset, count in zip(datasets, counts)
+    ]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": data_list
+    })
+
+
+@server_bp.route("/dataset/add", methods=["POST"])
+async def add_dataset():
+    resource = get_resource_manager()
+    dataset = Dataset(resource.mysql_client)
+    # Read parameters from the request body
+    body = await request.get_json()
+    name = body.get("name")
+    if not name:
+        return jsonify({
+            "status_code": 400,
+            "detail": "name is required"
+        })
+    # Insert the new dataset
+    await dataset.add_dataset(name)
+    return jsonify({
+        "status_code": 200,
+        "detail": "success"
+    })
+
+
+@server_bp.route("/content/get", methods=["GET"])
+async def get_content():
+    resource = get_resource_manager()
+    contents = Contents(resource.mysql_client)
+
+    # Read the request parameter
+    doc_id = request.args.get("docId")
+    if not doc_id:
+        return jsonify({
+            "status_code": 400,
+            "detail": "docId is required",
+            "data": {}
+        })
+
+    # Look up the content row
+    rows = await contents.select_content_by_doc_id(doc_id)
+
+    if not rows:
+        return jsonify({
+            "status_code": 404,
+            "detail": "content not found",
+            "data": {}
+        })
+
+    row = rows[0]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "title": row.get("title", ""),
+            "text": row.get("text", ""),
+            "doc_id": row.get("doc_id", "")
+        }
+    })
+
+
+@server_bp.route("/content/list", methods=["GET"])
+async def content_list():
+    resource = get_resource_manager()
+    contents = Contents(resource.mysql_client)
+
+    # Read pagination and filter parameters from the query string
+    page_num = int(request.args.get("page", 1))
+    page_size = int(request.args.get("pageSize", 10))
+    dataset_id = request.args.get("datasetId")
+    doc_status = int(request.args.get("doc_status", 1))
+
+    # order_by may be passed as a JSON string, e.g. {"id":"desc"}
+    import json
+    order_by_str = request.args.get("order_by", '{"id":"desc"}')
+    try:
+        order_by = json.loads(order_by_str)
+    except Exception:
+        order_by = {"id": "desc"}
+
+    # Run the paginated query
+    result = await contents.select_contents(
+        page_num=page_num,
+        page_size=page_size,
+        dataset_id=dataset_id,
+        doc_status=doc_status,
+        order_by=order_by,
+    )
+
+    # Format entities, keeping only the required fields
+    entities = [
+        {
+            "doc_id": row["doc_id"],
+            "title": row.get("title") or "",
+            "text": row.get("text") or "",
+        }
+        for row in result["entities"]
+    ]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "entities": entities,
+            "total_count": result["total_count"],
+            "page": result["page"],
+            "page_size": result["page_size"],
+            "total_pages": result["total_pages"]
+        }
+    })
+
+
+async def query_search(query_text, filters=None, search_type='', anns_field='vector_text',
+                       search_params=BASE_MILVUS_SEARCH_PARAMS, _source=False, es_size=10000, sort_by=None,
+                       milvus_size=20, limit=10):
+    """Shared retrieval helper: embed the query, then dispatch to the requested search type."""
+    if filters is None:
+        filters = {}
+    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
+    resource = get_resource_manager()
+    search_engine = HybridSearch(
+        milvus_pool=resource.milvus_client, es_pool=resource.es_client
+    )
+    try:
+        match search_type:
+            case "base":
+                response = await search_engine.base_vector_search(
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    limit=limit,
+                )
+                return response
+            case "hybrid":
+                response = await search_engine.hybrid_search(
+                    filters=filters,
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    es_size=es_size,
+                    sort_by=sort_by,
+                    milvus_size=milvus_size,
+                )
+                return response
+            case "strategy":
+                return None
+            case _:
+                return None
+    except Exception:
+        return None
+
+
+@server_bp.route("/query", methods=["GET"])
+async def query():
+    query_text = request.args.get("query")
+    if not query_text:
+        return jsonify({"status_code": 400, "detail": "query is required", "data": {}})
+    dataset_ids = (request.args.get("datasetIds") or "").split(",")
+    search_type = request.args.get("search_type", "hybrid")
+    query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
+                                       search_type=search_type)
+    if not query_results:
+        return jsonify({"status_code": 500, "detail": "search failed", "data": {}})
+    resource = get_resource_manager()
+    content_chunk_mapper = ContentChunks(resource.mysql_client)
+    dataset_mapper = Dataset(resource.mysql_client)
+    res = []
+    for result in query_results['results']:
+        content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
+                                                                         chunk_id=result['chunk_id'])
+        if not content_chunks:
+            return jsonify({
+                "status_code": 500,
+                "detail": "content_chunk not found",
+                "data": {}
+            })
+        content_chunk = content_chunks[0]
+        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        if not datasets:
+            return jsonify({
+                "status_code": 500,
+                "detail": "dataset not found",
+                "data": {}
+            })
+        dataset_name = datasets[0]["name"]
+        res.append({
+            "docId": content_chunk["doc_id"],
+            "content": content_chunk["text"],
+            "contentSummary": content_chunk["summary"],
+            "score": result["score"],
+            "datasetName": dataset_name,
+        })
+    data = {"results": res}
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": data
+    })
+
+
+@server_bp.route("/chat", methods=["GET"])
+async def chat():
+    query_text = request.args.get("query")
+    if not query_text:
+        return jsonify({"status_code": 400, "detail": "query is required", "data": {}})
+    dataset_ids = (request.args.get("datasetIds") or "").split(",")
+    search_type = request.args.get("search_type", "hybrid")
+    query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
+                                       search_type=search_type)
+    if not query_results:
+        return jsonify({"status_code": 500, "detail": "search failed", "data": {}})
+    resource = get_resource_manager()
+    content_chunk_mapper = ContentChunks(resource.mysql_client)
+    dataset_mapper = Dataset(resource.mysql_client)
+    res = []
+    for result in query_results['results']:
+        content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
+                                                                         chunk_id=result['chunk_id'])
+        if not content_chunks:
+            return jsonify({
+                "status_code": 500,
+                "detail": "content_chunk not found",
+                "data": {}
+            })
+        content_chunk = content_chunks[0]
+        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        if not datasets:
+            return jsonify({
+                "status_code": 500,
+                "detail": "dataset not found",
+                "data": {}
+            })
+        dataset_name = datasets[0]["name"]
+        res.append({
+            "docId": content_chunk["doc_id"],
+            "content": content_chunk["text"],
+            "contentSummary": content_chunk["summary"],
+            "score": result["score"],
+            "datasetName": dataset_name,
+        })
+
+    chat_classifier = ChatClassifier()
+    chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
+    data = {"results": res, "chat_res": chat_res}
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": data
+    })
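
The new endpoints can be exercised with any HTTP client. A minimal sketch using aiohttp (the base URL is an assumption; the server must have DeepSeek credentials configured for /chat to succeed):

    import asyncio

    import aiohttp

    async def demo():
        # Base URL is hypothetical; point it at wherever the Quart app is served.
        async with aiohttp.ClientSession("http://localhost:8001") as session:
            # Create a dataset, then list datasets with their content counts.
            async with session.post("/api/dataset/add", json={"name": "demo"}) as resp:
                print(await resp.json())
            async with session.get("/api/dataset/list") as resp:
                print(await resp.json())
            # Retrieve ranked chunks plus a DeepSeek-generated summary in one call.
            params = {"query": "What is Milvus?", "datasetIds": "1"}
            async with session.get("/api/chat", params=params) as resp:
                print(await resp.json())

    asyncio.run(demo())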