run_stage8.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """Stage 8 独立运行脚本"""
  4. import os
  5. import json
  6. import logging
  7. import argparse
  8. from stage8_similarity_analyzer import Stage8SimilarityAnalyzer
  9. def main():
  10. parser = argparse.ArgumentParser(
  11. description='Stage 8 解构特征相似度分析(独立运行)',
  12. formatter_class=argparse.RawDescriptionHelpFormatter,
  13. epilog="""
  14. 使用示例:
  15. # 基础用法 - 处理"墨镜"特征
  16. python3 run_stage8.py --feature "墨镜"
  17. # 处理多个特征
  18. python3 run_stage8.py --feature "墨镜" "耳环"
  19. # 自定义权重配置
  20. python3 run_stage8.py --feature "墨镜" --weight-embedding 0.7 --weight-semantic 0.3
  21. # 过滤低相似度特征
  22. python3 run_stage8.py --feature "墨镜" --min-similarity 0.3
  23. # 使用配置文件
  24. python3 run_stage8.py --config stage8_config.json
  25. # 自定义输入输出路径
  26. python3 run_stage8.py --input output_v2/stage7_custom.json --output output_v2/stage8_custom.json
  27. """
  28. )
  29. # 输入输出
  30. parser.add_argument(
  31. '--input',
  32. default='output_v2/stage7_with_deconstruction.json',
  33. help='Stage 7 结果文件路径(默认: output_v2/stage7_with_deconstruction.json)'
  34. )
  35. parser.add_argument(
  36. '--output',
  37. default='output_v2/stage8_similarity_scores.json',
  38. help='输出文件路径(默认: output_v2/stage8_similarity_scores.json)'
  39. )
  40. # 特征过滤
  41. parser.add_argument(
  42. '--feature',
  43. nargs='+',
  44. default=None,
  45. help='指定要处理的原始特征名称(可指定多个),如: --feature "墨镜" "耳环"'
  46. )
  47. # 相似度配置
  48. parser.add_argument(
  49. '--weight-embedding',
  50. type=float,
  51. default=0.5,
  52. help='向量模型权重(默认: 0.5)'
  53. )
  54. parser.add_argument(
  55. '--weight-semantic',
  56. type=float,
  57. default=0.5,
  58. help='LLM 模型权重(默认: 0.5)'
  59. )
  60. parser.add_argument(
  61. '--min-similarity',
  62. type=float,
  63. default=0.0,
  64. help='最小相似度阈值,低于此值的特征会被过滤(默认: 0.0,保留所有)'
  65. )
  66. # 并发配置
  67. parser.add_argument(
  68. '--max-workers',
  69. type=int,
  70. default=5,
  71. help='最大并发数(默认: 5)'
  72. )
  73. # 综合得分P计算配置
  74. parser.add_argument(
  75. '--stage6-path',
  76. default='output_v2/stage6_with_evaluations.json',
  77. help='Stage 6 数据文件路径,用于计算综合得分P(默认: output_v2/stage6_with_evaluations.json)'
  78. )
  79. parser.add_argument(
  80. '--no-update-stage6',
  81. action='store_true',
  82. help='不计算和更新综合得分P(默认会计算)'
  83. )
  84. # 配置文件
  85. parser.add_argument(
  86. '--config',
  87. help='从配置文件读取参数(JSON 格式)'
  88. )
  89. # 日志级别
  90. parser.add_argument(
  91. '--log-level',
  92. default='INFO',
  93. choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
  94. help='日志级别(默认: INFO)'
  95. )
  96. args = parser.parse_args()
  97. # 配置日志
  98. logging.basicConfig(
  99. level=getattr(logging, args.log_level),
  100. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  101. )
  102. logger = logging.getLogger(__name__)
  103. # 如果提供了配置文件,从文件读取参数
  104. if args.config:
  105. logger.info(f"从配置文件读取参数: {args.config}")
  106. try:
  107. with open(args.config, 'r', encoding='utf-8') as f:
  108. config = json.load(f)
  109. # 配置文件中的参数会覆盖命令行默认值,但不会覆盖用户显式指定的命令行参数
  110. args.input = config.get('input', args.input)
  111. args.output = config.get('output', args.output)
  112. args.feature = config.get('feature', args.feature)
  113. args.weight_embedding = config.get('weight_embedding', args.weight_embedding)
  114. args.weight_semantic = config.get('weight_semantic', args.weight_semantic)
  115. args.min_similarity = config.get('min_similarity', args.min_similarity)
  116. args.max_workers = config.get('max_workers', args.max_workers)
  117. args.stage6_path = config.get('stage6_path', args.stage6_path)
  118. if 'no_update_stage6' in config:
  119. args.no_update_stage6 = config.get('no_update_stage6', args.no_update_stage6)
  120. except Exception as e:
  121. logger.error(f"读取配置文件失败: {e}")
  122. return 1
  123. # 验证输入文件
  124. if not os.path.exists(args.input):
  125. logger.error(f"输入文件不存在: {args.input}")
  126. return 1
  127. # 读取 Stage 7 结果
  128. logger.info(f"读取 Stage 7 结果: {args.input}")
  129. try:
  130. with open(args.input, 'r', encoding='utf-8') as f:
  131. stage7_results = json.load(f)
  132. except Exception as e:
  133. logger.error(f"读取 Stage 7 结果失败: {e}")
  134. return 1
  135. # 打印配置信息
  136. logger.info("\n" + "=" * 60)
  137. logger.info("Stage 8 配置:")
  138. logger.info("=" * 60)
  139. logger.info(f"输入文件: {args.input}")
  140. logger.info(f"输出文件: {args.output}")
  141. if args.feature:
  142. logger.info(f"目标特征: {', '.join(args.feature)}")
  143. else:
  144. logger.info(f"目标特征: 全部")
  145. logger.info(f"向量模型权重: {args.weight_embedding}")
  146. logger.info(f"LLM 模型权重: {args.weight_semantic}")
  147. logger.info(f"最小相似度阈值: {args.min_similarity}")
  148. logger.info(f"最大并发数: {args.max_workers}")
  149. logger.info(f"Stage 6 文件路径: {args.stage6_path}")
  150. logger.info(f"计算综合得分P: {'否' if args.no_update_stage6 else '是'}")
  151. logger.info("=" * 60 + "\n")
  152. # 创建分析器
  153. try:
  154. analyzer = Stage8SimilarityAnalyzer(
  155. weight_embedding=args.weight_embedding,
  156. weight_semantic=args.weight_semantic,
  157. max_workers=args.max_workers,
  158. min_similarity=args.min_similarity,
  159. target_features=args.feature,
  160. stage6_path=args.stage6_path,
  161. update_stage6=not args.no_update_stage6
  162. )
  163. except Exception as e:
  164. logger.error(f"创建分析器失败: {e}")
  165. return 1
  166. # 运行分析
  167. try:
  168. stage8_results = analyzer.run(stage7_results, output_path=args.output)
  169. # 打印摘要
  170. logger.info("\n" + "=" * 60)
  171. logger.info("Stage 8 执行完成")
  172. logger.info("=" * 60)
  173. metadata = stage8_results['metadata']
  174. overall_stats = metadata['overall_statistics']
  175. logger.info(f"处理帖子数: {overall_stats['total_notes']}")
  176. logger.info(f"提取特征总数: {overall_stats['total_features_extracted']}")
  177. logger.info(f"平均特征数/帖子: {overall_stats['avg_features_per_note']}")
  178. logger.info(f"平均最高相似度: {overall_stats['avg_max_similarity']}")
  179. logger.info(f"包含高相似度特征的帖子: {overall_stats['notes_with_high_similarity']}")
  180. logger.info(f"总耗时: {metadata['processing_time_seconds']}秒")
  181. logger.info(f"结果已保存: {args.output}")
  182. logger.info("=" * 60 + "\n")
  183. # 打印 Top 5 高相似度特征示例
  184. if stage8_results['results']:
  185. logger.info("Top 5 高相似度特征示例:")
  186. all_features = []
  187. for result in stage8_results['results']:
  188. for feat in result['deconstructed_features'][:5]: # 每个帖子取前5个
  189. all_features.append({
  190. 'note_id': result['note_id'],
  191. 'feature_name': feat['feature_name'],
  192. 'dimension': feat['dimension'],
  193. 'similarity': feat['similarity_score']
  194. })
  195. # 按相似度排序,取 Top 5
  196. all_features.sort(key=lambda x: x['similarity'], reverse=True)
  197. for i, feat in enumerate(all_features[:5], 1):
  198. logger.info(f" {i}. [{feat['note_id'][:12]}...] "
  199. f"{feat['feature_name']} ({feat['dimension']}) "
  200. f"- 相似度: {feat['similarity']:.3f}")
  201. return 0
  202. except Exception as e:
  203. logger.error(f"Stage 8 执行失败: {e}", exc_info=True)
  204. return 1
  205. if __name__ == '__main__':
  206. exit(main())