clean_agent

丁云鹏 16 hours ago
parent commit cd264aa848
1 changed file with 120 additions and 80 deletions
agent.py  +120 -80

@@ -9,14 +9,14 @@ import json
 import sys
 import os
 import time
+import threading
+import asyncio
+import concurrent.futures
+import fcntl
+import errno
 from typing import Any, Dict, List, Optional, TypedDict, Annotated
 from contextlib import asynccontextmanager
-import asyncio
 from utils.mysql_db import MysqlHelper
-
-# Ensure project-local modules can be imported
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
 from fastapi import FastAPI, HTTPException, BackgroundTasks
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
@@ -35,6 +35,9 @@ except ImportError:
 from utils.logging_config import get_logger
 from tools.agent_tools import QueryDataTool, IdentifyTool, UpdateDataTool, StructureTool
 
+# Ensure project-local modules can be imported
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
 # Create the logger
 logger = get_logger('Agent')
 
@@ -69,6 +72,8 @@ class ExtractRequest(BaseModel):
 
 # Global variables
 identify_tool = None
+# Global thread pool
+THREAD_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=20)
 
 def update_request_status(request_id: str, status: int):
     """
@@ -113,53 +118,80 @@ async def lifespan(app: FastAPI):
     identify_tool = IdentifyTool()
     
     # Resume interrupted processes after startup
-    # Resume interrupted processes asynchronously to avoid blocking startup
-    app.state.restore_task = asyncio.create_task(restore_interrupted_processes())
+    # Resume interrupted processes in a background thread to avoid blocking startup
+    thread = threading.Thread(target=restore_interrupted_processes)
+    thread.daemon = True
+    thread.start()
+    app.state.restore_thread = thread
     
     yield
     
     # Run on shutdown
     logger.info("🛑 Shutting down Knowledge Agent service...")
-    # Gracefully cancel the background restore task
-    restore_task = getattr(app.state, 'restore_task', None)
-    if restore_task and not restore_task.done():
-        restore_task.cancel()
-        try:
-            await restore_task
-        except asyncio.CancelledError:
-            logger.info("✅ 已取消后台恢复任务")
+    # 关闭线程池
+    THREAD_POOL.shutdown(wait=False)
+    logger.info("✅ 已关闭线程池")
 
-async def restore_interrupted_processes():
+def restore_interrupted_processes():
     """
     启动后恢复中断的流程
     1. 找到knowledge_request表中parsing_status=1的request_id,去请求 /parse/async
     2. 找到knowledge_request表中extraction_status=1的request_id和query,去请求 /extract
     3. 找到knowledge_request表中expansion_status=1的request_id和query,去请求 /expand
+    
+    使用文件锁确保只有一个进程执行恢复操作
     """
+    # 定义锁文件路径
+    lock_file_path = "/tmp/knowledge_agent_restore.lock"
+    
     try:
-        logger.info("🔄 开始恢复中断的流程...")
-        
-        # 等待服务完全启动
-        await asyncio.sleep(3)
+        # Create or open the lock file
+        lock_file = open(lock_file_path, 'w')
         
-        # 1. Resume interrupted parsing processes
-        await restore_parsing_processes()
-        
-        # 2. Resume interrupted extraction processes
-        await restore_extraction_processes()
+        try:
+            # Try to acquire the file lock (non-blocking mode)
+            fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB)
+            logger.info("🔄 Restore lock acquired, resuming interrupted processes...")
+            
+            # Wait for the service to fully start
+            time.sleep(3)
+            
+            # 1. Resume interrupted parsing processes
+            restore_parsing_processes()
+            
+            # 2. Resume interrupted extraction processes
+            restore_extraction_processes()
 
-        # 3. Resume interrupted expansion processes
-        await restore_expansion_processes()
-        
-        logger.info("✅ Process restore complete")
-        
+            # 3. Resume interrupted expansion processes
+            restore_expansion_processes()
+            
+            logger.info("✅ Process restore complete")
+            
+            # Release the lock
+            fcntl.flock(lock_file, fcntl.LOCK_UN)
+            
+        except IOError as e:
+            # If the lock cannot be acquired (resource temporarily unavailable), another process is already performing the restore
+            if e.errno == errno.EAGAIN:
+                logger.info("⏩ Another process is performing the restore, skipping this run")
+            else:
+                logger.error(f"❌ Error while acquiring the restore lock: {e}")
+        finally:
+            # Close the lock file
+            lock_file.close()
+            
     except Exception as e:
         logger.error(f"❌ 流程恢复失败: {e}")
+        # 尝试清理锁文件
+        try:
+            if os.path.exists(lock_file_path):
+                os.remove(lock_file_path)
+        except:
+            pass
 
-async def restore_parsing_processes():
+def restore_parsing_processes():
     """恢复解析中断的流程"""
     try:
-        from utils.mysql_db import MysqlHelper
         
         # Query requests with parsing_status = 1
         sql = "SELECT request_id FROM knowledge_request WHERE parsing_status = 1"
@@ -175,7 +207,7 @@ async def restore_parsing_processes():
             request_id = row[0]
             try:
                 # Call the /parse/async endpoint, with retries
-                await call_parse_async_with_retry(request_id)
+                call_parse_async_with_retry(request_id)
                 logger.info(f"✅ Parsing process resumed: request_id={request_id}")
             except Exception as e:
                 logger.error(f"❌ Failed to resume parsing process: request_id={request_id}, error={e}")
@@ -183,10 +215,9 @@ async def restore_parsing_processes():
     except Exception as e:
         logger.error(f"❌ 恢复解析流程时发生错误: {e}")
 
-async def restore_extraction_processes():
+def restore_extraction_processes():
     """恢复提取中断的流程"""
     try:
-        from utils.mysql_db import MysqlHelper
         
         # Query requests and their query with extraction_status = 1
         sql = "SELECT request_id, query FROM knowledge_request WHERE extraction_status = 1"
@@ -203,7 +234,7 @@ async def restore_extraction_processes():
             query = row[1] if len(row) > 1 else ""
             try:
                 # Call the extraction function directly, with retries (status updates are handled inside)
-                await call_extract_with_retry(request_id, query)
+                call_extract_with_retry(request_id, query)
                 logger.info(f"✅ Extraction process resumed: request_id={request_id}")
             except Exception as e:
                 logger.error(f"❌ Failed to resume extraction process: request_id={request_id}, error={e}")
@@ -211,10 +242,9 @@ async def restore_extraction_processes():
     except Exception as e:
         logger.error(f"❌ 恢复提取流程时发生错误: {e}")
 
-async def restore_expansion_processes():
+def restore_expansion_processes():
     """恢复扩展中断的流程"""
     try:
-        from utils.mysql_db import MysqlHelper
         
         # Query requests and their query with expansion_status = 1
         sql = "SELECT request_id, query FROM knowledge_request WHERE expansion_status = 1"
@@ -231,7 +261,7 @@ async def restore_expansion_processes():
             query = row[1] if len(row) > 1 else ""
             try:
                 # Call the expansion function directly, with retries (status updates are handled inside)
-                await call_expand_with_retry(request_id, query)
+                call_expand_with_retry(request_id, query)
                 logger.info(f"✅ Expansion process resumed: request_id={request_id}")
             except Exception as e:
                 logger.error(f"❌ Failed to resume expansion process: request_id={request_id}, error={e}")
@@ -239,12 +269,13 @@ async def restore_expansion_processes():
     except Exception as e:
         logger.error(f"❌ 恢复扩展流程时发生错误: {e}")
 
-async def call_parse_async_with_retry(request_id: str, max_retries: int = 3):
+def call_parse_async_with_retry(request_id: str, max_retries: int = 3):
     """直接调用解析函数,带重试机制"""
     for attempt in range(max_retries):
         try:
-            # Call the background handler directly
-            await process_request_background(request_id)
+            # Call the background handler directly, via the thread pool
+            future = THREAD_POOL.submit(process_request_background_sync, request_id)
+            result = future.result()
             logger.info(f"Direct call to the parsing function succeeded: request_id={request_id}")
             return
                     
@@ -253,11 +284,11 @@ async def call_parse_async_with_retry(request_id: str, max_retries: int = 3):
         
         # If this is not the last attempt, wait and retry
         if attempt < max_retries - 1:
-            await asyncio.sleep(2 ** attempt)  # exponential backoff
+            time.sleep(2 ** attempt)  # exponential backoff
     
     logger.error(f"Direct call to the parsing function ultimately failed: request_id={request_id}, retried {max_retries} times")
 
-async def call_extract_with_retry(request_id: str, query: str, max_retries: int = 3):
+def call_extract_with_retry(request_id: str, query: str, max_retries: int = 3):
     """直接调用提取函数,带重试机制"""
     for attempt in range(max_retries):
         try:
@@ -265,17 +296,13 @@ async def call_extract_with_retry(request_id: str, query: str, max_retries: int
             update_extract_status(request_id, 1)
             
             # Call the extraction function directly (a sync function, run in the thread pool)
-            from agents.clean_agent.agent import execute_agent_with_api
-            import concurrent.futures
             
-            # Run the sync function in a thread pool
-            loop = asyncio.get_event_loop()
-            with concurrent.futures.ThreadPoolExecutor() as executor:
-                result = await loop.run_in_executor(
-                    executor, 
-                    execute_agent_with_api, 
-                    json.dumps({"query_word": query, "request_id": request_id})
-                )
+            # Run the sync function in the global thread pool
+            future = THREAD_POOL.submit(
+                execute_agent_with_api, 
+                json.dumps({"query_word": query, "request_id": request_id})
+            )
+            result = future.result()
             
             # Update status to completed
             update_extract_status(request_id, 2)
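One property of this restore path worth keeping in mind: submit() followed immediately by result() is synchronous end to end; the work simply runs on a pool thread while the caller blocks. result() also accepts a timeout, which is one way to stop a hung extraction from pinning a restore thread indefinitely. A sketch (the 600-second figure is an assumption, not from the code):

future = THREAD_POOL.submit(
    execute_agent_with_api,
    json.dumps({"query_word": query, "request_id": request_id}),
)
try:
    # Wait up to 10 minutes; raises TimeoutError if still running.
    # Note the task itself keeps running - it is not cancelled.
    result = future.result(timeout=600)
except concurrent.futures.TimeoutError:
    logger.error(f"Extraction timed out: request_id={request_id}")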
@@ -289,19 +316,19 @@ async def call_extract_with_retry(request_id: str, query: str, max_retries: int
         
         # If this is not the last attempt, wait and retry
         if attempt < max_retries - 1:
-            await asyncio.sleep(2 ** attempt)  # exponential backoff
+            time.sleep(2 ** attempt)  # exponential backoff
     
     logger.error(f"Direct call to the extraction function ultimately failed: request_id={request_id}, retried {max_retries} times")
 
-async def call_expand_with_retry(request_id: str, query: str, max_retries: int = 3):
+def call_expand_with_retry(request_id: str, query: str, max_retries: int = 3):
     """直接调用扩展函数,带重试机制"""
     for attempt in range(max_retries):
         try:
             # 直接调用扩展函数
-            from agents.expand_agent.agent import execute_expand_agent_with_api
             
-            # 直接调用同步函数,不使用线程池
-            result = execute_expand_agent_with_api(request_id, query)
+            # 在全局线程池中执行同步函数
+            future = THREAD_POOL.submit(execute_expand_agent_with_api, request_id, query)
+            result = future.result()
             
             logger.info(f"直接调用扩展函数成功: request_id={request_id}")
             return
@@ -311,7 +338,7 @@ async def call_expand_with_retry(request_id: str, query: str, max_retries: int =
         
         # If this is not the last attempt, wait and retry
         if attempt < max_retries - 1:
-            await asyncio.sleep(2 ** attempt)  # exponential backoff
+            time.sleep(2 ** attempt)  # exponential backoff
     
     logger.error(f"Direct call to the expansion function ultimately failed: request_id={request_id}, retried {max_retries} times")
 
@@ -605,15 +632,16 @@ async def parse_processing_async(request: TriggerRequest, background_tasks: Back
                 }
             RUNNING_REQUESTS.add(request.requestId)
         
-        async def _background_wrapper(rid: str):
+        def _background_wrapper_sync(rid: str):
             try:
-                await process_request_background(rid)
+                process_request_background_sync(rid)
             finally:
-                async with RUNNING_LOCK:
+                # Remove the request ID in a thread-safe way
+                with threading.Lock():
                     RUNNING_REQUESTS.discard(rid)
         
-        # Create the background task directly with asyncio (does not block this request from returning)
-        asyncio.create_task(_background_wrapper(request.requestId))
+        # Submit the background task to the global thread pool
+        THREAD_POOL.submit(_background_wrapper_sync, request.requestId)
         
         # Return immediately (non-blocking)
         return {
@@ -627,8 +655,8 @@ async def parse_processing_async(request: TriggerRequest, background_tasks: Back
         logger.error(f"提交异步任务失败: {e}")
         raise HTTPException(status_code=500, detail=f"提交任务失败: {str(e)}")
 
-async def process_request_background(request_id: str):
-    """后台处理请求"""
+def process_request_background_sync(request_id: str):
+    """后台处理请求(同步版本)"""
     try:
         logger.info(f"开始后台处理: requestId={request_id}")
         
@@ -663,6 +691,18 @@ async def process_request_background(request_id: str):
         # Processing failed, update status to 3
         update_request_status(request_id, 3)
 
+async def process_request_background(request_id: str):
+    """后台处理请求(异步版本,为了兼容性保留)"""
+    try:
+        # 在线程池中执行同步版本
+        loop = asyncio.get_event_loop()
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            await loop.run_in_executor(executor, process_request_background_sync, request_id)
+    except Exception as e:
+        logger.error(f"后台处理失败: requestId={request_id}, error={e}")
+        # 处理失败,更新状态为3
+        update_request_status(request_id, 3)
+
 
 extraction_requests: set = set()
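The compatibility wrapper above creates a throwaway ThreadPoolExecutor on every call. run_in_executor can just as well target the existing global pool, which skips the per-call executor setup and keeps max_workers=20 as the single concurrency cap. A sketch of that variant, assuming the surrounding module:

async def process_request_background(request_id: str):
    """Async shim over the sync worker, reusing the global pool."""
    try:
        loop = asyncio.get_running_loop()  # preferred over get_event_loop()
        # Handing run_in_executor the global pool means this call competes
        # for the same 20 workers as every other background task.
        await loop.run_in_executor(
            THREAD_POOL, process_request_background_sync, request_id
        )
    except Exception as e:
        logger.error(f"Background processing failed: requestId={request_id}, error={e}")
        update_request_status(request_id, 3)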
 
@@ -689,8 +729,8 @@ async def extract(request: ExtractRequest):
         # Update status to processing
         update_extract_status(requestId, 1)
         
-        # Create an async task to run the Agent
-        async def _execute_extract_async():
+        # Create a thread-pool task to run the Agent
+        def _execute_extract_sync():
             try:
                 result = execute_agent_with_api(json.dumps({"query_word": query, "request_id": requestId}))
                 # Update status to completed
@@ -701,12 +741,12 @@ async def extract(request: ExtractRequest):
                 logger.error(f"异步提取任务失败: requestId={requestId}, error={e}")
                 # 更新状态为处理失败
                 update_extract_status(requestId, 3)
-                raise
             finally:
+                # Remove the request ID
                 extraction_requests.discard(requestId)
         
-        # Create an async task without awaiting completion
-        asyncio.create_task(_execute_extract_async())
+        # Submit the task to the global thread pool
+        THREAD_POOL.submit(_execute_extract_sync)
         
         # Return the status immediately
         return {"status": 1, "requestId": requestId, "message": "Extraction task started and processing in the background"}
@@ -736,7 +776,8 @@ async def expand(request: ExpandRequest):
         
         # Concurrency debounce: allow only one run per requestId at a time
         expansion_requests = getattr(app.state, 'expansion_requests', set())
-        async with RUNNING_LOCK:
+        # Use a thread lock instead of an asyncio lock
+        with threading.Lock():
             if requestId in expansion_requests:
                 return {"status": 1, "requestId": requestId, "message": "Expansion query is already being processed"}
             # Create the set if it does not exist
@@ -747,10 +788,10 @@ async def expand(request: ExpandRequest):
         # Immediately update status to processing
         _update_expansion_status(requestId, 1)
         
-        # Create an async task to run the expansion Agent
-        async def _execute_expand_async():
+        # Create a thread-pool task to run the expansion Agent
+        def _execute_expand_sync():
             try:
-                # Call the sync function directly, using the thread pool
+                # Call the sync function directly, using the thread pool
                 from agents.expand_agent.agent import execute_expand_agent_with_api
                 result = execute_expand_agent_with_api(requestId, query)
                 # Update status to completed
@@ -761,15 +802,14 @@ async def expand(request: ExpandRequest):
                 logger.error(f"异步扩展查询任务失败: requestId={requestId}, error={e}")
                 # 更新状态为处理失败
                 _update_expansion_status(requestId, 3)
-                raise
             finally:
                 # Remove from the running set whether it succeeded or failed
-                async with RUNNING_LOCK:
+                with threading.Lock():
                     if hasattr(app.state, 'expansion_requests'):
                         app.state.expansion_requests.discard(requestId)
         
-        # Create an async task without awaiting completion
-        asyncio.create_task(_execute_expand_async())
+        # Submit the task to the global thread pool
+        THREAD_POOL.submit(_execute_expand_sync)
         
         # Return the status immediately
         return {"status": 1, "requestId": requestId, "message": "Expansion query task started and processing in the background"}