#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Pipeline wrapper.

Wraps EnhancedSearchV2 and exposes an interface that runs stages 3-7 only.
"""
import os
import json
import logging
from typing import Dict, List, Any, Optional

from src.pipeline.feature_search_pipeline import EnhancedSearchV2
from api.config import APIConfig

logger = logging.getLogger(__name__)
class PipelineWrapper:
    """Pipeline wrapper that reuses stages 3-7 of EnhancedSearchV2."""

    def __init__(self, output_dir: str, request_id: Optional[str] = None):
        """
        Initialize the pipeline wrapper.

        Args:
            output_dir: Request-specific output directory (provided by RequestContext)
            request_id: Request ID (used to tag log messages)
        """
        self.output_dir = output_dir
        self.request_id = request_id or "unknown"
        logger.info(f"[{self.request_id}] Initializing pipeline, output directory: {output_dir}")

        # Create a placeholder how.json (API mode does not need a real file)
        temp_how_file = os.path.join(output_dir, 'placeholder_how.json')
        with open(temp_how_file, 'w', encoding='utf-8') as f:
            json.dump({'解构结果': {}}, f)

        # Initialize the EnhancedSearchV2 instance with the per-request output_dir
        self.pipeline = EnhancedSearchV2(
            how_json_path=temp_how_file,  # placeholder file, never actually read
            openrouter_api_key=APIConfig.OPENROUTER_API_KEY,
            output_dir=output_dir,  # dedicated per-request directory
            top_n=10,
            max_total_searches=APIConfig.MAX_TOTAL_SEARCHES,
            search_max_workers=APIConfig.SEARCH_MAX_WORKERS,
            max_searches_per_feature=APIConfig.MAX_SEARCHES_PER_FEATURE,
            max_searches_per_base_word=APIConfig.MAX_SEARCHES_PER_BASE_WORD,
            enable_evaluation=True,
            evaluation_max_workers=APIConfig.EVALUATION_MAX_WORKERS,
            evaluation_max_notes_per_query=APIConfig.EVALUATION_MAX_NOTES_PER_QUERY,
            enable_deep_analysis=True,  # enable deep deconstruction analysis
            deep_analysis_only=False,
            deep_analysis_max_workers=APIConfig.DEEP_ANALYSIS_MAX_WORKERS,
            deep_analysis_max_notes=None,
            deep_analysis_skip_count=0,
            deep_analysis_sort_by='score',
            deep_analysis_api_url=APIConfig.DEEP_ANALYSIS_API_URL,
            deep_analysis_min_score=APIConfig.DEEP_ANALYSIS_MIN_SCORE,
            enable_similarity=True,  # enable similarity analysis
            similarity_weight_embedding=APIConfig.SIMILARITY_WEIGHT_EMBEDDING,
            similarity_weight_semantic=APIConfig.SIMILARITY_WEIGHT_SEMANTIC,
            similarity_max_workers=APIConfig.SIMILARITY_MAX_WORKERS,
            similarity_min_similarity=APIConfig.SIMILARITY_MIN_SIMILARITY
        )
        logger.info(f"[{self.request_id}] Pipeline wrapper initialized")
    def run_stages_3_to_7_sync(
        self,
        features_data: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Run the full stage 3-7 flow (synchronous version, intended to run in a thread pool).

        Args:
            features_data: Stage-2 output data (candidate_words.json format)

        Returns:
            Dict containing the stage 3-7 results

        Raises:
            ValueError: if features_data is empty
            RuntimeError: if any stage fails
        """
        try:
            logger.info("=" * 60)
            logger.info("Starting stages 3-7")
            logger.info("=" * 60)

            # Validate input data
            if not features_data:
                raise ValueError("features_data must not be empty")

            # Stage 3: multi-word combinations + LLM evaluation
            logger.info("Stage 3: generating search queries...")
            try:
                queries = self.pipeline.generate_search_queries(
                    features_data,
                    max_workers=APIConfig.QUERY_GENERATION_MAX_WORKERS,
                    max_candidates=APIConfig.MAX_CANDIDATES,
                    max_combo_length=APIConfig.MAX_COMBO_LENGTH
                )
            except Exception as e:
                logger.error(f"Stage 3 failed: {e}", exc_info=True)
                raise RuntimeError(f"Search query generation failed: {e}") from e

            # Stage 4: execute searches
            logger.info("Stage 4: executing searches...")
            try:
                search_results = self.pipeline.execute_search_queries(
                    queries,
                    search_delay=2.0,
                    top_n=self.pipeline.top_n
                )
            except Exception as e:
                logger.error(f"Stage 4 failed: {e}", exc_info=True)
                raise RuntimeError(f"Search execution failed: {e}") from e

            # Stage 5: LLM evaluation of search results
            logger.info("Stage 5: evaluating search results...")
            try:
                evaluation_results = self.pipeline.evaluate_search_results(search_results)
            except Exception as e:
                logger.error(f"Stage 5 failed: {e}", exc_info=True)
                raise RuntimeError(f"Result evaluation failed: {e}") from e

            # Stage 6: deep deconstruction analysis
            logger.info("Stage 6: deep deconstruction analysis...")
            try:
                deep_results = self.pipeline.deep_analyzer.run(evaluation_results)
            except Exception as e:
                logger.error(f"Stage 6 failed: {e}", exc_info=True)
                raise RuntimeError(f"Deep deconstruction analysis failed: {e}") from e

            # Stage 7: similarity analysis
            logger.info("Stage 7: similarity analysis...")
            try:
                # The synchronous version uses the run() method
                similarity_results = self.pipeline.similarity_analyzer.run(
                    deep_results,
                    output_path=os.path.join(self.output_dir, "similarity_analysis_results.json")
                )
            except Exception as e:
                logger.error(f"Stage 7 failed: {e}", exc_info=True)
                raise RuntimeError(f"Similarity analysis failed: {e}") from e

            # Important: similarity_analyzer.run() updates evaluation_results on
            # disk (adding comprehensive_score). Reload the updated file, since
            # the in-memory evaluation_results variable has not been updated.
            logger.info("Reloading updated evaluation results (with comprehensive_score)...")
            evaluated_results_path = os.path.join(self.output_dir, "evaluated_results.json")
            if os.path.exists(evaluated_results_path):
                with open(evaluated_results_path, 'r', encoding='utf-8') as f:
                    evaluation_results = json.load(f)
                logger.info(f"Reloaded evaluation results covering {len(evaluation_results)} original features")
            else:
                logger.warning(f"Evaluation results file not found: {evaluated_results_path}; using in-memory data")

            logger.info("=" * 60)
            logger.info("Stages 3-7 completed")
            logger.info("=" * 60)

            return {
                'evaluation_results': evaluation_results,  # includes the updated comprehensive_score
                'deep_results': deep_results,
                'similarity_results': similarity_results
            }

        except Exception as e:
            logger.error(f"Stages 3-7 failed: {e}", exc_info=True)
            raise
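

if __name__ == "__main__":
    # Minimal smoke-test sketch, not a production entry point. Assumptions:
    # APIConfig is fully populated (including a valid OPENROUTER_API_KEY), the
    # search backends are reachable, and the sample input below is illustrative
    # only; real features_data must follow the stage-2 candidate_words.json format.
    import tempfile

    logging.basicConfig(level=logging.INFO)
    with tempfile.TemporaryDirectory() as tmp_dir:
        wrapper = PipelineWrapper(output_dir=tmp_dir, request_id="local-test")
        sample_features = [{"feature": "...", "candidate_words": ["..."]}]
        results = wrapper.run_stages_3_to_7_sync(sample_features)
        print(json.dumps(sorted(results.keys()), ensure_ascii=False))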