瀏覽代碼

格式化代码

xueyiming 1 天之前
父節點
當前提交
770b61b07d
共有 2 個文件被更改,包括 34 次插入32 次删除
  1. 25 20
      applications/utils/task/async_task.py
  2. 9 12
      routes/blueprint.py

+ 25 - 20
applications/utils/task/async_task.py

@@ -13,6 +13,7 @@ from applications.utils.oss.oss_client import OSSClient
 from applications.utils.pdf.book_extract import book_extract
 from applications.utils.spider.study import study
 
+
 async def handle_books():
     try:
         # 获取资源管理器和客户端
@@ -26,7 +27,9 @@ async def handle_books():
         for book in books:
             book_id = book.get("book_id")
             # 获取提取状态
-            extract_status = (await books_mapper.select_book_extract_status(book_id))[0].get("extract_status")
+            extract_status = (await books_mapper.select_book_extract_status(book_id))[
+                0
+            ].get("extract_status")
 
             if extract_status == 0:
                 # 更新提取状态为处理中
@@ -45,10 +48,16 @@ async def handle_books():
                     # 提取书籍内容
                     res = await book_extract(book_path, book_id)
                     if res:
-                        content_list = res.get("results", {}).get(book_id, {}).get("content_list", [])
+                        content_list = (
+                            res.get("results", {})
+                            .get(book_id, {})
+                            .get("content_list", [])
+                        )
                         if content_list:
                             # 更新提取结果
-                            await books_mapper.update_book_extract_result(book_id, content_list)
+                            await books_mapper.update_book_extract_result(
+                                book_id, content_list
+                            )
 
                         # 创建文档 ID
                         doc_id = f"doc-{uuid.uuid4()}"
@@ -91,12 +100,8 @@ async def process_question(question, query_text, rag_chat_agent):
             study_task_id = study(question)["task_id"]
 
         qwen_client = QwenClient()
-        llm_search = qwen_client.search_and_chat(
-            user_prompt=question
-        )
-        decision = await rag_chat_agent.make_decision(
-            question, chat_result, llm_search
-        )
+        llm_search = qwen_client.search_and_chat(user_prompt=question)
+        decision = await rag_chat_agent.make_decision(question, chat_result, llm_search)
 
         # 构建返回的数据
         data = {
@@ -128,16 +133,16 @@ async def process_question(question, query_text, rag_chat_agent):
 
 
 async def query_search(
-        query_text,
-        filters=None,
-        search_type="",
-        anns_field="vector_text",
-        search_params=BASE_MILVUS_SEARCH_PARAMS,
-        _source=False,
-        es_size=10000,
-        sort_by=None,
-        milvus_size=20,
-        limit=10,
+    query_text,
+    filters=None,
+    search_type="",
+    anns_field="vector_text",
+    search_params=BASE_MILVUS_SEARCH_PARAMS,
+    _source=False,
+    es_size=10000,
+    sort_by=None,
+    milvus_size=20,
+    limit=10,
 ):
     if filters is None:
         filters = {}
@@ -194,4 +199,4 @@ async def query_search(
                     "datasetId": content_chunk["dataset_id"],
                 }
             )
-    return res[:limit]
+    return res[:limit]

+ 9 - 12
routes/blueprint.py

@@ -23,7 +23,11 @@ from applications.utils.chat import RAGChatAgent
 from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult, Books
 from applications.api.qwen import QwenClient
 from applications.utils.oss.oss_client import OSSClient
-from applications.utils.task.async_task import handle_books, process_question, query_search
+from applications.utils.task.async_task import (
+    handle_books,
+    process_question,
+    query_search,
+)
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 server_bp = cors(server_bp, allow_origin="*")
@@ -481,7 +485,6 @@ async def rag_search():
     return jsonify({"status_code": 200, "detail": "success", "data": data_list})
 
 
-
 @server_bp.route("/chat/history", methods=["GET"])
 async def chat_history():
     page_num = int(request.args.get("page", 1))
@@ -525,7 +528,6 @@ async def upload_pdf():
 
         # 获取文件名
         filename = file.filename
-        print(filename)
 
         book_id = f"book-{uuid.uuid4()}"
         # 检查文件的 MIME 类型是否是 application/pdf
@@ -549,17 +551,12 @@ async def upload_pdf():
             oss_path = f"rag/pdfs/{book_id}"
             oss_client.upload_file(file_path, oss_path)
             await books.insert_book(book_id, filename, oss_path)
-            # os.remove(file_path)
-            return jsonify(
-                {
-                    "status": "success",
-                    "message": f"File {filename} uploaded successfully to OSS!",
-                }
-            ), 200
+            return jsonify({"status_code": 200, "detail": "success"})
         except Exception as e:
-            return jsonify({"status": "error", "message": str(e)}), 500
+            return jsonify({"status_code": 500, "detail": str(e)})
+
     else:
-        return jsonify({"status": "error", "message": "No file uploaded."}), 400
+        return jsonify({"status_code": 400, "detail": "No file uploaded."})
 
 
 @server_bp.route("/process/book", methods=["GET"])