stage8_similarity_analyzer.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Stage 8 相似度分析器
  5. 计算 Stage 7 解构特征与原始特征的相似度评分
  6. """
  7. import os
  8. import json
  9. import time
  10. import logging
  11. import asyncio
  12. from datetime import datetime
  13. from typing import Dict, List, Any, Optional
  14. from lib.hybrid_similarity import compare_phrases_cartesian
  15. from lib.config import get_cache_dir
  16. try:
  17. from tqdm import tqdm
  18. TQDM_AVAILABLE = True
  19. except ImportError:
  20. TQDM_AVAILABLE = False
  21. logger = logging.getLogger(__name__)
  22. def extract_deconstructed_features(api_response: Dict) -> List[Dict]:
  23. """
  24. 从三点解构中提取所有特征
  25. Args:
  26. api_response: Stage 7 的 api_response 对象
  27. Returns:
  28. 特征列表,每个特征包含:
  29. - feature_name: 特征名称
  30. - dimension: 维度 (灵感点-全新内容/灵感点-共性差异/灵感点-共性内容/目的点/关键点)
  31. - dimension_detail: 维度细分 (实质/形式/意图等)
  32. - weight: 权重
  33. - source_index: 在该维度中的索引
  34. - source_*: 溯源信息 (候选编号、目的点描述、关键点描述等)
  35. """
  36. features = []
  37. # 检查 API 响应状态
  38. if api_response.get('status') != 'success':
  39. logger.warning(" API 响应状态不是 success,无法提取特征")
  40. return features
  41. result = api_response.get('result', {})
  42. # 检查是否有 data 字段
  43. if 'data' not in result:
  44. logger.warning(" API 响应中没有 data 字段")
  45. return features
  46. data = result['data']
  47. three_point = data.get('三点解构', {})
  48. if not three_point:
  49. logger.warning(" 三点解构数据为空")
  50. return features
  51. # 1. 提取灵感点 (3个子类别)
  52. inspiration = three_point.get('灵感点', {})
  53. for category in ['全新内容', '共性差异', '共性内容']:
  54. items = inspiration.get(category, [])
  55. for idx, item in enumerate(items):
  56. extracted_features = item.get('提取的特征', [])
  57. for feat in extracted_features:
  58. feature_name = feat.get('特征名称', '')
  59. if not feature_name:
  60. continue
  61. features.append({
  62. 'feature_name': feature_name,
  63. 'dimension': f'灵感点-{category}',
  64. 'dimension_detail': feat.get('维度分类', ''), # 注意字段名
  65. 'weight': feat.get('权重', 0),
  66. 'source_index': idx,
  67. 'source_candidate_number': item.get('候选编号', 0),
  68. 'source_inspiration': item.get('灵感点', '')
  69. })
  70. # 2. 提取目的点
  71. purpose = three_point.get('目的点', {})
  72. purposes_list = purpose.get('purposes', [])
  73. for idx, item in enumerate(purposes_list):
  74. extracted_features = item.get('提取的特征', [])
  75. for feat in extracted_features:
  76. feature_name = feat.get('特征名称', '')
  77. if not feature_name:
  78. continue
  79. features.append({
  80. 'feature_name': feature_name,
  81. 'dimension': '目的点',
  82. 'dimension_detail': feat.get('特征分类', ''), # 注意字段名
  83. 'weight': feat.get('权重', 0),
  84. 'source_index': idx,
  85. 'source_purpose': item.get('目的点', ''),
  86. 'source_purpose_dimension': item.get('维度', {})
  87. })
  88. # 3. 提取关键点
  89. key_points_data = three_point.get('关键点', {})
  90. key_points_list = key_points_data.get('key_points', [])
  91. for idx, item in enumerate(key_points_list):
  92. extracted_features = item.get('提取的特征', [])
  93. for feat in extracted_features:
  94. feature_name = feat.get('特征名称', '')
  95. if not feature_name:
  96. continue
  97. features.append({
  98. 'feature_name': feature_name,
  99. 'dimension': '关键点',
  100. 'dimension_detail': feat.get('维度', ''), # 注意字段名
  101. 'weight': feat.get('权重', 0),
  102. 'source_index': idx,
  103. 'source_candidate_number': item.get('候选编号', 0),
  104. 'source_key_point': item.get('关键点', ''),
  105. 'source_key_point_dimension': item.get('维度', '')
  106. })
  107. logger.info(f" 提取特征数量: {len(features)}")
  108. if features:
  109. # 统计各维度数量
  110. dimension_counts = {}
  111. for feat in features:
  112. dim = feat['dimension']
  113. dimension_counts[dim] = dimension_counts.get(dim, 0) + 1
  114. logger.info(f" 维度分布: {dimension_counts}")
  115. return features
  116. async def calculate_similarity_for_note(
  117. note_result: Dict,
  118. original_feature: str,
  119. weight_embedding: float = 0.5,
  120. weight_semantic: float = 0.5,
  121. min_similarity: float = 0.0
  122. ) -> Dict:
  123. """
  124. 计算单个帖子的所有特征与原始特征的相似度
  125. Args:
  126. note_result: Stage 7 的单个 result 对象
  127. original_feature: 原始特征名称
  128. weight_embedding: 向量模型权重
  129. weight_semantic: LLM 模型权重
  130. min_similarity: 最小相似度阈值,低于此值的特征会被过滤
  131. Returns:
  132. 包含相似度信息的结果对象
  133. """
  134. note_id = note_result.get('note_id', '')
  135. logger.info(f" [{note_id}] 开始计算相似度...")
  136. # 1. 提取解构特征
  137. deconstructed_features = extract_deconstructed_features(
  138. note_result['api_response']
  139. )
  140. if not deconstructed_features:
  141. logger.warning(f" [{note_id}] 没有提取到特征")
  142. return {
  143. 'note_id': note_id,
  144. 'original_feature': original_feature,
  145. 'evaluation_score': note_result.get('evaluation_score', 0),
  146. 'search_word': note_result.get('search_word', ''),
  147. 'note_data': note_result.get('note_data', {}),
  148. 'deconstructed_features': [],
  149. 'similarity_statistics': {
  150. 'total_features': 0,
  151. 'max_similarity': 0,
  152. 'min_similarity': 0,
  153. 'avg_similarity': 0,
  154. 'high_similarity_count': 0,
  155. 'medium_similarity_count': 0,
  156. 'low_similarity_count': 0
  157. }
  158. }
  159. # 2. 构建特征名称列表
  160. feature_names = [f['feature_name'] for f in deconstructed_features]
  161. logger.info(f" [{note_id}] 调用相似度计算 API (1×{len(feature_names)} 笛卡尔积)...")
  162. # 3. 批量计算相似度 (1×N 笛卡尔积)
  163. try:
  164. start_time = time.time()
  165. similarity_results = await compare_phrases_cartesian(
  166. phrases_a=[original_feature],
  167. phrases_b=feature_names,
  168. max_concurrent=50
  169. )
  170. elapsed = time.time() - start_time
  171. logger.info(f" [{note_id}] 相似度计算完成 ({elapsed:.1f}秒)")
  172. # 4. 映射结果回特征对象
  173. for i, feat in enumerate(deconstructed_features):
  174. feat['similarity_score'] = similarity_results[0][i]['相似度']
  175. feat['similarity_explanation'] = similarity_results[0][i]['说明']
  176. # 5. 过滤低相似度特征
  177. if min_similarity > 0:
  178. original_count = len(deconstructed_features)
  179. deconstructed_features = [
  180. f for f in deconstructed_features
  181. if f['similarity_score'] >= min_similarity
  182. ]
  183. filtered_count = original_count - len(deconstructed_features)
  184. if filtered_count > 0:
  185. logger.info(f" [{note_id}] 过滤掉 {filtered_count} 个低相似度特征 (< {min_similarity})")
  186. # 6. 计算统计信息
  187. if deconstructed_features:
  188. scores = [f['similarity_score'] for f in deconstructed_features]
  189. statistics = {
  190. 'total_features': len(scores),
  191. 'max_similarity': round(max(scores), 3),
  192. 'min_similarity': round(min(scores), 3),
  193. 'avg_similarity': round(sum(scores) / len(scores), 3),
  194. 'high_similarity_count': sum(1 for s in scores if s >= 0.7),
  195. 'medium_similarity_count': sum(1 for s in scores if 0.5 <= s < 0.7),
  196. 'low_similarity_count': sum(1 for s in scores if s < 0.5)
  197. }
  198. # 7. 按相似度降序排序
  199. deconstructed_features.sort(key=lambda x: x['similarity_score'], reverse=True)
  200. logger.info(f" [{note_id}] 统计: 最高={statistics['max_similarity']}, "
  201. f"平均={statistics['avg_similarity']}, "
  202. f"高相似度={statistics['high_similarity_count']}个")
  203. else:
  204. statistics = {
  205. 'total_features': 0,
  206. 'max_similarity': 0,
  207. 'min_similarity': 0,
  208. 'avg_similarity': 0,
  209. 'high_similarity_count': 0,
  210. 'medium_similarity_count': 0,
  211. 'low_similarity_count': 0
  212. }
  213. return {
  214. 'note_id': note_id,
  215. 'original_feature': original_feature,
  216. 'evaluation_score': note_result.get('evaluation_score', 0),
  217. 'search_word': note_result.get('search_word', ''),
  218. 'note_data': note_result.get('note_data', {}),
  219. 'deconstructed_features': deconstructed_features,
  220. 'similarity_statistics': statistics,
  221. 'processing_time_seconds': round(elapsed, 2)
  222. }
  223. except Exception as e:
  224. logger.error(f" [{note_id}] 相似度计算失败: {e}")
  225. return {
  226. 'note_id': note_id,
  227. 'original_feature': original_feature,
  228. 'evaluation_score': note_result.get('evaluation_score', 0),
  229. 'search_word': note_result.get('search_word', ''),
  230. 'note_data': note_result.get('note_data', {}),
  231. 'deconstructed_features': [],
  232. 'similarity_statistics': {
  233. 'total_features': 0,
  234. 'error': str(e)
  235. }
  236. }
  237. class Stage8SimilarityAnalyzer:
  238. """Stage 8: 解构特征与原始特征的相似度分析"""
  239. def __init__(
  240. self,
  241. weight_embedding: float = 0.5,
  242. weight_semantic: float = 0.5,
  243. max_workers: int = 5,
  244. min_similarity: float = 0.0,
  245. output_dir: str = "output_v2",
  246. target_features: Optional[List[str]] = None
  247. ):
  248. """
  249. 初始化 Stage 8 分析器
  250. Args:
  251. weight_embedding: 向量模型权重(默认 0.5)
  252. weight_semantic: LLM 模型权重(默认 0.5)
  253. max_workers: 最大并发数(默认 5)
  254. min_similarity: 最小相似度阈值(默认 0.0,保留所有特征)
  255. output_dir: 输出目录
  256. target_features: 指定要处理的原始特征列表(None = 处理所有特征)
  257. """
  258. self.weight_embedding = weight_embedding
  259. self.weight_semantic = weight_semantic
  260. self.max_workers = max_workers
  261. self.min_similarity = min_similarity
  262. self.output_dir = output_dir
  263. self.target_features = target_features
  264. # 验证权重
  265. total_weight = weight_embedding + weight_semantic
  266. if abs(total_weight - 1.0) > 0.001:
  267. raise ValueError(f"权重之和必须为1.0,当前为: {total_weight}")
  268. def _save_intermediate_results(
  269. self,
  270. results: List[Dict],
  271. output_path: str,
  272. processed_count: int,
  273. total_count: int,
  274. start_time: float
  275. ):
  276. """保存中间结果"""
  277. base_dir = os.path.dirname(output_path) or self.output_dir
  278. base_name = os.path.basename(output_path)
  279. name_without_ext = os.path.splitext(base_name)[0]
  280. intermediate_path = os.path.join(
  281. base_dir,
  282. f"{name_without_ext}_partial_{processed_count}of{total_count}.json"
  283. )
  284. # 统计
  285. total_features = sum(r['similarity_statistics']['total_features'] for r in results)
  286. avg_max_sim = sum(r['similarity_statistics']['max_similarity'] for r in results) / len(results)
  287. intermediate_result = {
  288. 'metadata': {
  289. 'stage': 'stage8_partial',
  290. 'description': f'部分结果({processed_count}/{total_count})',
  291. 'processed_notes': len(results),
  292. 'total_features_extracted': total_features,
  293. 'avg_max_similarity': round(avg_max_sim, 3),
  294. 'saved_at': datetime.now().isoformat(),
  295. 'processing_time_seconds': round(time.time() - start_time, 2)
  296. },
  297. 'results': results
  298. }
  299. os.makedirs(base_dir, exist_ok=True)
  300. with open(intermediate_path, 'w', encoding='utf-8') as f:
  301. json.dump(intermediate_result, f, ensure_ascii=False, indent=2)
  302. logger.info(f" 已保存中间结果: {intermediate_path}")
  303. async def run_async(
  304. self,
  305. stage7_results: Dict,
  306. output_path: Optional[str] = None
  307. ) -> Dict:
  308. """
  309. 执行 Stage 8 相似度分析(异步版本)
  310. Args:
  311. stage7_results: Stage 7 结果
  312. output_path: 输出路径(可选)
  313. Returns:
  314. Stage 8 结果
  315. """
  316. logger.info("\n" + "=" * 60)
  317. logger.info("Stage 8: 解构特征与原始特征的相似度分析")
  318. logger.info("=" * 60)
  319. # 打印配置
  320. logger.info("配置参数:")
  321. logger.info(f" 向量模型权重: {self.weight_embedding}")
  322. logger.info(f" LLM 模型权重: {self.weight_semantic}")
  323. logger.info(f" 最大并发数: {self.max_workers}")
  324. logger.info(f" 最小相似度阈值: {self.min_similarity}")
  325. if self.target_features:
  326. logger.info(f" 目标特征: {', '.join(self.target_features)}")
  327. else:
  328. logger.info(f" 目标特征: 全部")
  329. # 默认输出路径
  330. if output_path is None:
  331. output_path = os.path.join(self.output_dir, "stage8_similarity_scores.json")
  332. # 提取 Stage 7 结果
  333. results_list = stage7_results.get('results', [])
  334. # 过滤目标特征
  335. if self.target_features:
  336. results_list = [
  337. r for r in results_list
  338. if r.get('original_feature') in self.target_features
  339. ]
  340. total_notes = len(results_list)
  341. logger.info(f" 待处理帖子数: {total_notes}")
  342. if total_notes == 0:
  343. logger.warning(" 没有需要处理的帖子")
  344. return {
  345. 'metadata': {
  346. 'stage': 'stage8',
  347. 'processed_notes': 0
  348. },
  349. 'results': []
  350. }
  351. # 创建任务列表
  352. start_time = time.time()
  353. results = []
  354. # 使用 Semaphore 控制并发数
  355. semaphore = asyncio.Semaphore(self.max_workers)
  356. async def bounded_task(result):
  357. async with semaphore:
  358. return await calculate_similarity_for_note(
  359. result,
  360. result.get('original_feature', ''),
  361. self.weight_embedding,
  362. self.weight_semantic,
  363. self.min_similarity
  364. )
  365. tasks = [bounded_task(result) for result in results_list]
  366. # 带进度条执行
  367. if TQDM_AVAILABLE:
  368. logger.info(" 使用进度条显示...")
  369. processed_count = 0
  370. save_interval = 10
  371. for coro in tqdm(
  372. asyncio.as_completed(tasks),
  373. total=len(tasks),
  374. desc=" 相似度计算进度",
  375. unit="帖子",
  376. ncols=100
  377. ):
  378. result = await coro
  379. results.append(result)
  380. processed_count += 1
  381. # 增量保存
  382. if processed_count % save_interval == 0:
  383. self._save_intermediate_results(
  384. results,
  385. output_path,
  386. processed_count,
  387. total_notes,
  388. start_time
  389. )
  390. else:
  391. # 简单执行
  392. results = await asyncio.gather(*tasks)
  393. logger.info(f" 完成: {len(results)}/{total_notes}")
  394. processing_time = time.time() - start_time
  395. # 计算总体统计
  396. total_features = sum(r['similarity_statistics']['total_features'] for r in results)
  397. all_max_similarities = [r['similarity_statistics']['max_similarity'] for r in results if r['similarity_statistics']['total_features'] > 0]
  398. overall_stats = {
  399. 'total_notes': total_notes,
  400. 'total_features_extracted': total_features,
  401. 'avg_features_per_note': round(total_features / total_notes, 1) if total_notes > 0 else 0,
  402. 'avg_max_similarity': round(sum(all_max_similarities) / len(all_max_similarities), 3) if all_max_similarities else 0,
  403. 'notes_with_high_similarity': sum(1 for r in results if r['similarity_statistics'].get('high_similarity_count', 0) > 0)
  404. }
  405. logger.info(f"\n 总耗时: {processing_time:.1f}秒")
  406. logger.info(f" 总特征数: {total_features}")
  407. logger.info(f" 平均特征数/帖子: {overall_stats['avg_features_per_note']}")
  408. logger.info(f" 平均最高相似度: {overall_stats['avg_max_similarity']}")
  409. logger.info(f" 包含高相似度特征的帖子: {overall_stats['notes_with_high_similarity']}")
  410. # 构建最终结果
  411. final_result = {
  412. 'metadata': {
  413. 'stage': 'stage8',
  414. 'description': '解构特征与原始特征的相似度评分',
  415. 'source_file': stage7_results.get('metadata', {}).get('created_at', ''),
  416. 'target_features': self.target_features if self.target_features else '全部',
  417. 'similarity_config': {
  418. 'algorithm': 'hybrid_similarity',
  419. 'weight_embedding': self.weight_embedding,
  420. 'weight_semantic': self.weight_semantic,
  421. 'min_similarity_threshold': self.min_similarity
  422. },
  423. 'overall_statistics': overall_stats,
  424. 'created_at': datetime.now().isoformat(),
  425. 'processing_time_seconds': round(processing_time, 2)
  426. },
  427. 'results': results
  428. }
  429. # 保存结果
  430. os.makedirs(os.path.dirname(output_path) or self.output_dir, exist_ok=True)
  431. with open(output_path, 'w', encoding='utf-8') as f:
  432. json.dump(final_result, f, ensure_ascii=False, indent=2)
  433. logger.info(f" 结果已保存: {output_path}")
  434. return final_result
  435. def run(
  436. self,
  437. stage7_results: Dict,
  438. output_path: Optional[str] = None
  439. ) -> Dict:
  440. """
  441. 执行 Stage 8 相似度分析(同步版本)
  442. Args:
  443. stage7_results: Stage 7 结果
  444. output_path: 输出路径(可选)
  445. Returns:
  446. Stage 8 结果
  447. """
  448. return asyncio.run(self.run_async(stage7_results, output_path))
  449. def test_stage8_analyzer():
  450. """测试 Stage 8 分析器"""
  451. # 读取 Stage 7 结果
  452. stage7_path = "output_v2/stage7_with_deconstruction.json"
  453. if not os.path.exists(stage7_path):
  454. print(f"Stage 7 结果不存在: {stage7_path}")
  455. return
  456. with open(stage7_path, 'r', encoding='utf-8') as f:
  457. stage7_results = json.load(f)
  458. # 创建分析器
  459. analyzer = Stage8SimilarityAnalyzer(
  460. weight_embedding=0.5,
  461. weight_semantic=0.5,
  462. max_workers=3,
  463. min_similarity=0.3,
  464. target_features=["墨镜"]
  465. )
  466. # 运行分析
  467. stage8_results = analyzer.run(stage7_results)
  468. print(f"\n处理了 {stage8_results['metadata']['overall_statistics']['total_notes']} 个帖子")
  469. print(f"提取了 {stage8_results['metadata']['overall_statistics']['total_features_extracted']} 个特征")
  470. print(f"平均最高相似度: {stage8_results['metadata']['overall_statistics']['avg_max_similarity']}")
  471. if __name__ == '__main__':
  472. logging.basicConfig(
  473. level=logging.INFO,
  474. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  475. )
  476. test_stage8_analyzer()