point_match.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. """
  2. 选题点匹配 Tool - 判断推导选题点是否与帖子中的选题点匹配
  3. 功能:读取帖子选题点列表,与推导选题点做相似度计算,返回 combined_score >= 阈值的匹配对。
  4. """
  5. import importlib.util
  6. import json
  7. from pathlib import Path
  8. from typing import Any, List, Optional
  9. try:
  10. from agent.tools import tool, ToolResult, ToolContext
  11. except ImportError:
  12. def tool(*args, **kwargs):
  13. return lambda f: f
  14. ToolResult = None
  15. ToolContext = None
  16. # 加载 similarity_calc
  17. _utils_dir = Path(__file__).resolve().parent.parent / "utils"
  18. _sim_spec = importlib.util.spec_from_file_location(
  19. "similarity_calc",
  20. _utils_dir / "similarity_calc.py",
  21. )
  22. _sim_mod = importlib.util.module_from_spec(_sim_spec)
  23. _sim_spec.loader.exec_module(_sim_mod)
  24. similarity_matrix = _sim_mod.similarity_matrix
  25. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  26. # 默认匹配阈值
  27. DEFAULT_MATCH_THRESHOLD = 0.8
  28. def _post_topic_file(account_name: str, post_id: str) -> Path:
  29. """帖子选题点文件:../input/{account_name}/post_topic/{post_id}.json"""
  30. return _BASE_INPUT / account_name / "post_topic" / f"{post_id}.json"
  31. def _match_data_file(account_name: str, post_id: str) -> Path:
  32. """帖子选题点与人设树节点匹配结果文件:../input/{account_name}/match_data/{post_id}_匹配_all.json"""
  33. return _BASE_INPUT / account_name / "match_data" / f"{post_id}_匹配_all.json"
  34. def _load_match_data(
  35. account_name: str, post_id: str
  36. ) -> dict[tuple[str, str], float]:
  37. """
  38. 从匹配文件中读取 (帖子选题点, 人设树节点) -> 匹配分数。
  39. 文件为 JSON 数组,每项 name 为帖子选题点,match_personas 中 name 为人设树节点,match_score 为分数。
  40. """
  41. path = _match_data_file(account_name, post_id)
  42. if not path.is_file():
  43. return {}
  44. with open(path, "r", encoding="utf-8") as f:
  45. data = json.load(f)
  46. if not isinstance(data, list):
  47. return {}
  48. lookup: dict[tuple[str, str], float] = {}
  49. for item in data:
  50. if not isinstance(item, dict):
  51. continue
  52. post_point = item.get("name")
  53. personas = item.get("match_personas")
  54. if post_point is None or not isinstance(personas, list):
  55. continue
  56. post_point = str(post_point).strip()
  57. if not post_point:
  58. continue
  59. for mp in personas:
  60. if not isinstance(mp, dict):
  61. continue
  62. persona_name = mp.get("name")
  63. score = mp.get("match_score")
  64. if persona_name is None or score is None:
  65. continue
  66. persona_name = str(persona_name).strip()
  67. try:
  68. lookup[(post_point, persona_name)] = float(score)
  69. except (TypeError, ValueError):
  70. continue
  71. return lookup
  72. def _load_post_topic_points(account_name: str, post_id: str) -> List[str]:
  73. """从 post_topic JSON 读取帖子选题点列表。文件内容为字符串数组。"""
  74. path = _post_topic_file(account_name, post_id)
  75. if not path.is_file():
  76. return []
  77. with open(path, "r", encoding="utf-8") as f:
  78. data = json.load(f)
  79. if not isinstance(data, list):
  80. return []
  81. return [str(x).strip() for x in data if x and str(x).strip()]
  82. def _to_derivation_points(derivation_output_points: List[str]) -> List[str]:
  83. """从推导选题点字符串列表中筛出非空并 strip,返回列表。"""
  84. return [s.strip() for s in derivation_output_points if s is not None and str(s).strip()]
  85. async def match_derivation_to_post_points(
  86. derivation_output_points: List[str],
  87. account_name: str,
  88. post_id: str,
  89. match_threshold: float = DEFAULT_MATCH_THRESHOLD,
  90. ) -> List[dict[str, Any]]:
  91. """
  92. 判断推导选题点(视为人设树节点)是否与帖子选题点匹配,返回分数 >= match_threshold 的列表。
  93. 优先从 ../input/{account_name}/match_data/{post_id}_匹配_all.json 读取匹配分数,
  94. 文件中未出现的 (帖子选题点, 人设树节点) 再通过 similarity_matrix 计算。
  95. Returns:
  96. 每项: {"推导选题点": str, "帖子选题点": str, "匹配分数": float}
  97. """
  98. post_points = _load_post_topic_points(account_name, post_id)
  99. derivation_points = _to_derivation_points(derivation_output_points)
  100. if not derivation_points:
  101. return []
  102. if not post_points:
  103. return []
  104. # 从匹配文件读取 (帖子选题点, 人设树节点) -> 匹配分数;derivation_points 当作人设树节点
  105. match_lookup = _load_match_data(account_name, post_id)
  106. scores: dict[tuple[str, str], float] = {}
  107. missing_pairs: list[tuple[str, str]] = []
  108. for d in derivation_points:
  109. for p in post_points:
  110. key = (p, d)
  111. if key in match_lookup:
  112. scores[(d, p)] = match_lookup[key]
  113. else:
  114. missing_pairs.append((d, p))
  115. # 文件中没有的 (推导选题点, 帖子选题点) 用 similarity_matrix 计算
  116. if missing_pairs:
  117. derivation_missing = list({d for d, _ in missing_pairs})
  118. post_missing = list({p for _, p in missing_pairs})
  119. items = await similarity_matrix(derivation_missing, post_missing)
  120. for row in items:
  121. d, p = row["phrase_a"], row["phrase_b"]
  122. scores[(d, p)] = row["combined_score"]
  123. matched = []
  124. for (d, p), score in scores.items():
  125. if score >= match_threshold:
  126. matched.append({
  127. "推导选题点": d,
  128. "帖子选题点": p,
  129. "匹配分数": round(score, 6),
  130. })
  131. return matched
  132. @tool(
  133. description="判断推导选题点(人设树节点)与帖子选题点是否匹配。"
  134. "功能:根据账号与帖子ID,将传入的推导选题点列表与帖子选题点做匹配,返回达到阈值的匹配对。"
  135. "参数:derivation_output_points 为推导选题点字符串列表;account_name 为账号名;post_id 为帖子ID;match_threshold 为匹配分数阈值,默认 0.8。"
  136. "返回:ToolResult,output 为可读匹配结果文本,metadata.items 为匹配列表,每项含「推导选题点」「帖子选题点」「匹配分数」。"
  137. )
  138. async def point_match(
  139. derivation_output_points: List[str],
  140. account_name: str,
  141. post_id: str,
  142. match_threshold: float = DEFAULT_MATCH_THRESHOLD,
  143. context: Optional[ToolContext] = None,
  144. ) -> ToolResult:
  145. """
  146. 判断推导选题点与帖子选题点是否匹配,返回达到阈值的匹配对。
  147. 参数
  148. -------
  149. derivation_output_points : 推导选题点字符串列表。
  150. account_name : 账号名,用于定位 input 下的账号目录。
  151. post_id : 帖子ID,用于定位该帖的选题点与匹配数据。
  152. match_threshold : 匹配分数阈值,分数 >= 该值视为匹配成功,默认 0.8。
  153. context : 可选,Agent 工具上下文。
  154. 返回
  155. -------
  156. ToolResult:
  157. - title: 结果标题。
  158. - output: 可读的匹配结果文本(每行:推导选题点、帖子选题点、匹配分数)。
  159. - metadata: 含 account_name、post_id、match_threshold、count、items;
  160. items 为列表,每项为 {"推导选题点": str, "帖子选题点": str, "匹配分数": float}。
  161. - 出错时 error 为错误信息。
  162. """
  163. topic_path = _post_topic_file(account_name, post_id)
  164. if not topic_path.is_file():
  165. return ToolResult(
  166. title="帖子选题点文件不存在",
  167. output=f"帖子选题点文件不存在: {topic_path}",
  168. error="Post topic file not found",
  169. )
  170. try:
  171. derivation_points = _to_derivation_points(derivation_output_points)
  172. if not derivation_points:
  173. return ToolResult(
  174. title="参数无效",
  175. output="derivation_output_points 不能为空,且需为字符串列表",
  176. error="Invalid derivation_output_points",
  177. )
  178. matched = await match_derivation_to_post_points(
  179. derivation_output_points, account_name, post_id, match_threshold
  180. )
  181. if not matched:
  182. output = f"未找到 combined_score >= {match_threshold} 的匹配"
  183. else:
  184. lines = [
  185. f"- 推导: {x['推导选题点']}\t帖子: {x['帖子选题点']}\t分数={x['匹配分数']}"
  186. for x in matched
  187. ]
  188. output = "\n".join(lines)
  189. return ToolResult(
  190. title=f"选题点匹配结果 ({account_name}, post_id={post_id})",
  191. output=output,
  192. metadata={
  193. "account_name": account_name,
  194. "post_id": post_id,
  195. "match_threshold": match_threshold,
  196. "count": len(matched),
  197. "items": matched,
  198. },
  199. )
  200. except Exception as e:
  201. return ToolResult(
  202. title="选题点匹配失败",
  203. output=str(e),
  204. error=str(e),
  205. )
  206. def main() -> None:
  207. """本地测试:用家有大志账号、某帖子ID、推导选题点列表测试匹配。"""
  208. import asyncio
  209. account_name = "家有大志"
  210. post_id = "68fb6a5c000000000302e5de"
  211. derivation_output_points = ["分享", "创意改造", "柴犬", "不存在的点"]
  212. async def run():
  213. matched = await match_derivation_to_post_points(
  214. derivation_output_points, account_name, post_id
  215. )
  216. print(f"账号: {account_name}, post_id: {post_id}")
  217. print(f"推导选题点: {derivation_output_points}")
  218. print(f"匹配成功 {len(matched)} 条:\n")
  219. for x in matched:
  220. print(f" - 推导: {x['推导选题点']}\t帖子: {x['帖子选题点']}\t分数={x['匹配分数']}")
  221. if ToolResult is not None:
  222. result = await point_match(
  223. derivation_output_points=derivation_output_points,
  224. account_name=account_name,
  225. post_id=post_id,
  226. )
  227. print("\n--- Tool 返回 ---")
  228. print(result.output)
  229. asyncio.run(run())
  230. if __name__ == "__main__":
  231. main()