|
@@ -5,8 +5,10 @@ FastAPI服务主文件
|
|
|
提供搜索API端点
|
|
提供搜索API端点
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
|
|
+import asyncio
|
|
|
import logging
|
|
import logging
|
|
|
from typing import List, Dict, Any
|
|
from typing import List, Dict, Any
|
|
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from pydantic import BaseModel, Field
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
@@ -15,6 +17,7 @@ from api.data_converter import (
|
|
|
convert_pipeline_output_to_api_response
|
|
convert_pipeline_output_to_api_response
|
|
|
)
|
|
)
|
|
|
from api.pipeline_wrapper import PipelineWrapper
|
|
from api.pipeline_wrapper import PipelineWrapper
|
|
|
|
|
+from api.request_context import RequestContextManager
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
@@ -25,8 +28,10 @@ app = FastAPI(
|
|
|
version="1.0.0"
|
|
version="1.0.0"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-# 全局Pipeline包装器实例
|
|
|
|
|
-pipeline_wrapper: PipelineWrapper = None
|
|
|
|
|
|
|
+# 并发控制
|
|
|
|
|
+REQUEST_SEMAPHORE = None
|
|
|
|
|
+MAX_CONCURRENT_REQUESTS = 5
|
|
|
|
|
+EXECUTOR = None # 线程池执行器,用于运行阻塞的pipeline代码
|
|
|
|
|
|
|
|
|
|
|
|
|
# 请求模型
|
|
# 请求模型
|
|
@@ -78,120 +83,132 @@ class SearchResponse(BaseModel):
|
|
|
|
|
|
|
|
@app.on_event("startup")
|
|
@app.on_event("startup")
|
|
|
async def startup_event():
|
|
async def startup_event():
|
|
|
- """应用启动时初始化Pipeline包装器"""
|
|
|
|
|
- global pipeline_wrapper
|
|
|
|
|
- try:
|
|
|
|
|
- pipeline_wrapper = PipelineWrapper()
|
|
|
|
|
- logger.info("Pipeline包装器初始化成功")
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"Pipeline包装器初始化失败: {e}", exc_info=True)
|
|
|
|
|
- raise
|
|
|
|
|
|
|
+ """应用启动时初始化并发限流器和线程池"""
|
|
|
|
|
+ global REQUEST_SEMAPHORE, EXECUTOR
|
|
|
|
|
+ REQUEST_SEMAPHORE = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
|
|
|
|
+ EXECUTOR = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_REQUESTS)
|
|
|
|
|
+ logger.info(f"API服务已启动,最大并发请求数: {MAX_CONCURRENT_REQUESTS}")
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("shutdown")
|
|
@app.on_event("shutdown")
|
|
|
async def shutdown_event():
|
|
async def shutdown_event():
|
|
|
"""应用关闭时清理资源"""
|
|
"""应用关闭时清理资源"""
|
|
|
- global pipeline_wrapper
|
|
|
|
|
- if pipeline_wrapper:
|
|
|
|
|
- try:
|
|
|
|
|
- pipeline_wrapper.cleanup()
|
|
|
|
|
- logger.info("Pipeline包装器清理完成")
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.warning(f"Pipeline包装器清理失败: {e}")
|
|
|
|
|
|
|
+ global EXECUTOR
|
|
|
|
|
+ if EXECUTOR:
|
|
|
|
|
+ EXECUTOR.shutdown(wait=True)
|
|
|
|
|
+ logger.info("线程池已关闭")
|
|
|
|
|
+ logger.info("API服务关闭")
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/what/search", response_model=SearchResponse)
|
|
@app.post("/what/search", response_model=SearchResponse)
|
|
|
async def search(request: SearchRequest):
|
|
async def search(request: SearchRequest):
|
|
|
"""
|
|
"""
|
|
|
- 执行搜索和评估
|
|
|
|
|
-
|
|
|
|
|
|
|
+ 执行搜索和评估(支持并发)
|
|
|
|
|
+
|
|
|
|
|
+ 并发控制:最多5个请求同时处理,超出的请求会等待
|
|
|
|
|
+
|
|
|
Args:
|
|
Args:
|
|
|
request: 搜索请求
|
|
request: 搜索请求
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
Returns:
|
|
Returns:
|
|
|
搜索结果响应
|
|
搜索结果响应
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
Raises:
|
|
Raises:
|
|
|
HTTPException: 当请求参数无效或处理失败时
|
|
HTTPException: 当请求参数无效或处理失败时
|
|
|
"""
|
|
"""
|
|
|
- try:
|
|
|
|
|
- logger.info(f"收到搜索请求: original_target={request.original_target}, "
|
|
|
|
|
- f"persona_features数量={len(request.persona_features)}, "
|
|
|
|
|
- f"candidate_words数量={len(request.candidate_words)}")
|
|
|
|
|
-
|
|
|
|
|
- # 验证Pipeline包装器是否已初始化
|
|
|
|
|
- if pipeline_wrapper is None:
|
|
|
|
|
- logger.error("Pipeline包装器未初始化")
|
|
|
|
|
- raise HTTPException(status_code=503, detail="Pipeline包装器未初始化,请稍后重试")
|
|
|
|
|
-
|
|
|
|
|
- # 验证输入参数
|
|
|
|
|
- if not request.original_target or not request.original_target.strip():
|
|
|
|
|
- raise HTTPException(status_code=400, detail="original_target不能为空")
|
|
|
|
|
-
|
|
|
|
|
- if not request.persona_features or len(request.persona_features) == 0:
|
|
|
|
|
- raise HTTPException(status_code=400, detail="persona_features不能为空")
|
|
|
|
|
-
|
|
|
|
|
- if not request.candidate_words or len(request.candidate_words) == 0:
|
|
|
|
|
- raise HTTPException(status_code=400, detail="candidate_words不能为空")
|
|
|
|
|
-
|
|
|
|
|
- # 验证persona_features中的persona_feature_name
|
|
|
|
|
- for idx, pf in enumerate(request.persona_features):
|
|
|
|
|
- if not pf.persona_feature_name or not pf.persona_feature_name.strip():
|
|
|
|
|
- raise HTTPException(
|
|
|
|
|
- status_code=400,
|
|
|
|
|
- detail=f"persona_features[{idx}].persona_feature_name不能为空"
|
|
|
|
|
|
|
+ # 步骤1:获取并发许可(限流)
|
|
|
|
|
+ async with REQUEST_SEMAPHORE:
|
|
|
|
|
+ # 步骤2:创建请求专用上下文
|
|
|
|
|
+ async with RequestContextManager(base_dir="temp_requests") as ctx:
|
|
|
|
|
+ try:
|
|
|
|
|
+ logger.info(f"[{ctx.request_id}] 收到搜索请求: "
|
|
|
|
|
+ f"original_target={request.original_target}, "
|
|
|
|
|
+ f"persona_features={len(request.persona_features)}, "
|
|
|
|
|
+ f"candidate_words={len(request.candidate_words)}")
|
|
|
|
|
+
|
|
|
|
|
+ # 验证输入参数
|
|
|
|
|
+ if not request.original_target or not request.original_target.strip():
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="original_target不能为空")
|
|
|
|
|
+
|
|
|
|
|
+ if not request.persona_features or len(request.persona_features) == 0:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="persona_features不能为空")
|
|
|
|
|
+
|
|
|
|
|
+ if not request.candidate_words or len(request.candidate_words) == 0:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="candidate_words不能为空")
|
|
|
|
|
+
|
|
|
|
|
+ # 验证persona_features中的persona_feature_name
|
|
|
|
|
+ for idx, pf in enumerate(request.persona_features):
|
|
|
|
|
+ if not pf.persona_feature_name or not pf.persona_feature_name.strip():
|
|
|
|
|
+ raise HTTPException(
|
|
|
|
|
+ status_code=400,
|
|
|
|
|
+ detail=f"persona_features[{idx}].persona_feature_name不能为空"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 步骤3:创建请求专用的Pipeline实例
|
|
|
|
|
+ pipeline_wrapper = PipelineWrapper(
|
|
|
|
|
+ output_dir=str(ctx.work_dir),
|
|
|
|
|
+ request_id=ctx.request_id
|
|
|
)
|
|
)
|
|
|
-
|
|
|
|
|
- # 步骤1:将API输入转换为pipeline格式
|
|
|
|
|
- logger.info("步骤1:转换API输入格式...")
|
|
|
|
|
- try:
|
|
|
|
|
- features_data = convert_api_input_to_pipeline_format(
|
|
|
|
|
- original_target=request.original_target,
|
|
|
|
|
- persona_features=[pf.dict() for pf in request.persona_features],
|
|
|
|
|
- candidate_words=request.candidate_words
|
|
|
|
|
- )
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"API输入格式转换失败: {e}", exc_info=True)
|
|
|
|
|
- raise HTTPException(status_code=400, detail=f"输入格式转换失败: {str(e)}")
|
|
|
|
|
-
|
|
|
|
|
- if not features_data:
|
|
|
|
|
- raise HTTPException(status_code=400, detail="无法构建有效的特征数据,请检查输入参数")
|
|
|
|
|
-
|
|
|
|
|
- # 步骤2:执行阶段3-7
|
|
|
|
|
- logger.info("步骤2:执行阶段3-7...")
|
|
|
|
|
- try:
|
|
|
|
|
- pipeline_output = await pipeline_wrapper.run_stages_3_to_7(features_data)
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"阶段3-7执行失败: {e}", exc_info=True)
|
|
|
|
|
- raise HTTPException(status_code=500, detail=f"Pipeline执行失败: {str(e)}")
|
|
|
|
|
-
|
|
|
|
|
- # 验证pipeline输出
|
|
|
|
|
- if not pipeline_output or 'evaluation_results' not in pipeline_output:
|
|
|
|
|
- logger.error("Pipeline输出格式不正确")
|
|
|
|
|
- raise HTTPException(status_code=500, detail="Pipeline输出格式不正确")
|
|
|
|
|
-
|
|
|
|
|
- # 步骤3:将pipeline输出转换为API响应格式
|
|
|
|
|
- logger.info("步骤3:转换API输出格式...")
|
|
|
|
|
- try:
|
|
|
|
|
- response = convert_pipeline_output_to_api_response(
|
|
|
|
|
- pipeline_results=pipeline_output['evaluation_results'],
|
|
|
|
|
- original_target=request.original_target,
|
|
|
|
|
- similarity_results=pipeline_output.get('similarity_results')
|
|
|
|
|
- )
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"API输出格式转换失败: {e}", exc_info=True)
|
|
|
|
|
- raise HTTPException(status_code=500, detail=f"输出格式转换失败: {str(e)}")
|
|
|
|
|
-
|
|
|
|
|
- logger.info(f"搜索完成: 找到 {len(response['search_results'])} 个有效结果 "
|
|
|
|
|
- f"(综合得分P > 0)")
|
|
|
|
|
-
|
|
|
|
|
- return response
|
|
|
|
|
-
|
|
|
|
|
- except HTTPException:
|
|
|
|
|
- raise
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"搜索请求处理失败: {e}", exc_info=True)
|
|
|
|
|
- raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
|
|
|
|
|
|
|
+
|
|
|
|
|
+ # 步骤4:转换API输入格式
|
|
|
|
|
+ logger.info(f"[{ctx.request_id}] 转换API输入格式...")
|
|
|
|
|
+ try:
|
|
|
|
|
+ features_data = convert_api_input_to_pipeline_format(
|
|
|
|
|
+ original_target=request.original_target,
|
|
|
|
|
+ persona_features=[pf.dict() for pf in request.persona_features],
|
|
|
|
|
+ candidate_words=request.candidate_words
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[{ctx.request_id}] API输入格式转换失败: {e}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=400, detail=f"输入格式转换失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+ if not features_data:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="无法构建有效的特征数据")
|
|
|
|
|
+
|
|
|
|
|
+ # 步骤5:执行阶段3-7(在线程池中运行以避免阻塞)
|
|
|
|
|
+ logger.info(f"[{ctx.request_id}] 执行阶段3-7...")
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 获取当前事件循环
|
|
|
|
|
+ loop = asyncio.get_event_loop()
|
|
|
|
|
+
|
|
|
|
|
+ # 在线程池中运行阻塞的pipeline代码
|
|
|
|
|
+ pipeline_output = await loop.run_in_executor(
|
|
|
|
|
+ EXECUTOR,
|
|
|
|
|
+ pipeline_wrapper.run_stages_3_to_7_sync,
|
|
|
|
|
+ features_data
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[{ctx.request_id}] 阶段3-7执行失败: {e}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"Pipeline执行失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+ # 验证pipeline输出
|
|
|
|
|
+ if not pipeline_output or 'evaluation_results' not in pipeline_output:
|
|
|
|
|
+ logger.error(f"[{ctx.request_id}] Pipeline输出格式不正确")
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="Pipeline输出格式不正确")
|
|
|
|
|
+
|
|
|
|
|
+ # 步骤6:转换API输出格式
|
|
|
|
|
+ logger.info(f"[{ctx.request_id}] 转换API输出格式...")
|
|
|
|
|
+ try:
|
|
|
|
|
+ response = convert_pipeline_output_to_api_response(
|
|
|
|
|
+ pipeline_results=pipeline_output['evaluation_results'],
|
|
|
|
|
+ original_target=request.original_target,
|
|
|
|
|
+ similarity_results=pipeline_output.get('similarity_results')
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[{ctx.request_id}] API输出格式转换失败: {e}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"输出格式转换失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"[{ctx.request_id}] 搜索完成: "
|
|
|
|
|
+ f"找到 {len(response['search_results'])} 个有效结果")
|
|
|
|
|
+
|
|
|
|
|
+ # 步骤7:返回结果(工作目录会在退出上下文时自动清理)
|
|
|
|
|
+ return response
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[{ctx.request_id}] 请求失败: {e}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
|
|
@app.get("/health")
|
|
@@ -199,6 +216,7 @@ async def health_check():
|
|
|
"""健康检查端点"""
|
|
"""健康检查端点"""
|
|
|
return {
|
|
return {
|
|
|
"status": "healthy",
|
|
"status": "healthy",
|
|
|
- "pipeline_initialized": pipeline_wrapper is not None
|
|
|
|
|
|
|
+ "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
|
|
|
|
|
+ "semaphore_initialized": REQUEST_SEMAPHORE is not None
|
|
|
}
|
|
}
|
|
|
|
|
|