# point_match.py (9.5 KB)
"""
Topic-point matching tool: decide whether derived topic points match the topic points of a post.

Reads the post's topic-point list, computes similarity against the derived
topic points, and returns the pairs whose combined_score >= the threshold.
"""
  5. import json
  6. import sys
  7. from pathlib import Path
  8. from typing import Any, List, Optional
  9. # 保证直接运行或作为包加载时都能解析 utils(IDE 可跳转)
  10. _root = Path(__file__).resolve().parent.parent
  11. if str(_root) not in sys.path:
  12. sys.path.insert(0, str(_root))
  13. from utils.similarity_calc import similarity_matrix
  14. try:
  15. from agent.tools import tool, ToolResult, ToolContext
  16. except ImportError:
  17. def tool(*args, **kwargs):
  18. return lambda f: f
  19. ToolResult = None
  20. ToolContext = None
  21. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  22. # 默认匹配阈值
  23. DEFAULT_MATCH_THRESHOLD = 0.6
  24. def _post_topic_file(account_name: str, post_id: str) -> Path:
  25. """帖子选题点文件:../input/{account_name}/处理后数据/post_topic/{post_id}.json"""
  26. return _BASE_INPUT / account_name / "处理后数据" / "post_topic" / f"{post_id}.json"
  27. def _match_data_file(account_name: str, post_id: str) -> Path:
  28. """帖子选题点与人设树节点匹配结果文件:../input/{account_name}/处理后数据/match_data/{post_id}_匹配_all.json"""
  29. return _BASE_INPUT / account_name / "处理后数据" / "match_data" / f"{post_id}_匹配_all.json"
  30. def _load_match_data(
  31. account_name: str, post_id: str
  32. ) -> dict[tuple[str, str], float]:
  33. """
  34. 从匹配文件中读取 (帖子选题点, 人设树节点) -> 匹配分数。
  35. 文件为 JSON 数组,每项 name 为帖子选题点,match_personas 中 name 为人设树节点,match_score 为分数。
  36. """
  37. path = _match_data_file(account_name, post_id)
  38. if not path.is_file():
  39. return {}
  40. with open(path, "r", encoding="utf-8") as f:
  41. data = json.load(f)
  42. if not isinstance(data, list):
  43. return {}
  44. lookup: dict[tuple[str, str], float] = {}
  45. for item in data:
  46. if not isinstance(item, dict):
  47. continue
  48. post_point = item.get("name")
  49. personas = item.get("match_personas")
  50. if post_point is None or not isinstance(personas, list):
  51. continue
  52. post_point = str(post_point).strip()
  53. if not post_point:
  54. continue
  55. for mp in personas:
  56. if not isinstance(mp, dict):
  57. continue
  58. persona_name = mp.get("name")
  59. score = mp.get("match_score")
  60. if persona_name is None or score is None:
  61. continue
  62. persona_name = str(persona_name).strip()
  63. try:
  64. lookup[(post_point, persona_name)] = float(score)
  65. except (TypeError, ValueError):
  66. continue
  67. return lookup
  68. def _load_post_topic_points(account_name: str, post_id: str) -> List[str]:
  69. """从 post_topic JSON 读取帖子选题点列表。文件内容为字符串数组。"""
  70. path = _post_topic_file(account_name, post_id)
  71. if not path.is_file():
  72. return []
  73. with open(path, "r", encoding="utf-8") as f:
  74. data = json.load(f)
  75. if not isinstance(data, list):
  76. return []
  77. return [str(x).strip() for x in data if x and str(x).strip()]
  78. def _to_derivation_points(derivation_output_points: List[str]) -> List[str]:
  79. """
  80. 从推导选题点字符串列表中筛出非空并 strip,返回去重后的列表。
  81. 兼容 "叙事编排+商业融入+物品" 格式:先按 "+" 拆分,再展开为独立选题点,最终去重(保持首次出现顺序)。
  82. """
  83. seen: set[str] = set()
  84. result: List[str] = []
  85. for s in derivation_output_points:
  86. if s is None:
  87. continue
  88. for part in str(s).split("+"):
  89. part = part.strip()
  90. if part and part not in seen:
  91. seen.add(part)
  92. result.append(part)
  93. return result
  94. async def match_derivation_to_post_points(
  95. derivation_output_points: List[str],
  96. account_name: str,
  97. post_id: str,
  98. match_threshold: float = DEFAULT_MATCH_THRESHOLD,
  99. ) -> List[dict[str, Any]]:
  100. """
  101. 判断推导选题点(视为人设树节点)是否与帖子选题点匹配,返回分数 >= match_threshold 的列表。
  102. 优先从 ../input/{account_name}/match_data/{post_id}_匹配_all.json 读取匹配分数,
  103. 文件中未出现的 (帖子选题点, 人设树节点) 再通过 similarity_matrix 计算。
  104. Returns:
  105. 每项: {"推导选题点": str, "帖子选题点": str, "匹配分数": float}
  106. """
  107. post_points = _load_post_topic_points(account_name, post_id)
  108. derivation_points = _to_derivation_points(derivation_output_points)
  109. if not derivation_points:
  110. return []
  111. if not post_points:
  112. return []
  113. # 从匹配文件读取 (帖子选题点, 人设树节点) -> 匹配分数;derivation_points 当作人设树节点
  114. match_lookup = _load_match_data(account_name, post_id)
  115. scores: dict[tuple[str, str], float] = {}
  116. missing_pairs: list[tuple[str, str]] = []
  117. for d in derivation_points:
  118. for p in post_points:
  119. key = (p, d)
  120. if key in match_lookup:
  121. scores[(d, p)] = match_lookup[key]
  122. else:
  123. missing_pairs.append((d, p))
  124. # 文件中没有的 (推导选题点, 帖子选题点) 用 similarity_matrix 计算
  125. if missing_pairs:
  126. derivation_missing = list({d for d, _ in missing_pairs})
  127. post_missing = list({p for _, p in missing_pairs})
  128. items = await similarity_matrix(derivation_missing, post_missing)
  129. for row in items:
  130. d, p = row["phrase_a"], row["phrase_b"]
  131. scores[(d, p)] = row["combined_score"]
  132. matched = []
  133. for (d, p), score in scores.items():
  134. if score >= match_threshold:
  135. matched.append({
  136. "推导选题点": d,
  137. "帖子选题点": p,
  138. "匹配分数": round(score, 6),
  139. })
  140. return matched
  141. @tool()
  142. async def point_match(
  143. derivation_output_points: List[str],
  144. account_name: str,
  145. post_id: str,
  146. match_threshold: float = DEFAULT_MATCH_THRESHOLD,
  147. ) -> ToolResult:
  148. """
  149. 判断推导选题点与帖子选题点是否匹配,返回达到阈值的匹配对。
  150. Args:
  151. derivation_output_points : 推导选题点字符串列表。
  152. account_name : 账号名,用于定位 input 下的账号目录。
  153. post_id : 帖子ID,用于定位该帖的选题点与匹配数据。
  154. match_threshold : 匹配分数阈值,分数 >= 该值视为匹配成功,默认 0.6。
  155. Returns:
  156. ToolResult:
  157. - title: 结果标题。
  158. - output: 可读的匹配结果文本(每行:推导选题点、帖子选题点、匹配分数)。
  159. - 出错时 error 为错误信息。
  160. """
  161. topic_path = _post_topic_file(account_name, post_id)
  162. if not topic_path.is_file():
  163. return ToolResult(
  164. title="帖子选题点文件不存在",
  165. output=f"帖子选题点文件不存在: {topic_path}",
  166. error="Post topic file not found",
  167. )
  168. try:
  169. derivation_points = _to_derivation_points(derivation_output_points)
  170. if not derivation_points:
  171. return ToolResult(
  172. title="参数无效",
  173. output="derivation_output_points 不能为空,且需为字符串列表",
  174. error="Invalid derivation_output_points",
  175. )
  176. matched = await match_derivation_to_post_points(
  177. derivation_output_points, account_name, post_id, match_threshold
  178. )
  179. if not matched:
  180. output = f"未找到分数 >= {match_threshold} 的匹配对"
  181. else:
  182. lines = [
  183. f"- 推导: {x['推导选题点']}\t帖子: {x['帖子选题点']}\t分数={x['匹配分数']}"
  184. for x in matched
  185. ]
  186. output = "\n".join(lines)
  187. return ToolResult(
  188. title=f"选题点匹配结果 ({account_name}, post_id={post_id})",
  189. output=output,
  190. metadata={
  191. "account_name": account_name,
  192. "post_id": post_id,
  193. "match_threshold": match_threshold,
  194. "count": len(matched),
  195. },
  196. )
  197. except Exception as e:
  198. return ToolResult(
  199. title="选题点匹配失败",
  200. output=str(e),
  201. error=str(e),
  202. )
  203. def main() -> None:
  204. """本地测试:用家有大志账号、某帖子ID、推导选题点列表测试匹配。"""
  205. import asyncio
  206. account_name = "家有大志"
  207. post_id = "68fb6a5c000000000302e5de"
  208. derivation_output_points = ["分享+创意改造", "柴犬", "不存在的点"]
  209. async def run():
  210. matched = await match_derivation_to_post_points(
  211. derivation_output_points, account_name, post_id
  212. )
  213. print(f"账号: {account_name}, post_id: {post_id}")
  214. print(f"推导选题点: {derivation_output_points}")
  215. print(f"匹配成功 {len(matched)} 条:\n")
  216. for x in matched:
  217. print(f" - 推导: {x['推导选题点']}\t帖子: {x['帖子选题点']}\t分数={x['匹配分数']}")
  218. if ToolResult is not None:
  219. result = await point_match(
  220. derivation_output_points=derivation_output_points,
  221. account_name=account_name,
  222. post_id=post_id,
  223. )
  224. print("\n--- Tool 返回 ---")
  225. print(result.output)
  226. asyncio.run(run())
  227. if __name__ == "__main__":
  228. main()