| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- FastAPI服务主文件
- 提供搜索API端点
- """
import logging
import threading
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

from api.data_converter import (
    convert_api_input_to_pipeline_format,
    convert_pipeline_output_to_api_response,
)
from api.pipeline_wrapper import PipelineWrapper
logger = logging.getLogger(__name__)

# FastAPI application instance that the endpoints below are registered on.
app = FastAPI(
    title="特征搜索API服务",
    description="复用阶段3-7的搜索和评估服务",
    version="1.0.0",
)

# Process-wide pipeline wrapper shared by all requests. It is None until the
# startup hook assigns it, so the annotation must admit None.
pipeline_wrapper: Optional[PipelineWrapper] = None
# ----- Request models -----
class PersonaFeature(BaseModel):
    """A single persona feature supplied by the client in a search request."""

    # Required field; the search handler additionally rejects blank names.
    persona_feature_name: str = Field(..., description="人设特征名称")
class SearchRequest(BaseModel):
    """Request body accepted by the search endpoint."""

    # Name of the original target the search is performed for.
    original_target: str = Field(..., description="原始目标名称")
    # Persona features; min_items=1 makes an empty list a validation error.
    persona_features: List[PersonaFeature] = Field(
        ...,
        min_items=1,
        description="人设特征列表",
    )
    # Candidate words; min_items=1 makes an empty list a validation error.
    candidate_words: List[str] = Field(
        ...,
        min_items=1,
        description="候选词列表",
    )
# ----- Response models -----
class MatchedNote(BaseModel):
    """A note (post) matched for a search word, together with its scores."""

    note_id: str
    note_title: str
    evaluation_score: float
    max_similarity: float
    contribution: float
    # The complete search-result record for this note.
    note_data: Dict[str, Any]
class ComprehensiveScoreDetail(BaseModel):
    """Breakdown of a comprehensive score; every field defaults to zero.

    NOTE(review): SearchResult.comprehensive_score_detail is typed as a plain
    Dict[str, Any], so this model is not referenced in the visible code —
    confirm whether it is used elsewhere.
    """

    N: int = 0
    M: int = 0
    total_contribution: float = 0.0
    # Presumably the comprehensive score itself (the search log mentions
    # "综合得分P") — confirm against the data converter.
    P: float = 0.0
class SearchResult(BaseModel):
    """Scored outcome for one search word, with the notes it matched."""

    search_word: str
    comprehensive_score: float
    # Free-form breakdown dict; its keys are not constrained by this model.
    comprehensive_score_detail: Dict[str, Any]
    matched_notes: List[MatchedNote]
class SearchResponse(BaseModel):
    """Top-level response body returned by the search endpoint."""

    # Echo of the request's original_target.
    original_target: str
    search_results: List[SearchResult]
@app.on_event("startup")
async def startup_event():
    """Construct the shared PipelineWrapper when the application starts.

    A construction failure is logged with its traceback and re-raised.
    """
    global pipeline_wrapper
    try:
        instance = PipelineWrapper()
        pipeline_wrapper = instance
        logger.info("Pipeline包装器初始化成功")
    except Exception as exc:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Pipeline包装器初始化失败: {exc}")
        raise
@app.on_event("shutdown")
async def shutdown_event():
    """Release pipeline resources on shutdown.

    Does nothing when no wrapper was created. A cleanup failure is logged as a
    warning and not propagated.
    """
    # Only read here, so no `global` declaration is needed.
    wrapper = pipeline_wrapper
    if not wrapper:
        return
    try:
        wrapper.cleanup()
        logger.info("Pipeline包装器清理完成")
    except Exception as exc:
        logger.warning(f"Pipeline包装器清理失败: {exc}")
# run_stages_3_to_7 is a synchronous call. search() runs it in a worker thread
# so it does not block the event loop. This lock keeps pipeline runs
# one-at-a-time, as they were when the call ran directly on the event loop,
# because nothing visible here shows PipelineWrapper is safe to run concurrently.
_pipeline_run_lock = threading.Lock()


def _run_pipeline_serialized(wrapper: "PipelineWrapper", features_data):
    """Run stages 3-7 on *wrapper* while holding ``_pipeline_run_lock``.

    Meant to be executed in a worker thread (via ``run_in_threadpool``).

    Returns:
        Whatever ``wrapper.run_stages_3_to_7(features_data)`` returns.
    """
    with _pipeline_run_lock:
        return wrapper.run_stages_3_to_7(features_data)


def _validate_search_request(request: "SearchRequest") -> None:
    """Reject search requests whose required inputs are empty or blank.

    Raises:
        HTTPException: 400 when original_target is empty/whitespace, when
            persona_features or candidate_words is empty, or when any
            persona_feature_name is empty/whitespace.
    """
    if not request.original_target or not request.original_target.strip():
        raise HTTPException(status_code=400, detail="original_target不能为空")

    # Pydantic's min_items=1 normally rejects empty lists first; kept as a
    # defensive check.
    if not request.persona_features:
        raise HTTPException(status_code=400, detail="persona_features不能为空")

    if not request.candidate_words:
        raise HTTPException(status_code=400, detail="candidate_words不能为空")

    for idx, pf in enumerate(request.persona_features):
        if not pf.persona_feature_name or not pf.persona_feature_name.strip():
            raise HTTPException(
                status_code=400,
                detail=f"persona_features[{idx}].persona_feature_name不能为空",
            )


@app.post("/what/search", response_model=SearchResponse)
async def search(request: SearchRequest):
    """Run the search-and-evaluation pipeline (stages 3-7) for one request.

    Args:
        request: The search request body.

    Returns:
        The object built by ``convert_pipeline_output_to_api_response`` (this
        handler reads its ``'search_results'`` entry), which FastAPI validates
        against ``SearchResponse``.

    Raises:
        HTTPException: 503 if the pipeline wrapper is not initialized.
            400 for blank inputs, input-conversion failures, or empty
            converted feature data. 500 for pipeline failures, malformed
            pipeline output, output-conversion failures, or any other
            unexpected error.
    """
    try:
        logger.info(f"收到搜索请求: original_target={request.original_target}, "
                    f"persona_features数量={len(request.persona_features)}, "
                    f"candidate_words数量={len(request.candidate_words)}")

        # Snapshot the global so the same wrapper is checked and used.
        wrapper = pipeline_wrapper
        if wrapper is None:
            logger.error("Pipeline包装器未初始化")
            raise HTTPException(status_code=503, detail="Pipeline包装器未初始化,请稍后重试")

        _validate_search_request(request)

        # Step 1: convert the API input into the pipeline's feature format.
        logger.info("步骤1:转换API输入格式...")
        try:
            features_data = convert_api_input_to_pipeline_format(
                original_target=request.original_target,
                persona_features=[pf.dict() for pf in request.persona_features],
                candidate_words=request.candidate_words,
            )
        except Exception as e:
            logger.error(f"API输入格式转换失败: {e}", exc_info=True)
            raise HTTPException(status_code=400, detail=f"输入格式转换失败: {str(e)}") from e

        if not features_data:
            raise HTTPException(status_code=400, detail="无法构建有效的特征数据,请检查输入参数")

        # Step 2: run stages 3-7. The call is synchronous, so it runs in a
        # worker thread. Otherwise every other request, including /health,
        # would stall until it returned.
        logger.info("步骤2:执行阶段3-7...")
        try:
            pipeline_output = await run_in_threadpool(
                _run_pipeline_serialized, wrapper, features_data
            )
        except Exception as e:
            logger.error(f"阶段3-7执行失败: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Pipeline执行失败: {str(e)}") from e

        if not pipeline_output or 'evaluation_results' not in pipeline_output:
            logger.error("Pipeline输出格式不正确")
            raise HTTPException(status_code=500, detail="Pipeline输出格式不正确")

        # Step 3: convert the pipeline output into the API response shape.
        logger.info("步骤3:转换API输出格式...")
        try:
            response = convert_pipeline_output_to_api_response(
                pipeline_results=pipeline_output['evaluation_results'],
                original_target=request.original_target,
                similarity_results=pipeline_output.get('similarity_results'),
            )
        except Exception as e:
            logger.error(f"API输出格式转换失败: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"输出格式转换失败: {str(e)}") from e

        logger.info(f"搜索完成: 找到 {len(response['search_results'])} 个有效结果 "
                    f"(综合得分P > 0)")

        return response

    except HTTPException:
        # Already carries the intended status code; propagate unchanged.
        raise
    except Exception as e:
        logger.error(f"搜索请求处理失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}") from e
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Always reports status "healthy", plus whether the pipeline wrapper has
    been created.
    """
    initialized = pipeline_wrapper is not None
    return dict(status="healthy", pipeline_initialized=initialized)
|