| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Stage 8 相似度分析器
- 计算 Stage 7 解构特征与原始特征的相似度评分
- """
- import os
- import json
- import time
- import logging
- import asyncio
- from datetime import datetime
- from typing import Dict, List, Any, Optional
- from lib.hybrid_similarity import compare_phrases_cartesian
- from lib.config import get_cache_dir
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional: without it the analyzer runs without a progress bar.
    TQDM_AVAILABLE = False
else:
    TQDM_AVAILABLE = True

# Module-level logger shared by every function and method in this file.
logger = logging.getLogger(__name__)
def _iter_named_features(items: Optional[List[Dict]]):
    """Yield ``(index, item, feature, feature_name)`` for every named feature.

    ``items`` is one list from the 三点解构 payload (a 灵感点 category,
    ``purposes`` or ``key_points``). Each item may carry a ``提取的特征``
    list. Entries whose ``特征名称`` is missing or empty are skipped.
    A ``None`` list (JSON null) is treated as empty.
    """
    for idx, item in enumerate(items or []):
        for feat in item.get('提取的特征') or []:
            feature_name = feat.get('特征名称', '')
            if feature_name:
                yield idx, item, feat, feature_name


def extract_deconstructed_features(api_response: Dict) -> List[Dict]:
    """Extract every named feature from a Stage 7 三点解构 (three-point deconstruction).

    Args:
        api_response: The ``api_response`` object of one Stage 7 result.

    Returns:
        A flat list of feature dicts. Each dict contains:
          - feature_name: the feature's name (特征名称)
          - dimension: one of 灵感点-全新内容 / 灵感点-共性差异 /
            灵感点-共性内容 / 目的点 / 关键点
          - dimension_detail: the sub-classification. It is read from
            '维度分类' (灵感点), '特征分类' (目的点) or '维度' (关键点).
          - weight: the feature's 权重 (0 if absent)
          - source_index: position of the source item within its list
          - source_*: provenance fields copied from the source item
            (候选编号, 灵感点 / 目的点 / 关键点 text, 维度)
        Returns an empty list when the response is not successful or carries
        no 三点解构 data. Present-but-null containers are treated as empty
        instead of raising.
    """
    from collections import Counter

    features: List[Dict] = []

    # Only successful responses carry a usable payload.
    if api_response.get('status') != 'success':
        logger.warning(" API 响应状态不是 success,无法提取特征")
        return features

    # `or {}` also covers keys that are present but null in the JSON.
    result = api_response.get('result') or {}
    if 'data' not in result:
        logger.warning(" API 响应中没有 data 字段")
        return features

    three_point = (result['data'] or {}).get('三点解构') or {}
    if not three_point:
        logger.warning(" 三点解构数据为空")
        return features

    # 1. 灵感点: three sub-categories; the sub-classification key is '维度分类'.
    inspiration = three_point.get('灵感点') or {}
    for category in ('全新内容', '共性差异', '共性内容'):
        for idx, item, feat, name in _iter_named_features(inspiration.get(category)):
            features.append({
                'feature_name': name,
                'dimension': f'灵感点-{category}',
                'dimension_detail': feat.get('维度分类', ''),
                'weight': feat.get('权重', 0),
                'source_index': idx,
                'source_candidate_number': item.get('候选编号', 0),
                'source_inspiration': item.get('灵感点', '')
            })

    # 2. 目的点: the sub-classification key is '特征分类'.
    purpose = three_point.get('目的点') or {}
    for idx, item, feat, name in _iter_named_features(purpose.get('purposes')):
        features.append({
            'feature_name': name,
            'dimension': '目的点',
            'dimension_detail': feat.get('特征分类', ''),
            'weight': feat.get('权重', 0),
            'source_index': idx,
            'source_purpose': item.get('目的点', ''),
            'source_purpose_dimension': item.get('维度', {})
        })

    # 3. 关键点: the sub-classification key is '维度'.
    key_points = three_point.get('关键点') or {}
    for idx, item, feat, name in _iter_named_features(key_points.get('key_points')):
        features.append({
            'feature_name': name,
            'dimension': '关键点',
            'dimension_detail': feat.get('维度', ''),
            'weight': feat.get('权重', 0),
            'source_index': idx,
            'source_candidate_number': item.get('候选编号', 0),
            'source_key_point': item.get('关键点', ''),
            'source_key_point_dimension': item.get('维度', '')
        })

    logger.info(f" 提取特征数量: {len(features)}")
    if features:
        # dict(...) keeps the log output identical to a plain first-seen-order dict.
        dimension_counts = dict(Counter(f['dimension'] for f in features))
        logger.info(f" 维度分布: {dimension_counts}")

    return features
def _empty_similarity_statistics() -> Dict[str, Any]:
    """Return the all-zero per-note statistics record (no scored features)."""
    return {
        'total_features': 0,
        'max_similarity': 0,
        'min_similarity': 0,
        'avg_similarity': 0,
        'high_similarity_count': 0,
        'medium_similarity_count': 0,
        'low_similarity_count': 0
    }


def _summarize_similarity_scores(scores: List[float]) -> Dict[str, Any]:
    """Summarize per-feature similarity scores into the per-note statistics record.

    Scores >= 0.7 count as high, scores in [0.5, 0.7) as medium, and scores
    below 0.5 as low. An empty list yields the all-zero record.
    """
    if not scores:
        return _empty_similarity_statistics()
    return {
        'total_features': len(scores),
        'max_similarity': round(max(scores), 3),
        'min_similarity': round(min(scores), 3),
        'avg_similarity': round(sum(scores) / len(scores), 3),
        'high_similarity_count': sum(1 for s in scores if s >= 0.7),
        'medium_similarity_count': sum(1 for s in scores if 0.5 <= s < 0.7),
        'low_similarity_count': sum(1 for s in scores if s < 0.5)
    }


async def calculate_similarity_for_note(
    note_result: Dict,
    original_feature: str,
    weight_embedding: float = 0.5,
    weight_semantic: float = 0.5,
    min_similarity: float = 0.0
) -> Dict:
    """Score every deconstructed feature of one note against its original feature.

    Args:
        note_result: One Stage 7 ``result`` object. Its ``api_response`` is
            passed to ``extract_deconstructed_features``.
        original_feature: The original feature name to compare against.
        weight_embedding: Embedding-model weight. NOTE(review): it is not
            forwarded to ``compare_phrases_cartesian`` below; confirm whether
            that function should receive it.
        weight_semantic: LLM-model weight. It has the same caveat as
            ``weight_embedding``.
        min_similarity: Features scoring below this value are dropped. It is
            only applied when it is > 0.

    Returns:
        A dict with the note's identifying fields, the scored features
        (sorted by ``similarity_score``, highest first) and a
        ``similarity_statistics`` record. On a scoring failure the features
        are empty. The statistics are then all zero plus an ``error``
        message, so every statistics key is always present for downstream
        aggregation.
    """
    note_id = note_result.get('note_id', '')
    logger.info(f" [{note_id}] 开始计算相似度...")

    def build_result(features: List[Dict], statistics: Dict[str, Any], **extra: Any) -> Dict:
        # Common envelope shared by the empty, success and error paths.
        return {
            'note_id': note_id,
            'original_feature': original_feature,
            'evaluation_score': note_result.get('evaluation_score', 0),
            'search_word': note_result.get('search_word', ''),
            'note_data': note_result.get('note_data', {}),
            'deconstructed_features': features,
            'similarity_statistics': statistics,
            **extra
        }

    # 1. Extract features. A missing/null api_response is treated as an
    #    unsuccessful response instead of raising KeyError outside the try.
    deconstructed_features = extract_deconstructed_features(
        note_result.get('api_response') or {}
    )
    if not deconstructed_features:
        logger.warning(f" [{note_id}] 没有提取到特征")
        return build_result([], _empty_similarity_statistics())

    # 2. Names to score, index-aligned with deconstructed_features.
    feature_names = [f['feature_name'] for f in deconstructed_features]
    logger.info(f" [{note_id}] 调用相似度计算 API (1×{len(feature_names)} 笛卡尔积)...")

    # 3. One 1×N cartesian comparison. Any failure (API error, malformed
    #    response, non-numeric scores) falls through to the error result below.
    try:
        start_time = time.time()
        similarity_results = await compare_phrases_cartesian(
            phrases_a=[original_feature],
            phrases_b=feature_names,
            max_concurrent=50
        )
        elapsed = time.time() - start_time
        logger.info(f" [{note_id}] 相似度计算完成 ({elapsed:.1f}秒)")

        # 4. Row 0 compares original_feature with every feature, in feature_names order.
        row = similarity_results[0]
        for i, feat in enumerate(deconstructed_features):
            feat['similarity_score'] = row[i]['相似度']
            feat['similarity_explanation'] = row[i]['说明']

        # 5. Drop low-similarity features (only when a positive threshold is set).
        if min_similarity > 0:
            original_count = len(deconstructed_features)
            deconstructed_features = [
                f for f in deconstructed_features
                if f['similarity_score'] >= min_similarity
            ]
            filtered_count = original_count - len(deconstructed_features)
            if filtered_count > 0:
                logger.info(f" [{note_id}] 过滤掉 {filtered_count} 个低相似度特征 (< {min_similarity})")

        # 6. Statistics over the surviving features.
        statistics = _summarize_similarity_scores(
            [f['similarity_score'] for f in deconstructed_features]
        )
        if deconstructed_features:
            # 7. Highest similarity first.
            deconstructed_features.sort(key=lambda x: x['similarity_score'], reverse=True)
            logger.info(f" [{note_id}] 统计: 最高={statistics['max_similarity']}, "
                        f"平均={statistics['avg_similarity']}, "
                        f"高相似度={statistics['high_similarity_count']}个")

        return build_result(
            deconstructed_features,
            statistics,
            processing_time_seconds=round(elapsed, 2)
        )
    except Exception as e:
        logger.error(f" [{note_id}] 相似度计算失败: {e}")
        # Keep every statistics key so aggregators that index them do not crash.
        return build_result([], {**_empty_similarity_statistics(), 'error': str(e)})
class Stage8SimilarityAnalyzer:
    """Stage 8: score Stage 7 deconstructed features against each note's original feature.

    Runs ``calculate_similarity_for_note`` for every Stage 7 result
    (optionally restricted to ``target_features``) with bounded concurrency,
    writes the combined output as JSON and returns it.
    """

    def __init__(
        self,
        weight_embedding: float = 0.5,
        weight_semantic: float = 0.5,
        max_workers: int = 5,
        min_similarity: float = 0.0,
        output_dir: str = "output_v2",
        target_features: Optional[List[str]] = None
    ):
        """Configure the analyzer.

        Args:
            weight_embedding: Embedding-model weight (default 0.5).
            weight_semantic: LLM-model weight (default 0.5).
            max_workers: Maximum number of notes scored concurrently (default 5).
            min_similarity: Minimum similarity threshold. The default 0.0 keeps
                every feature.
            output_dir: Directory used for default and intermediate output paths.
            target_features: Original features to process. ``None`` processes all.

        Raises:
            ValueError: If the two weights do not sum to 1.0 (within 0.001),
                or if ``max_workers`` is less than 1.
        """
        self.weight_embedding = weight_embedding
        self.weight_semantic = weight_semantic
        self.max_workers = max_workers
        self.min_similarity = min_similarity
        self.output_dir = output_dir
        self.target_features = target_features

        total_weight = weight_embedding + weight_semantic
        if abs(total_weight - 1.0) > 0.001:
            raise ValueError(f"权重之和必须为1.0,当前为: {total_weight}")
        # asyncio.Semaphore(0) would block every task in run_async forever.
        if max_workers < 1:
            raise ValueError(f"max_workers 必须至少为 1,当前为: {max_workers}")

    @staticmethod
    def _overall_statistics(results: List[Dict], total_notes: int) -> Dict[str, Any]:
        """Aggregate per-note ``similarity_statistics`` into run-level statistics.

        Missing statistic keys (e.g. on error-shaped results) count as 0.
        ``avg_max_similarity`` only averages notes that have at least one feature.
        """
        per_note = [r.get('similarity_statistics', {}) for r in results]
        total_features = sum(s.get('total_features', 0) for s in per_note)
        max_similarities = [
            s.get('max_similarity', 0) for s in per_note
            if s.get('total_features', 0) > 0
        ]
        return {
            'total_notes': total_notes,
            'total_features_extracted': total_features,
            'avg_features_per_note': round(total_features / total_notes, 1) if total_notes > 0 else 0,
            'avg_max_similarity': round(sum(max_similarities) / len(max_similarities), 3) if max_similarities else 0,
            'notes_with_high_similarity': sum(1 for s in per_note if s.get('high_similarity_count', 0) > 0)
        }

    def _save_intermediate_results(
        self,
        results: List[Dict],
        output_path: str,
        processed_count: int,
        total_count: int,
        start_time: float
    ):
        """Write a partial snapshot next to ``output_path``.

        The file is named ``<stem>_partial_<processed>of<total>.json``. It goes
        in ``output_path``'s directory, or in ``self.output_dir`` when
        ``output_path`` has no directory part.
        """
        base_dir = os.path.dirname(output_path) or self.output_dir
        stem = os.path.splitext(os.path.basename(output_path))[0]
        intermediate_path = os.path.join(
            base_dir,
            f"{stem}_partial_{processed_count}of{total_count}.json"
        )

        # Same aggregation as the final result; tolerant of error-shaped results.
        summary = self._overall_statistics(results, len(results))
        intermediate_result = {
            'metadata': {
                'stage': 'stage8_partial',
                'description': f'部分结果({processed_count}/{total_count})',
                'processed_notes': len(results),
                'total_features_extracted': summary['total_features_extracted'],
                'avg_max_similarity': summary['avg_max_similarity'],
                'saved_at': datetime.now().isoformat(),
                'processing_time_seconds': round(time.time() - start_time, 2)
            },
            'results': results
        }

        os.makedirs(base_dir, exist_ok=True)
        with open(intermediate_path, 'w', encoding='utf-8') as f:
            json.dump(intermediate_result, f, ensure_ascii=False, indent=2)
        logger.info(f" 已保存中间结果: {intermediate_path}")

    async def run_async(
        self,
        stage7_results: Dict,
        output_path: Optional[str] = None
    ) -> Dict:
        """Run Stage 8 similarity analysis (async).

        Args:
            stage7_results: The Stage 7 output (its ``results`` list is processed).
            output_path: Where to write the JSON result. Defaults to
                ``<output_dir>/stage8_similarity_scores.json``.

        Returns:
            The Stage 8 result: ``metadata`` (including ``overall_statistics``)
            and ``results``, in the same order as the Stage 7 input.
        """
        logger.info("\n" + "=" * 60)
        logger.info("Stage 8: 解构特征与原始特征的相似度分析")
        logger.info("=" * 60)

        logger.info("配置参数:")
        logger.info(f" 向量模型权重: {self.weight_embedding}")
        logger.info(f" LLM 模型权重: {self.weight_semantic}")
        logger.info(f" 最大并发数: {self.max_workers}")
        logger.info(f" 最小相似度阈值: {self.min_similarity}")
        if self.target_features:
            logger.info(f" 目标特征: {', '.join(self.target_features)}")
        else:
            logger.info(" 目标特征: 全部")

        if output_path is None:
            output_path = os.path.join(self.output_dir, "stage8_similarity_scores.json")

        results_list = stage7_results.get('results', [])
        if self.target_features:
            results_list = [
                r for r in results_list
                if r.get('original_feature') in self.target_features
            ]

        total_notes = len(results_list)
        logger.info(f" 待处理帖子数: {total_notes}")

        if total_notes == 0:
            logger.warning(" 没有需要处理的帖子")
            # Include overall_statistics so callers can read it unconditionally.
            return {
                'metadata': {
                    'stage': 'stage8',
                    'processed_notes': 0,
                    'overall_statistics': self._overall_statistics([], 0)
                },
                'results': []
            }

        start_time = time.time()
        semaphore = asyncio.Semaphore(self.max_workers)

        async def bounded_task(index: int, note: Dict):
            # The index lets as_completed results be put back in input order.
            async with semaphore:
                outcome = await calculate_similarity_for_note(
                    note,
                    note.get('original_feature', ''),
                    self.weight_embedding,
                    self.weight_semantic,
                    self.min_similarity
                )
            return index, outcome

        tasks = [bounded_task(i, note) for i, note in enumerate(results_list)]

        if TQDM_AVAILABLE:
            logger.info(" 使用进度条显示...")
            ordered: List[Optional[Dict]] = [None] * total_notes
            processed_count = 0
            save_interval = 10

            for coro in tqdm(
                asyncio.as_completed(tasks),
                total=len(tasks),
                desc=" 相似度计算进度",
                unit="帖子",
                ncols=100
            ):
                index, outcome = await coro
                ordered[index] = outcome
                processed_count += 1

                # Periodic snapshot of everything completed so far.
                if processed_count % save_interval == 0:
                    self._save_intermediate_results(
                        [r for r in ordered if r is not None],
                        output_path,
                        processed_count,
                        total_notes,
                        start_time
                    )
            results = [r for r in ordered if r is not None]
        else:
            # gather already preserves input order.
            pairs = await asyncio.gather(*tasks)
            results = [outcome for _, outcome in pairs]

        logger.info(f" 完成: {len(results)}/{total_notes}")

        processing_time = time.time() - start_time
        overall_stats = self._overall_statistics(results, total_notes)

        logger.info(f"\n 总耗时: {processing_time:.1f}秒")
        logger.info(f" 总特征数: {overall_stats['total_features_extracted']}")
        logger.info(f" 平均特征数/帖子: {overall_stats['avg_features_per_note']}")
        logger.info(f" 平均最高相似度: {overall_stats['avg_max_similarity']}")
        logger.info(f" 包含高相似度特征的帖子: {overall_stats['notes_with_high_similarity']}")

        final_result = {
            'metadata': {
                'stage': 'stage8',
                'description': '解构特征与原始特征的相似度评分',
                # NOTE(review): despite the key name, this stores Stage 7's created_at, not a path.
                'source_file': stage7_results.get('metadata', {}).get('created_at', ''),
                'target_features': self.target_features if self.target_features else '全部',
                'similarity_config': {
                    'algorithm': 'hybrid_similarity',
                    'weight_embedding': self.weight_embedding,
                    'weight_semantic': self.weight_semantic,
                    'min_similarity_threshold': self.min_similarity
                },
                'overall_statistics': overall_stats,
                'created_at': datetime.now().isoformat(),
                'processing_time_seconds': round(processing_time, 2)
            },
            'results': results
        }

        os.makedirs(os.path.dirname(output_path) or self.output_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(final_result, f, ensure_ascii=False, indent=2)
        logger.info(f" 结果已保存: {output_path}")

        return final_result

    def run(
        self,
        stage7_results: Dict,
        output_path: Optional[str] = None
    ) -> Dict:
        """Run Stage 8 similarity analysis synchronously via ``asyncio.run``.

        Args:
            stage7_results: The Stage 7 output.
            output_path: Optional output path (see ``run_async``).

        Returns:
            The Stage 8 result dict returned by ``run_async``.
        """
        return asyncio.run(self.run_async(stage7_results, output_path))
def test_stage8_analyzer():
    """Manual smoke test: run Stage 8 on the local Stage 7 output.

    Reads ``output_v2/stage7_with_deconstruction.json``. If the file is
    missing, prints a notice and returns. Otherwise it analyses only notes
    whose original feature is "墨镜" and prints summary figures.
    """
    stage7_path = "output_v2/stage7_with_deconstruction.json"
    if not os.path.exists(stage7_path):
        print(f"Stage 7 结果不存在: {stage7_path}")
        return

    with open(stage7_path, 'r', encoding='utf-8') as f:
        stage7_results = json.load(f)

    analyzer = Stage8SimilarityAnalyzer(
        weight_embedding=0.5,
        weight_semantic=0.5,
        max_workers=3,
        min_similarity=0.3,
        target_features=["墨镜"]
    )
    stage8_results = analyzer.run(stage7_results)

    # When no note matches, the analyzer's early return may carry no
    # overall_statistics; read it defensively instead of raising KeyError.
    overall = stage8_results.get('metadata', {}).get('overall_statistics', {})
    print(f"\n处理了 {overall.get('total_notes', 0)} 个帖子")
    print(f"提取了 {overall.get('total_features_extracted', 0)} 个特征")
    print(f"平均最高相似度: {overall.get('avg_max_similarity', 0)}")
if __name__ == '__main__':
    # Script entry point: INFO-level logging to the console, then the Stage 8 smoke test.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    test_stage8_analyzer()
|