# find_pattern.py (9.3 KB)
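"""
Find Pattern tool - fetch the patterns in the pattern library that meet a conditional-probability threshold.

Reads an account's pattern library, merges and de-duplicates it, filters the
result by conditional probability, and returns the top-N patterns (pattern
name and conditional probability).
"""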
  5. import importlib.util
  6. import json
  7. from pathlib import Path
  8. from typing import Any, Optional
  9. try:
  10. from agent.tools import tool, ToolResult, ToolContext
  11. except ImportError:
  12. def tool(*args, **kwargs):
  13. return lambda f: f
  14. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  15. ToolContext = None
  16. # 与 pattern_data_process 一致的 key 定义
  17. TOP_KEYS = [
  18. "depth_max_with_name",
  19. "depth_mixed",
  20. "depth_max_concrete",
  21. "depth2_medium",
  22. "depth1_abstract",
  23. ]
  24. SUB_KEYS = ["two_x", "one_x", "zero_x"]
  25. # 加载 conditional_ratio_calc(与 find_tree_node 一致)
  26. _utils_dir = Path(__file__).resolve().parent.parent / "utils"
  27. _cond_spec = importlib.util.spec_from_file_location(
  28. "conditional_ratio_calc",
  29. _utils_dir / "conditional_ratio_calc.py",
  30. )
  31. _cond_mod = importlib.util.module_from_spec(_cond_spec)
  32. _cond_spec.loader.exec_module(_cond_mod)
  33. calc_pattern_conditional_ratio = _cond_mod.calc_pattern_conditional_ratio
  34. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  35. def _pattern_file(account_name: str) -> Path:
  36. """pattern 库文件:../input/{account_name}/原始数据/pattern/processed_edge_data.json"""
  37. return _BASE_INPUT / account_name / "原始数据" / "pattern" / "processed_edge_data.json"
  38. def _slim_pattern(p: dict) -> tuple[float, int, list[str], int]:
  39. """提取 name 列表(去重保序)、support、length、post_count。"""
  40. names = [item["name"] for item in (p.get("items") or [])]
  41. seen = set()
  42. unique = []
  43. for n in names:
  44. if n not in seen:
  45. seen.add(n)
  46. unique.append(n)
  47. support = round(float(p.get("support", 0)), 4)
  48. length = int(p.get("length", 0))
  49. post_count = int(p.get("post_count", 0))
  50. return support, length, unique, post_count
  51. def _merge_and_dedupe(patterns: list[dict]) -> list[dict]:
  52. """
  53. 按 items 的 name 集合去重(不区分顺序),留 support 最大;
  54. 输出格式保留 s、l、i(nameA+nameB+nameC)及 post_count,供条件概率计算使用。
  55. """
  56. key_to_best: dict[tuple, tuple[float, int, int]] = {}
  57. for p in patterns:
  58. support, length, unique, post_count = _slim_pattern(p)
  59. if not unique:
  60. continue
  61. key = tuple(sorted(unique))
  62. if key not in key_to_best or support > key_to_best[key][0]:
  63. key_to_best[key] = (support, length, post_count)
  64. out = []
  65. for k, (s, l, post_count) in key_to_best.items():
  66. if s < 0.1:
  67. continue
  68. out.append({
  69. "s": s,
  70. "l": l,
  71. "i": "+".join(k),
  72. "post_count": post_count,
  73. })
  74. out.sort(key=lambda x: x["s"] * x["l"], reverse=True)
  75. return out
  76. def _load_and_merge_patterns(account_name: str) -> list[dict]:
  77. """读取 pattern 库 JSON,按 TOP_KEYS/SUB_KEYS 合并为列表并做合并、去重。"""
  78. path = _pattern_file(account_name)
  79. if not path.is_file():
  80. return []
  81. with open(path, "r", encoding="utf-8") as f:
  82. data = json.load(f)
  83. all_patterns = []
  84. for top in TOP_KEYS:
  85. if top not in data:
  86. continue
  87. block = data[top]
  88. for sub in SUB_KEYS:
  89. all_patterns.extend(block.get(sub) or [])
  90. return _merge_and_dedupe(all_patterns)
  91. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  92. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  93. out = []
  94. for item in derived_items:
  95. if isinstance(item, dict):
  96. topic = item.get("topic") or item.get("已推导的选题点")
  97. source = item.get("source_node") or item.get("推导来源人设树节点")
  98. if topic is not None and source is not None:
  99. out.append((str(topic).strip(), str(source).strip()))
  100. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  101. out.append((str(item[0]).strip(), str(item[1]).strip()))
  102. return out
  103. def get_patterns_by_conditional_ratio(
  104. account_name: str,
  105. derived_list: list[tuple[str, str]],
  106. conditional_ratio_threshold: float,
  107. top_n: int,
  108. ) -> list[dict[str, Any]]:
  109. """
  110. 从 pattern 库中获取条件概率 >= 阈值的 pattern,按条件概率降序(同分按 length 降序),返回 top_n 条。
  111. derived_list 为空时,条件概率使用 pattern 自身的 support(s)。
  112. 返回每项:pattern名称(nameA+nameB+nameC)、条件概率。
  113. """
  114. merged = _load_and_merge_patterns(account_name)
  115. if not merged:
  116. return []
  117. base_dir = _BASE_INPUT
  118. scored: list[tuple[dict, float]] = []
  119. if not derived_list:
  120. # derived_items 为空:条件概率取 pattern 本身的 support (s)
  121. for p in merged:
  122. ratio = float(p.get("s", 0))
  123. if ratio >= conditional_ratio_threshold:
  124. scored.append((p, ratio))
  125. else:
  126. for p in merged:
  127. ratio = calc_pattern_conditional_ratio(
  128. account_name, derived_list, p, base_dir=base_dir
  129. )
  130. if ratio >= conditional_ratio_threshold:
  131. scored.append((p, ratio))
  132. scored.sort(key=lambda x: (-x[1], -x[0]["l"]))
  133. result = []
  134. for p, ratio in scored[:top_n]:
  135. result.append({
  136. "pattern名称": p["i"],
  137. "条件概率": round(ratio, 6),
  138. })
  139. return result
  140. @tool(
  141. description="从 pattern 库中获取符合条件概率阈值的 pattern。"
  142. "输入:账号名、已推导选题点列表(可为空)、条件概率阈值、topN。"
  143. "derived_items 为空时,条件概率使用 pattern 自身的 support。"
  144. "返回:pattern 名称(nameA+nameB+nameC)及条件概率,按条件概率从高到低最多 topN 条。"
  145. )
  146. async def find_pattern(
  147. account_name: str,
  148. derived_items: list[dict[str, str]],
  149. conditional_ratio_threshold: float,
  150. top_n: int = 20,
  151. context: Optional[ToolContext] = None,
  152. ) -> ToolResult:
  153. """
  154. 从 pattern 库中获取符合条件概率阈值的 pattern。
  155. derived_items:可为空;非空时每项为 {"topic": "已推导选题点", "source_node": "推导来源人设树节点"}。
  156. 当 derived_items 为空时,各 pattern 的条件概率取其 support(s);非空时按已推导帖子集合计算条件概率。
  157. 返回每条:pattern名称(nameA+nameB+nameC)、条件概率。
  158. """
  159. pattern_path = _pattern_file(account_name)
  160. if not pattern_path.is_file():
  161. return ToolResult(
  162. title="Pattern 库不存在",
  163. output=f"pattern 文件不存在: {pattern_path}",
  164. error="Pattern file not found",
  165. )
  166. try:
  167. derived_list = _parse_derived_list(derived_items or [])
  168. items = get_patterns_by_conditional_ratio(
  169. account_name, derived_list, conditional_ratio_threshold, top_n
  170. )
  171. if not items:
  172. output = f"未找到条件概率 >= {conditional_ratio_threshold} 的 pattern"
  173. else:
  174. lines = [
  175. f"- {x['pattern名称']}\t条件概率={x['条件概率']}"
  176. for x in items
  177. ]
  178. output = "\n".join(lines)
  179. return ToolResult(
  180. title=f"符合条件概率的 Pattern ({account_name}, 阈值={conditional_ratio_threshold})",
  181. output=output,
  182. metadata={
  183. "account_name": account_name,
  184. "conditional_ratio_threshold": conditional_ratio_threshold,
  185. "top_n": top_n,
  186. "count": len(items),
  187. "items": items,
  188. },
  189. )
  190. except Exception as e:
  191. return ToolResult(
  192. title="查找 Pattern 失败",
  193. output=str(e),
  194. error=str(e),
  195. )
  196. def main() -> None:
  197. """本地测试:用家有大志账号、已推导选题点,查询符合条件概率阈值的 pattern。"""
  198. import asyncio
  199. account_name = "家有大志"
  200. # 已推导选题点,每项:已推导的选题点 + 推导来源人设树节点
  201. derived_items = [
  202. {"topic": "分享", "source_node": "分享"},
  203. {"topic": "柴犬", "source_node": "动物角色"},
  204. ]
  205. conditional_ratio_threshold = 0.01
  206. top_n = 10
  207. # 1)直接调用核心函数
  208. derived_list = _parse_derived_list(derived_items)
  209. items = get_patterns_by_conditional_ratio(
  210. account_name, derived_list, conditional_ratio_threshold, top_n
  211. )
  212. print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
  213. print(f"共 {len(items)} 条 pattern:\n")
  214. for x in items:
  215. print(f" - {x['pattern名称']}\t条件概率={x['条件概率']}")
  216. # 2)有 agent 时通过 tool 接口再跑一遍
  217. if ToolResult is not None:
  218. async def run_tool():
  219. result = await find_pattern(
  220. account_name=account_name,
  221. derived_items=derived_items,
  222. conditional_ratio_threshold=conditional_ratio_threshold,
  223. top_n=top_n,
  224. )
  225. print("\n--- Tool 返回 ---")
  226. print(result.output)
  227. asyncio.run(run_tool())
  228. if __name__ == "__main__":
  229. main()