point_match.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. """
  2. 选题点匹配 Tool - 判断推导选题点是否与帖子中的选题点匹配
  3. 功能:读取帖子选题点列表,与推导选题点做相似度计算,返回 combined_score >= 阈值的匹配对。
  4. """
  5. import json
  6. import sys
  7. from pathlib import Path
  8. from typing import Any, List, Optional
  9. # 保证直接运行或作为包加载时都能解析 utils(IDE 可跳转)
  10. _root = Path(__file__).resolve().parent.parent
  11. if str(_root) not in sys.path:
  12. sys.path.insert(0, str(_root))
  13. from utils.similarity_calc import similarity_matrix
  14. try:
  15. from agent.tools import tool, ToolResult, ToolContext
  16. except ImportError:
  17. def tool(*args, **kwargs):
  18. return lambda f: f
  19. ToolResult = None
  20. ToolContext = None
  21. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  22. # 默认匹配阈值
  23. DEFAULT_MATCH_THRESHOLD = 0.6
  24. def _post_topic_file(account_name: str, post_id: str) -> Path:
  25. """帖子选题点文件:../input/{account_name}/post_topic/{post_id}.json"""
  26. return _BASE_INPUT / account_name / "post_topic" / f"{post_id}.json"
  27. def _match_data_file(account_name: str, post_id: str) -> Path:
  28. """帖子选题点与人设树节点匹配结果文件:../input/{account_name}/match_data/{post_id}_匹配_all.json"""
  29. return _BASE_INPUT / account_name / "match_data" / f"{post_id}_匹配_all.json"
  30. def _load_match_data(
  31. account_name: str, post_id: str
  32. ) -> dict[tuple[str, str], float]:
  33. """
  34. 从匹配文件中读取 (帖子选题点, 人设树节点) -> 匹配分数。
  35. 文件为 JSON 数组,每项 name 为帖子选题点,match_personas 中 name 为人设树节点,match_score 为分数。
  36. """
  37. path = _match_data_file(account_name, post_id)
  38. if not path.is_file():
  39. return {}
  40. with open(path, "r", encoding="utf-8") as f:
  41. data = json.load(f)
  42. if not isinstance(data, list):
  43. return {}
  44. lookup: dict[tuple[str, str], float] = {}
  45. for item in data:
  46. if not isinstance(item, dict):
  47. continue
  48. post_point = item.get("name")
  49. personas = item.get("match_personas")
  50. if post_point is None or not isinstance(personas, list):
  51. continue
  52. post_point = str(post_point).strip()
  53. if not post_point:
  54. continue
  55. for mp in personas:
  56. if not isinstance(mp, dict):
  57. continue
  58. persona_name = mp.get("name")
  59. score = mp.get("match_score")
  60. if persona_name is None or score is None:
  61. continue
  62. persona_name = str(persona_name).strip()
  63. try:
  64. lookup[(post_point, persona_name)] = float(score)
  65. except (TypeError, ValueError):
  66. continue
  67. return lookup
  68. def _load_post_topic_points(account_name: str, post_id: str) -> List[str]:
  69. """从 post_topic JSON 读取帖子选题点列表。文件内容为字符串数组。"""
  70. path = _post_topic_file(account_name, post_id)
  71. if not path.is_file():
  72. return []
  73. with open(path, "r", encoding="utf-8") as f:
  74. data = json.load(f)
  75. if not isinstance(data, list):
  76. return []
  77. return [str(x).strip() for x in data if x and str(x).strip()]
  78. def _to_derivation_points(derivation_output_points: List[str]) -> List[str]:
  79. """从推导选题点字符串列表中筛出非空并 strip,返回列表。"""
  80. return [s.strip() for s in derivation_output_points if s is not None and str(s).strip()]
  81. async def match_derivation_to_post_points(
  82. derivation_output_points: List[str],
  83. account_name: str,
  84. post_id: str,
  85. match_threshold: float = DEFAULT_MATCH_THRESHOLD,
  86. ) -> List[dict[str, Any]]:
  87. """
  88. 判断推导选题点(视为人设树节点)是否与帖子选题点匹配,返回分数 >= match_threshold 的列表。
  89. 优先从 ../input/{account_name}/match_data/{post_id}_匹配_all.json 读取匹配分数,
  90. 文件中未出现的 (帖子选题点, 人设树节点) 再通过 similarity_matrix 计算。
  91. Returns:
  92. 每项: {"推导选题点": str, "帖子选题点": str, "匹配分数": float}
  93. """
  94. post_points = _load_post_topic_points(account_name, post_id)
  95. derivation_points = _to_derivation_points(derivation_output_points)
  96. if not derivation_points:
  97. return []
  98. if not post_points:
  99. return []
  100. # 从匹配文件读取 (帖子选题点, 人设树节点) -> 匹配分数;derivation_points 当作人设树节点
  101. match_lookup = _load_match_data(account_name, post_id)
  102. scores: dict[tuple[str, str], float] = {}
  103. missing_pairs: list[tuple[str, str]] = []
  104. for d in derivation_points:
  105. for p in post_points:
  106. key = (p, d)
  107. if key in match_lookup:
  108. scores[(d, p)] = match_lookup[key]
  109. else:
  110. missing_pairs.append((d, p))
  111. # 文件中没有的 (推导选题点, 帖子选题点) 用 similarity_matrix 计算
  112. if missing_pairs:
  113. derivation_missing = list({d for d, _ in missing_pairs})
  114. post_missing = list({p for _, p in missing_pairs})
  115. items = await similarity_matrix(derivation_missing, post_missing)
  116. for row in items:
  117. d, p = row["phrase_a"], row["phrase_b"]
  118. scores[(d, p)] = row["combined_score"]
  119. matched = []
  120. for (d, p), score in scores.items():
  121. if score >= match_threshold:
  122. matched.append({
  123. "推导选题点": d,
  124. "帖子选题点": p,
  125. "匹配分数": round(score, 6),
  126. })
  127. return matched
  128. @tool(
  129. description="判断推导选题点(人设树节点)与帖子选题点是否匹配。"
  130. "功能:根据账号与帖子ID,将传入的推导选题点列表与帖子选题点做匹配,返回达到阈值的匹配对。"
  131. "参数:derivation_output_points 为推导选题点字符串列表;account_name 为账号名;post_id 为帖子ID;match_threshold 为匹配分数阈值,默认 0.8。"
  132. "返回:ToolResult,output 为可读匹配结果文本,metadata.items 为匹配列表,每项含「推导选题点」「帖子选题点」「匹配分数」。"
  133. )
  134. async def point_match(
  135. derivation_output_points: List[str],
  136. account_name: str,
  137. post_id: str,
  138. match_threshold: float = DEFAULT_MATCH_THRESHOLD,
  139. context: Optional[ToolContext] = None,
  140. ) -> ToolResult:
  141. """
  142. 判断推导选题点与帖子选题点是否匹配,返回达到阈值的匹配对。
  143. 参数
  144. -------
  145. derivation_output_points : 推导选题点字符串列表。
  146. account_name : 账号名,用于定位 input 下的账号目录。
  147. post_id : 帖子ID,用于定位该帖的选题点与匹配数据。
  148. match_threshold : 匹配分数阈值,分数 >= 该值视为匹配成功,默认 0.8。
  149. context : 可选,Agent 工具上下文。
  150. 返回
  151. -------
  152. ToolResult:
  153. - title: 结果标题。
  154. - output: 可读的匹配结果文本(每行:推导选题点、帖子选题点、匹配分数)。
  155. - metadata: 含 account_name、post_id、match_threshold、count、items;
  156. items 为列表,每项为 {"推导选题点": str, "帖子选题点": str, "匹配分数": float}。
  157. - 出错时 error 为错误信息。
  158. """
  159. topic_path = _post_topic_file(account_name, post_id)
  160. if not topic_path.is_file():
  161. return ToolResult(
  162. title="帖子选题点文件不存在",
  163. output=f"帖子选题点文件不存在: {topic_path}",
  164. error="Post topic file not found",
  165. )
  166. try:
  167. derivation_points = _to_derivation_points(derivation_output_points)
  168. if not derivation_points:
  169. return ToolResult(
  170. title="参数无效",
  171. output="derivation_output_points 不能为空,且需为字符串列表",
  172. error="Invalid derivation_output_points",
  173. )
  174. matched = await match_derivation_to_post_points(
  175. derivation_output_points, account_name, post_id, match_threshold
  176. )
  177. if not matched:
  178. output = f"未找到 combined_score >= {match_threshold} 的匹配"
  179. else:
  180. lines = [
  181. f"- 推导: {x['推导选题点']}\t帖子: {x['帖子选题点']}\t分数={x['匹配分数']}"
  182. for x in matched
  183. ]
  184. output = "\n".join(lines)
  185. return ToolResult(
  186. title=f"选题点匹配结果 ({account_name}, post_id={post_id})",
  187. output=output,
  188. metadata={
  189. "account_name": account_name,
  190. "post_id": post_id,
  191. "match_threshold": match_threshold,
  192. "count": len(matched),
  193. "items": matched,
  194. },
  195. )
  196. except Exception as e:
  197. return ToolResult(
  198. title="选题点匹配失败",
  199. output=str(e),
  200. error=str(e),
  201. )
  202. def main() -> None:
  203. """本地测试:用家有大志账号、某帖子ID、推导选题点列表测试匹配。"""
  204. import asyncio
  205. account_name = "家有大志"
  206. post_id = "68fb6a5c000000000302e5de"
  207. derivation_output_points = ["分享", "创意改造", "柴犬", "不存在的点"]
  208. async def run():
  209. matched = await match_derivation_to_post_points(
  210. derivation_output_points, account_name, post_id
  211. )
  212. print(f"账号: {account_name}, post_id: {post_id}")
  213. print(f"推导选题点: {derivation_output_points}")
  214. print(f"匹配成功 {len(matched)} 条:\n")
  215. for x in matched:
  216. print(f" - 推导: {x['推导选题点']}\t帖子: {x['帖子选题点']}\t分数={x['匹配分数']}")
  217. if ToolResult is not None:
  218. result = await point_match(
  219. derivation_output_points=derivation_output_points,
  220. account_name=account_name,
  221. post_id=post_id,
  222. )
  223. print("\n--- Tool 返回 ---")
  224. print(result.output)
  225. asyncio.run(run())
  226. if __name__ == "__main__":
  227. main()