
refactor: streamline matching pipeline and progress display

- Unify the progress-callback interface and simplify parameter passing
  - hybrid_similarity: merge llm_progress_callback and embedding_progress_callback into a single progress_callback
  - semantic_similarity: add progress-callback support
  - text_embedding_api: remove the unused max_concurrent parameter

- Improve progress display
  - replace the custom ProgressTracker with tqdm
  - show an independent progress bar per post
  - clearer reporting of task progress

- Optimize the save strategy
  - save each post immediately after processing to avoid memory buildup
  - switch to serial processing for easier error localization and progress tracking

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
yangxiaohui, 1 week ago
Commit bbfd5a7778

+ 11 - 16
lib/hybrid_similarity.py

@@ -143,11 +143,10 @@ async def compare_phrases_cartesian(
     phrases_a: List[str],
     phrases_b: List[str],
     max_concurrent: int = 50,
-    llm_progress_callback: Optional[callable] = None,
-    embedding_progress_callback: Optional[callable] = None
+    progress_callback: Optional[callable] = None
 ) -> List[List[Dict[str, Any]]]:
     """
-    Hybrid-similarity cartesian batch computation: M×N matrix (with dual progress callbacks)
+    Hybrid-similarity cartesian batch computation: M×N matrix
 
     Combines the embedding-model API cartesian product (fast) with concurrent LLM calls (already optimized)
     Uses default weights: embedding 0.5, LLM 0.5
@@ -156,8 +155,7 @@ async def compare_phrases_cartesian(
         phrases_a: First list of phrases (M items)
         phrases_b: Second list of phrases (N items)
         max_concurrent: Maximum concurrency, default 50 (caps concurrent LLM calls)
-        llm_progress_callback: LLM progress callback, invoked once per completed LLM task
-        embedding_progress_callback: Embedding progress callback, invoked once per completed embedding task
+        progress_callback: Progress callback, invoked each time an LLM task completes
 
     Returns:
         Nested list List[List[Dict]]; each Dict contains the full result
@@ -174,11 +172,14 @@ async def compare_phrases_cartesian(
         >>> print(results[0][0]['相似度'])  # hybrid similarity
         >>> print(results[0][1]['说明'])    # full explanation
 
-        >>> # Custom concurrency control
+        >>> # Using a progress callback
+        >>> def on_progress(count):
+        ...     print(f"{count} tasks completed")
         >>> results = await compare_phrases_cartesian(
         ...     ["深度学习"],
         ...     ["神经网络", "Python"],
-        ...     max_concurrent=100  # raise concurrency
+        ...     max_concurrent=100,
+        ...     progress_callback=on_progress
         ... )
     """
     # Parameter validation
@@ -198,23 +199,17 @@ async def compare_phrases_cartesian(
     embedding_results = await asyncio.to_thread(
         compare_phrases_cartesian_api,
         phrases_a,
-        phrases_b,
-        max_concurrent,
-        None  # no callback passed
+        phrases_b
     )
     elapsed = time.time() - start_time
     # print(f"✓ Embedding model done, elapsed: {elapsed:.1f}s")  # debug only
 
-    # After the embedding model finishes, update progress once in bulk (instead of looping 25,704 times)
-    if embedding_progress_callback:
-        embedding_progress_callback(M * N)  # pass the total, update once
-
     # 2. LLM model: concurrent calls (M×N tasks, capped by max_concurrent)
     semantic_results = await compare_phrases_cartesian_semantic(
         phrases_a,
         phrases_b,
-        max_concurrent,  # concurrency cap for LLM calls
-        llm_progress_callback  # forward the LLM progress callback
+        max_concurrent,
+        progress_callback  # forward the progress callback
     )
     # embedding_results[i][j] = {"相似度": float, "说明": str}
     # semantic_results[i][j] = {"相似度": float, "说明": str}
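
The new surface is easy to drive with tqdm directly. A minimal usage sketch against the signature above (the tqdm wiring here is illustrative, not part of this commit):

```python
import asyncio

from tqdm import tqdm

from lib.hybrid_similarity import compare_phrases_cartesian

async def demo() -> None:
    phrases_a = ["深度学习", "机器学习"]  # M = 2
    phrases_b = ["神经网络", "Python"]    # N = 2
    # tqdm.update(n) takes a single int, matching the callback contract,
    # so the bar's own method can serve as progress_callback.
    with tqdm(total=len(phrases_a) * len(phrases_b), unit="pair") as bar:
        results = await compare_phrases_cartesian(
            phrases_a,
            phrases_b,
            max_concurrent=50,
            progress_callback=bar.update,
        )
    print(results[0][0]["相似度"])  # hybrid score for the first pair

if __name__ == "__main__":
    asyncio.run(demo())
```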

+ 9 - 3
lib/semantic_similarity.py

@@ -550,10 +550,11 @@ async def compare_phrases(
 async def compare_phrases_cartesian(
     phrases_a: List[str],
     phrases_b: List[str],
-    max_concurrent: int = 50
+    max_concurrent: int = 50,
+    progress_callback: Optional[callable] = None
 ) -> List[List[Dict[str, Any]]]:
     """
-    Cartesian batch computation: M×N concurrent LLM calls (with concurrency control)
+    Cartesian batch computation: M×N concurrent LLM calls (with concurrency control and a progress callback)
 
     Exists for architectural uniformity; implemented internally via concurrency (LLMs cannot truly batch)
 
@@ -561,6 +562,7 @@ async def compare_phrases_cartesian(
         phrases_a: First list of phrases (M items)
         phrases_b: Second list of phrases (N items)
         max_concurrent: Maximum concurrency, default 50
+        progress_callback: Progress callback, invoked each time a task completes
 
     Returns:
         Nested list List[List[Dict]]; each Dict contains the full comparison result
@@ -588,7 +590,11 @@ async def compare_phrases_cartesian(
 
     async def limited_compare(phrase_a: str, phrase_b: str):
         async with semaphore:
-            return await compare_phrases(phrase_a, phrase_b)
+            result = await compare_phrases(phrase_a, phrase_b)
+            # invoke the progress callback
+            if progress_callback:
+                progress_callback(1)
+            return result
 
     # Create M×N semaphore-bounded concurrent tasks
     tasks = []
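
A note on the hook's semantics, with a self-contained sketch: the callback fires inside each coroutine on the event loop thread, so a plain counter (or tqdm.update) needs no locking, and although callbacks arrive in completion order, asyncio.gather still returns results in task order. The pattern below is a generic restatement of limited_compare, not library code:

```python
import asyncio
import random

async def cartesian_with_progress(items_a, items_b, worker,
                                  max_concurrent=50, progress_callback=None):
    # Same shape as limited_compare above: a semaphore caps concurrency,
    # and the callback reports one unit per finished pair.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def limited(a, b):
        async with semaphore:
            result = await worker(a, b)
            if progress_callback:
                progress_callback(1)
            return result

    tasks = [limited(a, b) for a in items_a for b in items_b]
    flat = await asyncio.gather(*tasks)  # task order preserved, length M*N
    n = len(items_b)
    return [flat[i * n:(i + 1) * n] for i in range(len(items_a))]

async def fake_worker(a, b):
    await asyncio.sleep(random.random() / 100)  # stand-in for an LLM call
    return {"pair": (a, b)}

grid = asyncio.run(cartesian_with_progress(
    ["x", "y"], ["u", "v"], fake_worker,
    progress_callback=lambda n: print(n, end="")))
print("\n", grid[1][0]["pair"])  # ('y', 'u') regardless of completion order
```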

+ 1 - 3
lib/text_embedding_api.py

@@ -267,8 +267,7 @@ def compare_phrases_batch(
 
 def compare_phrases_cartesian(
     phrases_a: List[str],
-    phrases_b: List[str],
-    max_concurrent: int = 50
+    phrases_b: List[str]
 ) -> List[List[Dict[str, Any]]]:
     """
     Compute cartesian-product similarity (M×N matrix)
@@ -279,7 +278,6 @@ def compare_phrases_cartesian(
     Args:
         phrases_a: First list of phrases (M items)
         phrases_b: Second list of phrases (N items)
-        max_concurrent: Maximum concurrency (the API call is a single batch; parameter kept only for interface consistency)
 
     Returns:
         M×N result matrix (nested list)
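
Dropping the parameter is consistent with how a batched embedding endpoint behaves: this module's internals aren't shown in the diff, but assuming the usual approach (embed each list once, then take a normalized matrix product), there is no per-pair call for max_concurrent to throttle. A sketch of that shape, under those assumptions:

```python
import numpy as np

def cosine_cartesian(emb_a: np.ndarray, emb_b: np.ndarray) -> np.ndarray:
    """Full M×N cosine-similarity grid from two embedding matrices.

    emb_a is (M, d) and emb_b is (N, d), e.g. from one batched embedding
    call per list. A single matrix product covers every pair, which is why
    a max_concurrent knob had nothing to control on this code path.
    """
    a = emb_a / np.linalg.norm(emb_a, axis=1, keepdims=True)
    b = emb_b / np.linalg.norm(emb_b, axis=1, keepdims=True)
    return a @ b.T  # shape (M, N)

# Example with random stand-in embeddings:
rng = np.random.default_rng(0)
print(cosine_cartesian(rng.normal(size=(3, 8)), rng.normal(size=(2, 8))).shape)  # (3, 2)
```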

+ 72 - 106
script/data_processing/match_inspiration_features.py

@@ -12,7 +12,7 @@ import asyncio
 from pathlib import Path
 from typing import Dict, List
 import sys
-from datetime import datetime
+from tqdm import tqdm
 
 # Add the project root to the path
 project_root = Path(__file__).parent.parent.parent
@@ -21,58 +21,8 @@ sys.path.insert(0, str(project_root))
 from lib.hybrid_similarity import compare_phrases_cartesian
 from script.data_processing.path_config import PathConfig
 
-# Progress tracking
-class ProgressTracker:
-    """Progress tracker"""
-    def __init__(self, total: int):
-        self.total = total
-        self.completed = 0
-        self.start_time = datetime.now()
-        self.last_update_time = datetime.now()
-        self.last_completed = 0
-
-    def update(self, count: int = 1):
-        """Update progress"""
-        self.completed += count
-        current_time = datetime.now()
-
-        # Update at most once per second, or when the total is reached
-        if (current_time - self.last_update_time).total_seconds() >= 1.0 or self.completed >= self.total:
-            self.display()
-            self.last_update_time = current_time
-            self.last_completed = self.completed
-
-    def display(self):
-        """Render the progress display"""
-        if self.total == 0:
-            return
-
-        percentage = (self.completed / self.total) * 100
-        elapsed = (datetime.now() - self.start_time).total_seconds()
-
-        # Compute speed and estimated time remaining
-        if elapsed > 0:
-            speed = self.completed / elapsed
-            if speed > 0:
-                remaining = (self.total - self.completed) / speed
-                eta_str = f", ETA: {int(remaining)}s"
-            else:
-                eta_str = ""
-        else:
-            eta_str = ""
-
-        bar_length = 40
-        filled_length = int(bar_length * self.completed / self.total)
-        bar = '█' * filled_length + '░' * (bar_length - filled_length)
-
-        print(f"\r  Progress: [{bar}] {self.completed}/{self.total} ({percentage:.1f}%){eta_str}", end='', flush=True)
-
-        # Newline on completion
-        if self.completed >= self.total:
-            print()
-
-# Global progress tracker
-progress_tracker = None
+# Global progress bar
+progress_bar = None
 
 
 async def process_single_point(
@@ -95,7 +45,7 @@ async def process_single_point(
     Returns:
         Point data containing the list of how steps
     """
-    global progress_tracker
+    global progress_bar
 
     point_name = point.get("名称", "")
     feature_list = point.get("特征列表", [])
@@ -110,19 +60,22 @@ async def process_single_point(
     feature_names = [f.get("特征名称", "") for f in feature_list]
     persona_names = [pf["特征名称"] for pf in persona_features]
 
+    # Define the progress callback
+    def on_llm_progress(count: int):
+        """Callback invoked each time the LLM completes a task"""
+        if progress_bar:
+            progress_bar.update(count)
+
     # Core optimization: one hybrid-model cartesian call computes the full M×N grid
-    try:
-        similarity_results = await compare_phrases_cartesian(
-            feature_names,      # M features
-            persona_names,      # N personas
-            max_concurrent=100  # max LLM concurrency
-        )
-        # similarity_results[i][j] = {"相似度": float, "说明": str}
-    except Exception as e:
-        print(f"\n⚠️  Hybrid model call failed: {e}")
-        result = point.copy()
-        result["how步骤列表"] = []
-        return result
+    # max_concurrent caps the underlying LLM's global concurrency
+    similarity_results = await compare_phrases_cartesian(
+        feature_names,      # M features
+        persona_names,      # N personas
+        max_concurrent=100,  # max LLM concurrency (globally shared)
+        progress_callback=on_llm_progress  # forward the progress callback
+    )
+    # similarity_results[i][j] = {"相似度": float, "说明": str}
+
 
     # Build the match results (using the full results returned by the module)
     feature_match_results = []
@@ -161,7 +114,7 @@ async def process_single_point(
                     all_categories = set()
                     for ft in ["灵感点", "关键点", "目的点"]:
                         if ft in category_mapping:
-                            for fname, fdata in category_mapping[ft].items():
+                            for _fname, fdata in category_mapping[ft].items():
                                 cats = fdata.get("所属分类", [])
                                 all_categories.update(cats)
 
@@ -181,10 +134,6 @@ async def process_single_point(
             }
             match_results.append(match_result)
 
-            # Update progress
-            if progress_tracker:
-                progress_tracker.update(1)
-
         feature_match_results.append({
             "特征名称": feature_name,
             "权重": feature_weight,
@@ -231,36 +180,56 @@ async def process_single_task(
     Returns:
         The task with the how-deconstruction result attached
     """
+    global progress_bar
+
     post_id = task.get("帖子id", "")
-    print(f"\n[{task_index}/{total_tasks}] Processing post: {post_id}")
 
     # Fetch the what-deconstruction result
     what_result = task.get("what解构结果", {})
 
+    # Count the total match tasks for the current post
+    current_task_match_count = 0
+    for point_type in ["灵感点", "关键点", "目的点"]:
+        point_list = what_result.get(f"{point_type}列表", [])
+        for point in point_list:
+            feature_count = len(point.get("特征列表", []))
+            current_task_match_count += feature_count * len(all_persona_features)
+
+    # Create a progress bar for the current post
+    progress_bar = tqdm(
+        total=current_task_match_count,
+        desc=f"[{task_index}/{total_tasks}] {post_id}",
+        unit="match",
+        ncols=100
+    )
+
     # Build the how-deconstruction result
     how_result = {}
 
-    # Process inspiration, key, and purpose points
+    # Serially process inspiration, key, and purpose points
     for point_type in ["灵感点", "关键点", "目的点"]:
         point_list_key = f"{point_type}列表"
         point_list = what_result.get(point_list_key, [])
 
         if point_list:
-            # Process all points concurrently
-            tasks = [
-                process_single_point(
+            updated_point_list = []
+            # Process each point serially
+            for point in point_list:
+                result = await process_single_point(
                     point=point,
                     point_type=point_type,
                     persona_features=all_persona_features,
                     category_mapping=category_mapping,
                     model_name=model_name
                 )
-                for point in point_list
-            ]
-            updated_point_list = await asyncio.gather(*tasks)
+                updated_point_list.append(result)
 
             # Append to the how-deconstruction result
-            how_result[point_list_key] = list(updated_point_list)
+            how_result[point_list_key] = updated_point_list
+
+    # Close the current post's progress bar
+    if progress_bar:
+        progress_bar.close()
 
     # Update the task
     updated_task = task.copy()
@@ -273,22 +242,22 @@ async def process_task_list(
     task_list: List[Dict],
     persona_features_dict: Dict,
     category_mapping: Dict = None,
-    model_name: str = None
+    model_name: str = None,
+    output_dir: Path = None
 ) -> List[Dict]:
     """
-    Process the entire deconstruction task list (concurrent execution)
+    Process the entire deconstruction task list (serial execution; each post is saved immediately after processing)
 
     Args:
         task_list: Deconstruction task list
         persona_features_dict: Persona feature dict (inspiration, purpose, and key points)
         category_mapping: Feature-to-category mapping dict
         model_name: Name of the model to use
+        output_dir: Output directory (if provided, each post is saved immediately after processing)
 
     Returns:
         Task list with how-deconstruction results attached
     """
-    global progress_tracker
-
     # Merge the three kinds of persona features (inspiration, key, and purpose points)
     all_features = []
 
@@ -336,12 +305,10 @@ async def process_task_list(
     print(f"Total match tasks: {total_match_count:,}")
     print()
 
-    # Initialize the global progress tracker
-    progress_tracker = ProgressTracker(total_match_count)
-
-    # Process all tasks concurrently
-    tasks = [
-        process_single_task(
+    # Process all tasks serially (one at a time, saving each as soon as it finishes)
+    updated_task_list = []
+    for i, task in enumerate(task_list, 1):
+        updated_task = await process_single_task(
             task=task,
             task_index=i,
             total_tasks=len(task_list),
@@ -349,11 +316,19 @@ async def process_task_list(
             category_mapping=category_mapping,
             model_name=model_name
         )
-        for i, task in enumerate(task_list, 1)
-    ]
-    updated_task_list = await asyncio.gather(*tasks)
+        updated_task_list.append(updated_task)
+
+        # Save the current post's result immediately
+        if output_dir:
+            post_id = updated_task.get("帖子id", "unknown")
+            output_file = output_dir / f"{post_id}_how.json"
 
-    return list(updated_task_list)
+            with open(output_file, "w", encoding="utf-8") as f:
+                json.dump(updated_task, f, ensure_ascii=False, indent=4)
+
+            print(f"  ✓ Saved: {output_file.name}")
+
+    return updated_task_list
 
 
 async def main():
@@ -393,24 +368,15 @@ async def main():
     task_list = task_list_data.get("解构任务列表", [])
     print(f"Total tasks: {len(task_list)}")
 
-    # Process the task list
+    # Process the task list (each post is saved as soon as it finishes)
     updated_task_list = await process_task_list(
         task_list=task_list,
         persona_features_dict=persona_features_data,
         category_mapping=category_mapping,
-        model_name=None  # use the default model
+        model_name=None,  # use the default model
+        output_dir=output_dir  # pass the output directory to enable immediate saving
     )
 
-    # Save results to separate files
-    print(f"\nSaving results to: {output_dir}")
-    for task in updated_task_list:
-        post_id = task.get("帖子id", "unknown")
-        output_file = output_dir / f"{post_id}_how.json"
-
-        print(f"  Saving: {output_file.name}")
-        with open(output_file, "w", encoding="utf-8") as f:
-            json.dump(task, f, ensure_ascii=False, indent=4)
-
     print("\nDone!")
 
     # Print statistics
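
Since each post now lands on disk as <post_id>_how.json the moment it finishes, downstream steps can consume a partially completed run. A hypothetical reader sketch (the output path and the top-level result key are assumptions for illustration; the line that attaches how_result to the task falls outside the diff context shown above):

```python
import json
from pathlib import Path

# Hypothetical location -- adjust to whatever PathConfig resolves in this repo.
output_dir = Path("output/how_results")

for path in sorted(output_dir.glob("*_how.json")):
    task = json.loads(path.read_text(encoding="utf-8"))
    post_id = task.get("帖子id", "unknown")
    how = task.get("how解构结果", {})  # assumed key; mirrors how_result above
    n_points = sum(len(how.get(f"{t}列表", []))
                   for t in ["灵感点", "关键点", "目的点"])
    print(f"{post_id}: {n_points} points matched")
```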