#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Main module of the FastAPI service.

Exposes the search endpoint (``POST /what/search``) and a health-check
endpoint (``GET /health``).
"""

import asyncio
import logging
from typing import List, Dict, Any
from concurrent.futures import ThreadPoolExecutor

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from api.data_converter import (
    convert_api_input_to_pipeline_format,
    convert_pipeline_output_to_api_response
)
from api.pipeline_wrapper import PipelineWrapper
from api.request_context import RequestContextManager

logger = logging.getLogger(__name__)

# Create the FastAPI application.
app = FastAPI(
    title="特征搜索API服务",
    description="复用阶段3-7的搜索和评估服务",
    version="1.0.0"
)

# Concurrency control. Both objects are created in the startup hook so that
# they are bound to the running application rather than to import time.
REQUEST_SEMAPHORE = None
MAX_CONCURRENT_REQUESTS = 5
EXECUTOR = None  # thread pool used to run the blocking pipeline code


# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class PersonaFeature(BaseModel):
    """A single persona feature supplied by the caller."""

    persona_feature_name: str = Field(..., description="人设特征名称")


class SearchRequest(BaseModel):
    """Body of a search request."""

    original_target: str = Field(..., description="原始目标名称")
    persona_features: List[PersonaFeature] = Field(
        ..., description="人设特征列表", min_items=1
    )
    candidate_words: List[str] = Field(
        ..., description="候选词列表", min_items=1
    )


# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class MatchedFeature(BaseModel):
    """A deconstructed feature that matched a note."""

    feature_name: str = Field(..., description="特征名称")
    dimension: str = Field(..., description="特征维度(如:灵感点-全新内容)")
    dimension_detail: str = Field(..., description="维度详情(如:实质、形式、意图)")
    weight: float = Field(..., description="特征权重")
    similarity_score: float = Field(..., description="相似度得分")


class MatchedNote(BaseModel):
    """A note (post) that matched a search word, with its evaluation."""

    note_id: str
    note_title: str
    evaluation_score: float
    max_similarity: float
    contribution: float
    # Evaluation details
    evaluation_reasoning: str = Field(..., description="综合得分的评分说明")
    key_matching_points: List[str] = Field(..., description="关键匹配点列表")
    query_relevance: str = Field(..., description="Query相关性(相关/不相关)")
    query_relevance_explanation: str = Field(..., description="Query相关性说明")
    matched_features: List[MatchedFeature] = Field(
        ..., description="匹配的解构特征列表"
    )
    note_data: Dict[str, Any]  # complete search-result information


class ComprehensiveScoreDetail(BaseModel):
    """Breakdown of a comprehensive score (fields N, M, total_contribution, P)."""

    N: int = 0
    M: int = 0
    total_contribution: float = 0.0
    P: float = 0.0


class SearchResult(BaseModel):
    """Result for one search word."""

    search_word: str
    source_words: List[str] = Field(
        ...,
        description="query来源组合,数组格式,包含生成该query所使用的所有来源词"
    )
    comprehensive_score: float
    comprehensive_score_detail: Dict[str, Any]
    matched_notes: List[MatchedNote]


class SearchResponse(BaseModel):
    """Top-level response of the search endpoint."""

    original_target: str
    search_results: List[SearchResult]


# ---------------------------------------------------------------------------
# Application lifecycle
# ---------------------------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Create the request semaphore and the thread pool on startup."""
    global REQUEST_SEMAPHORE, EXECUTOR
    REQUEST_SEMAPHORE = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    EXECUTOR = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_REQUESTS)
    logger.info("API服务已启动,最大并发请求数: %s", MAX_CONCURRENT_REQUESTS)


@app.on_event("shutdown")
async def shutdown_event():
    """Shut the thread pool down (waiting for running jobs) on shutdown."""
    global EXECUTOR
    if EXECUTOR is not None:
        EXECUTOR.shutdown(wait=True)
        # Drop the reference so a late request is rejected by the
        # initialisation guard in `search` instead of hitting a dead pool.
        EXECUTOR = None
        logger.info("线程池已关闭")
    logger.info("API服务关闭")


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
def _validate_search_request(request: SearchRequest) -> None:
    """Raise ``HTTPException(400)`` if *request* has blank or empty fields."""
    if not request.original_target or not request.original_target.strip():
        raise HTTPException(status_code=400, detail="original_target不能为空")

    if not request.persona_features:
        raise HTTPException(status_code=400, detail="persona_features不能为空")

    if not request.candidate_words:
        raise HTTPException(status_code=400, detail="candidate_words不能为空")

    # Every persona feature must carry a non-blank name.
    for idx, pf in enumerate(request.persona_features):
        if not pf.persona_feature_name or not pf.persona_feature_name.strip():
            raise HTTPException(
                status_code=400,
                detail=f"persona_features[{idx}].persona_feature_name不能为空"
            )


@app.post("/what/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """
    Run the search-and-evaluation pipeline for one request.

    At most ``MAX_CONCURRENT_REQUESTS`` requests are processed at the same
    time; further requests wait on the semaphore.

    Args:
        request: The search request.

    Returns:
        The converted pipeline output, serialised as ``SearchResponse``.

    Raises:
        HTTPException: 400 for invalid input or input-conversion failure,
            500 for pipeline / output-conversion / unexpected failures,
            503 if the service has not been initialised by the startup hook.
    """
    # Without this guard, `async with None` raises an opaque error and
    # `run_in_executor(None, ...)` would silently use the loop's default
    # executor instead of the dedicated pool.
    if REQUEST_SEMAPHORE is None or EXECUTOR is None:
        raise HTTPException(status_code=503, detail="服务尚未完成初始化")

    # Step 1: acquire a concurrency permit (rate limiting).
    async with REQUEST_SEMAPHORE:
        # Step 2: create a per-request context.
        async with RequestContextManager(base_dir="temp_requests") as ctx:
            try:
                logger.info(
                    "[%s] 收到搜索请求: original_target=%s, "
                    "persona_features=%s, candidate_words=%s",
                    ctx.request_id,
                    request.original_target,
                    len(request.persona_features),
                    len(request.candidate_words),
                )

                # Validate the input parameters.
                _validate_search_request(request)

                # Step 3: create a pipeline instance dedicated to this request.
                pipeline_wrapper = PipelineWrapper(
                    output_dir=str(ctx.work_dir),
                    request_id=ctx.request_id
                )

                # Step 4: convert the API input into the pipeline format.
                logger.info("[%s] 转换API输入格式...", ctx.request_id)
                try:
                    features_data = convert_api_input_to_pipeline_format(
                        original_target=request.original_target,
                        persona_features=[
                            pf.dict() for pf in request.persona_features
                        ],
                        candidate_words=request.candidate_words
                    )
                except Exception as e:
                    logger.exception(
                        "[%s] API输入格式转换失败: %s", ctx.request_id, e
                    )
                    raise HTTPException(
                        status_code=400,
                        detail=f"输入格式转换失败: {str(e)}"
                    ) from e

                if not features_data:
                    raise HTTPException(
                        status_code=400, detail="无法构建有效的特征数据"
                    )

                # Step 5: run stages 3-7 in the thread pool so the blocking
                # pipeline code does not stall the event loop.
                logger.info("[%s] 执行阶段3-7...", ctx.request_id)
                try:
                    # We are inside a coroutine, so the running loop is the
                    # right one to schedule on.
                    loop = asyncio.get_running_loop()
                    pipeline_output = await loop.run_in_executor(
                        EXECUTOR,
                        pipeline_wrapper.run_stages_3_to_7_sync,
                        features_data
                    )
                except Exception as e:
                    logger.exception(
                        "[%s] 阶段3-7执行失败: %s", ctx.request_id, e
                    )
                    raise HTTPException(
                        status_code=500,
                        detail=f"Pipeline执行失败: {str(e)}"
                    ) from e

                # Validate the pipeline output.
                if not pipeline_output or 'evaluation_results' not in pipeline_output:
                    logger.error("[%s] Pipeline输出格式不正确", ctx.request_id)
                    raise HTTPException(
                        status_code=500, detail="Pipeline输出格式不正确"
                    )

                # Step 6: convert the pipeline output into the API format.
                logger.info("[%s] 转换API输出格式...", ctx.request_id)
                try:
                    response = convert_pipeline_output_to_api_response(
                        pipeline_results=pipeline_output['evaluation_results'],
                        original_target=request.original_target,
                        similarity_results=pipeline_output.get('similarity_results')
                    )
                except Exception as e:
                    logger.exception(
                        "[%s] API输出格式转换失败: %s", ctx.request_id, e
                    )
                    raise HTTPException(
                        status_code=500,
                        detail=f"输出格式转换失败: {str(e)}"
                    ) from e

                logger.info(
                    "[%s] 搜索完成: 找到 %s 个有效结果",
                    ctx.request_id,
                    len(response['search_results']),
                )

                # Step 7: return the result. NOTE(review): the original
                # comment states the work directory is cleaned up when the
                # request context exits - confirm in RequestContextManager.
                return response

            except HTTPException:
                # Already a well-formed HTTP error; pass it through untouched.
                raise
            except Exception as e:
                logger.exception("[%s] 请求失败: %s", ctx.request_id, e)
                raise HTTPException(
                    status_code=500, detail=f"内部服务器错误: {str(e)}"
                ) from e


@app.get("/health")
async def health_check():
    """Report service status and whether the semaphore has been created."""
    return {
        "status": "healthy",
        "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
        "semaphore_initialized": REQUEST_SEMAPHORE is not None
    }