import asyncio
import json
import traceback
import uuid
from typing import Any, Dict

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.api import get_basic_embedding, get_img_embedding
from applications.async_task import (
    AutoRechunkTask,
    BuildGraph,
    ChunkEmbeddingTask,
    DeleteTask,
)
from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    BASE_MILVUS_SEARCH_PARAMS,
)
from applications.resource import get_resource_manager
from applications.search import HybridSearch
from applications.utils.chat import RAGChatAgent
from applications.utils.mysql import ChatResult, ContentChunks, Contents, Dataset
from applications.utils.spider.study import study

server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")
- @server_bp.route("/embed", methods=["POST"])
- async def embed():
- body = await request.get_json()
- text = body.get("text")
- model_name = body.get("model", DEFAULT_MODEL)
- if not LOCAL_MODEL_CONFIG.get(model_name):
- return jsonify({"error": "error model"})
- embedding = await get_basic_embedding(text, model_name)
- return jsonify({"embedding": embedding})
- @server_bp.route("/img_embed", methods=["POST"])
- async def img_embed():
- body = await request.get_json()
- url_list = body.get("url_list")
- if not url_list:
- return jsonify({"error": "error url_list"})
- embedding = await get_img_embedding(url_list)
- return jsonify(embedding)
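
# Example request (a sketch; the image URL is a placeholder):
#   curl -X POST http://localhost:8080/api/img_embed \
#        -H 'Content-Type: application/json' \
#        -d '{"url_list": ["https://example.com/a.jpg"]}'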
- @server_bp.route("/delete", methods=["POST"])
- async def delete():
- body = await request.get_json()
- level = body.get("level")
- params = body.get("params")
- if not level or not params:
- return jsonify({"error": "error level or params"})
- resource = get_resource_manager()
- del_task = DeleteTask(resource)
- response = await del_task.deal(level, params)
- return jsonify(response)
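
# Example request (a sketch; the accepted "level" values and the shape of
# "params" are defined by DeleteTask and assumed here):
#   curl -X POST http://localhost:8080/api/delete \
#        -H 'Content-Type: application/json' \
#        -d '{"level": "doc", "params": {"doc_id": "doc-..."}}'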
- @server_bp.route("/chunk", methods=["POST"])
- async def chunk():
- body = await request.get_json()
- text = body.get("text", "")
- ori_doc_id = body.get("doc_id")
- text = text.strip()
- if not text:
- return jsonify({"error": "error text"})
- resource = get_resource_manager()
- # generate doc id
- if ori_doc_id:
- body["re_chunk"] = True
- doc_id = ori_doc_id
- else:
- doc_id = f"doc-{uuid.uuid4()}"
- chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
- doc_id = await chunk_task.deal(body)
- return jsonify({"doc_id": doc_id})
- @server_bp.route("/search", methods=["POST"])
- async def search():
- """
- filters: Dict[str, Any], # 条件过滤
- query_vec: List[float], # query 的向量
- anns_field: str = "vector_text", # query指定的向量空间
- search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
- query_text: str = None, #是否通过 topic 倒排
- _source=False, # 是否返回元数据
- es_size: int = 10000, #es 第一层过滤数量
- sort_by: str = None, # 排序
- milvus_size: int = 10 # milvus粗排返回数量
- :return:
- """
- body = await request.get_json()
- # 解析数据
- search_type: str = body.get("search_type")
- filters: Dict[str, Any] = body.get("filters", {})
- anns_field: str = body.get("anns_field", "vector_text")
- search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
- query_text: str = body.get("query_text")
- _source: bool = body.get("_source", False)
- es_size: int = body.get("es_size", 10000)
- sort_by: str = body.get("sort_by")
- milvus_size: int = body.get("milvus", 20)
- limit: int = body.get("limit", 10)
- path_between_chunks: dict = body.get("path_between_chunks", {})
- if not query_text:
- return jsonify({"error": "error query_text"})
- query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
- resource = get_resource_manager()
- search_engine = HybridSearch(
- milvus_pool=resource.milvus_client,
- es_pool=resource.es_client,
- graph_pool=resource.graph_client,
- )
- try:
- match search_type:
- case "base":
- response = await search_engine.base_vector_search(
- query_vec=query_vector,
- anns_field=anns_field,
- search_params=search_params,
- limit=limit,
- )
- return jsonify(response), 200
- case "hybrid":
- response = await search_engine.hybrid_search(
- filters=filters,
- query_vec=query_vector,
- anns_field=anns_field,
- search_params=search_params,
- es_size=es_size,
- sort_by=sort_by,
- milvus_size=milvus_size,
- )
- return jsonify(response), 200
- case "hybrid2":
- co_fields = {"Entity": filters["entities"][0]}
- response = await search_engine.hybrid_search_with_graph(
- filters=filters,
- query_vec=query_vector,
- anns_field=anns_field,
- search_params=search_params,
- es_size=es_size,
- sort_by=sort_by,
- milvus_size=milvus_size,
- co_occurrence_fields=co_fields,
- shortest_path_fields=path_between_chunks,
- )
- return jsonify(response), 200
- case "strategy":
- return jsonify({"error": "strategy not implemented"}), 405
- case _:
- return jsonify({"error": "error search_type"}), 200
- except Exception as e:
- return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
- @server_bp.route("/dataset/list", methods=["GET"])
- async def dataset_list():
- resource = get_resource_manager()
- datasets = await Dataset(resource.mysql_client).select_dataset()
- # 创建所有任务
- tasks = [
- Contents(resource.mysql_client).select_count(dataset["id"])
- for dataset in datasets
- ]
- counts = await asyncio.gather(*tasks)
- # 组装数据
- data_list = [
- {
- "dataset_id": dataset["id"],
- "name": dataset["name"],
- "count": count,
- "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
- }
- for dataset, count in zip(datasets, counts)
- ]
- return jsonify({"status_code": 200, "detail": "success", "data": data_list})
- @server_bp.route("/dataset/add", methods=["POST"])
- async def add_dataset():
- resource = get_resource_manager()
- dataset_mapper = Dataset(resource.mysql_client)
- # 从请求体里取参数
- body = await request.get_json()
- name = body.get("name")
- if not name:
- return jsonify({"status_code": 400, "detail": "name is required"})
- # 执行新增
- dataset = await dataset_mapper.select_dataset_by_name(name)
- if dataset:
- return jsonify({"status_code": 400, "detail": "name is exist"})
- await dataset_mapper.add_dataset(name)
- new_dataset = await dataset_mapper.select_dataset_by_name(name)
- return jsonify(
- {
- "status_code": 200,
- "detail": "success",
- "data": {"datasetId": new_dataset[0]["id"]},
- }
- )
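
# Example request (a sketch):
#   curl -X POST http://localhost:8080/api/dataset/add \
#        -H 'Content-Type: application/json' \
#        -d '{"name": "my-dataset"}'
# -> {"status_code": 200, "detail": "success", "data": {"datasetId": <new id>}}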
- @server_bp.route("/content/get", methods=["GET"])
- async def get_content():
- resource = get_resource_manager()
- contents = Contents(resource.mysql_client)
- # 获取请求参数
- doc_id = request.args.get("docId")
- if not doc_id:
- return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
- # 查询内容
- rows = await contents.select_content_by_doc_id(doc_id)
- if not rows:
- return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
- row = rows[0]
- return jsonify(
- {
- "status_code": 200,
- "detail": "success",
- "data": {
- "title": row.get("title", ""),
- "text": row.get("text", ""),
- "doc_id": row.get("doc_id", ""),
- },
- }
- )
- @server_bp.route("/content/list", methods=["GET"])
- async def content_list():
- resource = get_resource_manager()
- contents = Contents(resource.mysql_client)
- # 从 URL 查询参数获取分页和过滤参数
- page_num = int(request.args.get("page", 1))
- page_size = int(request.args.get("pageSize", 10))
- dataset_id = request.args.get("datasetId")
- doc_status = int(request.args.get("doc_status", 1))
- # order_by 可以用 JSON 字符串传递
- import json
- order_by_str = request.args.get("order_by", '{"id":"desc"}')
- try:
- order_by = json.loads(order_by_str)
- except Exception:
- order_by = {"id": "desc"}
- # 调用 select_contents,获取分页字典
- result = await contents.select_contents(
- page_num=page_num,
- page_size=page_size,
- dataset_id=dataset_id,
- doc_status=doc_status,
- order_by=order_by,
- )
- # 格式化 entities,只保留必要字段
- entities = [
- {
- "doc_id": row["doc_id"],
- "title": row.get("title") or "",
- "text": row.get("text") or "",
- "statusDesc": "可用" if row.get("status") == 2 else "不可用",
- }
- for row in result["entities"]
- ]
- return jsonify(
- {
- "status_code": 200,
- "detail": "success",
- "data": {
- "entities": entities,
- "total_count": result["total_count"],
- "page": result["page"],
- "page_size": result["page_size"],
- "total_pages": result["total_pages"],
- },
- }
- )


async def query_search(
    query_text,
    filters=None,
    search_type="",
    anns_field="vector_text",
    search_params=BASE_MILVUS_SEARCH_PARAMS,
    _source=False,
    es_size=10000,
    sort_by=None,
    milvus_size=20,
    limit=10,
):
    if filters is None:
        filters = {}
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client,
        es_pool=resource.es_client,
        graph_pool=resource.graph_client,
    )
    try:
        match search_type:
            case "base":
                # Raw vector-search results are returned as-is, without the
                # chunk enrichment applied to hybrid results below
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return response
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
            case "strategy":
                return None
            case _:
                return None
    except Exception:
        return None
    if response is None:
        return None
    # Enrich each hit with the stored chunk text and summary
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    res = []
    for result in response["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if content_chunks:
            content_chunk = content_chunks[0]
            res.append(
                {
                    "docId": content_chunk["doc_id"],
                    "content": content_chunk["text"],
                    "contentSummary": content_chunk["summary"],
                    "score": result["score"],
                    "datasetId": content_chunk["dataset_id"],
                }
            )
    return res[:limit]
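
# Example use (a sketch; the dataset id is illustrative):
#   results = await query_search(
#       "what is a vector database",
#       filters={"dataset_id": ["12"]},
#       search_type="hybrid",
#       limit=5,
#   )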
- @server_bp.route("/query", methods=["GET"])
- async def query():
- query_text = request.args.get("query")
- dataset_ids = request.args.get("datasetIds").split(",")
- search_type = request.args.get("search_type", "hybrid")
- query_results = await query_search(
- query_text=query_text,
- filters={"dataset_id": dataset_ids},
- search_type=search_type,
- )
- resource = get_resource_manager()
- dataset_mapper = Dataset(resource.mysql_client)
- for result in query_results:
- datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
- if datasets:
- dataset_name = datasets[0]["name"]
- result["datasetName"] = dataset_name
- data = {"results": query_results}
- return jsonify({"status_code": 200, "detail": "success", "data": data})
- @server_bp.route("/chat", methods=["GET"])
- async def chat():
- query_text = request.args.get("query")
- dataset_id_strs = request.args.get("datasetIds")
- dataset_ids = dataset_id_strs.split(",")
- search_type = request.args.get("search_type", "hybrid")
- query_results = await query_search(
- query_text=query_text,
- filters={"dataset_id": dataset_ids},
- search_type=search_type,
- )
- resource = get_resource_manager()
- chat_result_mapper = ChatResult(resource.mysql_client)
- dataset_mapper = Dataset(resource.mysql_client)
- for result in query_results:
- datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
- if datasets:
- dataset_name = datasets[0]["name"]
- result["datasetName"] = dataset_name
- rag_chat_agent = RAGChatAgent()
- chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
- study_task_id = None
- if chat_result["status"] == 0:
- study_task_id = study(query_text)['task_id']
- llm_search = await rag_chat_agent.llm_search(query_text)
- decision = await rag_chat_agent.make_decision(chat_result, llm_search)
- data = {
- "results": query_results,
- "chat_res": decision["result"],
- "rag_summary": chat_result["summary"],
- "llm_summary": llm_search["answer"],
- }
- await chat_result_mapper.insert_chat_result(
- query_text,
- dataset_id_strs,
- json.dumps(query_results, ensure_ascii=False),
- chat_result["summary"],
- chat_result["relevance_score"],
- chat_result["status"],
- llm_search["answer"],
- llm_search["source"],
- llm_search["status"],
- decision["result"],
- study_task_id,
- is_web=1,
- )
- return jsonify({"status_code": 200, "detail": "success", "data": data})
- @server_bp.route("/chunk/list", methods=["GET"])
- async def chunk_list():
- resource = get_resource_manager()
- content_chunk = ContentChunks(resource.mysql_client)
- # 从 URL 查询参数获取分页和过滤参数
- page_num = int(request.args.get("page", 1))
- page_size = int(request.args.get("pageSize", 10))
- doc_id = request.args.get("docId")
- if not doc_id:
- return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
- # 调用 select_contents,获取分页字典
- result = await content_chunk.select_chunk_contents(
- page_num=page_num, page_size=page_size, doc_id=doc_id
- )
- if not result:
- return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
- # 格式化 entities,只保留必要字段
- entities = [
- {
- "id": row["id"],
- "chunk_id": row["chunk_id"],
- "doc_id": row["doc_id"],
- "summary": row.get("summary") or "",
- "text": row.get("text") or "",
- "statusDesc": "可用" if row.get("chunk_status") == 2 else "不可用",
- }
- for row in result["entities"]
- ]
- return jsonify(
- {
- "status_code": 200,
- "detail": "success",
- "data": {
- "entities": entities,
- "total_count": result["total_count"],
- "page": result["page"],
- "page_size": result["page_size"],
- "total_pages": result["total_pages"],
- },
- }
- )
- @server_bp.route("/auto_rechunk", methods=["GET"])
- async def auto_rechunk():
- resource = get_resource_manager()
- auto_rechunk_task = AutoRechunkTask(mysql_client=resource.mysql_client)
- process_cnt = await auto_rechunk_task.deal()
- return jsonify({"status_code": 200, "detail": "success", "cnt": process_cnt})
- @server_bp.route("/build_graph", methods=["POST"])
- async def delete_task():
- body = await request.get_json()
- doc_id: str = body.get("doc_id")
- if not doc_id:
- return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
- dataset_id: str = body.get("dataset_id", 12)
- batch: bool = body.get("batch_process", False)
- resource = get_resource_manager()
- build_graph_task = BuildGraph(
- neo4j=resource.graph_client,
- es_client=resource.es_client,
- mysql_client=resource.mysql_client,
- )
- if batch:
- await build_graph_task.deal_batch(dataset_id)
- else:
- await build_graph_task.deal(doc_id)
- return jsonify({"status_code": 200, "detail": "success", "data": {}})
- @server_bp.route("/rag/search", methods=["POST"])
- async def rag_search():
- body = await request.get_json()
- query_text = body.get("queryText")
- dataset_id_strs = "11,12"
- dataset_ids = dataset_id_strs.split(",")
- search_type = "hybrid"
- query_results = await query_search(
- query_text=query_text,
- filters={"dataset_id": dataset_ids},
- search_type=search_type,
- limit=5,
- )
- resource = get_resource_manager()
- chat_result_mapper = ChatResult(resource.mysql_client)
- rag_chat_agent = RAGChatAgent()
- chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
- study_task_id = None
- if chat_result["status"] == 0:
- study_task_id = study(query_text)['task_id']
- llm_search = await rag_chat_agent.llm_search(query_text)
- decision = await rag_chat_agent.make_decision(chat_result, llm_search)
- data = {
- "result": decision["result"],
- "status": decision["status"],
- "relevance_score": decision["relevance_score"],
- }
- await chat_result_mapper.insert_chat_result(
- query_text,
- dataset_id_strs,
- json.dumps(query_results, ensure_ascii=False),
- chat_result["summary"],
- chat_result["relevance_score"],
- chat_result["status"],
- llm_search["answer"],
- llm_search["source"],
- llm_search["status"],
- decision["result"],
- study_task_id
- )
- return jsonify({"status_code": 200, "detail": "success", "data": data})
- @server_bp.route("/chat/history", methods=["GET"])
- async def chat_history():
- page_num = int(request.args.get("page", 1))
- page_size = int(request.args.get("pageSize", 10))
- resource = get_resource_manager()
- chat_result_mapper = ChatResult(resource.mysql_client)
- result = await chat_result_mapper.select_chat_results(page_num, page_size)
- return jsonify(
- {
- "status_code": 200,
- "detail": "success",
- "data": {
- "entities": result["entities"],
- "total_count": result["total_count"],
- "page": result["page"],
- "page_size": result["page_size"],
- "total_pages": result["total_pages"],
- },
- }
- )
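
# Example request (a sketch):
#   curl 'http://localhost:8080/api/chat/history?page=1&pageSize=10'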