pipeline_wrapper.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Pipeline wrapper.

Wraps EnhancedSearchV2 and exposes an interface that runs only stages 3-7.
"""
import os
import json
import logging
import tempfile
from typing import Dict, List, Any

from src.pipeline.feature_search_pipeline import EnhancedSearchV2
from api.config import APIConfig

logger = logging.getLogger(__name__)


class PipelineWrapper:
    """Pipeline wrapper that reuses stages 3-7 of EnhancedSearchV2."""

    def __init__(self):
        """Initialize the pipeline wrapper."""
        # Create a temporary output directory.
        self.temp_output_dir = tempfile.mkdtemp(prefix='api_pipeline_')
        logger.info(f"Created temporary output directory: {self.temp_output_dir}")

        # Initialize the EnhancedSearchV2 instance.
        # Note: the how_json_path argument is required, but we never use it
        # (stages 1-2 are skipped), so write an empty placeholder file.
        temp_how_file = os.path.join(self.temp_output_dir, 'temp_how.json')
        with open(temp_how_file, 'w', encoding='utf-8') as f:
            json.dump({'解构结果': {}}, f)

        self.pipeline = EnhancedSearchV2(
            how_json_path=temp_how_file,  # placeholder file, never actually read
            openrouter_api_key=APIConfig.OPENROUTER_API_KEY,
            output_dir=self.temp_output_dir,
            top_n=10,
            max_total_searches=APIConfig.MAX_TOTAL_SEARCHES,
            search_max_workers=APIConfig.SEARCH_MAX_WORKERS,
            max_searches_per_feature=APIConfig.MAX_SEARCHES_PER_FEATURE,
            max_searches_per_base_word=APIConfig.MAX_SEARCHES_PER_BASE_WORD,
            enable_evaluation=True,
            evaluation_max_workers=APIConfig.EVALUATION_MAX_WORKERS,
            evaluation_max_notes_per_query=APIConfig.EVALUATION_MAX_NOTES_PER_QUERY,
            enable_deep_analysis=True,  # enable deep deconstruction analysis
            deep_analysis_only=False,
            deep_analysis_max_workers=APIConfig.DEEP_ANALYSIS_MAX_WORKERS,
            deep_analysis_max_notes=None,
            deep_analysis_skip_count=0,
            deep_analysis_sort_by='score',
            deep_analysis_api_url=APIConfig.DEEP_ANALYSIS_API_URL,
            deep_analysis_min_score=APIConfig.DEEP_ANALYSIS_MIN_SCORE,
            enable_similarity=True,  # enable similarity analysis
            similarity_weight_embedding=APIConfig.SIMILARITY_WEIGHT_EMBEDDING,
            similarity_weight_semantic=APIConfig.SIMILARITY_WEIGHT_SEMANTIC,
            similarity_max_workers=APIConfig.SIMILARITY_MAX_WORKERS,
            similarity_min_similarity=APIConfig.SIMILARITY_MIN_SIMILARITY
        )
        logger.info("Pipeline wrapper initialized")
    async def run_stages_3_to_7(
        self,
        features_data: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Run the full stage 3-7 flow.

        Args:
            features_data: Data in the stage-2 output format
                (candidate_words.json format).

        Returns:
            A dict containing the results of stages 3-7.

        Raises:
            Exception: If any stage fails.
        """
        try:
            logger.info("=" * 60)
            logger.info("Starting stages 3-7")
            logger.info("=" * 60)

            # Validate the input data.
            if not features_data:
                raise ValueError("features_data must not be empty")

            # Stage 3: multi-word combination + LLM evaluation.
            logger.info("Stage 3: generating search queries...")
            try:
                queries = self.pipeline.generate_search_queries(
                    features_data,
                    max_workers=APIConfig.QUERY_GENERATION_MAX_WORKERS,
                    max_candidates=APIConfig.MAX_CANDIDATES,
                    max_combo_length=APIConfig.MAX_COMBO_LENGTH
                )
            except Exception as e:
                logger.error(f"Stage 3 failed: {e}", exc_info=True)
                raise Exception(f"Search query generation failed: {e}") from e

            # Stage 4: execute the searches.
            logger.info("Stage 4: executing searches...")
            try:
                search_results = self.pipeline.execute_search_queries(
                    queries,
                    search_delay=2.0,
                    top_n=self.pipeline.top_n
                )
            except Exception as e:
                logger.error(f"Stage 4 failed: {e}", exc_info=True)
                raise Exception(f"Search execution failed: {e}") from e

            # Stage 5: LLM evaluation of the search results.
            logger.info("Stage 5: evaluating search results...")
            try:
                evaluation_results = self.pipeline.evaluate_search_results(search_results)
            except Exception as e:
                logger.error(f"Stage 5 failed: {e}", exc_info=True)
                raise Exception(f"Result evaluation failed: {e}") from e

            # Stage 6: deep deconstruction analysis.
            logger.info("Stage 6: deep deconstruction analysis...")
            try:
                deep_results = self.pipeline.deep_analyzer.run(evaluation_results)
            except Exception as e:
                logger.error(f"Stage 6 failed: {e}", exc_info=True)
                raise Exception(f"Deep deconstruction analysis failed: {e}") from e

            # Stage 7: similarity analysis.
            logger.info("Stage 7: similarity analysis...")
            try:
                # In an async context, call run_async directly instead of run.
                similarity_results = await self.pipeline.similarity_analyzer.run_async(
                    deep_results,
                    output_path=os.path.join(self.temp_output_dir, "similarity_analysis_results.json")
                )
            except Exception as e:
                logger.error(f"Stage 7 failed: {e}", exc_info=True)
                raise Exception(f"Similarity analysis failed: {e}") from e

            # Important: similarity_analyzer.run_async updates the evaluation
            # results file on disk (adding comprehensive_score), so reload it;
            # the in-memory evaluation_results variable is not updated.
            logger.info("Reloading updated evaluation results (with comprehensive_score)...")
            evaluated_results_path = os.path.join(self.temp_output_dir, "evaluated_results.json")
            if os.path.exists(evaluated_results_path):
                with open(evaluated_results_path, 'r', encoding='utf-8') as f:
                    evaluation_results = json.load(f)
                logger.info(f"Reloaded evaluation results with {len(evaluation_results)} original features")
            else:
                logger.warning(f"Evaluation results file not found: {evaluated_results_path}; using in-memory data")

            logger.info("=" * 60)
            logger.info("Stages 3-7 completed")
            logger.info("=" * 60)

            return {
                'evaluation_results': evaluation_results,  # includes the updated comprehensive_score
                'deep_results': deep_results,
                'similarity_results': similarity_results
            }
        except Exception as e:
            logger.error(f"Stages 3-7 failed: {e}", exc_info=True)
            raise
    def cleanup(self):
        """Remove temporary files."""
        try:
            import shutil
            if os.path.exists(self.temp_output_dir):
                shutil.rmtree(self.temp_output_dir)
                logger.info(f"Removed temporary directory: {self.temp_output_dir}")
        except Exception as e:
            logger.warning(f"Failed to remove temporary directory: {e}")
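

# --- Usage sketch: a minimal asyncio entry point for this wrapper. ---
# The features_data below is a hypothetical placeholder; the real input is
# the stage-2 candidate_words.json output, whose exact schema is defined
# upstream. Pairing run_stages_3_to_7 with cleanup() in try/finally ensures
# the temporary output directory is removed even if a stage fails.
if __name__ == '__main__':
    import asyncio

    async def main() -> None:
        wrapper = PipelineWrapper()
        try:
            # Hypothetical stand-in; replace with real stage-2 output.
            features_data: List[Dict[str, Any]] = [{'feature': 'example'}]
            results = await wrapper.run_stages_3_to_7(features_data)
            # The returned dict has three keys (see run_stages_3_to_7).
            print(sorted(results.keys()))
        finally:
            wrapper.cleanup()

    asyncio.run(main())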