search_service.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. FastAPI服务主文件
  5. 提供搜索API端点
  6. """
import logging
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from api.data_converter import (
    convert_api_input_to_pipeline_format,
    convert_pipeline_output_to_api_response,
)
from api.pipeline_wrapper import PipelineWrapper
  16. logger = logging.getLogger(__name__)
  17. # 创建FastAPI应用
  18. app = FastAPI(
  19. title="特征搜索API服务",
  20. description="复用阶段3-7的搜索和评估服务",
  21. version="1.0.0"
  22. )
  23. # 全局Pipeline包装器实例
  24. pipeline_wrapper: PipelineWrapper = None
  25. # 请求模型
  26. class PersonaFeature(BaseModel):
  27. """人设特征模型"""
  28. persona_feature_name: str = Field(..., description="人设特征名称")
  29. class SearchRequest(BaseModel):
  30. """搜索请求模型"""
  31. original_target: str = Field(..., description="原始目标名称")
  32. persona_features: List[PersonaFeature] = Field(..., description="人设特征列表", min_items=1)
  33. candidate_words: List[str] = Field(..., description="候选词列表", min_items=1)
  34. # 响应模型
class MatchedNote(BaseModel):
    """匹配的帖子模型"""

    # One matched note (post) inside a SearchResult. This is a declarative
    # response schema: field order, names and types define both the serialized
    # JSON and the OpenAPI schema. The docstring above is schema-visible too.
    note_id: str
    note_title: str
    evaluation_score: float
    max_similarity: float
    contribution: float
    note_data: Dict[str, Any]  # Complete search-result information for this note
class ComprehensiveScoreDetail(BaseModel):
    """综合得分详情模型"""

    # Breakdown of a result's comprehensive score; every field defaults to 0.
    # P appears to be the comprehensive score itself (the search handler logs
    # results as "综合得分P > 0"). The meanings of N, M and total_contribution
    # are defined by the scoring pipeline, not here — TODO confirm.
    # NOTE(review): no model in this module references this class;
    # SearchResult.comprehensive_score_detail is typed as a plain Dict[str, Any].
    N: int = 0
    M: int = 0
    total_contribution: float = 0.0
    P: float = 0.0
  49. class SearchResult(BaseModel):
  50. """搜索结果模型"""
  51. search_word: str
  52. source_words: List[str] = Field(..., description="query来源组合,数组格式,包含生成该query所使用的所有来源词")
  53. comprehensive_score: float
  54. comprehensive_score_detail: Dict[str, Any]
  55. matched_notes: List[MatchedNote]
class SearchResponse(BaseModel):
    """搜索响应模型"""

    # Top-level body of POST /what/search (used as that route's
    # response_model). original_target is the value the handler passes to the
    # output converter, i.e. the request's original_target; search_results
    # holds the scored queries.
    original_target: str
    search_results: List[SearchResult]
  60. @app.on_event("startup")
  61. async def startup_event():
  62. """应用启动时初始化Pipeline包装器"""
  63. global pipeline_wrapper
  64. try:
  65. pipeline_wrapper = PipelineWrapper()
  66. logger.info("Pipeline包装器初始化成功")
  67. except Exception as e:
  68. logger.error(f"Pipeline包装器初始化失败: {e}", exc_info=True)
  69. raise
  70. @app.on_event("shutdown")
  71. async def shutdown_event():
  72. """应用关闭时清理资源"""
  73. global pipeline_wrapper
  74. if pipeline_wrapper:
  75. try:
  76. pipeline_wrapper.cleanup()
  77. logger.info("Pipeline包装器清理完成")
  78. except Exception as e:
  79. logger.warning(f"Pipeline包装器清理失败: {e}")
  80. @app.post("/what/search", response_model=SearchResponse)
  81. async def search(request: SearchRequest):
  82. """
  83. 执行搜索和评估
  84. Args:
  85. request: 搜索请求
  86. Returns:
  87. 搜索结果响应
  88. Raises:
  89. HTTPException: 当请求参数无效或处理失败时
  90. """
  91. try:
  92. logger.info(f"收到搜索请求: original_target={request.original_target}, "
  93. f"persona_features数量={len(request.persona_features)}, "
  94. f"candidate_words数量={len(request.candidate_words)}")
  95. # 验证Pipeline包装器是否已初始化
  96. if pipeline_wrapper is None:
  97. logger.error("Pipeline包装器未初始化")
  98. raise HTTPException(status_code=503, detail="Pipeline包装器未初始化,请稍后重试")
  99. # 验证输入参数
  100. if not request.original_target or not request.original_target.strip():
  101. raise HTTPException(status_code=400, detail="original_target不能为空")
  102. if not request.persona_features or len(request.persona_features) == 0:
  103. raise HTTPException(status_code=400, detail="persona_features不能为空")
  104. if not request.candidate_words or len(request.candidate_words) == 0:
  105. raise HTTPException(status_code=400, detail="candidate_words不能为空")
  106. # 验证persona_features中的persona_feature_name
  107. for idx, pf in enumerate(request.persona_features):
  108. if not pf.persona_feature_name or not pf.persona_feature_name.strip():
  109. raise HTTPException(
  110. status_code=400,
  111. detail=f"persona_features[{idx}].persona_feature_name不能为空"
  112. )
  113. # 步骤1:将API输入转换为pipeline格式
  114. logger.info("步骤1:转换API输入格式...")
  115. try:
  116. features_data = convert_api_input_to_pipeline_format(
  117. original_target=request.original_target,
  118. persona_features=[pf.dict() for pf in request.persona_features],
  119. candidate_words=request.candidate_words
  120. )
  121. except Exception as e:
  122. logger.error(f"API输入格式转换失败: {e}", exc_info=True)
  123. raise HTTPException(status_code=400, detail=f"输入格式转换失败: {str(e)}")
  124. if not features_data:
  125. raise HTTPException(status_code=400, detail="无法构建有效的特征数据,请检查输入参数")
  126. # 步骤2:执行阶段3-7
  127. logger.info("步骤2:执行阶段3-7...")
  128. try:
  129. pipeline_output = await pipeline_wrapper.run_stages_3_to_7(features_data)
  130. except Exception as e:
  131. logger.error(f"阶段3-7执行失败: {e}", exc_info=True)
  132. raise HTTPException(status_code=500, detail=f"Pipeline执行失败: {str(e)}")
  133. # 验证pipeline输出
  134. if not pipeline_output or 'evaluation_results' not in pipeline_output:
  135. logger.error("Pipeline输出格式不正确")
  136. raise HTTPException(status_code=500, detail="Pipeline输出格式不正确")
  137. # 步骤3:将pipeline输出转换为API响应格式
  138. logger.info("步骤3:转换API输出格式...")
  139. try:
  140. response = convert_pipeline_output_to_api_response(
  141. pipeline_results=pipeline_output['evaluation_results'],
  142. original_target=request.original_target,
  143. similarity_results=pipeline_output.get('similarity_results')
  144. )
  145. except Exception as e:
  146. logger.error(f"API输出格式转换失败: {e}", exc_info=True)
  147. raise HTTPException(status_code=500, detail=f"输出格式转换失败: {str(e)}")
  148. logger.info(f"搜索完成: 找到 {len(response['search_results'])} 个有效结果 "
  149. f"(综合得分P > 0)")
  150. return response
  151. except HTTPException:
  152. raise
  153. except Exception as e:
  154. logger.error(f"搜索请求处理失败: {e}", exc_info=True)
  155. raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
  156. @app.get("/health")
  157. async def health_check():
  158. """健康检查端点"""
  159. return {
  160. "status": "healthy",
  161. "pipeline_initialized": pipeline_wrapper is not None
  162. }