| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- FastAPI服务主文件
- 提供搜索API端点
- """
- import asyncio
- import logging
- from typing import List, Dict, Any
- from concurrent.futures import ThreadPoolExecutor
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel, Field
- from api.data_converter import (
- convert_api_input_to_pipeline_format,
- convert_pipeline_output_to_api_response
- )
- from api.pipeline_wrapper import PipelineWrapper
- from api.request_context import RequestContextManager
logger = logging.getLogger(__name__)

# --- Concurrency control -----------------------------------------------------
# Upper bound on requests processed at once; also sizes the worker pool.
MAX_CONCURRENT_REQUESTS = 5
# Both of these start as None and are filled in by the startup hook, presumably
# so the semaphore is created inside the running event loop (required before
# Python 3.10, where asyncio primitives bind to a loop at construction).
REQUEST_SEMAPHORE: "asyncio.Semaphore | None" = None
# Thread pool that runs the blocking pipeline code off the event loop.
EXECUTOR: "ThreadPoolExecutor | None" = None

# --- FastAPI application -----------------------------------------------------
app = FastAPI(
    title="特征搜索API服务",
    description="复用阶段3-7的搜索和评估服务",
    version="1.0.0",
)
# Request models
class PersonaFeature(BaseModel):
    """人设特征模型"""
    # One persona feature supplied by the caller. (The class docstring and the
    # Field description are kept verbatim because pydantic/FastAPI publish them
    # in the generated OpenAPI schema.)
    # persona_feature_name: name of the persona feature; required. Blank values
    # (after strip()) are rejected by the /what/search handler.
    persona_feature_name: str = Field(..., description="人设特征名称")
class SearchRequest(BaseModel):
    """搜索请求模型"""
    # Request body for POST /what/search. (Docstring and Field descriptions are
    # kept verbatim because they appear in the OpenAPI schema.)
    # original_target: name of the original target; required.
    original_target: str = Field(..., description="原始目标名称")
    # persona_features: list of persona features; pydantic rejects an empty list.
    # NOTE(review): `min_items` is the pydantic v1 keyword; pydantic v2
    # deprecates it in favor of `min_length` — confirm the pinned version.
    persona_features: List[PersonaFeature] = Field(..., description="人设特征列表", min_items=1)
    # candidate_words: list of candidate words; at least one item required.
    candidate_words: List[str] = Field(..., description="候选词列表", min_items=1)
# Response models
class MatchedFeature(BaseModel):
    """匹配的解构特征模型"""
    # A deconstructed feature of a note that matched the query. (Docstring and
    # Field descriptions are kept verbatim: they appear in the OpenAPI schema.)
    # feature_name: feature name.
    feature_name: str = Field(..., description="特征名称")
    # dimension: feature dimension, e.g. "灵感点-全新内容".
    dimension: str = Field(..., description="特征维度(如:灵感点-全新内容)")
    # dimension_detail: dimension detail, e.g. substance / form / intent.
    dimension_detail: str = Field(..., description="维度详情(如:实质、形式、意图)")
    # weight: feature weight.
    weight: float = Field(..., description="特征权重")
    # similarity_score: similarity score for this feature.
    similarity_score: float = Field(..., description="相似度得分")
class MatchedNote(BaseModel):
    """匹配的帖子模型"""
    # A note (post) that matched a search word, with its evaluation details.
    # (Docstring and Field descriptions are kept verbatim: they appear in the
    # OpenAPI schema.)
    # NOTE(review): the exact meaning and range of evaluation_score,
    # max_similarity and contribution are not visible in this file — presumably
    # produced by the pipeline's evaluation stage; confirm in api.data_converter.
    note_id: str
    note_title: str
    evaluation_score: float
    max_similarity: float
    contribution: float
    # Evaluation details
    # evaluation_reasoning: explanation of the comprehensive score.
    evaluation_reasoning: str = Field(..., description="综合得分的评分说明")
    # key_matching_points: list of key matching points.
    key_matching_points: List[str] = Field(..., description="关键匹配点列表")
    # query_relevance: "相关" (relevant) or "不相关" (not relevant).
    query_relevance: str = Field(..., description="Query相关性(相关/不相关)")
    # query_relevance_explanation: explanation of the query-relevance verdict.
    query_relevance_explanation: str = Field(..., description="Query相关性说明")
    # matched_features: deconstructed features of this note that matched.
    matched_features: List[MatchedFeature] = Field(..., description="匹配的解构特征列表")
    note_data: Dict[str, Any]  # Full search-result record for this note.
class ComprehensiveScoreDetail(BaseModel):
    """综合得分详情模型"""
    # Breakdown of a search result's comprehensive score; every field defaults
    # to zero.
    # NOTE(review): this model is not used by SearchResult, which types
    # comprehensive_score_detail as a plain Dict[str, Any]. The meanings of N, M,
    # total_contribution and P are not visible here — presumably defined by the
    # scoring code in the pipeline / api.data_converter; confirm before use.
    N: int = 0
    M: int = 0
    total_contribution: float = 0.0
    P: float = 0.0
class SearchResult(BaseModel):
    """搜索结果模型"""
    # One generated search word together with its score and matched notes.
    # (Docstring and Field description are kept verbatim: they appear in the
    # OpenAPI schema.)
    # search_word: the query word that was searched.
    search_word: str
    # source_words: the source words that were combined to generate this query.
    source_words: List[str] = Field(..., description="query来源组合,数组格式,包含生成该query所使用的所有来源词")
    comprehensive_score: float
    # Free-form score breakdown. NOTE(review): a ComprehensiveScoreDetail model
    # exists in this module but is not used here — confirm whether the dict is
    # expected to follow that shape.
    comprehensive_score_detail: Dict[str, Any]
    matched_notes: List[MatchedNote]
class SearchResponse(BaseModel):
    """搜索响应模型"""
    # Response body of POST /what/search (its response_model). The docstring is
    # kept verbatim because it appears in the OpenAPI schema.
    # original_target: echoes the request's original_target.
    original_target: str
    search_results: List[SearchResult]
@app.on_event("startup")
async def startup_event():
    """Create the request semaphore and the pipeline thread pool.

    Both are sized by MAX_CONCURRENT_REQUESTS, so at most that many requests
    are admitted at once and at most that many pipeline runs occupy worker
    threads at the same time.
    """
    global REQUEST_SEMAPHORE, EXECUTOR
    limit = MAX_CONCURRENT_REQUESTS
    # Created here, inside the running event loop, rather than at import time.
    REQUEST_SEMAPHORE = asyncio.Semaphore(limit)
    EXECUTOR = ThreadPoolExecutor(max_workers=limit)
    logger.info(f"API服务已启动,最大并发请求数: {limit}")
@app.on_event("shutdown")
async def shutdown_event():
    """Release resources when the application shuts down.

    If the pipeline thread pool was created, wait for its in-flight work to
    finish and shut it down. shutdown(wait=True) is a synchronous call, so the
    event loop is blocked until those threads complete.
    """
    pool = EXECUTOR
    if pool:
        pool.shutdown(wait=True)
        logger.info("线程池已关闭")
    logger.info("API服务关闭")
def _get_request_semaphore():
    """Return the shared request semaphore, creating it on first use.

    The startup hook normally creates REQUEST_SEMAPHORE, but startup hooks do
    not run in every hosting setup (for example a Starlette TestClient used
    without a ``with`` block). Without this fallback the handler would fail on
    ``async with None`` before doing any work.

    Only called from coroutines on the running event loop. There is no
    ``await`` between the None check and the assignment, so requests on the
    same loop cannot create two different semaphores.
    """
    global REQUEST_SEMAPHORE
    if REQUEST_SEMAPHORE is None:
        REQUEST_SEMAPHORE = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    return REQUEST_SEMAPHORE


def _validate_search_request(request: SearchRequest) -> None:
    """Raise HTTPException(400) if a required request field is blank."""
    if not request.original_target or not request.original_target.strip():
        raise HTTPException(status_code=400, detail="original_target不能为空")
    # pydantic's min_items=1 already enforces these; kept as a defensive check.
    if not request.persona_features:
        raise HTTPException(status_code=400, detail="persona_features不能为空")
    if not request.candidate_words:
        raise HTTPException(status_code=400, detail="candidate_words不能为空")
    # Every persona_feature_name must be non-blank after stripping whitespace.
    for idx, pf in enumerate(request.persona_features):
        name = pf.persona_feature_name
        if not name or not name.strip():
            raise HTTPException(
                status_code=400,
                detail=f"persona_features[{idx}].persona_feature_name不能为空",
            )


@app.post("/what/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """
    执行搜索和评估(支持并发)
    并发控制:最多5个请求同时处理,超出的请求会等待
    Args:
        request: 搜索请求
    Returns:
        搜索结果响应
    Raises:
        HTTPException: 当请求参数无效或处理失败时
    """
    # The docstring above is kept verbatim because FastAPI publishes it as this
    # endpoint's OpenAPI description. In English:
    # Run pipeline stages 3-7 for one search request and return the converted
    # API response. At most MAX_CONCURRENT_REQUESTS requests are processed at
    # once; further requests wait on the semaphore.
    # Raises HTTPException 400 for blank input or an input-conversion failure.
    # Raises HTTPException 500 for a pipeline failure, malformed pipeline
    # output, an output-conversion failure, or any other unexpected error.

    # Step 1: acquire a concurrency slot (rate limiting).
    async with _get_request_semaphore():
        # Step 2: per-request context providing request_id and work_dir.
        async with RequestContextManager(base_dir="temp_requests") as ctx:
            rid = ctx.request_id
            try:
                logger.info(
                    "[%s] 收到搜索请求: original_target=%s, persona_features=%d, candidate_words=%d",
                    rid,
                    request.original_target,
                    len(request.persona_features),
                    len(request.candidate_words),
                )
                _validate_search_request(request)

                # Step 3: pipeline instance bound to this request's work dir.
                pipeline_wrapper = PipelineWrapper(
                    output_dir=str(ctx.work_dir),
                    request_id=rid,
                )

                # Step 4: convert the API input into the pipeline's format.
                logger.info("[%s] 转换API输入格式...", rid)
                try:
                    features_data = convert_api_input_to_pipeline_format(
                        original_target=request.original_target,
                        persona_features=[pf.dict() for pf in request.persona_features],
                        candidate_words=request.candidate_words,
                    )
                except Exception as e:
                    logger.exception("[%s] API输入格式转换失败: %s", rid, e)
                    raise HTTPException(status_code=400, detail=f"输入格式转换失败: {str(e)}") from e
                if not features_data:
                    raise HTTPException(status_code=400, detail="无法构建有效的特征数据")

                # Step 5: run stages 3-7 in the worker pool so the blocking
                # pipeline does not stall the event loop. If EXECUTOR is None
                # (startup hook not run), run_in_executor uses the loop's
                # default executor instead.
                # NOTE(review): if this task is cancelled while awaiting here, the
                # worker thread keeps running while RequestContextManager exits
                # (and cleans up work_dir) underneath it — confirm that
                # PipelineWrapper tolerates this.
                logger.info("[%s] 执行阶段3-7...", rid)
                loop = asyncio.get_running_loop()
                try:
                    pipeline_output = await loop.run_in_executor(
                        EXECUTOR,
                        pipeline_wrapper.run_stages_3_to_7_sync,
                        features_data,
                    )
                except Exception as e:
                    logger.exception("[%s] 阶段3-7执行失败: %s", rid, e)
                    raise HTTPException(status_code=500, detail=f"Pipeline执行失败: {str(e)}") from e

                # The pipeline must return a mapping containing 'evaluation_results'.
                if not pipeline_output or 'evaluation_results' not in pipeline_output:
                    logger.error("[%s] Pipeline输出格式不正确", rid)
                    raise HTTPException(status_code=500, detail="Pipeline输出格式不正确")

                # Step 6: convert the pipeline output into the API response.
                logger.info("[%s] 转换API输出格式...", rid)
                try:
                    response = convert_pipeline_output_to_api_response(
                        pipeline_results=pipeline_output['evaluation_results'],
                        original_target=request.original_target,
                        similarity_results=pipeline_output.get('similarity_results'),
                    )
                except Exception as e:
                    logger.exception("[%s] API输出格式转换失败: %s", rid, e)
                    raise HTTPException(status_code=500, detail=f"输出格式转换失败: {str(e)}") from e

                logger.info("[%s] 搜索完成: 找到 %d 个有效结果", rid, len(response['search_results']))

                # Step 7: return; the work dir is cleaned up when the context exits.
                return response
            except HTTPException:
                # Already mapped to a status code; propagate unchanged.
                raise
            except Exception as e:
                logger.exception("[%s] 请求失败: %s", rid, e)
                raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}") from e
@app.get("/health")
async def health_check():
    """健康检查端点"""
    # Health probe (docstring kept verbatim: FastAPI publishes it as the
    # endpoint description). Reports the configured concurrency limit and
    # whether the request semaphore has been created yet.
    semaphore_ready = REQUEST_SEMAPHORE is not None
    payload = dict(
        status="healthy",
        max_concurrent_requests=MAX_CONCURRENT_REQUESTS,
        semaphore_initialized=semaphore_ready,
    )
    return payload
|