import asyncio
import json
import traceback
import uuid
from typing import Dict, Any

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    ChunkerConfig,
    BASE_MILVUS_SEARCH_PARAMS,
)
from applications.resource import get_resource_manager
from applications.api import get_basic_embedding, get_img_embedding
from applications.async_task import ChunkEmbeddingTask, DeleteTask
from applications.search import HybridSearch
from applications.utils.chat import ChatClassifier
from applications.utils.mysql import Dataset, Contents, ContentChunks

server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")


@server_bp.route("/embed", methods=["POST"])
async def embed():
    body = await request.get_json()
    text = body.get("text")
    model_name = body.get("model", DEFAULT_MODEL)
    if not LOCAL_MODEL_CONFIG.get(model_name):
        return jsonify({"error": "unknown model"}), 400
    embedding = await get_basic_embedding(text, model_name)
    return jsonify({"embedding": embedding})


@server_bp.route("/img_embed", methods=["POST"])
async def img_embed():
    body = await request.get_json()
    url_list = body.get("url_list")
    if not url_list:
        return jsonify({"error": "url_list is required"}), 400
    embedding = await get_img_embedding(url_list)
    return jsonify(embedding)


@server_bp.route("/delete", methods=["POST"])
async def delete():
    body = await request.get_json()
    level = body.get("level")
    params = body.get("params")
    if not level or not params:
        return jsonify({"error": "level and params are required"}), 400
    resource = get_resource_manager()
    delete_task = DeleteTask(resource)
    response = await delete_task.deal(level, params)
    return jsonify(response)


@server_bp.route("/chunk", methods=["POST"])
async def chunk():
    body = await request.get_json()
    text = body.get("text", "").strip()
    if not text:
        return jsonify({"error": "text is required"}), 400
    resource = get_resource_manager()
    doc_id = f"doc-{uuid.uuid4()}"
    chunk_task = ChunkEmbeddingTask(
        resource.mysql_client,
        resource.milvus_client,
        cfg=ChunkerConfig(),
        doc_id=doc_id,
        es_pool=resource.es_client,
    )
    doc_id = await chunk_task.deal(body)
    return jsonify({"doc_id": doc_id})


@server_bp.route("/search", methods=["POST"])
async def search():
    """
    filters: Dict[str, Any]                       # filter conditions
    query_vec: List[float]                        # embedding of the query
    anns_field: str = "vector_text"               # vector field to search against
    search_params: Optional[Dict[str, Any]]       # vector distance / index parameters
    query_text: str = None                        # raw query text (also used for topic inverted lookup)
    _source: bool = False                         # whether to return source metadata
    es_size: int = 10000                          # first-stage ES filter size
    sort_by: str = None                           # sort order
    milvus_size: int = 20                         # number of candidates from Milvus coarse ranking
    :return:
    """
    body = await request.get_json()
    # parse the request body
    search_type: str = body.get("search_type")
    filters: Dict[str, Any] = body.get("filters", {})
    anns_field: str = body.get("anns_field", "vector_text")
    search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
    query_text: str = body.get("query_text")
    _source: bool = body.get("_source", False)
    es_size: int = body.get("es_size", 10000)
    sort_by: str = body.get("sort_by")
    milvus_size: int = body.get("milvus_size", 20)
    limit: int = body.get("limit", 10)
    if not query_text:
        return jsonify({"error": "query_text is required"}), 400

    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)

    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client, es_pool=resource.es_client
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return jsonify(response), 200
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
                return jsonify(response), 200
            case "strategy":
                return jsonify({"error": "strategy not implemented"}), 501
            case _:
                return jsonify({"error": "unknown search_type"}), 400
    except Exception as e:
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
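# Example /search request body (illustrative; the dataset id is an assumption):
#
#   POST /api/search
#   {
#       "search_type": "hybrid",
#       "query_text": "vector databases",
#       "filters": {"dataset_id": ["1"]},
#       "limit": 10
#   }
#
# "base" runs a pure Milvus ANN search; "hybrid" first narrows candidates in
# Elasticsearch (up to es_size of them), then ranks the survivors in Milvus.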
@server_bp.route("/dataset/list", methods=["GET"])
async def dataset_list():
    resource = get_resource_manager()
    datasets = await Dataset(resource.mysql_client).select_dataset()
    # issue one count query per dataset and await them concurrently
    tasks = [
        Contents(resource.mysql_client).select_count(dataset["id"])
        for dataset in datasets
    ]
    counts = await asyncio.gather(*tasks)
    # assemble the response payload
    data_list = [
        {
            "dataset_id": dataset["id"],
            "name": dataset["name"],
            "count": count,
            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
        }
        for dataset, count in zip(datasets, counts)
    ]
    return jsonify({"status_code": 200, "detail": "success", "data": data_list})


@server_bp.route("/dataset/add", methods=["POST"])
async def add_dataset():
    resource = get_resource_manager()
    dataset = Dataset(resource.mysql_client)
    # read parameters from the request body
    body = await request.get_json()
    name = body.get("name")
    if not name:
        return jsonify({"status_code": 400, "detail": "name is required"})
    # insert the new dataset
    await dataset.add_dataset(name)
    return jsonify({"status_code": 200, "detail": "success"})


@server_bp.route("/content/get", methods=["GET"])
async def get_content():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # read query-string parameters
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
    # look up the content row
    rows = await contents.select_content_by_doc_id(doc_id)
    if not rows:
        return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
    row = rows[0]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "title": row.get("title", ""),
                "text": row.get("text", ""),
                "doc_id": row.get("doc_id", ""),
            },
        }
    )


@server_bp.route("/content/list", methods=["GET"])
async def content_list():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # pagination and filter parameters from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    dataset_id = request.args.get("datasetId")
    doc_status = int(request.args.get("doc_status", 1))
    # order_by may be passed as a JSON string, e.g. '{"id":"desc"}'
    order_by_str = request.args.get("order_by", '{"id":"desc"}')
    try:
        order_by = json.loads(order_by_str)
    except json.JSONDecodeError:
        order_by = {"id": "desc"}
    # fetch one page of contents
    result = await contents.select_contents(
        page_num=page_num,
        page_size=page_size,
        dataset_id=dataset_id,
        doc_status=doc_status,
        order_by=order_by,
    )
    # keep only the fields the frontend needs
    entities = [
        {
            "doc_id": row["doc_id"],
            "title": row.get("title") or "",
            "text": row.get("text") or "",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )


async def query_search(
    query_text,
    filters=None,
    search_type="",
    anns_field="vector_text",
    search_params=BASE_MILVUS_SEARCH_PARAMS,
    _source=False,
    es_size=10000,
    sort_by=None,
    milvus_size=20,
    limit=10,
):
    if filters is None:
        filters = {}
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client, es_pool=resource.es_client
    )
    try:
        match search_type:
            case "base":
                return await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
            case "hybrid":
                return await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
            case "strategy":
                return None
            case _:
                return None
    except Exception:
        return None
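# Minimal usage sketch for query_search (illustrative values; the result shape
# shown is an assumption inferred from how /query consumes it):
#
#   query_results = await query_search(
#       query_text="how does chunking work",
#       filters={"dataset_id": ["1", "2"]},
#       search_type="hybrid",
#   )
#   # query_results["results"] -> [{"doc_id": ..., "chunk_id": ..., "score": ...}, ...]
#   # Returns None when the search type is unknown or the backend raises.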
async def _build_chunk_results(query_results, resource):
    """Resolve each search hit to its chunk text and dataset name.

    Returns (results, error_response); exactly one of the two is None.
    """
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    dataset_mapper = Dataset(resource.mysql_client)
    res = []
    for result in query_results["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if not content_chunks:
            return None, jsonify(
                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
            )
        content_chunk = content_chunks[0]
        datasets = await dataset_mapper.select_dataset_by_id(content_chunk["dataset_id"])
        if not datasets:
            return None, jsonify(
                {"status_code": 500, "detail": "dataset not found", "data": {}}
            )
        dataset_name = datasets[0]["name"]
        res.append(
            {
                "docId": content_chunk["doc_id"],
                "content": content_chunk["text"],
                "contentSummary": content_chunk["summary"],
                "score": result["score"],
                "datasetName": dataset_name,
            }
        )
    return res, None


@server_bp.route("/query", methods=["GET"])
async def query():
    query_text = request.args.get("query")
    dataset_ids = request.args.get("datasetIds", "").split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    if not query_results:
        return jsonify({"status_code": 500, "detail": "search failed", "data": {}})
    resource = get_resource_manager()
    res, error = await _build_chunk_results(query_results, resource)
    if error is not None:
        return error
    return jsonify({"status_code": 200, "detail": "success", "data": {"results": res}})


@server_bp.route("/chat", methods=["GET"])
async def chat():
    query_text = request.args.get("query")
    dataset_ids = request.args.get("datasetIds", "").split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    if not query_results:
        return jsonify({"status_code": 500, "detail": "search failed", "data": {}})
    resource = get_resource_manager()
    res, error = await _build_chunk_results(query_results, resource)
    if error is not None:
        return error
    # pass the retrieved chunks to the LLM for an answer
    chat_classifier = ChatClassifier()
    chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {"results": res, "chat_res": chat_res},
        }
    )
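# Example requests (illustrative; the dataset ids are assumptions):
#
#   GET /api/query?query=milvus+index+types&datasetIds=1,2&search_type=hybrid
#   GET /api/chat?query=milvus+index+types&datasetIds=1,2
#
# /query returns the matched chunks; /chat additionally passes them to
# ChatClassifier.chat_with_deepseek and returns its answer as data.chat_res.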