yangxiaohui 1 周之前
父節點
當前提交
458321e73f
共有 2 個文件被更改,包括 51 次插入、17 次刪除
  1. +22 -14
      lib/hybrid_similarity.py
  2. +29 -3
      script/data_processing/path_config.py

+ 22 - 14
lib/hybrid_similarity.py

@@ -142,10 +142,12 @@ async def compare_phrases(
 async def compare_phrases_cartesian(
     phrases_a: List[str],
     phrases_b: List[str],
-    max_concurrent: int = 50
+    max_concurrent: int = 50,
+    llm_progress_callback: Optional[callable] = None,
+    embedding_progress_callback: Optional[callable] = None
 ) -> List[List[Dict[str, Any]]]:
     """
-    混合相似度笛卡尔积批量计算:M×N矩阵
+    混合相似度笛卡尔积批量计算:M×N矩阵(带双进度回调)
 
     结合向量模型API笛卡尔积(快速)和LLM并发调用(已优化)
     使用默认权重:向量0.5,LLM 0.5
@@ -154,6 +156,8 @@ async def compare_phrases_cartesian(
         phrases_a: 第一组短语列表(M个)
         phrases_b: 第二组短语列表(N个)
         max_concurrent: 最大并发数,默认50(控制LLM调用并发)
+        llm_progress_callback: LLM进度回调函数,每完成一个LLM任务调用一次
+        embedding_progress_callback: 向量进度回调函数,每完成一个向量任务调用一次
 
     Returns:
         嵌套列表 List[List[Dict]],每个Dict包含完整结果
@@ -187,26 +191,30 @@ async def compare_phrases_cartesian(
     weight_embedding = 0.5
     weight_semantic = 0.5
 
-    # 并发执行两个任务
-    # 1. 向量模型:使用API笛卡尔积(一次调用获取M×N完整结果)
-    embedding_task = asyncio.to_thread(
+    # 串行执行两个任务(向量模型快,先执行;避免并发死锁)
+    # 1. 向量模型:使用API笛卡尔积(一次调用获取M×N完整结果,通常1-2秒)
+    import time
+    start_time = time.time()
+    embedding_results = await asyncio.to_thread(
         compare_phrases_cartesian_api,
         phrases_a,
         phrases_b,
-        max_concurrent  # 传递并发参数(API不使用,但保持接口一致)
+        max_concurrent,
+        None  # 不传递回调
     )
+    elapsed = time.time() - start_time
+    # print(f"✓ 向量模型完成,耗时: {elapsed:.1f}秒")  # 调试用
+
+    # 向量模型完成后,一次性批量更新进度(而不是循环25704次)
+    if embedding_progress_callback:
+        embedding_progress_callback(M * N)  # 传递总数,一次更新
 
     # 2. LLM模型:使用并发调用(M×N个任务,受max_concurrent控制)
-    semantic_task = compare_phrases_cartesian_semantic(
+    semantic_results = await compare_phrases_cartesian_semantic(
         phrases_a,
         phrases_b,
-        max_concurrent  # 传递并发参数控制LLM调用
-    )
-
-    # 等待两个任务完成
-    embedding_results, semantic_results = await asyncio.gather(
-        embedding_task,
-        semantic_task
+        max_concurrent,  # 传递并发参数控制LLM调用
+        llm_progress_callback  # 传递LLM进度回调
     )
     # embedding_results[i][j] = {"相似度": float, "说明": str}
     # semantic_results[i][j] = {"相似度": float, "说明": str}

+ 29 - 3
script/data_processing/path_config.py

@@ -51,6 +51,14 @@ class PathConfig:
         with open(self.config_file, "r", encoding="utf-8") as f:
             self.config = json.load(f)
 
+    def _get_account_config(self, account_name: str) -> Optional[Dict]:
+        """获取特定账号的配置"""
+        accounts = self.config.get("accounts", [])
+        for acc in accounts:
+            if acc["name"] == account_name:
+                return acc
+        return None
+
     def _get_data_root(self) -> Path:
         """
         获取数据根目录
@@ -173,22 +181,40 @@ class PathConfig:
 
     # ===== 输入路径 =====
 
+    def _get_input_path(self, path_key: str) -> str:
+        """
+        获取输入路径配置,支持账号级别的自定义路径
+
+        优先级:
+        1. 账号特定配置 (accounts[x].paths.input.path_key)
+        2. 全局默认配置 (paths.input.path_key)
+        """
+        # 1. 检查账号特定配置
+        account_config = self._get_account_config(self.account_name)
+        if account_config and "paths" in account_config:
+            account_paths = account_config["paths"]
+            if "input" in account_paths and path_key in account_paths["input"]:
+                return account_paths["input"][path_key]
+
+        # 2. 使用全局默认配置
+        return self.config["paths"]["input"][path_key]
+
     @property
     def current_posts_dir(self) -> Path:
         """当前帖子what解构结果目录"""
-        rel_path = self.config["paths"]["input"]["current_posts"]
+        rel_path = self._get_input_path("current_posts")
         return self.account_dir / rel_path
 
     @property
     def historical_posts_dir(self) -> Path:
         """过去帖子what解构结果目录"""
-        rel_path = self.config["paths"]["input"]["historical_posts"]
+        rel_path = self._get_input_path("historical_posts")
         return self.account_dir / rel_path
 
     @property
     def pattern_cluster_file(self) -> Path:
         """pattern聚合结果文件"""
-        rel_path = self.config["paths"]["input"]["pattern_cluster"]
+        rel_path = self._get_input_path("pattern_cluster")
         return self.account_dir / rel_path
 
     # ===== 输出路径 =====