刘立冬 3 viikkoa sitten
vanhempi
commit
abc89e41bb
2 muutettua tiedostoa jossa 218 lisäystä ja 2 poistoa
  1. 20 1
      run_stage8.py
  2. 198 1
      stage8_similarity_analyzer.py

+ 20 - 1
run_stage8.py

@@ -83,6 +83,18 @@ def main():
         help='最大并发数(默认: 5)'
     )
 
+    # 综合得分P计算配置
+    parser.add_argument(
+        '--stage6-path',
+        default='output_v2/stage6_with_evaluations.json',
+        help='Stage 6 数据文件路径,用于计算综合得分P(默认: output_v2/stage6_with_evaluations.json)'
+    )
+    parser.add_argument(
+        '--no-update-stage6',
+        action='store_true',
+        help='不计算和更新综合得分P(默认会计算)'
+    )
+
     # 配置文件
     parser.add_argument(
         '--config',
@@ -121,6 +133,9 @@ def main():
             args.weight_semantic = config.get('weight_semantic', args.weight_semantic)
             args.min_similarity = config.get('min_similarity', args.min_similarity)
             args.max_workers = config.get('max_workers', args.max_workers)
+            args.stage6_path = config.get('stage6_path', args.stage6_path)
+            if 'no_update_stage6' in config:
+                args.no_update_stage6 = config.get('no_update_stage6', args.no_update_stage6)
 
         except Exception as e:
             logger.error(f"读取配置文件失败: {e}")
@@ -154,6 +169,8 @@ def main():
     logger.info(f"LLM 模型权重: {args.weight_semantic}")
     logger.info(f"最小相似度阈值: {args.min_similarity}")
     logger.info(f"最大并发数: {args.max_workers}")
+    logger.info(f"Stage 6 文件路径: {args.stage6_path}")
+    logger.info(f"计算综合得分P: {'否' if args.no_update_stage6 else '是'}")
     logger.info("=" * 60 + "\n")
 
     # 创建分析器
@@ -163,7 +180,9 @@ def main():
             weight_semantic=args.weight_semantic,
             max_workers=args.max_workers,
             min_similarity=args.min_similarity,
-            target_features=args.feature
+            target_features=args.feature,
+            stage6_path=args.stage6_path,
+            update_stage6=not args.no_update_stage6
         )
     except Exception as e:
         logger.error(f"创建分析器失败: {e}")

+ 198 - 1
stage8_similarity_analyzer.py

@@ -283,7 +283,9 @@ class Stage8SimilarityAnalyzer:
         max_workers: int = 5,
         min_similarity: float = 0.0,
         output_dir: str = "output_v2",
-        target_features: Optional[List[str]] = None
+        target_features: Optional[List[str]] = None,
+        stage6_path: str = 'output_v2/stage6_with_evaluations.json',
+        update_stage6: bool = True
     ):
         """
         初始化 Stage 8 分析器
@@ -295,6 +297,8 @@ class Stage8SimilarityAnalyzer:
             min_similarity: 最小相似度阈值(默认 0.0,保留所有特征)
             output_dir: 输出目录
             target_features: 指定要处理的原始特征列表(None = 处理所有特征)
+            stage6_path: Stage 6 数据文件路径(用于计算综合得分)
+            update_stage6: 是否计算并更新 Stage 6 的综合得分(默认 True)
         """
         self.weight_embedding = weight_embedding
         self.weight_semantic = weight_semantic
@@ -302,6 +306,8 @@ class Stage8SimilarityAnalyzer:
         self.min_similarity = min_similarity
         self.output_dir = output_dir
         self.target_features = target_features
+        self.stage6_path = stage6_path
+        self.update_stage6 = update_stage6
 
         # 验证权重
         total_weight = weight_embedding + weight_semantic
@@ -503,8 +509,199 @@ class Stage8SimilarityAnalyzer:
 
         logger.info(f"  结果已保存: {output_path}")
 
+        # 计算并更新综合得分P
+        if self.update_stage6:
+            logger.info("\n" + "=" * 60)
+            logger.info("开始计算综合得分P并更新Stage 6数据...")
+            logger.info("=" * 60)
+            self._calculate_and_update_comprehensive_scores(results)
+
         return final_result
 
+    def _calculate_and_update_comprehensive_scores(self, stage8_results: List[Dict]):
+        """
+        计算综合得分P并更新Stage 6数据
+
+        Args:
+            stage8_results: Stage 8 的结果列表
+        """
+        try:
+            # 1. 加载 Stage 6 数据
+            logger.info(f"  加载 Stage 6 数据: {self.stage6_path}")
+            if not os.path.exists(self.stage6_path):
+                logger.error(f"  Stage 6 文件不存在: {self.stage6_path}")
+                return
+
+            with open(self.stage6_path, 'r', encoding='utf-8') as f:
+                stage6_data = json.load(f)
+
+            # 2. 构建 Stage 8 映射 (note_id → max_similarity)
+            logger.info("  构建相似度映射...")
+            similarity_map = {}
+            for result in stage8_results:
+                note_id = result['note_id']
+                max_similarity = result['similarity_statistics']['max_similarity']
+                similarity_map[note_id] = max_similarity
+
+            logger.info(f"  相似度映射条目数: {len(similarity_map)}")
+
+            # 3. 遍历 Stage 6 中的所有原始特征和搜索词,计算 P 值
+            # Stage 6 数据是一个列表,每个元素是一个原始特征
+            updated_count = 0
+            total_searches = 0
+
+            logger.info(f"  开始遍历 {len(stage6_data)} 个原始特征...")
+
+            for feature_item in stage6_data:
+                original_feature = feature_item.get('原始特征名称', '')
+                logger.info(f"\n  处理原始特征: {original_feature}")
+
+                # 遍历每个分组
+                for group in feature_item.get('组合评估结果_分组', []):
+                    source_word = group.get('source_word', '')
+
+                    # 遍历该分组的所有搜索词
+                    for search_item in group.get('top10_searches', []):
+                        search_word = search_item.get('search_word', '')
+                        total_searches += 1
+
+                        logger.info(f"    处理搜索词: {search_word} (来源: {source_word})")
+
+                        # 计算该搜索词的综合得分
+                        p_score, p_detail = self._calculate_single_query_score(
+                            search_item,
+                            similarity_map
+                        )
+
+                        # 更新搜索词数据
+                        if p_score is not None:
+                            search_item['comprehensive_score'] = round(p_score, 3)
+                            search_item['comprehensive_score_detail'] = p_detail
+                            updated_count += 1
+                            logger.info(f"      综合得分P = {p_score:.3f} (M={p_detail['M']}, N={p_detail['N']})")
+                        else:
+                            logger.warning(f"      无法计算综合得分(可能缺少数据)")
+
+            # 4. 保存更新后的 Stage 6 数据
+            logger.info(f"\n  保存更新后的 Stage 6 数据...")
+            logger.info(f"  已更新 {updated_count}/{total_searches} 个搜索词")
+
+            with open(self.stage6_path, 'w', encoding='utf-8') as f:
+                json.dump(stage6_data, f, ensure_ascii=False, indent=2)
+
+            logger.info(f"  更新完成: {self.stage6_path}")
+
+        except Exception as e:
+            logger.error(f"  计算综合得分失败: {e}", exc_info=True)
+
+    def _calculate_single_query_score(
+        self,
+        query: Dict,
+        similarity_map: Dict[str, float]
+    ) -> tuple[Optional[float], Optional[Dict]]:
+        """
+        计算单个查询的综合得分P
+
+        Args:
+            query: Stage 6 中的单个查询对象
+            similarity_map: note_id → max_similarity 的映射
+
+        Returns:
+            (P值, 详细计算信息) 或 (None, None)
+        """
+        # 获取总帖子数 N
+        evaluation_with_filter = query.get('evaluation_with_filter', {})
+        N = evaluation_with_filter.get('total_notes', 0)
+
+        if N == 0:
+            logger.warning(f"    查询总帖子数为0,无法计算P值")
+            return None, None
+
+        # 获取笔记评估数据和原始笔记数据
+        notes_evaluation = evaluation_with_filter.get('notes_evaluation', [])
+        search_result = query.get('search_result', {})
+        notes_data = search_result.get('data', {}).get('data', [])
+
+        if not notes_evaluation or not notes_data:
+            logger.warning(f"    缺少评估数据或笔记数据")
+            return 0.0, {
+                'N': N,
+                'M': 0,
+                'total_contribution': 0.0,
+                'complete_matches': []
+            }
+
+        # 获取完全匹配的帖子列表 (综合得分 >= 0.8)
+        complete_matches_data = []
+        for note_eval in notes_evaluation:
+            score = note_eval.get('综合得分', 0)
+            if score >= 0.8:
+                note_index = note_eval.get('note_index', -1)
+                if 0 <= note_index < len(notes_data):
+                    # 从原始数据中获取note_id
+                    note_id = notes_data[note_index].get('id', '')
+                    note_card = notes_data[note_index].get('note_card', {})
+                    note_title = note_card.get('display_title', '')
+
+                    complete_matches_data.append({
+                        'note_id': note_id,
+                        'note_title': note_title,
+                        'evaluation_score': score,
+                        'note_index': note_index
+                    })
+
+        M = len(complete_matches_data)
+        logger.info(f"    完全匹配数: M = {M}/{N}")
+
+        if M == 0:
+            # 没有完全匹配,P = 0
+            return 0.0, {
+                'N': N,
+                'M': 0,
+                'total_contribution': 0.0,
+                'complete_matches': []
+            }
+
+        # 计算每个完全匹配的贡献 a×b
+        contributions = []
+        total_contribution = 0.0
+
+        for match in complete_matches_data:
+            note_id = match['note_id']
+            evaluation_score = match['evaluation_score']  # a 值
+
+            # 从 similarity_map 获取 b 值
+            max_similarity = similarity_map.get(note_id, 0)  # b 值
+
+            # 计算贡献
+            contribution = evaluation_score * max_similarity
+            total_contribution += contribution
+
+            # 保存详细信息
+            contributions.append({
+                'note_id': note_id,
+                'note_title': match['note_title'],
+                'evaluation_score': round(evaluation_score, 3),
+                'max_similarity': round(max_similarity, 3),
+                'contribution': round(contribution, 3)
+            })
+
+        # 计算综合得分 P = Σ(a×b) / N
+        P = total_contribution / N
+
+        # 按贡献降序排序
+        contributions.sort(key=lambda x: x['contribution'], reverse=True)
+
+        # 构建详细信息
+        detail = {
+            'N': N,
+            'M': M,
+            'total_contribution': round(total_contribution, 3),
+            'complete_matches': contributions
+        }
+
+        return P, detail
+
     def run(
         self,
         stage7_results: Dict,