  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Pipeline包装器
  5. 封装EnhancedSearchV2,提供只执行阶段3-7的接口
  6. """
import json
import logging
import os
import shutil
import tempfile
from typing import Dict, List, Any, Optional

from src.pipeline.feature_search_pipeline import EnhancedSearchV2
from api.config import APIConfig
  13. logger = logging.getLogger(__name__)
  14. class PipelineWrapper:
  15. """Pipeline包装器,复用阶段3-7"""
  16. def __init__(self):
  17. """初始化Pipeline包装器"""
  18. # 创建临时输出目录
  19. self.temp_output_dir = tempfile.mkdtemp(prefix='api_pipeline_')
  20. logger.info(f"创建临时输出目录: {self.temp_output_dir}")
  21. # 初始化EnhancedSearchV2实例
  22. # 注意:how_json_path参数是必需的,但我们不会使用它(因为我们跳过阶段1-2)
  23. # 创建一个空的临时文件作为占位符
  24. temp_how_file = os.path.join(self.temp_output_dir, 'temp_how.json')
  25. with open(temp_how_file, 'w', encoding='utf-8') as f:
  26. import json
  27. json.dump({'解构结果': {}}, f)
  28. self.pipeline = EnhancedSearchV2(
  29. how_json_path=temp_how_file, # 占位符文件,实际不会使用
  30. openrouter_api_key=APIConfig.OPENROUTER_API_KEY,
  31. output_dir=self.temp_output_dir,
  32. top_n=10,
  33. max_total_searches=APIConfig.MAX_TOTAL_SEARCHES,
  34. search_max_workers=APIConfig.SEARCH_MAX_WORKERS,
  35. max_searches_per_feature=APIConfig.MAX_SEARCHES_PER_FEATURE,
  36. max_searches_per_base_word=APIConfig.MAX_SEARCHES_PER_BASE_WORD,
  37. enable_evaluation=True,
  38. evaluation_max_workers=APIConfig.EVALUATION_MAX_WORKERS,
  39. evaluation_max_notes_per_query=APIConfig.EVALUATION_MAX_NOTES_PER_QUERY,
  40. enable_deep_analysis=True, # 启用深度解构
  41. deep_analysis_only=False,
  42. deep_analysis_max_workers=APIConfig.DEEP_ANALYSIS_MAX_WORKERS,
  43. deep_analysis_max_notes=None,
  44. deep_analysis_skip_count=0,
  45. deep_analysis_sort_by='score',
  46. deep_analysis_api_url=APIConfig.DEEP_ANALYSIS_API_URL,
  47. deep_analysis_min_score=APIConfig.DEEP_ANALYSIS_MIN_SCORE,
  48. enable_similarity=True, # 启用相似度分析
  49. similarity_weight_embedding=APIConfig.SIMILARITY_WEIGHT_EMBEDDING,
  50. similarity_weight_semantic=APIConfig.SIMILARITY_WEIGHT_SEMANTIC,
  51. similarity_max_workers=APIConfig.SIMILARITY_MAX_WORKERS,
  52. similarity_min_similarity=APIConfig.SIMILARITY_MIN_SIMILARITY
  53. )
  54. logger.info("Pipeline包装器初始化完成")
  55. def run_stages_3_to_7(
  56. self,
  57. features_data: List[Dict[str, Any]]
  58. ) -> Dict[str, Any]:
  59. """
  60. 执行阶段3-7的完整流程
  61. Args:
  62. features_data: 阶段2的输出格式数据(candidate_words.json格式)
  63. Returns:
  64. 包含阶段3-7结果的字典
  65. Raises:
  66. Exception: 当任何阶段执行失败时
  67. """
  68. try:
  69. logger.info("=" * 60)
  70. logger.info("开始执行阶段3-7")
  71. logger.info("=" * 60)
  72. # 验证输入数据
  73. if not features_data:
  74. raise ValueError("features_data不能为空")
  75. # 阶段3:多词组合 + LLM评估
  76. logger.info("阶段3:生成搜索词...")
  77. try:
  78. queries = self.pipeline.generate_search_queries(
  79. features_data,
  80. max_workers=APIConfig.QUERY_GENERATION_MAX_WORKERS,
  81. max_candidates=APIConfig.MAX_CANDIDATES,
  82. max_combo_length=APIConfig.MAX_COMBO_LENGTH
  83. )
  84. except Exception as e:
  85. logger.error(f"阶段3执行失败: {e}", exc_info=True)
  86. raise Exception(f"搜索词生成失败: {str(e)}")
  87. # 阶段4:执行搜索
  88. logger.info("阶段4:执行搜索...")
  89. try:
  90. search_results = self.pipeline.execute_search_queries(
  91. queries,
  92. search_delay=2.0,
  93. top_n=self.pipeline.top_n
  94. )
  95. except Exception as e:
  96. logger.error(f"阶段4执行失败: {e}", exc_info=True)
  97. raise Exception(f"搜索执行失败: {str(e)}")
  98. # 阶段5:LLM评估搜索结果
  99. logger.info("阶段5:评估搜索结果...")
  100. try:
  101. evaluation_results = self.pipeline.evaluate_search_results(search_results)
  102. except Exception as e:
  103. logger.error(f"阶段5执行失败: {e}", exc_info=True)
  104. raise Exception(f"结果评估失败: {str(e)}")
  105. # 阶段6:深度解构分析
  106. logger.info("阶段6:深度解构分析...")
  107. try:
  108. deep_results = self.pipeline.deep_analyzer.run(evaluation_results)
  109. except Exception as e:
  110. logger.error(f"阶段6执行失败: {e}", exc_info=True)
  111. raise Exception(f"深度解构分析失败: {str(e)}")
  112. # 阶段7:相似度分析
  113. logger.info("阶段7:相似度分析...")
  114. try:
  115. similarity_results = self.pipeline.similarity_analyzer.run(
  116. deep_results,
  117. output_path=os.path.join(self.temp_output_dir, "similarity_analysis_results.json")
  118. )
  119. except Exception as e:
  120. logger.error(f"阶段7执行失败: {e}", exc_info=True)
  121. raise Exception(f"相似度分析失败: {str(e)}")
  122. # 注意:similarity_analyzer.run会自动更新evaluation_results中的comprehensive_score
  123. # 所以我们需要重新加载更新后的evaluation_results
  124. # 但由于similarity_analyzer已经更新了内存中的evaluation_results,这里直接使用即可
  125. logger.info("=" * 60)
  126. logger.info("阶段3-7执行完成")
  127. logger.info("=" * 60)
  128. return {
  129. 'evaluation_results': evaluation_results, # 已包含更新后的comprehensive_score
  130. 'deep_results': deep_results,
  131. 'similarity_results': similarity_results
  132. }
  133. except Exception as e:
  134. logger.error(f"执行阶段3-7失败: {e}", exc_info=True)
  135. raise
  136. def cleanup(self):
  137. """清理临时文件"""
  138. try:
  139. import shutil
  140. if os.path.exists(self.temp_output_dir):
  141. shutil.rmtree(self.temp_output_dir)
  142. logger.info(f"已清理临时目录: {self.temp_output_dir}")
  143. except Exception as e:
  144. logger.warning(f"清理临时目录失败: {e}")