"""
Topic-point matching tool: decide whether derived topic points match the topic points of a post.

Reads the post's topic-point list, scores it against the derived topic points,
and returns the pairs whose combined_score >= the threshold.
"""

import json
import sys
from pathlib import Path
from typing import Any, List, Optional

# Make `utils` importable both when run directly and when loaded as a package
# (this also lets IDEs resolve the import).
_root = Path(__file__).resolve().parent.parent
if str(_root) not in sys.path:
    sys.path.insert(0, str(_root))

from utils.similarity_calc import similarity_matrix

try:
    from agent.tools import tool, ToolResult, ToolContext
except ImportError:
    # Fallback when the agent framework is unavailable: `tool` becomes a no-op
    # decorator. point_match() cannot build a ToolResult in this mode;
    # main() checks for that before calling it.
    def tool(*args, **kwargs):
        return lambda f: f

    ToolResult = None
    ToolContext = None

_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"

# Default match threshold
DEFAULT_MATCH_THRESHOLD = 0.6


def _post_topic_file(account_name: str, post_id: str) -> Path:
    """Post topic-point file: ../input/{account_name}/post_topic/{post_id}.json"""
    return _BASE_INPUT / account_name / "post_topic" / f"{post_id}.json"


def _match_data_file(account_name: str, post_id: str) -> Path:
    """File of match results between post topic points and persona-tree nodes:
    ../input/{account_name}/match_data/{post_id}_匹配_all.json"""
    return _BASE_INPUT / account_name / "match_data" / f"{post_id}_匹配_all.json"


def _load_match_data(
    account_name: str, post_id: str
) -> dict[tuple[str, str], float]:
    """
    Read (post topic point, persona-tree node) -> match score from the match file.

    The file is a JSON array. In each item, `name` is the post topic point;
    within `match_personas`, `name` is the persona-tree node and `match_score`
    is the score.
    """
    path = _match_data_file(account_name, post_id)
    if not path.is_file():
        return {}
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        return {}

    lookup: dict[tuple[str, str], float] = {}
    for item in data:
        if not isinstance(item, dict):
            continue
        post_point = item.get("name")
        personas = item.get("match_personas")
        if post_point is None or not isinstance(personas, list):
            continue
        post_point = str(post_point).strip()
        if not post_point:
            continue
        for mp in personas:
            if not isinstance(mp, dict):
                continue
            persona_name = mp.get("name")
            score = mp.get("match_score")
            if persona_name is None or score is None:
                continue
            persona_name = str(persona_name).strip()
            try:
                lookup[(post_point, persona_name)] = float(score)
            except (TypeError, ValueError):
                # Skip entries whose score is not numeric.
                continue
    return lookup


def _load_post_topic_points(account_name: str, post_id: str) -> List[str]:
    """Read the post's topic-point list from the post_topic JSON (an array of strings)."""
    path = _post_topic_file(account_name, post_id)
    if not path.is_file():
        return []
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        return []
    return [str(x).strip() for x in data if x and str(x).strip()]


def _to_derivation_points(derivation_output_points: List[str]) -> List[str]:
    """Drop None and blank entries from the derived topic points and strip the rest."""
    # str(s) keeps a stray non-string entry from raising AttributeError on .strip().
    return [
        str(s).strip()
        for s in derivation_output_points
        if s is not None and str(s).strip()
    ]


async def match_derivation_to_post_points(
    derivation_output_points: List[str],
    account_name: str,
    post_id: str,
    match_threshold: float = DEFAULT_MATCH_THRESHOLD,
) -> List[dict[str, Any]]:
    """
    Match derived topic points (treated as persona-tree nodes) against the post's
    topic points and return the pairs whose score >= match_threshold.

    Scores are read from ../input/{account_name}/match_data/{post_id}_匹配_all.json
    first; any (post topic point, persona-tree node) pair absent from that file
    is computed with similarity_matrix.

    Returns:
        A list of {"推导选题点": str, "帖子选题点": str, "匹配分数": float},
        i.e. derived topic point, post topic point, match score.
    """
    post_points = _load_post_topic_points(account_name, post_id)
    derivation_points = _to_derivation_points(derivation_output_points)
    if not derivation_points:
        return []
    if not post_points:
        return []

    # Look up (post topic point, persona-tree node) -> score in the match file;
    # the derived points play the role of persona-tree nodes.
    match_lookup = _load_match_data(account_name, post_id)
    scores: dict[tuple[str, str], float] = {}
    missing_pairs: list[tuple[str, str]] = []
    for d in derivation_points:
        for p in post_points:
            key = (p, d)
            if key in match_lookup:
                scores[(d, p)] = match_lookup[key]
            else:
                missing_pairs.append((d, p))

    # Compute the (derived, post) pairs that the file does not cover.
    if missing_pairs:
        missing = set(missing_pairs)
        # dict.fromkeys de-duplicates while preserving order.
        derivation_missing = list(dict.fromkeys(d for d, _ in missing_pairs))
        post_missing = list(dict.fromkeys(p for _, p in missing_pairs))
        items = await similarity_matrix(derivation_missing, post_missing)
        for row in items:
            d, p = row["phrase_a"], row["phrase_b"]
            # The matrix spans every derivation_missing x post_missing combination,
            # so keep only the truly missing pairs and never overwrite a file score.
            if (d, p) in missing:
                scores[(d, p)] = row["combined_score"]

    matched = []
    for (d, p), score in scores.items():
        if score >= match_threshold:
            matched.append({
                "推导选题点": d,
                "帖子选题点": p,
                "匹配分数": round(score, 6),
            })
    return matched


@tool(
    description="Decide whether derived topic points (persona-tree nodes) match a post's topic points. "
    "Function: given an account and a post ID, match the supplied derived topic points against the "
    "post's topic points and return the pairs that reach the threshold. "
    "Parameters: derivation_output_points is a list of derived topic-point strings; account_name is the "
    "account name; post_id is the post ID; match_threshold is the match-score threshold, default 0.6. "
    "Returns: a ToolResult whose output is readable match text and whose metadata.items is the match list; "
    "each item has the keys 「推导选题点」「帖子选题点」「匹配分数」."
)
async def point_match(
    derivation_output_points: List[str],
    account_name: str,
    post_id: str,
    match_threshold: float = DEFAULT_MATCH_THRESHOLD,
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """
    Match derived topic points against a post's topic points and return the
    pairs that reach the threshold.

    Parameters
    ----------
    derivation_output_points : list of derived topic-point strings.
    account_name : account name, used to locate the account directory under input.
    post_id : post ID, used to locate the post's topic points and match data.
    match_threshold : match-score threshold; a score >= this value counts as a
        match. Defaults to DEFAULT_MATCH_THRESHOLD (0.6).
    context : optional agent tool context.

    Returns
    -------
    ToolResult:
        - title: result title.
        - output: readable match text (one line per pair: derived topic point,
          post topic point, match score).
        - metadata: account_name, post_id, match_threshold, count, items;
          items is a list of {"推导选题点": str, "帖子选题点": str, "匹配分数": float}.
        - error: the error message, on failure.
    """
    topic_path = _post_topic_file(account_name, post_id)
    if not topic_path.is_file():
        return ToolResult(
            title="Post topic file not found",
            output=f"Post topic file not found: {topic_path}",
            error="Post topic file not found",
        )
    try:
        derivation_points = _to_derivation_points(derivation_output_points)
        if not derivation_points:
            return ToolResult(
                title="Invalid parameter",
                output="derivation_output_points must be a non-empty list of strings",
                error="Invalid derivation_output_points",
            )
        matched = await match_derivation_to_post_points(
            derivation_output_points, account_name, post_id, match_threshold
        )
        if not matched:
            output = f"No match found with combined_score >= {match_threshold}"
        else:
            lines = [
                f"- derived: {x['推导选题点']}\tpost: {x['帖子选题点']}\tscore={x['匹配分数']}"
                for x in matched
            ]
            output = "\n".join(lines)
        return ToolResult(
            title=f"Topic-point match result ({account_name}, post_id={post_id})",
            output=output,
            metadata={
                "account_name": account_name,
                "post_id": post_id,
                "match_threshold": match_threshold,
                "count": len(matched),
                "items": matched,
            },
        )
    except Exception as e:
        return ToolResult(
            title="Topic-point matching failed",
            output=str(e),
            error=str(e),
        )


def main() -> None:
    """Local test: run matching for the 家有大志 account, a sample post ID, and a list of derived topic points."""
    import asyncio

    account_name = "家有大志"
    post_id = "68fb6a5c000000000302e5de"
    derivation_output_points = ["分享", "创意改造", "柴犬", "不存在的点"]

    async def run():
        matched = await match_derivation_to_post_points(
            derivation_output_points, account_name, post_id
        )
        print(f"account: {account_name}, post_id: {post_id}")
        print(f"derived topic points: {derivation_output_points}")
        print(f"{len(matched)} match(es):\n")
        for x in matched:
            print(f"  - derived: {x['推导选题点']}\tpost: {x['帖子选题点']}\tscore={x['匹配分数']}")

        # point_match needs the real ToolResult class, so skip it in fallback mode.
        if ToolResult is not None:
            result = await point_match(
                derivation_output_points=derivation_output_points,
                account_name=account_name,
                post_id=post_id,
            )
            print("\n--- Tool result ---")
            print(result.output)

    asyncio.run(run())


if __name__ == "__main__":
    main()