Explorar o código

Merge branch 'feature/xueyiming/2025-10-09-update-query' of Server/rag_server into master

xueyiming committed 1 month ago
parent
commit
7c0dc4eb33

+ 37 - 0
applications/utils/mysql/books.py

@@ -23,3 +23,40 @@ class Books(BaseMySQLClient):
         return await self.pool.async_save(
             query=query, params=(new_status, book_id, ori_status)
         )
+
async def insert_book(self, book_id, book_name, book_oss_path):
    """Register a freshly uploaded book row in the ``books`` table.

    :param book_id: unique identifier of the book (e.g. "book-<uuid>")
    :param book_name: original file name of the uploaded PDF
    :param book_oss_path: OSS object key where the PDF is stored
    :return: whatever the pool's ``async_save`` returns for the INSERT
    """
    sql = """
            INSERT INTO books (book_id, book_name, book_oss_path)
             VALUES (%s, %s, %s);
        """
    return await self.pool.async_save(query=sql, params=(book_id, book_name, book_oss_path))
+
async def select_init_books(self):
    """Fetch every book still waiting for extraction (extract_status = 0).

    :return: rows with book_id, book_name, book_oss_path and extract_status
    """
    sql = """
            SELECT book_id, book_name, book_oss_path, extract_status
            FROM books
            WHERE extract_status = 0;
        """
    return await self.pool.async_fetch(query=sql)
+
async def select_book_extract_status(self, book_id):
    """Look up the extraction status of a single book.

    :param book_id: identifier of the book to query
    :return: rows containing book_id and extract_status (empty if unknown)
    """
    sql = """
            SELECT book_id, extract_status
            FROM books
            WHERE book_id = %s;
        """
    return await self.pool.async_fetch(query=sql, params=(book_id,))
+
async def update_book_extract_status(self, book_id, status):
    """Set ``extract_status`` for one book (0=pending, 1=processing, 99=failed).

    :param book_id: identifier of the book to update
    :param status: new status value to store
    :return: whatever the pool's ``async_save`` returns for the UPDATE
    """
    sql = """
            UPDATE books SET extract_status = %s WHERE book_id = %s;
            """
    return await self.pool.async_save(query=sql, params=(status, book_id))
+
async def update_book_extract_result(self, book_id, extract_result):
    """Store the extraction result and flip ``extract_status`` to 2 (done).

    NOTE(review): *extract_result* is bound directly as a SQL parameter, so it
    must arrive in a driver-bindable form (e.g. a JSON string) — a raw Python
    list is unlikely to bind; confirm against the pool implementation.

    :param book_id: identifier of the book to update
    :param extract_result: serialized extraction payload to persist
    :return: whatever the pool's ``async_save`` returns for the UPDATE
    """
    sql = """
            UPDATE books SET extract_result = %s, extract_status = 2 WHERE book_id = %s;
            """
    return await self.pool.async_save(query=sql, params=(extract_result, book_id))

+ 0 - 0
applications/utils/oss/__init__.py


+ 90 - 0
applications/utils/oss/oss_client.py

@@ -0,0 +1,90 @@
+import oss2
+import os
+
+
class OSSClient:
    """Minimal Aliyun OSS wrapper used for book PDF storage.

    Every operation raises ``Exception`` with a descriptive message on
    failure (the original error contract), with the underlying OSS error
    chained as the cause.
    """

    # Default endpoint: China East 1 (Hangzhou).
    DEFAULT_ENDPOINT = "oss-cn-hangzhou.aliyuncs.com"

    def __init__(self, access_key_id=None, access_key_secret=None, bucket_name=None):
        """
        Initialize the OSS client.

        :param access_key_id: access key id; falls back to the
            ``OSS_ACCESS_KEY_ID`` environment variable, then a built-in default
        :param access_key_secret: access key secret; falls back to
            ``OSS_ACCESS_KEY_SECRET``, then a built-in default
        :param bucket_name: bucket name; falls back to ``OSS_BUCKET_NAME``,
            then "art-pubbucket"
        :raises ValueError: if no credentials can be resolved
        """
        # SECURITY(review): credentials were hard-coded in the repository.
        # Prefer environment variables; the hard-coded values remain only as a
        # last-resort fallback for backward compatibility and should be
        # rotated and removed. (The original also discarded a caller-provided
        # key when only one of the pair was given; each value now falls back
        # independently.)
        if access_key_id is None:
            access_key_id = os.environ.get("OSS_ACCESS_KEY_ID", "LTAIP6x1l3DXfSxm")
        if access_key_secret is None:
            access_key_secret = os.environ.get(
                "OSS_ACCESS_KEY_SECRET", "KbTaM9ars4OX3PMS6Xm7rtxGr1FLon"
            )
        if bucket_name is None:
            bucket_name = os.environ.get("OSS_BUCKET_NAME", "art-pubbucket")

        # Fail fast when no usable credentials are available.
        if not access_key_id or not access_key_secret:
            raise ValueError(
                "ACCESS_KEY_ID and ACCESS_KEY_SECRET must be set in the environment variables."
            )

        # Bind the bucket handle once; all operations below reuse it.
        self.auth = oss2.Auth(access_key_id, access_key_secret)
        self.bucket = oss2.Bucket(self.auth, self.DEFAULT_ENDPOINT, bucket_name)

    def upload_file(self, local_file_path, oss_file_path):
        """
        Upload a local file to OSS.

        :param local_file_path: path of the local file to upload
        :param oss_file_path: destination object key (e.g. pdfs/myfile.pdf)
        :return: dict with status/message on success
        :raises FileNotFoundError: if the local file is missing
        :raises Exception: on any OSS failure (cause chained)
        """
        if not os.path.exists(local_file_path):
            raise FileNotFoundError(f"Local file {local_file_path} does not exist.")

        try:
            self.bucket.put_object_from_file(oss_file_path, local_file_path)
            return {"status": "success", "message": f"File uploaded to {oss_file_path}"}
        except Exception as e:
            raise Exception(f"Error uploading file to OSS: {str(e)}") from e

    def download_file(self, oss_file_path, local_file_path):
        """
        Download a file from OSS to a local path.

        :param oss_file_path: source object key (e.g. pdfs/myfile.pdf)
        :param local_file_path: local path to write to
        :return: dict with status/message on success
        :raises Exception: on any OSS failure (cause chained)
        """
        try:
            self.bucket.get_object_to_file(oss_file_path, local_file_path)
            return {
                "status": "success",
                "message": f"File downloaded to {local_file_path}",
            }
        except Exception as e:
            raise Exception(f"Error downloading file from OSS: {str(e)}") from e

    def delete_file(self, oss_file_path):
        """
        Delete a file from OSS.

        :param oss_file_path: object key to delete (e.g. pdfs/myfile.pdf)
        :return: dict with status/message on success
        :raises Exception: on any OSS failure (cause chained)
        """
        try:
            self.bucket.delete_object(oss_file_path)
            return {
                "status": "success",
                "message": f"File {oss_file_path} deleted from OSS",
            }
        except Exception as e:
            raise Exception(f"Error deleting file from OSS: {str(e)}") from e

    def file_exists(self, oss_file_path):
        """
        Check whether an object exists on OSS.

        Uses ``Bucket.object_exists`` (a metadata probe) instead of the
        original ``get_object`` call, which downloaded the entire object body
        just to test for existence.

        :param oss_file_path: object key to probe (e.g. pdfs/myfile.pdf)
        :return: True if the object exists, False otherwise
        :raises Exception: on any OSS failure other than "not found"
        """
        try:
            return self.bucket.object_exists(oss_file_path)
        except Exception as e:
            raise Exception(f"Error checking file existence on OSS: {str(e)}") from e

+ 0 - 0
applications/utils/pdf/__init__.py


+ 29 - 0
applications/utils/pdf/book_extract.py

@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
import asyncio

import requests
+
+
async def book_extract(book_path, book_id, timeout=600):
    """POST a local PDF to the file-parse service and return its JSON payload.

    :param book_path: local path of the PDF file to parse
    :param book_id: identifier used both as the upload's file name and as the
        key under which the service reports results
    :param timeout: per-request timeout in seconds (new, generous default —
        the original request had none and could hang forever)
    :return: decoded JSON response from the parse service
    :raises requests.HTTPError: if the service answers with an error status
    """

    def _post():
        # All blocking work (file read + HTTP round-trip) runs here, inside
        # the worker thread started by asyncio.to_thread below, so the event
        # loop is never blocked by this synchronous `requests` call.
        with open(book_path, "rb") as f:
            return requests.post(
                "http://192.168.100.31:8003/file_parse",
                headers={"accept": "application/json"},
                data={
                    "return_model_output": "false",
                    "return_md": "false",
                    "return_images": "false",
                    "end_page_id": "99999",
                    "parse_method": "auto",
                    "start_page_id": "0",
                    "lang_list": "ch",
                    "output_dir": "./output",
                    "server_url": "string",
                    "return_content_list": "true",
                    "backend": "pipeline",
                    "table_enable": "true",
                    "response_format_zip": "false",
                    "formula_enable": "true",
                },
                files={"files": (book_id, f, "application/pdf")},
                timeout=timeout,
            )

    response = await asyncio.to_thread(_post)
    # Surface HTTP failures loudly instead of json-decoding an error body.
    response.raise_for_status()
    return response.json()

+ 1 - 0
applications/utils/spider/study.py

@@ -1,4 +1,5 @@
 import json
+import time
 
 import requests
 

+ 0 - 0
applications/utils/task/__init__.py


+ 203 - 0
applications/utils/task/async_task.py

@@ -0,0 +1,203 @@
+import json
+import os
+import uuid
+
+from applications.api import get_basic_embedding
+from applications.api.qwen import QwenClient
+from applications.async_task import ChunkBooksTask
+from applications.config import BASE_MILVUS_SEARCH_PARAMS, DEFAULT_MODEL
+from applications.resource import get_resource_manager
+from applications.search import HybridSearch
+from applications.utils.mysql import Books, ChatResult, ContentChunks
+from applications.utils.oss.oss_client import OSSClient
+from applications.utils.pdf.book_extract import book_extract
+from applications.utils.spider.study import study
+
+
async def handle_books():
    """Background worker: download pending books from OSS, run PDF extraction,
    persist the result, then hand each book to the chunking pipeline.

    Status codes observed on the books table: 0 = pending, 1 = processing,
    2 = extracted (set by update_book_extract_result), 99 = failed.
    """
    try:
        resource = get_resource_manager()
        books_mapper = Books(resource.mysql_client)
        oss_client = OSSClient()

        books = await books_mapper.select_init_books()

        for book in books:
            book_id = book.get("book_id")
            # Re-read the status so a book already claimed by a concurrent run
            # is skipped.
            status_rows = await books_mapper.select_book_extract_status(book_id)
            extract_status = status_rows[0].get("extract_status") if status_rows else None
            if extract_status != 0:
                continue

            # Claim the book: mark it as "processing".
            await books_mapper.update_book_extract_status(book_id, 1)
            book_path = os.path.join("/tmp", book_id)

            # The /tmp copy left behind by the upload endpoint acts as a cache;
            # only hit OSS when it is gone.
            if not os.path.exists(book_path):
                oss_path = f"rag/pdfs/{book_id}"
                try:
                    # BUG FIX: OSSClient.download_file is synchronous; the
                    # original code awaited it, which raises TypeError.
                    oss_client.download_file(oss_path, book_path)
                except Exception as e:
                    # Mark as failed rather than leaving the row stuck at 1.
                    await books_mapper.update_book_extract_status(book_id, 99)
                    print(f"download failed for {book_id}: {e}")
                    continue

            try:
                res = await book_extract(book_path, book_id)
                content_list = (
                    (res or {}).get("results", {}).get(book_id, {}).get("content_list", [])
                )
                if content_list:
                    # Serialize before binding: MySQL drivers cannot bind a
                    # raw Python list as a query parameter.
                    await books_mapper.update_book_extract_result(
                        book_id, json.dumps(content_list, ensure_ascii=False)
                    )
                # NOTE(review): when content_list is empty the book stays at
                # status 1 indefinitely — confirm whether that is intended.
            except Exception as e:
                await books_mapper.update_book_extract_status(book_id, 99)
                print(f"extract failed for {book_id}: {e}")
                continue

            # Hand the extracted book over to the chunking pipeline.
            doc_id = f"doc-{uuid.uuid4()}"
            chunk_task = ChunkBooksTask(doc_id=doc_id, resource=resource)
            await chunk_task.deal({"book_id": book_id})

    except Exception as e:
        # Top-level guard: this runs as a fire-and-forget task, so never let
        # an exception propagate unobserved.
        print(f"处理请求失败,错误: {e}")
+
+
async def process_question(question, query_text, rag_chat_agent):
    """Answer one question via RAG + LLM web search and persist the exchange.

    :param question: the user question to answer
    :param query_text: currently unused; kept for caller compatibility
    :param rag_chat_agent: agent exposing chat_with_deepseek / make_decision
    :return: result payload dict, or {"query", "error"} on failure
    """
    try:
        dataset_id_strs = "11,12"
        dataset_ids = dataset_id_strs.split(",")
        search_type = "hybrid"

        # Retrieve candidate chunks from the hybrid index.
        query_results = await query_search(
            query_text=question,
            filters={"dataset_id": dataset_ids},
            search_type=search_type,
        )

        # Let deepseek summarise the question against the retrieved context.
        chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)

        # Status 0 means the RAG answer was unsatisfying: queue a "study" job.
        study_task_id = study(question)["task_id"] if chat_result["status"] == 0 else None

        # Independent web-search answer from Qwen, then arbitrate between both.
        llm_search = QwenClient().search_and_chat(user_prompt=question)
        decision = await rag_chat_agent.make_decision(question, chat_result, llm_search)

        data = {
            "query": question,
            "result": decision["result"],
            "status": decision["status"],
            "relevance_score": decision["relevance_score"],
            # "used_tools": decision["used_tools"],
        }

        # Persist the full exchange for the chat-history endpoint.
        resource = get_resource_manager()
        chat_result_mapper = ChatResult(resource.mysql_client)
        await chat_result_mapper.insert_chat_result(
            question,
            dataset_id_strs,
            json.dumps(query_results, ensure_ascii=False),
            chat_result["summary"],
            chat_result["relevance_score"],
            chat_result["status"],
            llm_search["content"],
            json.dumps(llm_search["search_results"], ensure_ascii=False),
            1,
            decision["result"],
            study_task_id,
        )
        return data
    except Exception as e:
        print(f"Error processing question: {question}. Error: {str(e)}")
        return {"query": question, "error": str(e)}
+
+
+async def query_search(
+    query_text,
+    filters=None,
+    search_type="",
+    anns_field="vector_text",
+    search_params=BASE_MILVUS_SEARCH_PARAMS,
+    _source=False,
+    es_size=10000,
+    sort_by=None,
+    milvus_size=20,
+    limit=10,
+):
+    if filters is None:
+        filters = {}
+    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
+    resource = get_resource_manager()
+    search_engine = HybridSearch(
+        milvus_pool=resource.milvus_client,
+        es_pool=resource.es_client,
+        graph_pool=resource.graph_client,
+    )
+    try:
+        match search_type:
+            case "base":
+                response = await search_engine.base_vector_search(
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    limit=limit,
+                )
+                return response
+            case "hybrid":
+                response = await search_engine.hybrid_search(
+                    filters=filters,
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    es_size=es_size,
+                    sort_by=sort_by,
+                    milvus_size=milvus_size,
+                )
+            case "strategy":
+                return None
+            case _:
+                return None
+    except Exception as e:
+        return None
+    if response is None:
+        return None
+    resource = get_resource_manager()
+    content_chunk_mapper = ContentChunks(resource.mysql_client)
+    res = []
+    for result in response["results"]:
+        content_chunks = await content_chunk_mapper.select_chunk_content(
+            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
+        )
+        if content_chunks:
+            content_chunk = content_chunks[0]
+            res.append(
+                {
+                    "docId": content_chunk["doc_id"],
+                    "content": content_chunk["text"],
+                    "contentSummary": content_chunk["summary"],
+                    "score": result["score"],
+                    "datasetId": content_chunk["dataset_id"],
+                }
+            )
+    return res[:limit]

+ 1 - 1
mcp_server/server.py

@@ -10,7 +10,7 @@ from applications.utils.chat import RAGChatAgent
 from applications.utils.mysql import ChatResult
 from applications.api.qwen import QwenClient
 from applications.utils.spider.study import study
-from routes.blueprint import query_search
+from applications.utils.task.async_task import query_search
 
 
 def create_mcp_server() -> Server:

+ 1 - 0
requirements.txt

@@ -26,4 +26,5 @@ langchain==0.3.27
 langchain-core==0.3.76
 langchain-text-splitters==0.3.11
 mcp==1.14.1
+oss2==2.19.1
 dashscope==1.24.6

+ 68 - 133
routes/blueprint.py

@@ -1,5 +1,6 @@
 import asyncio
 import json
+import os
 import traceback
 import uuid
 from typing import Dict, Any
@@ -19,9 +20,14 @@ from applications.config import (
 from applications.resource import get_resource_manager
 from applications.search import HybridSearch
 from applications.utils.chat import RAGChatAgent
-from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult
+from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult, Books
 from applications.api.qwen import QwenClient
-from applications.utils.spider.study import study
+from applications.utils.oss.oss_client import OSSClient
+from applications.utils.task.async_task import (
+    handle_books,
+    process_question,
+    query_search,
+)
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 server_bp = cors(server_bp, allow_origin="*")
@@ -312,76 +318,6 @@ async def content_list():
     )
 
 
-async def query_search(
-    query_text,
-    filters=None,
-    search_type="",
-    anns_field="vector_text",
-    search_params=BASE_MILVUS_SEARCH_PARAMS,
-    _source=False,
-    es_size=10000,
-    sort_by=None,
-    milvus_size=20,
-    limit=10,
-):
-    if filters is None:
-        filters = {}
-    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
-    resource = get_resource_manager()
-    search_engine = HybridSearch(
-        milvus_pool=resource.milvus_client,
-        es_pool=resource.es_client,
-        graph_pool=resource.graph_client,
-    )
-    try:
-        match search_type:
-            case "base":
-                response = await search_engine.base_vector_search(
-                    query_vec=query_vector,
-                    anns_field=anns_field,
-                    search_params=search_params,
-                    limit=limit,
-                )
-                return response
-            case "hybrid":
-                response = await search_engine.hybrid_search(
-                    filters=filters,
-                    query_vec=query_vector,
-                    anns_field=anns_field,
-                    search_params=search_params,
-                    es_size=es_size,
-                    sort_by=sort_by,
-                    milvus_size=milvus_size,
-                )
-            case "strategy":
-                return None
-            case _:
-                return None
-    except Exception as e:
-        return None
-    if response is None:
-        return None
-    resource = get_resource_manager()
-    content_chunk_mapper = ContentChunks(resource.mysql_client)
-    res = []
-    for result in response["results"]:
-        content_chunks = await content_chunk_mapper.select_chunk_content(
-            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
-        )
-        if content_chunks:
-            content_chunk = content_chunks[0]
-            res.append(
-                {
-                    "docId": content_chunk["doc_id"],
-                    "content": content_chunk["text"],
-                    "contentSummary": content_chunk["summary"],
-                    "score": result["score"],
-                    "datasetId": content_chunk["dataset_id"],
-                }
-            )
-    return res[:limit]
-
-
 @server_bp.route("/query", methods=["GET"])
 async def query():
     query_text = request.args.get("query")
@@ -549,67 +485,6 @@ async def rag_search():
     return jsonify({"status_code": 200, "detail": "success", "data": data_list})
 
 
-async def process_question(question, query_text, rag_chat_agent):
-    try:
-        dataset_id_strs = "11,12"
-        dataset_ids = dataset_id_strs.split(",")
-        search_type = "hybrid"
-
-        # 执行查询任务
-        query_results = await query_search(
-            query_text=question,
-            filters={"dataset_id": dataset_ids},
-            search_type=search_type,
-        )
-
-        resource = get_resource_manager()
-        chat_result_mapper = ChatResult(resource.mysql_client)
-
-        # 异步执行 chat 与 deepseek 的对话
-        chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)
-
-        # # 判断是否需要执行 study
-        study_task_id = None
-        if chat_result["status"] == 0:
-            study_task_id = study(question)["task_id"]
-
-        qwen_client = QwenClient()
-        llm_search = qwen_client.search_and_chat(
-            user_prompt=question
-        )
-        decision = await rag_chat_agent.make_decision(
-            question, chat_result, llm_search
-        )
-
-        # 构建返回的数据
-        data = {
-            "query": question,
-            "result": decision["result"],
-            "status": decision["status"],
-            "relevance_score": decision["relevance_score"],
-            # "used_tools": decision["used_tools"],
-        }
-
-        # 插入数据库
-        await chat_result_mapper.insert_chat_result(
-            question,
-            dataset_id_strs,
-            json.dumps(query_results, ensure_ascii=False),
-            chat_result["summary"],
-            chat_result["relevance_score"],
-            chat_result["status"],
-            llm_search["content"],
-            json.dumps(llm_search["search_results"], ensure_ascii=False),
-            1,
-            decision["result"],
-            study_task_id,
-        )
-        return data
-    except Exception as e:
-        print(f"Error processing question: {question}. Error: {str(e)}")
-        return {"query": question, "error": str(e)}
-
-
 @server_bp.route("/chat/history", methods=["GET"])
 async def chat_history():
     page_num = int(request.args.get("page", 1))
@@ -630,3 +505,63 @@ async def chat_history():
             },
         }
     )
+
+
+@server_bp.route("/upload/file", methods=["POST"])
+async def upload_pdf():
+    # 获取前端上传的文件
+    # 先等待 request.files 属性来确保文件已加载
+    files = await request.files
+
+    # 获取文件对象
+    file = files.get("file")
+
+    if file:
+        # 检查文件扩展名是否是 .pdf
+        if not file.filename.lower().endswith(".pdf"):
+            return jsonify(
+                {
+                    "status": "error",
+                    "message": "Invalid file format. Only PDF files are allowed.",
+                }
+            ), 400
+
+        # 获取文件名
+        filename = file.filename
+
+        book_id = f"book-{uuid.uuid4()}"
+        # 检查文件的 MIME 类型是否是 application/pdf
+        if file.content_type != "application/pdf":
+            return jsonify(
+                {
+                    "status": "error",
+                    "message": "Invalid MIME type. Only PDF files are allowed.",
+                }
+            ), 400
+
+        # 保存到本地(可选,视需要)
+        file_path = os.path.join("/tmp", book_id)  # 临时存储路径
+        await file.save(file_path)
+        resource = get_resource_manager()
+        books = Books(resource.mysql_client)
+        # 上传到 OSS
+        try:
+            oss_client = OSSClient()
+            # 上传文件到 OSS
+            oss_path = f"rag/pdfs/{book_id}"
+            oss_client.upload_file(file_path, oss_path)
+            await books.insert_book(book_id, filename, oss_path)
+            return jsonify({"status_code": 200, "detail": "success"})
+        except Exception as e:
+            return jsonify({"status_code": 500, "detail": str(e)})
+
+    else:
+        return jsonify({"status_code": 400, "detail": "No file uploaded."})
+
+
+@server_bp.route("/process/book", methods=["GET"])
+async def process_book():
+    # 创建异步任务来后台处理书籍
+    asyncio.create_task(handle_books())
+    # 返回立即响应
+    return jsonify({"status": "success", "message": "任务已提交后台处理"}), 200