@@ -1,8 +1,10 @@
+import asyncio
 import traceback
 import uuid
 from typing import Dict, Any
 
 from quart import Blueprint, jsonify, request
+from quart_cors import cors
 
 from applications.config import (
     DEFAULT_MODEL,
@@ -15,9 +17,11 @@ from applications.api import get_basic_embedding
 from applications.api import get_img_embedding
 from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.search import HybridSearch
-
+from applications.utils.chat import ChatClassifier
+from applications.utils.mysql import Dataset, Contents, ContentChunks
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
+# Allow cross-origin requests from any origin on this blueprint
+server_bp = cors(server_bp, allow_origin="*")
 
 
 @server_bp.route("/embed", methods=["POST"])
@@ -138,3 +142,265 @@ async def search():
         return jsonify({"error": "error search_type"}), 200
     except Exception as e:
         return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
+
+
+@server_bp.route("/dataset/list", methods=["GET"])
+async def dataset_list():
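+    """List all datasets, each with its content count and creation date."""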
+    resource = get_resource_manager()
+    datasets = await Dataset(resource.mysql_client).select_dataset()
+
+    # Issue one count query per dataset and run them concurrently
+    tasks = [
+        Contents(resource.mysql_client).select_count(dataset["id"])
+        for dataset in datasets
+    ]
+    counts = await asyncio.gather(*tasks)
+
+    # Assemble the response payload
+    data_list = [
+        {
+            "dataset_id": dataset["id"],
+            "name": dataset["name"],
+            "count": count,
+            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
+        }
+        for dataset, count in zip(datasets, counts)
+    ]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": data_list
+    })
+
+
+@server_bp.route("/dataset/add", methods=["POST"])
+async def add_dataset():
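+    """Create a new dataset; expects a JSON body with a "name" field."""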
+    resource = get_resource_manager()
+    dataset = Dataset(resource.mysql_client)
+    # Pull the dataset name from the JSON request body
+    body = await request.get_json()
+    name = body.get("name") if body else None
+    if not name:
+        return jsonify({
+            "status_code": 400,
+            "detail": "name is required"
+        })
+    # Insert the new dataset
+    await dataset.add_dataset(name)
+    return jsonify({
+        "status_code": 200,
+        "detail": "success"
+    })
+
+
+@server_bp.route("/content/get", methods=["GET"])
+async def get_content():
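+    """Fetch a single content record (title, text, doc_id) by the docId query parameter."""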
+    resource = get_resource_manager()
+    contents = Contents(resource.mysql_client)
+
+    # Read the target document id from the query string
+    doc_id = request.args.get("docId")
+    if not doc_id:
+        return jsonify({
+            "status_code": 400,
+            "detail": "doc_id is required",
+            "data": {}
+        })
+
+    # Look up the content rows for this doc_id
+    rows = await contents.select_content_by_doc_id(doc_id)
+
+    if not rows:
+        return jsonify({
+            "status_code": 404,
+            "detail": "content not found",
+            "data": {}
+        })
+
+    row = rows[0]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "title": row.get("title", ""),
+            "text": row.get("text", ""),
+            "doc_id": row.get("doc_id", "")
+        }
+    })
+
+
+@server_bp.route("/content/list", methods=["GET"])
+async def content_list():
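+    """Paginated content listing, filterable by datasetId and doc_status."""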
+    resource = get_resource_manager()
+    contents = Contents(resource.mysql_client)
+
+    # Pagination and filter parameters from the URL query string
+    page_num = int(request.args.get("page", 1))
+    page_size = int(request.args.get("pageSize", 10))
+    dataset_id = request.args.get("datasetId")
+    doc_status = int(request.args.get("doc_status", 1))
+
+    # order_by may be passed as a JSON string, e.g. {"id": "desc"}
+    import json
+    order_by_str = request.args.get("order_by", '{"id":"desc"}')
+    try:
+        order_by = json.loads(order_by_str)
+    except json.JSONDecodeError:
+        order_by = {"id": "desc"}
+
+    # Call select_contents to fetch one page as a pagination dict
+    result = await contents.select_contents(
+        page_num=page_num,
+        page_size=page_size,
+        dataset_id=dataset_id,
+        doc_status=doc_status,
+        order_by=order_by,
+    )
+
+    # Format entities, keeping only the fields the client needs
+    entities = [
+        {
+            "doc_id": row["doc_id"],
+            "title": row.get("title") or "",
+            "text": row.get("text") or "",
+        }
+        for row in result["entities"]
+    ]
+
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "entities": entities,
+            "total_count": result["total_count"],
+            "page": result["page"],
+            "page_size": result["page_size"],
+            "total_pages": result["total_pages"]
+        }
+    })
+
+
+async def query_search(query_text, filters=None, search_type='', anns_field='vector_text',
+                       search_params=BASE_MILVUS_SEARCH_PARAMS, _source=False, es_size=10000,
+                       sort_by=None, milvus_size=20, limit=10):
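+    """Embed query_text and dispatch to the selected search backend.
+
+    Returns the raw search response, or None for an unknown search_type or on error.
+    """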
+    if filters is None:
+        filters = {}
+    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
+    resource = get_resource_manager()
+    search_engine = HybridSearch(
+        milvus_pool=resource.milvus_client, es_pool=resource.es_client
+    )
+    try:
+        match search_type:
+            case "base":
+                response = await search_engine.base_vector_search(
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    limit=limit,
+                )
+                return response
+            case "hybrid":
+                response = await search_engine.hybrid_search(
+                    filters=filters,
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    es_size=es_size,
+                    sort_by=sort_by,
+                    milvus_size=milvus_size,
+                )
+                return response
+            case "strategy":
+                # Strategy search is not implemented yet
+                return None
+            case _:
+                return None
+    except Exception:
+        # Swallow search errors; callers treat None as a failed search
+        return None
+
+
+@server_bp.route("/query", methods=["GET"])
+async def query():
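+    """Search the given datasets and return matching chunks hydrated from MySQL."""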
+    query_text = request.args.get("query")
+    # Default to "" so a missing datasetIds param doesn't raise AttributeError
+    dataset_ids = request.args.get("datasetIds", "").split(",")
+    search_type = request.args.get("search_type", "hybrid")
+    query_results = await query_search(
+        query_text=query_text, filters={"dataset_id": dataset_ids}, search_type=search_type)
+    # query_search returns None on error or for an unknown search_type
+    if not query_results:
+        return jsonify({
+            "status_code": 500,
+            "detail": "search failed",
+            "data": {}
+        })
+    resource = get_resource_manager()
+    content_chunk_mapper = ContentChunks(resource.mysql_client)
+    dataset_mapper = Dataset(resource.mysql_client)
+    res = []
+    for result in query_results['results']:
+        content_chunks = await content_chunk_mapper.select_chunk_content(
+            doc_id=result['doc_id'], chunk_id=result['chunk_id'])
+        if not content_chunks:
+            return jsonify({
+                "status_code": 500,
+                "detail": "content_chunk not found",
+                "data": {}
+            })
+        content_chunk = content_chunks[0]
+        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        if not datasets:
+            return jsonify({
+                "status_code": 500,
+                "detail": "dataset not found",
+                "data": {}
+            })
+        dataset = datasets[0]
+        dataset_name = dataset['name'] if dataset else None
+        res.append({
+            'docId': content_chunk['doc_id'],
+            'content': content_chunk['text'],
+            'contentSummary': content_chunk['summary'],
+            'score': result['score'],
+            'datasetName': dataset_name,
+        })
+    data = {'results': res}
+    return jsonify({'status_code': 200, 'detail': 'success', 'data': data})
+
+
+@server_bp.route("/chat", methods=["GET"])
+async def chat():
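+    """Run the same retrieval as /query, then answer the query with DeepSeek over the hits."""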
+    query_text = request.args.get("query")
+    # Default to "" so a missing datasetIds param doesn't raise AttributeError
+    dataset_ids = request.args.get("datasetIds", "").split(",")
+    search_type = request.args.get("search_type", "hybrid")
+    query_results = await query_search(
+        query_text=query_text, filters={"dataset_id": dataset_ids}, search_type=search_type)
+    # query_search returns None on error or for an unknown search_type
+    if not query_results:
+        return jsonify({
+            "status_code": 500,
+            "detail": "search failed",
+            "data": {}
+        })
+    resource = get_resource_manager()
+    content_chunk_mapper = ContentChunks(resource.mysql_client)
+    dataset_mapper = Dataset(resource.mysql_client)
+    res = []
+    for result in query_results['results']:
+        content_chunks = await content_chunk_mapper.select_chunk_content(
+            doc_id=result['doc_id'], chunk_id=result['chunk_id'])
+        if not content_chunks:
+            return jsonify({
+                "status_code": 500,
+                "detail": "content_chunk not found",
+                "data": {}
+            })
+        content_chunk = content_chunks[0]
+        datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
+        if not datasets:
+            return jsonify({
+                "status_code": 500,
+                "detail": "dataset not found",
+                "data": {}
+            })
+        dataset = datasets[0]
+        dataset_name = dataset['name'] if dataset else None
+        res.append({
+            'docId': content_chunk['doc_id'],
+            'content': content_chunk['text'],
+            'contentSummary': content_chunk['summary'],
+            'score': result['score'],
+            'datasetName': dataset_name,
+        })
+
+    # Generate an answer over the retrieved chunks (DeepSeek-backed)
+    chat_classifier = ChatClassifier()
+    chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
+    data = {'results': res, 'chat_res': chat_res}
+    return jsonify({'status_code': 200, 'detail': 'success', 'data': data})
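+
+
+# Illustrative requests for the new endpoints (paths from the routes above;
+# parameter values are assumptions, not part of this change):
+#   GET  /api/dataset/list
+#   POST /api/dataset/add      JSON body: {"name": "my-dataset"}
+#   GET  /api/content/list?datasetId=1&page=1&pageSize=10
+#   GET  /api/query?query=...&datasetIds=1,2&search_type=hybrid
+#   GET  /api/chat?query=...&datasetIds=1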