@@ -9,14 +9,14 @@ import json
 import sys
 import os
 import time
+import threading
+import asyncio
+import concurrent.futures
+import fcntl
+import errno
 from typing import Any, Dict, List, Optional, TypedDict, Annotated
 from contextlib import asynccontextmanager
-import asyncio
 from utils.mysql_db import MysqlHelper
-
-# Make sure this project's modules can be imported
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
 from fastapi import FastAPI, HTTPException, BackgroundTasks
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
@@ -35,6 +35,9 @@ except ImportError:
 from utils.logging_config import get_logger
 from tools.agent_tools import QueryDataTool, IdentifyTool, UpdateDataTool, StructureTool

+# Make sure this project's modules can be imported
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
 # Create logger
 logger = get_logger('Agent')

@@ -69,6 +72,10 @@ class ExtractRequest(BaseModel):

 # Global variables
 identify_tool = None
+# Global thread pool
+THREAD_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=20)
+# Shared lock for the request-ID sets (a fresh threading.Lock() created per
+# `with` block is never contended and therefore protects nothing)
+STATE_LOCK = threading.Lock()

 def update_request_status(request_id: str, status: int):
     """
@@ -113,53 +120,80 @@ async def lifespan(app: FastAPI):
     identify_tool = IdentifyTool()

     # Resume interrupted pipelines after startup
-    # Resume asynchronously so startup is not blocked
-    app.state.restore_task = asyncio.create_task(restore_interrupted_processes())
+    # Resume on a daemon thread so startup is not blocked
+    thread = threading.Thread(target=restore_interrupted_processes)
+    thread.daemon = True
+    thread.start()
+    app.state.restore_thread = thread

     yield

     # Runs on shutdown
     logger.info("🛑 Shutting down Knowledge Agent service...")
-    # Gracefully cancel the background restore task
-    restore_task = getattr(app.state, 'restore_task', None)
-    if restore_task and not restore_task.done():
-        restore_task.cancel()
-        try:
-            await restore_task
-        except asyncio.CancelledError:
-            logger.info("✅ Background restore task cancelled")
+    # Shut down the thread pool
+    THREAD_POOL.shutdown(wait=False)
+    logger.info("✅ Thread pool shut down")

-async def restore_interrupted_processes():
+def restore_interrupted_processes():
     """
     Resume interrupted pipelines after startup:
     1. Find request_ids in knowledge_request with parsing_status=1 and re-trigger /parse/async
     2. Find request_ids and queries in knowledge_request with extraction_status=1 and re-trigger /extract
     3. Find request_ids and queries in knowledge_request with expansion_status=1 and re-trigger /expand
+
+    A file lock ensures only one process performs the recovery.
     """
+    # Lock file path
+    lock_file_path = "/tmp/knowledge_agent_restore.lock"
+
     try:
-        logger.info("🔄 Resuming interrupted pipelines...")
-
-        # Wait until the service is fully up
-        await asyncio.sleep(3)
+        # Create or open the lock file
+        lock_file = open(lock_file_path, 'w')

-        # 1. Resume interrupted parsing pipelines
-        await restore_parsing_processes()
-
-        # 2. Resume interrupted extraction pipelines
-        await restore_extraction_processes()
+        try:
+            # Try to acquire the file lock (non-blocking mode)
+            fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB)
+            logger.info("🔄 Restore lock acquired, resuming interrupted pipelines...")
+
+            # Wait until the service is fully up
+            time.sleep(3)
+
+            # 1. Resume interrupted parsing pipelines
+            restore_parsing_processes()
+
+            # 2. Resume interrupted extraction pipelines
+            restore_extraction_processes()

-        # 3. Resume interrupted expansion pipelines
-        await restore_expansion_processes()
-
-        logger.info("✅ Pipeline recovery finished")
-
+            # 3. Resume interrupted expansion pipelines
+            restore_expansion_processes()
+
+            logger.info("✅ Pipeline recovery finished")
+
+            # Release the lock
+            fcntl.flock(lock_file, fcntl.LOCK_UN)
+
+        except IOError as e:
+            # EAGAIN (resource temporarily unavailable) means another process
+            # already holds the lock and is running the recovery
+            if e.errno == errno.EAGAIN:
+                logger.info("⏩ Another process is running the recovery, skipping")
+            else:
+                logger.error(f"❌ Error while acquiring the restore lock: {e}")
+        finally:
+            # Close the lock file
+            lock_file.close()
+
     except Exception as e:
         logger.error(f"❌ Pipeline recovery failed: {e}")
+        # Try to clean up the lock file
+        try:
+            if os.path.exists(lock_file_path):
+                os.remove(lock_file_path)
+        except OSError:
+            pass

-async def restore_parsing_processes():
+def restore_parsing_processes():
     """Resume interrupted parsing pipelines"""
     try:
-        from utils.mysql_db import MysqlHelper

         # Query requests with parsing_status=1
         sql = "SELECT request_id FROM knowledge_request WHERE parsing_status = 1"
@@ -175,7 +209,7 @@ async def restore_parsing_processes():
             request_id = row[0]
             try:
                 # Call the /parse/async logic, with retries
-                await call_parse_async_with_retry(request_id)
+                call_parse_async_with_retry(request_id)
                 logger.info(f"✅ Parsing pipeline resumed: request_id={request_id}")
             except Exception as e:
                 logger.error(f"❌ Failed to resume parsing pipeline: request_id={request_id}, error={e}")
@@ -183,10 +217,9 @@
     except Exception as e:
         logger.error(f"❌ Error while resuming parsing pipelines: {e}")

-async def restore_extraction_processes():
+def restore_extraction_processes():
     """Resume interrupted extraction pipelines"""
     try:
-        from utils.mysql_db import MysqlHelper

         # Query requests and queries with extraction_status=1
         sql = "SELECT request_id, query FROM knowledge_request WHERE extraction_status = 1"
@@ -203,7 +236,7 @@
             query = row[1] if len(row) > 1 else ""
             try:
                 # Call the extraction function directly, with retries (status updates are handled inside)
-                await call_extract_with_retry(request_id, query)
+                call_extract_with_retry(request_id, query)
                 logger.info(f"✅ Extraction pipeline resumed: request_id={request_id}")
             except Exception as e:
                 logger.error(f"❌ Failed to resume extraction pipeline: request_id={request_id}, error={e}")
@@ -211,10 +244,9 @@
     except Exception as e:
         logger.error(f"❌ Error while resuming extraction pipelines: {e}")

-async def restore_expansion_processes():
+def restore_expansion_processes():
     """Resume interrupted expansion pipelines"""
     try:
-        from utils.mysql_db import MysqlHelper

         # Query requests and queries with expansion_status=1
         sql = "SELECT request_id, query FROM knowledge_request WHERE expansion_status = 1"
@@ -231,7 +263,7 @@
             query = row[1] if len(row) > 1 else ""
             try:
                 # Call the expansion function directly, with retries (status updates are handled inside)
-                await call_expand_with_retry(request_id, query)
+                call_expand_with_retry(request_id, query)
                 logger.info(f"✅ Expansion pipeline resumed: request_id={request_id}")
             except Exception as e:
                 logger.error(f"❌ Failed to resume expansion pipeline: request_id={request_id}, error={e}")
@@ -239,12 +271,13 @@
     except Exception as e:
         logger.error(f"❌ Error while resuming expansion pipelines: {e}")

-async def call_parse_async_with_retry(request_id: str, max_retries: int = 3):
+def call_parse_async_with_retry(request_id: str, max_retries: int = 3):
    """Call the parse handler directly, with retries"""
    for attempt in range(max_retries):
        try:
-            # Call the background handler directly
-            await process_request_background(request_id)
+            # Run the background handler on the global thread pool
+            future = THREAD_POOL.submit(process_request_background_sync, request_id)
+            result = future.result()
             logger.info(f"Parse handler call succeeded: request_id={request_id}")
             return

@@ -253,11 +286,11 @@ async def call_parse_async_with_retry(request_id: str, max_retries: int = 3):

             # If this is not the last attempt, wait before retrying
             if attempt < max_retries - 1:
-                await asyncio.sleep(2 ** attempt)  # exponential backoff
+                time.sleep(2 ** attempt)  # exponential backoff

     logger.error(f"Parse handler call ultimately failed: request_id={request_id}, retried {max_retries} times")

-async def call_extract_with_retry(request_id: str, query: str, max_retries: int = 3):
+def call_extract_with_retry(request_id: str, query: str, max_retries: int = 3):
     """Call the extraction function directly, with retries"""
     for attempt in range(max_retries):
         try:
@@ -265,17 +298,13 @@ async def call_extract_with_retry(request_id: str, query: str, max_retries: int
             update_extract_status(request_id, 1)

             # Call the extraction function directly (synchronous, runs on the thread pool)
-            from agents.clean_agent.agent import execute_agent_with_api
-            import concurrent.futures

-            # Run the synchronous function on an ad-hoc thread pool
-            loop = asyncio.get_event_loop()
-            with concurrent.futures.ThreadPoolExecutor() as executor:
-                result = await loop.run_in_executor(
-                    executor,
-                    execute_agent_with_api,
-                    json.dumps({"query_word": query, "request_id": request_id})
-                )
+            # Run the synchronous function on the global thread pool
+            future = THREAD_POOL.submit(
+                execute_agent_with_api,
+                json.dumps({"query_word": query, "request_id": request_id})
+            )
+            result = future.result()

             # Update status to finished
             update_extract_status(request_id, 2)
@@ -289,19 +318,20 @@ async def call_extract_with_retry(request_id: str, query: str, max_retries: int

             # If this is not the last attempt, wait before retrying
             if attempt < max_retries - 1:
-                await asyncio.sleep(2 ** attempt)  # exponential backoff
+                time.sleep(2 ** attempt)  # exponential backoff

     logger.error(f"Extraction call ultimately failed: request_id={request_id}, retried {max_retries} times")

-async def call_expand_with_retry(request_id: str, query: str, max_retries: int = 3):
+def call_expand_with_retry(request_id: str, query: str, max_retries: int = 3):
     """Call the expansion function directly, with retries"""
     for attempt in range(max_retries):
         try:
             # Call the expansion function directly
             from agents.expand_agent.agent import execute_expand_agent_with_api

-            # Call the synchronous function directly, without a thread pool
-            result = execute_expand_agent_with_api(request_id, query)
+            # Run the synchronous function on the global thread pool (the local
+            # import above is kept: nothing imports this function at module level)
+            future = THREAD_POOL.submit(execute_expand_agent_with_api, request_id, query)
+            result = future.result()

             logger.info(f"Expansion call succeeded: request_id={request_id}")
             return
@@ -311,7 +341,7 @@ async def call_expand_with_retry(request_id: str, query: str, max_retries: int =

             # If this is not the last attempt, wait before retrying
             if attempt < max_retries - 1:
-                await asyncio.sleep(2 ** attempt)  # exponential backoff
+                time.sleep(2 ** attempt)  # exponential backoff

     logger.error(f"Expansion call ultimately failed: request_id={request_id}, retried {max_retries} times")
@@ -605,15 +635,16 @@ async def parse_processing_async(request: TriggerRequest, background_tasks: Back
         }
         RUNNING_REQUESTS.add(request.requestId)

-        async def _background_wrapper(rid: str):
+        def _background_wrapper_sync(rid: str):
             try:
-                await process_request_background(rid)
+                process_request_background_sync(rid)
             finally:
-                async with RUNNING_LOCK:
+                # Remove the request ID under the shared lock
+                with STATE_LOCK:
                     RUNNING_REQUESTS.discard(rid)

-        # Create the background task directly with asyncio (does not block this request)
-        asyncio.create_task(_background_wrapper(request.requestId))
+        # Submit the background task to the global thread pool
+        THREAD_POOL.submit(_background_wrapper_sync, request.requestId)

         # Return immediately (non-blocking)
         return {
@@ -627,8 +658,8 @@ async def parse_processing_async(request: TriggerRequest, background_tasks: Back
         logger.error(f"Failed to submit async task: {e}")
         raise HTTPException(status_code=500, detail=f"Failed to submit task: {str(e)}")

-async def process_request_background(request_id: str):
-    """Handle the request in the background"""
+def process_request_background_sync(request_id: str):
+    """Handle the request in the background (synchronous version)"""
     try:
         logger.info(f"Background processing started: requestId={request_id}")

@@ -663,6 +694,18 @@
         # Processing failed, set status to 3
         update_request_status(request_id, 3)

+async def process_request_background(request_id: str):
+    """Handle the request in the background (async version, kept for compatibility)"""
+    try:
+        # Run the synchronous version on a worker thread
+        loop = asyncio.get_running_loop()
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            await loop.run_in_executor(executor, process_request_background_sync, request_id)
+    except Exception as e:
+        logger.error(f"Background processing failed: requestId={request_id}, error={e}")
+        # Processing failed, set status to 3
+        update_request_status(request_id, 3)
+

 extraction_requests: set = set()
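
Note: a lock only excludes threads that actually share it, which is why the edits below route every access to these sets through the single module-level STATE_LOCK. The check-then-add race being guarded looks like this in isolation (hypothetical helper names):

    import threading

    _lock = threading.Lock()  # one lock shared by all callers
    _running: set = set()

    def try_claim(request_id: str) -> bool:
        """Atomically claim a request ID; False if it is already running."""
        with _lock:
            # Without a shared lock, two threads can both pass this membership
            # test and start duplicate work for the same request ID
            if request_id in _running:
                return False
            _running.add(request_id)
            return True

    def release(request_id: str) -> None:
        with _lock:
            _running.discard(request_id)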
@@ -689,8 +732,8 @@ async def extract(request: ExtractRequest):
         # Set status to in-progress
         update_extract_status(requestId, 1)

-        # Create an async task to run the agent
-        async def _execute_extract_async():
+        # Create a thread-pool task to run the agent
+        def _execute_extract_sync():
             try:
                 result = execute_agent_with_api(json.dumps({"query_word": query, "request_id": requestId}))
                 # Update status to finished
@@ -701,12 +744,12 @@ async def extract(request: ExtractRequest):
                 logger.error(f"Async extraction task failed: requestId={requestId}, error={e}")
                 # Update status to failed
                 update_extract_status(requestId, 3)
-                raise
             finally:
+                # Remove the request ID
                 extraction_requests.discard(requestId)

-        # Create the async task without awaiting it
-        asyncio.create_task(_execute_extract_async())
+        # Submit the task to the global thread pool
+        THREAD_POOL.submit(_execute_extract_sync)

         # Return the status immediately
         return {"status": 1, "requestId": requestId, "message": "Extraction task started; processing in background"}
@@ -736,7 +779,8 @@ async def expand(request: ExpandRequest):

         # Concurrency guard: only one run per requestId
         expansion_requests = getattr(app.state, 'expansion_requests', set())
-        async with RUNNING_LOCK:
+        # Use the shared threading lock instead of the asyncio lock
+        with STATE_LOCK:
             if requestId in expansion_requests:
                 return {"status": 1, "requestId": requestId, "message": "Expansion already in progress"}
             # Create the set if it does not exist yet
@@ -747,10 +791,10 @@ async def expand(request: ExpandRequest):
         # Set status to in-progress right away
         _update_expansion_status(requestId, 1)

-        # Create an async task to run the expansion agent
-        async def _execute_expand_async():
+        # Create a thread-pool task to run the expansion agent
+        def _execute_expand_sync():
             try:
-                # Call the synchronous function directly, without a thread pool
+                # Called directly here; the wrapper itself runs on the global thread pool
                 from agents.expand_agent.agent import execute_expand_agent_with_api
                 result = execute_expand_agent_with_api(requestId, query)
                 # Update status to finished
@@ -761,15 +805,14 @@ async def expand(request: ExpandRequest):
                 logger.error(f"Async expansion task failed: requestId={requestId}, error={e}")
                 # Update status to failed
                 _update_expansion_status(requestId, 3)
-                raise
             finally:
                 # Success or failure, remove it from the running set
-                async with RUNNING_LOCK:
+                with STATE_LOCK:
                     if hasattr(app.state, 'expansion_requests'):
                         app.state.expansion_requests.discard(requestId)

-        # Create the async task without awaiting it
-        asyncio.create_task(_execute_expand_async())
+        # Submit the task to the global thread pool
+        THREAD_POOL.submit(_execute_expand_sync)

         # Return the status immediately
         return {"status": 1, "requestId": requestId, "message": "Expansion task started; processing in background"}
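
Note: taken together, the diff funnels all background work through one module-level ThreadPoolExecutor whose lifecycle is owned by the FastAPI lifespan. Condensed to its skeleton (illustrative only, reusing the names from the diff):

    import threading
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    THREAD_POOL = ThreadPoolExecutor(max_workers=20)

    def restore_interrupted_processes():
        pass  # stand-in for the recovery logic above

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # recovery runs on a daemon thread so startup is not blocked
        threading.Thread(target=restore_interrupted_processes, daemon=True).start()
        yield
        THREAD_POOL.shutdown(wait=False)  # wait=False: do not block shutdown

    app = FastAPI(lifespan=lifespan)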