search_service.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. FastAPI服务主文件
  5. 提供搜索API端点
  6. """
  7. import asyncio
  8. import logging
  9. from typing import List, Dict, Any
  10. from concurrent.futures import ThreadPoolExecutor
  11. from fastapi import FastAPI, HTTPException
  12. from pydantic import BaseModel, Field
  13. from api.data_converter import (
  14. convert_api_input_to_pipeline_format,
  15. convert_pipeline_output_to_api_response
  16. )
  17. from api.pipeline_wrapper import PipelineWrapper
  18. from api.request_context import RequestContextManager
  19. logger = logging.getLogger(__name__)
  20. # 创建FastAPI应用
  21. app = FastAPI(
  22. title="特征搜索API服务",
  23. description="复用阶段3-7的搜索和评估服务",
  24. version="1.0.0"
  25. )
  26. # 并发控制
  27. REQUEST_SEMAPHORE = None
  28. MAX_CONCURRENT_REQUESTS = 5
  29. EXECUTOR = None # 线程池执行器,用于运行阻塞的pipeline代码
  30. # 请求模型
  31. class PersonaFeature(BaseModel):
  32. """人设特征模型"""
  33. persona_feature_name: str = Field(..., description="人设特征名称")
  34. class SearchRequest(BaseModel):
  35. """搜索请求模型"""
  36. original_target: str = Field(..., description="原始目标名称")
  37. persona_features: List[PersonaFeature] = Field(..., description="人设特征列表", min_items=1)
  38. candidate_words: List[str] = Field(..., description="候选词列表", min_items=1)
  39. # 响应模型
  40. class MatchedNote(BaseModel):
  41. """匹配的帖子模型"""
  42. note_id: str
  43. note_title: str
  44. evaluation_score: float
  45. max_similarity: float
  46. contribution: float
  47. note_data: Dict[str, Any] # 完整的搜索结果信息
  48. class ComprehensiveScoreDetail(BaseModel):
  49. """综合得分详情模型"""
  50. N: int = 0
  51. M: int = 0
  52. total_contribution: float = 0.0
  53. P: float = 0.0
  54. class SearchResult(BaseModel):
  55. """搜索结果模型"""
  56. search_word: str
  57. source_words: List[str] = Field(..., description="query来源组合,数组格式,包含生成该query所使用的所有来源词")
  58. comprehensive_score: float
  59. comprehensive_score_detail: Dict[str, Any]
  60. matched_notes: List[MatchedNote]
  61. class SearchResponse(BaseModel):
  62. """搜索响应模型"""
  63. original_target: str
  64. search_results: List[SearchResult]
  65. @app.on_event("startup")
  66. async def startup_event():
  67. """应用启动时初始化并发限流器和线程池"""
  68. global REQUEST_SEMAPHORE, EXECUTOR
  69. REQUEST_SEMAPHORE = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
  70. EXECUTOR = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_REQUESTS)
  71. logger.info(f"API服务已启动,最大并发请求数: {MAX_CONCURRENT_REQUESTS}")
  72. @app.on_event("shutdown")
  73. async def shutdown_event():
  74. """应用关闭时清理资源"""
  75. global EXECUTOR
  76. if EXECUTOR:
  77. EXECUTOR.shutdown(wait=True)
  78. logger.info("线程池已关闭")
  79. logger.info("API服务关闭")
  80. @app.post("/what/search", response_model=SearchResponse)
  81. async def search(request: SearchRequest):
  82. """
  83. 执行搜索和评估(支持并发)
  84. 并发控制:最多5个请求同时处理,超出的请求会等待
  85. Args:
  86. request: 搜索请求
  87. Returns:
  88. 搜索结果响应
  89. Raises:
  90. HTTPException: 当请求参数无效或处理失败时
  91. """
  92. # 步骤1:获取并发许可(限流)
  93. async with REQUEST_SEMAPHORE:
  94. # 步骤2:创建请求专用上下文
  95. async with RequestContextManager(base_dir="temp_requests") as ctx:
  96. try:
  97. logger.info(f"[{ctx.request_id}] 收到搜索请求: "
  98. f"original_target={request.original_target}, "
  99. f"persona_features={len(request.persona_features)}, "
  100. f"candidate_words={len(request.candidate_words)}")
  101. # 验证输入参数
  102. if not request.original_target or not request.original_target.strip():
  103. raise HTTPException(status_code=400, detail="original_target不能为空")
  104. if not request.persona_features or len(request.persona_features) == 0:
  105. raise HTTPException(status_code=400, detail="persona_features不能为空")
  106. if not request.candidate_words or len(request.candidate_words) == 0:
  107. raise HTTPException(status_code=400, detail="candidate_words不能为空")
  108. # 验证persona_features中的persona_feature_name
  109. for idx, pf in enumerate(request.persona_features):
  110. if not pf.persona_feature_name or not pf.persona_feature_name.strip():
  111. raise HTTPException(
  112. status_code=400,
  113. detail=f"persona_features[{idx}].persona_feature_name不能为空"
  114. )
  115. # 步骤3:创建请求专用的Pipeline实例
  116. pipeline_wrapper = PipelineWrapper(
  117. output_dir=str(ctx.work_dir),
  118. request_id=ctx.request_id
  119. )
  120. # 步骤4:转换API输入格式
  121. logger.info(f"[{ctx.request_id}] 转换API输入格式...")
  122. try:
  123. features_data = convert_api_input_to_pipeline_format(
  124. original_target=request.original_target,
  125. persona_features=[pf.dict() for pf in request.persona_features],
  126. candidate_words=request.candidate_words
  127. )
  128. except Exception as e:
  129. logger.error(f"[{ctx.request_id}] API输入格式转换失败: {e}", exc_info=True)
  130. raise HTTPException(status_code=400, detail=f"输入格式转换失败: {str(e)}")
  131. if not features_data:
  132. raise HTTPException(status_code=400, detail="无法构建有效的特征数据")
  133. # 步骤5:执行阶段3-7(在线程池中运行以避免阻塞)
  134. logger.info(f"[{ctx.request_id}] 执行阶段3-7...")
  135. try:
  136. # 获取当前事件循环
  137. loop = asyncio.get_event_loop()
  138. # 在线程池中运行阻塞的pipeline代码
  139. pipeline_output = await loop.run_in_executor(
  140. EXECUTOR,
  141. pipeline_wrapper.run_stages_3_to_7_sync,
  142. features_data
  143. )
  144. except Exception as e:
  145. logger.error(f"[{ctx.request_id}] 阶段3-7执行失败: {e}", exc_info=True)
  146. raise HTTPException(status_code=500, detail=f"Pipeline执行失败: {str(e)}")
  147. # 验证pipeline输出
  148. if not pipeline_output or 'evaluation_results' not in pipeline_output:
  149. logger.error(f"[{ctx.request_id}] Pipeline输出格式不正确")
  150. raise HTTPException(status_code=500, detail="Pipeline输出格式不正确")
  151. # 步骤6:转换API输出格式
  152. logger.info(f"[{ctx.request_id}] 转换API输出格式...")
  153. try:
  154. response = convert_pipeline_output_to_api_response(
  155. pipeline_results=pipeline_output['evaluation_results'],
  156. original_target=request.original_target,
  157. similarity_results=pipeline_output.get('similarity_results')
  158. )
  159. except Exception as e:
  160. logger.error(f"[{ctx.request_id}] API输出格式转换失败: {e}", exc_info=True)
  161. raise HTTPException(status_code=500, detail=f"输出格式转换失败: {str(e)}")
  162. logger.info(f"[{ctx.request_id}] 搜索完成: "
  163. f"找到 {len(response['search_results'])} 个有效结果")
  164. # 步骤7:返回结果(工作目录会在退出上下文时自动清理)
  165. return response
  166. except HTTPException:
  167. raise
  168. except Exception as e:
  169. logger.error(f"[{ctx.request_id}] 请求失败: {e}", exc_info=True)
  170. raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
  171. @app.get("/health")
  172. async def health_check():
  173. """健康检查端点"""
  174. return {
  175. "status": "healthy",
  176. "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
  177. "semaphore_initialized": REQUEST_SEMAPHORE is not None
  178. }