|
@@ -1,5 +1,6 @@
|
|
|
import asyncio
|
|
import asyncio
|
|
|
import json
|
|
import json
|
|
|
|
|
+import os
|
|
|
import traceback
|
|
import traceback
|
|
|
import uuid
|
|
import uuid
|
|
|
from typing import Dict, Any
|
|
from typing import Dict, Any
|
|
@@ -19,9 +20,14 @@ from applications.config import (
|
|
|
from applications.resource import get_resource_manager
|
|
from applications.resource import get_resource_manager
|
|
|
from applications.search import HybridSearch
|
|
from applications.search import HybridSearch
|
|
|
from applications.utils.chat import RAGChatAgent
|
|
from applications.utils.chat import RAGChatAgent
|
|
|
-from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult
|
|
|
|
|
|
|
+from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult, Books
|
|
|
from applications.api.qwen import QwenClient
|
|
from applications.api.qwen import QwenClient
|
|
|
-from applications.utils.spider.study import study
|
|
|
|
|
|
|
+from applications.utils.oss.oss_client import OSSClient
|
|
|
|
|
+from applications.utils.task.async_task import (
|
|
|
|
|
+ handle_books,
|
|
|
|
|
+ process_question,
|
|
|
|
|
+ query_search,
|
|
|
|
|
+)
|
|
|
|
|
|
|
|
server_bp = Blueprint("api", __name__, url_prefix="/api")
|
|
server_bp = Blueprint("api", __name__, url_prefix="/api")
|
|
|
server_bp = cors(server_bp, allow_origin="*")
|
|
server_bp = cors(server_bp, allow_origin="*")
|
|
@@ -312,76 +318,6 @@ async def content_list():
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
-async def query_search(
|
|
|
|
|
- query_text,
|
|
|
|
|
- filters=None,
|
|
|
|
|
- search_type="",
|
|
|
|
|
- anns_field="vector_text",
|
|
|
|
|
- search_params=BASE_MILVUS_SEARCH_PARAMS,
|
|
|
|
|
- _source=False,
|
|
|
|
|
- es_size=10000,
|
|
|
|
|
- sort_by=None,
|
|
|
|
|
- milvus_size=20,
|
|
|
|
|
- limit=10,
|
|
|
|
|
-):
|
|
|
|
|
- if filters is None:
|
|
|
|
|
- filters = {}
|
|
|
|
|
- query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
|
|
|
|
|
- resource = get_resource_manager()
|
|
|
|
|
- search_engine = HybridSearch(
|
|
|
|
|
- milvus_pool=resource.milvus_client,
|
|
|
|
|
- es_pool=resource.es_client,
|
|
|
|
|
- graph_pool=resource.graph_client,
|
|
|
|
|
- )
|
|
|
|
|
- try:
|
|
|
|
|
- match search_type:
|
|
|
|
|
- case "base":
|
|
|
|
|
- response = await search_engine.base_vector_search(
|
|
|
|
|
- query_vec=query_vector,
|
|
|
|
|
- anns_field=anns_field,
|
|
|
|
|
- search_params=search_params,
|
|
|
|
|
- limit=limit,
|
|
|
|
|
- )
|
|
|
|
|
- return response
|
|
|
|
|
- case "hybrid":
|
|
|
|
|
- response = await search_engine.hybrid_search(
|
|
|
|
|
- filters=filters,
|
|
|
|
|
- query_vec=query_vector,
|
|
|
|
|
- anns_field=anns_field,
|
|
|
|
|
- search_params=search_params,
|
|
|
|
|
- es_size=es_size,
|
|
|
|
|
- sort_by=sort_by,
|
|
|
|
|
- milvus_size=milvus_size,
|
|
|
|
|
- )
|
|
|
|
|
- case "strategy":
|
|
|
|
|
- return None
|
|
|
|
|
- case _:
|
|
|
|
|
- return None
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- return None
|
|
|
|
|
- if response is None:
|
|
|
|
|
- return None
|
|
|
|
|
- resource = get_resource_manager()
|
|
|
|
|
- content_chunk_mapper = ContentChunks(resource.mysql_client)
|
|
|
|
|
- res = []
|
|
|
|
|
- for result in response["results"]:
|
|
|
|
|
- content_chunks = await content_chunk_mapper.select_chunk_content(
|
|
|
|
|
- doc_id=result["doc_id"], chunk_id=result["chunk_id"]
|
|
|
|
|
- )
|
|
|
|
|
- if content_chunks:
|
|
|
|
|
- content_chunk = content_chunks[0]
|
|
|
|
|
- res.append(
|
|
|
|
|
- {
|
|
|
|
|
- "docId": content_chunk["doc_id"],
|
|
|
|
|
- "content": content_chunk["text"],
|
|
|
|
|
- "contentSummary": content_chunk["summary"],
|
|
|
|
|
- "score": result["score"],
|
|
|
|
|
- "datasetId": content_chunk["dataset_id"],
|
|
|
|
|
- }
|
|
|
|
|
- )
|
|
|
|
|
- return res[:limit]
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
@server_bp.route("/query", methods=["GET"])
|
|
@server_bp.route("/query", methods=["GET"])
|
|
|
async def query():
|
|
async def query():
|
|
|
query_text = request.args.get("query")
|
|
query_text = request.args.get("query")
|
|
@@ -549,67 +485,6 @@ async def rag_search():
|
|
|
return jsonify({"status_code": 200, "detail": "success", "data": data_list})
|
|
return jsonify({"status_code": 200, "detail": "success", "data": data_list})
|
|
|
|
|
|
|
|
|
|
|
|
|
-async def process_question(question, query_text, rag_chat_agent):
|
|
|
|
|
- try:
|
|
|
|
|
- dataset_id_strs = "11,12"
|
|
|
|
|
- dataset_ids = dataset_id_strs.split(",")
|
|
|
|
|
- search_type = "hybrid"
|
|
|
|
|
-
|
|
|
|
|
- # 执行查询任务
|
|
|
|
|
- query_results = await query_search(
|
|
|
|
|
- query_text=question,
|
|
|
|
|
- filters={"dataset_id": dataset_ids},
|
|
|
|
|
- search_type=search_type,
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- resource = get_resource_manager()
|
|
|
|
|
- chat_result_mapper = ChatResult(resource.mysql_client)
|
|
|
|
|
-
|
|
|
|
|
- # 异步执行 chat 与 deepseek 的对话
|
|
|
|
|
- chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)
|
|
|
|
|
-
|
|
|
|
|
- # # 判断是否需要执行 study
|
|
|
|
|
- study_task_id = None
|
|
|
|
|
- if chat_result["status"] == 0:
|
|
|
|
|
- study_task_id = study(question)["task_id"]
|
|
|
|
|
-
|
|
|
|
|
- qwen_client = QwenClient()
|
|
|
|
|
- llm_search = qwen_client.search_and_chat(
|
|
|
|
|
- user_prompt=question
|
|
|
|
|
- )
|
|
|
|
|
- decision = await rag_chat_agent.make_decision(
|
|
|
|
|
- question, chat_result, llm_search
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- # 构建返回的数据
|
|
|
|
|
- data = {
|
|
|
|
|
- "query": question,
|
|
|
|
|
- "result": decision["result"],
|
|
|
|
|
- "status": decision["status"],
|
|
|
|
|
- "relevance_score": decision["relevance_score"],
|
|
|
|
|
- # "used_tools": decision["used_tools"],
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- # 插入数据库
|
|
|
|
|
- await chat_result_mapper.insert_chat_result(
|
|
|
|
|
- question,
|
|
|
|
|
- dataset_id_strs,
|
|
|
|
|
- json.dumps(query_results, ensure_ascii=False),
|
|
|
|
|
- chat_result["summary"],
|
|
|
|
|
- chat_result["relevance_score"],
|
|
|
|
|
- chat_result["status"],
|
|
|
|
|
- llm_search["content"],
|
|
|
|
|
- json.dumps(llm_search["search_results"], ensure_ascii=False),
|
|
|
|
|
- 1,
|
|
|
|
|
- decision["result"],
|
|
|
|
|
- study_task_id,
|
|
|
|
|
- )
|
|
|
|
|
- return data
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- print(f"Error processing question: {question}. Error: {str(e)}")
|
|
|
|
|
- return {"query": question, "error": str(e)}
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
@server_bp.route("/chat/history", methods=["GET"])
|
|
@server_bp.route("/chat/history", methods=["GET"])
|
|
|
async def chat_history():
|
|
async def chat_history():
|
|
|
page_num = int(request.args.get("page", 1))
|
|
page_num = int(request.args.get("page", 1))
|
|
@@ -630,3 +505,63 @@ async def chat_history():
|
|
|
},
|
|
},
|
|
|
}
|
|
}
|
|
|
)
|
|
)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@server_bp.route("/upload/file", methods=["POST"])
async def upload_pdf():
    """Accept one PDF upload, push it to OSS and register it as a book.

    Reads the multipart field ``file`` from the request, validates it by
    extension and MIME type, stages it under ``/tmp``, uploads it to OSS at
    ``rag/pdfs/<book_id>`` and records it through ``Books.insert_book``.

    Returns:
        * ``{"status_code": 400, "detail": "No file uploaded."}`` when the
          ``file`` field is absent or empty.
        * ``({"status": "error", "message": ...}, 400)`` when the extension
          is not ``.pdf`` or the MIME type is not ``application/pdf``.
        * ``{"status_code": 200, "detail": "success"}`` on success.
        * ``{"status_code": 500, "detail": <error text>}`` when the OSS
          upload or the database insert raises.
    """
    # request.files has to be awaited before the uploaded parts are available.
    files = await request.files
    file = files.get("file")
    if not file:
        return jsonify({"status_code": 400, "detail": "No file uploaded."})

    # Defensive: fall back to an empty name so a part without a filename is
    # rejected by the extension check instead of raising AttributeError.
    filename = file.filename or ""
    if not filename.lower().endswith(".pdf"):
        return jsonify(
            {
                "status": "error",
                "message": "Invalid file format. Only PDF files are allowed.",
            }
        ), 400

    # The extension alone is not trusted; the declared MIME type must agree.
    if file.content_type != "application/pdf":
        return jsonify(
            {
                "status": "error",
                "message": "Invalid MIME type. Only PDF files are allowed.",
            }
        ), 400

    book_id = f"book-{uuid.uuid4()}"
    # Temporary staging path; the file only has to live until the OSS upload.
    file_path = os.path.join("/tmp", book_id)
    try:
        await file.save(file_path)
        resource = get_resource_manager()
        books = Books(resource.mysql_client)
        try:
            oss_client = OSSClient()
            oss_path = f"rag/pdfs/{book_id}"
            # upload_file is invoked without await in the original, so it is
            # treated as a blocking call and moved off the event loop.
            # NOTE(review): confirm OSSClient.upload_file is synchronous.
            await asyncio.to_thread(oss_client.upload_file, file_path, oss_path)
            await books.insert_book(book_id, filename, oss_path)
            return jsonify({"status_code": 200, "detail": "success"})
        except Exception as e:
            return jsonify({"status_code": 500, "detail": str(e)})
    finally:
        # Always drop the staged copy so uploads do not accumulate in /tmp.
        # NOTE(review): assumes later processing reads the book from OSS via
        # the stored oss_path, not from this local path - confirm.
        try:
            os.remove(file_path)
        except OSError:
            pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
# Strong references to in-flight background tasks. The event loop keeps only
# weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_background_tasks = set()


def _finalize_background_task(task):
    """Release a finished background task and report its failure, if any."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    # Retrieving the exception also prevents asyncio's
    # "Task exception was never retrieved" warning.
    exc = task.exception()
    if exc is not None:
        print(f"Background book processing failed: {exc!r}")
        traceback.print_exception(type(exc), exc, exc.__traceback__)


@server_bp.route("/process/book", methods=["GET"])
async def process_book():
    """Start ``handle_books()`` as a background task and respond immediately.

    The task is not awaited; the response only confirms that it was
    scheduled, not that processing finished or succeeded.

    Returns:
        ``({"status": "success", "message": ...}, 200)``.
    """
    # Schedule the book-processing coroutine without blocking the request.
    task = asyncio.create_task(handle_books())
    # Hold a reference until completion; the callback drops it and logs errors.
    _background_tasks.add(task)
    task.add_done_callback(_finalize_background_task)
    # Reply right away.
    return jsonify({"status": "success", "message": "任务已提交后台处理"}), 200
|