|
@@ -15,6 +15,7 @@ import concurrent.futures
|
|
|
import fcntl
|
|
|
import errno
|
|
|
import multiprocessing
|
|
|
+import signal
|
|
|
from typing import Any, Dict, List, Optional, TypedDict, Annotated
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
@@ -78,6 +79,41 @@ class ExtractRequest(BaseModel):
|
|
|
identify_tool = None
|
|
|
# 全局线程池
|
|
|
THREAD_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=20)
|
|
|
+# 活跃的进程池列表
|
|
|
+ACTIVE_POOLS = []
|
|
|
+POOLS_LOCK = threading.Lock()
|
|
|
+
|
|
|
+def cleanup_all_pools():
|
|
|
+ """清理所有活跃的进程池"""
|
|
|
+ global ACTIVE_POOLS, POOLS_LOCK
|
|
|
+ with POOLS_LOCK:
|
|
|
+ logger.info(f"开始清理 {len(ACTIVE_POOLS)} 个活跃进程池...")
|
|
|
+ for pool in ACTIVE_POOLS:
|
|
|
+ try:
|
|
|
+ logger.info("正在终止进程池...")
|
|
|
+ pool.terminate()
|
|
|
+ pool.join(timeout=5) # 等待5秒
|
|
|
+ if pool._state != 'CLOSED':
|
|
|
+ logger.warning("进程池未正常关闭,强制终止")
|
|
|
+ pool.kill()
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"清理进程池时出错: {e}")
|
|
|
+ ACTIVE_POOLS.clear()
|
|
|
+ logger.info("所有进程池已清理")
|
|
|
+
|
|
|
+def signal_handler(signum, frame):
|
|
|
+ """信号处理器"""
|
|
|
+ logger.info(f"收到信号 {signum},开始清理...")
|
|
|
+ cleanup_all_pools()
|
|
|
+ # 关闭线程池
|
|
|
+ THREAD_POOL.shutdown(wait=False)
|
|
|
+ logger.info("清理完成,退出程序")
|
|
|
+ sys.exit(0)
|
|
|
+
|
|
|
+def register_signal_handlers():
|
|
|
+ """注册信号处理器"""
|
|
|
+ signal.signal(signal.SIGTERM, signal_handler)
|
|
|
+ signal.signal(signal.SIGINT, signal_handler)
|
|
|
|
|
|
def get_identify_tool():
|
|
|
"""惰性初始化 IdentifyTool,确保在子进程中可用"""
|
|
@@ -124,6 +160,9 @@ async def lifespan(app: FastAPI):
|
|
|
# 启动时执行
|
|
|
logger.info("🚀 启动 Knowledge Agent 服务...")
|
|
|
|
|
|
+ # 注册信号处理器
|
|
|
+ register_signal_handlers()
|
|
|
+
|
|
|
# 初始化全局工具
|
|
|
global identify_tool
|
|
|
identify_tool = IdentifyTool()
|
|
@@ -139,6 +178,8 @@ async def lifespan(app: FastAPI):
|
|
|
|
|
|
# 关闭时执行
|
|
|
logger.info("🛑 关闭 Knowledge Agent 服务...")
|
|
|
+ # 清理所有进程池
|
|
|
+ cleanup_all_pools()
|
|
|
# 关闭线程池
|
|
|
THREAD_POOL.shutdown(wait=False)
|
|
|
logger.info("✅ 已关闭线程池")
|
|
@@ -542,9 +583,26 @@ def create_langgraph_workflow():
|
|
|
except RuntimeError:
|
|
|
pass # 如果已经设置过,忽略错误
|
|
|
|
|
|
- with multiprocessing.Pool(processes=7) as pool:
|
|
|
+ pool = None
|
|
|
+ try:
|
|
|
+ pool = multiprocessing.Pool(processes=7)
|
|
|
+ with POOLS_LOCK:
|
|
|
+ ACTIVE_POOLS.append(pool)
|
|
|
+
|
|
|
logger.info(f"开始多进程处理: 数量={len(process_args)}, 使用7个进程")
|
|
|
results = pool.map(process_single_item, process_args)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"多进程处理异常: {e}")
|
|
|
+ results = []
|
|
|
+ finally:
|
|
|
+ if pool is not None:
|
|
|
+ logger.info("正在关闭多进程池...")
|
|
|
+ pool.close()
|
|
|
+ pool.join()
|
|
|
+ with POOLS_LOCK:
|
|
|
+ if pool in ACTIVE_POOLS:
|
|
|
+ ACTIVE_POOLS.remove(pool)
|
|
|
+ logger.info("多进程池已关闭")
|
|
|
|
|
|
# 恢复原始启动方法
|
|
|
try:
|