liuzhiheng 3 дня назад
Родитель
Коммит
bc42ad0ee4

+ 17 - 13
examples_how/overall_derivation/tools/find_pattern.py

@@ -133,6 +133,7 @@ def get_patterns_by_conditional_ratio(
     返回每项:pattern名称(nameA+nameB+nameC)、条件概率。
     返回每项:pattern名称(nameA+nameB+nameC)、条件概率。
     """
     """
     merged = _load_and_merge_patterns(account_name)
     merged = _load_and_merge_patterns(account_name)
+    print(f"_load_and_merge_patterns,patterns: {len(merged)}")
     if not merged:
     if not merged:
         return []
         return []
     base_dir = _BASE_INPUT
     base_dir = _BASE_INPUT
@@ -190,7 +191,7 @@ def get_patterns_by_conditional_ratio(
     description="按条件概率从 pattern 库中筛选 pattern,优先返回包含已推导选题点的 pattern,并检查每个 pattern 的元素是否与帖子选题点匹配。"
     description="按条件概率从 pattern 库中筛选 pattern,优先返回包含已推导选题点的 pattern,并检查每个 pattern 的元素是否与帖子选题点匹配。"
     "功能:根据账号与已推导选题点(可选),筛选条件概率不低于阈值的 pattern;当 derived_items 非空时,优先返回 pattern 元素中包含已推导选题点的 pattern;同时对每个 pattern 的所有元素做帖子选题点匹配,匹配结果直接包含在返回数据中。"
     "功能:根据账号与已推导选题点(可选),筛选条件概率不低于阈值的 pattern;当 derived_items 非空时,优先返回 pattern 元素中包含已推导选题点的 pattern;同时对每个 pattern 的所有元素做帖子选题点匹配,匹配结果直接包含在返回数据中。"
     "参数:account_name 为账号名;post_id 为帖子ID,用于加载帖子选题点并做匹配判断;derived_items 为已推导选题点列表,每项含 topic(或已推导的选题点)与 source_node(或推导来源人设树节点),可为空,为空时条件概率使用 pattern 自身的 support;conditional_ratio_threshold 为条件概率阈值;top_n 为返回条数上限,默认 100。"
     "参数:account_name 为账号名;post_id 为帖子ID,用于加载帖子选题点并做匹配判断;derived_items 为已推导选题点列表,每项含 topic(或已推导的选题点)与 source_node(或推导来源人设树节点),可为空,为空时条件概率使用 pattern 自身的 support;conditional_ratio_threshold 为条件概率阈值;top_n 为返回条数上限,默认 100。"
-    "返回:ToolResult,output 为可读的 pattern 列表文本,metadata.items 为列表,每项含「pattern名称」(nameA+nameB+nameC 形式)、「条件概率」、「帖子选题点匹配」(匹配到帖子选题点的元素列表,每项含 pattern元素、帖子选题点与匹配分数;若无匹配则为字符串'无匹配帖子选题点')。"
+    "返回:ToolResult,output 为可读的 pattern 列表文本,metadata.items 为列表,每项含「pattern名称」(nameA+nameB+nameC 形式)、「条件概率」、「帖子选题点匹配」(仅当 pattern 元素匹配到至少 2 个不同帖子选题点时才返回匹配列表,每项含 pattern元素、帖子选题点与匹配分数;否则为字符串'无匹配帖子选题点')。"
 )
 )
 async def find_pattern(
 async def find_pattern(
     account_name: str,
     account_name: str,
@@ -221,7 +222,7 @@ async def find_pattern(
         - output: 可读的 pattern 列表文本(每行:pattern名称、条件概率、帖子匹配情况)。
         - output: 可读的 pattern 列表文本(每行:pattern名称、条件概率、帖子匹配情况)。
         - metadata: 含 account_name、conditional_ratio_threshold、top_n、count、items;
         - metadata: 含 account_name、conditional_ratio_threshold、top_n、count、items;
           items 为列表,每项为 {"pattern名称": str, "条件概率": float,
           items 为列表,每项为 {"pattern名称": str, "条件概率": float,
-          "帖子选题点匹配": list[{"pattern元素": str, "帖子选题点": str, "匹配分数": float}] 或 "无匹配帖子选题点"}。
+          "帖子选题点匹配": 仅当匹配到至少 2 个不同帖子选题点时为 list[{"pattern元素", "帖子选题点", "匹配分数"}],否则为 "无匹配帖子选题点"}。
         - 出错时 error 为错误信息。
         - 出错时 error 为错误信息。
     """
     """
     pattern_path = _pattern_file(account_name)
     pattern_path = _pattern_file(account_name)
@@ -263,7 +264,11 @@ async def find_pattern(
                             "帖子选题点": post_match["帖子选题点"],
                             "帖子选题点": post_match["帖子选题点"],
                             "匹配分数": post_match["匹配分数"],
                             "匹配分数": post_match["匹配分数"],
                         })
                         })
-                item["帖子选题点匹配"] = pattern_matches if pattern_matches else "无匹配帖子选题点"
+                # 仅当 pattern 元素匹配到至少 2 个不同帖子选题点时才返回匹配信息,否则返回无匹配
+                distinct_post_points = len({m["帖子选题点"] for m in pattern_matches})
+                item["帖子选题点匹配"] = (
+                    pattern_matches if distinct_post_points >= 2 else "无匹配帖子选题点"
+                )
         if not items:
         if not items:
             output = f"未找到条件概率 >= {conditional_ratio_threshold} 的 pattern"
             output = f"未找到条件概率 >= {conditional_ratio_threshold} 的 pattern"
         else:
         else:
@@ -286,7 +291,6 @@ async def find_pattern(
                 "conditional_ratio_threshold": conditional_ratio_threshold,
                 "conditional_ratio_threshold": conditional_ratio_threshold,
                 "top_n": top_n,
                 "top_n": top_n,
                 "count": len(items),
                 "count": len(items),
-                "items": items,
             },
             },
         )
         )
     except Exception as e:
     except Exception as e:
@@ -310,17 +314,17 @@ def main() -> None:
         {"topic": "叙事结构", "source_node": "叙事逻辑"},
         {"topic": "叙事结构", "source_node": "叙事逻辑"},
     ]
     ]
     conditional_ratio_threshold = 0.01
     conditional_ratio_threshold = 0.01
-    top_n = 100
+    top_n = 2000
 
 
     # 1)直接调用核心函数(不含帖子匹配,仅验证排序逻辑)
     # 1)直接调用核心函数(不含帖子匹配,仅验证排序逻辑)
-    derived_list = _parse_derived_list(derived_items)
-    items = get_patterns_by_conditional_ratio(
-        account_name, derived_list, conditional_ratio_threshold, top_n, post_id
-    )
-    print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
-    print(f"共 {len(items)} 条 pattern:\n")
-    for x in items:
-        print(f"  - {x['pattern名称']}\t条件概率={x['条件概率']}")
+    # derived_list = _parse_derived_list(derived_items)
+    # items = get_patterns_by_conditional_ratio(
+    #     account_name, derived_list, conditional_ratio_threshold, top_n, post_id
+    # )
+    # print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
+    # print(f"共 {len(items)} 条 pattern:\n")
+    # for x in items:
+    #     print(f"  - {x['pattern名称']}\t条件概率={x['条件概率']}")
 
 
     # 2)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配)
     # 2)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配)
     if ToolResult is not None:
     if ToolResult is not None:

+ 4 - 4
examples_how/overall_derivation/tools/find_tree_node.py

@@ -230,7 +230,7 @@ async def find_tree_constant_nodes(
         return ToolResult(
         return ToolResult(
             title=f"常量节点 ({account_name})",
             title=f"常量节点 ({account_name})",
             output=output,
             output=output,
-            metadata={"account_name": account_name, "count": len(items), "items": items},
+            metadata={"account_name": account_name, "count": len(items)},
         )
         )
     except Exception as e:
     except Exception as e:
         return ToolResult(
         return ToolResult(
@@ -323,7 +323,6 @@ async def find_tree_nodes_by_conditional_ratio(
                 "threshold": conditional_ratio_threshold,
                 "threshold": conditional_ratio_threshold,
                 "top_n": top_n,
                 "top_n": top_n,
                 "count": len(items),
                 "count": len(items),
-                "items": items,
             },
             },
         )
         )
     except Exception as e:
     except Exception as e:
@@ -342,9 +341,10 @@ def main() -> None:
     post_id = "68fb6a5c000000000302e5de"
     post_id = "68fb6a5c000000000302e5de"
     derived_items = [
     derived_items = [
         {"topic": "分享", "source_node": "分享"},
         {"topic": "分享", "source_node": "分享"},
+        {"topic": "叙事结构", "source_node": "叙事结构"},
     ]
     ]
     conditional_ratio_threshold = 0.1
     conditional_ratio_threshold = 0.1
-    top_n = 10
+    top_n = 100
 
 
     # 1)常量节点(核心函数,无匹配)
     # 1)常量节点(核心函数,无匹配)
     constant_nodes = get_constant_nodes(account_name)
     constant_nodes = get_constant_nodes(account_name)
@@ -368,7 +368,7 @@ def main() -> None:
         async def run_tools():
         async def run_tools():
             r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
             r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
             print("--- find_tree_constant_nodes ---")
             print("--- find_tree_constant_nodes ---")
-            print(r1.output[:200] + "..." if len(r1.output) > 200 else r1.output)
+            print(r1.output[:2000] + "..." if len(r1.output) > 2000 else r1.output)
             r2 = await find_tree_nodes_by_conditional_ratio(
             r2 = await find_tree_nodes_by_conditional_ratio(
                 account_name,
                 account_name,
                 post_id=post_id,
                 post_id=post_id,

+ 5 - 0
examples_how/overall_derivation/utils/similarity_calc.py

@@ -361,6 +361,11 @@ async def test_similarity_matrix() -> None:
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
+    # 直接运行 python similarity_calc.py 时,将项目根加入 path,以便 import agent
+    _root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
+    if _root not in __import__("sys").path:
+        __import__("sys").path.insert(0, _root)
+
     test_phrase_pairs()
     test_phrase_pairs()
     test_extract_json_array()
     test_extract_json_array()
     print("运行集成测试(需 embedding API、OPEN_ROUTER_API_KEY 及 agent 依赖)...")
     print("运行集成测试(需 embedding API、OPEN_ROUTER_API_KEY 及 agent 依赖)...")