| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260 |
- """
- 选题点匹配 Tool - 判断推导选题点是否与帖子中的选题点匹配
- 功能:读取帖子选题点列表,优先从匹配数据文件读取 (帖子选题点, 推导选题点) 的匹配分数,文件中缺失的再做相似度计算 (combined_score),返回匹配分数 >= 阈值的匹配对。
- """
- import json
- import sys
- from pathlib import Path
- from typing import Any, List, Optional
# Put the project root (two levels above this file) on sys.path so that the
# `utils` package resolves both when this file is run directly and when it is
# loaded as part of a package; the membership check avoids duplicate entries.
_root = Path(__file__).resolve().parents[1]
if str(_root) not in sys.path:
    sys.path.insert(0, str(_root))
- from utils.similarity_calc import similarity_matrix
# The agent framework is optional. Without it, fall back to a no-op `tool`
# decorator and `None` placeholders so this module can still be imported.
try:
    from agent.tools import tool, ToolResult, ToolContext
except ImportError:
    def tool(*_args, **_kwargs):
        """No-op stand-in for `agent.tools.tool`: the decorated function is returned unchanged."""
        def _passthrough(func):
            return func
        return _passthrough

    ToolResult = ToolContext = None
# Root of the input data directory: <parent of this file's directory>/input.
_BASE_INPUT = Path(__file__).resolve().parents[1] / "input"
# Default match threshold: a pair counts as matched when its score is >= this value.
DEFAULT_MATCH_THRESHOLD = 0.6
def _post_topic_file(account_name: str, post_id: str) -> Path:
    """Return the post topic-point file path: ../input/{account_name}/post_topic/{post_id}.json."""
    file_name = f"{post_id}.json"
    return _BASE_INPUT.joinpath(account_name, "post_topic", file_name)
def _match_data_file(account_name: str, post_id: str) -> Path:
    """Return the path of the file holding post-topic-point vs. persona-tree-node match results:
    ../input/{account_name}/match_data/{post_id}_匹配_all.json
    """
    match_dir = _BASE_INPUT.joinpath(account_name, "match_data")
    return match_dir / f"{post_id}_匹配_all.json"
def _load_match_data(
    account_name: str, post_id: str
) -> dict[tuple[str, str], float]:
    """
    Build a (post topic point, persona tree node) -> match score lookup from the match file.

    The file located by ``_match_data_file`` is parsed as JSON and must be a list.
    For each element, ``name`` is the post topic point and ``match_personas`` is a
    list whose entries carry ``name`` (persona tree node) and ``match_score``.
    Both names are converted with ``str()`` and stripped; scores go through ``float()``.

    Returns ``{}`` when the file is missing or its top level is not a list.
    Non-dict elements, missing fields, a blank post point, and scores rejected by
    ``float()`` are skipped; a repeated key keeps the last score seen. Invalid JSON
    is not caught here, so the decoding error propagates to the caller.
    """
    path = _match_data_file(account_name, post_id)
    if not path.is_file():
        return {}
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        return {}

    lookup: dict[tuple[str, str], float] = {}
    for entry in data:
        if not isinstance(entry, dict):
            continue
        raw_point, personas = entry.get("name"), entry.get("match_personas")
        if raw_point is None or not isinstance(personas, list):
            continue
        point = str(raw_point).strip()
        if not point:
            continue
        for persona in personas:
            if not isinstance(persona, dict):
                continue
            raw_node, raw_score = persona.get("name"), persona.get("match_score")
            if raw_node is None or raw_score is None:
                continue
            node = str(raw_node).strip()
            try:
                score = float(raw_score)
            except (TypeError, ValueError):
                continue
            lookup[(point, node)] = score
    return lookup
def _load_post_topic_points(account_name: str, post_id: str) -> List[str]:
    """Read the post's topic points from its post_topic JSON file.

    The file is expected to contain a JSON array of strings. Returns ``[]`` when the
    file does not exist or its top level is not a list. Falsy entries and entries that
    are blank after ``str(...).strip()`` are dropped; the rest are returned stripped.
    """
    path = _post_topic_file(account_name, post_id)
    if not path.is_file():
        return []
    payload = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(payload, list):
        return []
    points: List[str] = []
    for entry in payload:
        if not entry:
            continue
        text = str(entry).strip()
        if text:
            points.append(text)
    return points
def _to_derivation_points(derivation_output_points: Optional[List[str]]) -> List[str]:
    """Normalize derivation topic points into a list of non-blank, stripped strings.

    - A ``None`` input yields ``[]``.
    - A single bare string is treated as a one-element list; otherwise it would be
      iterated character by character.
    - ``None`` entries and entries that are blank after stripping are dropped.
    - Every kept entry goes through ``str()`` before ``strip()``, so non-string items
      such as numbers are kept in string form instead of raising ``AttributeError``.
    """
    if derivation_output_points is None:
        return []
    if isinstance(derivation_output_points, str):
        derivation_output_points = [derivation_output_points]
    points: List[str] = []
    for raw in derivation_output_points:
        if raw is None:
            continue
        text = str(raw).strip()
        if text:
            points.append(text)
    return points
async def match_derivation_to_post_points(
    derivation_output_points: List[str],
    account_name: str,
    post_id: str,
    match_threshold: float = DEFAULT_MATCH_THRESHOLD,
) -> List[dict[str, Any]]:
    """
    Match derivation topic points (treated as persona tree nodes) against the post's topic points.

    Scores are read first from ../input/{account_name}/match_data/{post_id}_匹配_all.json.
    Only (derivation point, post point) pairs absent from that file are scored with
    ``similarity_matrix`` (its ``combined_score``). Scores taken from the file are never
    overwritten by computed ones.

    Args:
        derivation_output_points: derivation topic point strings; None/blank entries are
            ignored and duplicates are collapsed.
        account_name: account directory name under ../input.
        post_id: post id used to locate the topic-point and match-data files.
        match_threshold: a pair matches when its score is >= this value.

    Returns:
        A list of {"推导选题点": str, "帖子选题点": str, "匹配分数": float}, ordered by
        derivation point and then post point, both in first-seen input order. Empty when
        either side has no points.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the output is deterministic.
    derivation_points = list(dict.fromkeys(_to_derivation_points(derivation_output_points)))
    if not derivation_points:
        return []
    post_points = list(dict.fromkeys(_load_post_topic_points(account_name, post_id)))
    if not post_points:
        return []

    # The match file is keyed (post point, persona node); derivation points play the persona-node role.
    match_lookup = _load_match_data(account_name, post_id)
    scores: dict[tuple[str, str], float] = {}
    missing_pairs: list[tuple[str, str]] = []
    for d in derivation_points:
        for p in post_points:
            file_score = match_lookup.get((p, d))
            if file_score is None:
                missing_pairs.append((d, p))
            else:
                scores[(d, p)] = file_score

    if missing_pairs:
        missing_set = set(missing_pairs)
        derivation_missing = list(dict.fromkeys(d for d, _ in missing_pairs))
        post_missing = list(dict.fromkeys(p for _, p in missing_pairs))
        items = await similarity_matrix(derivation_missing, post_missing)
        for row in items:
            pair = (row["phrase_a"], row["phrase_b"])
            # The computed matrix covers the whole cross product of the missing points, which can
            # include pairs already scored by the file; keep only the genuinely missing ones.
            if pair in missing_set:
                scores[pair] = row["combined_score"]

    matched: List[dict[str, Any]] = []
    for d in derivation_points:
        for p in post_points:
            score = scores.get((d, p))
            if score is not None and score >= match_threshold:
                matched.append({
                    "推导选题点": d,
                    "帖子选题点": p,
                    "匹配分数": round(score, 6),
                })
    return matched
@tool(
    description="判断推导选题点(人设树节点)与帖子选题点是否匹配。"
    "功能:根据账号与帖子ID,将传入的推导选题点列表与帖子选题点做匹配,返回达到阈值的匹配对。"
    "参数:derivation_output_points 为推导选题点字符串列表;account_name 为账号名;post_id 为帖子ID;"
    f"match_threshold 为匹配分数阈值,默认 {DEFAULT_MATCH_THRESHOLD}。"
    "返回:ToolResult,output 为可读匹配结果文本,metadata.items 为匹配列表,每项含「推导选题点」「帖子选题点」「匹配分数」。"
)
async def point_match(
    derivation_output_points: List[str],
    account_name: str,
    post_id: str,
    match_threshold: float = DEFAULT_MATCH_THRESHOLD,
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """
    Decide whether derivation topic points match the post's topic points; return pairs reaching the threshold.

    Args:
        derivation_output_points: list of derivation topic point strings. A single bare
            string is treated as a one-element list; None counts as empty.
        account_name: account name, used to locate the account directory under input.
        post_id: post id, used to locate the post's topic points and match data.
        match_threshold: pairs whose score is >= this value match. Defaults to
            DEFAULT_MATCH_THRESHOLD. It is converted with float(), so numeric strings are accepted.
        context: optional agent tool context (not used by this function).

    Returns:
        ToolResult with:
        - title: result title.
        - output: readable match text (one line per pair: derivation point, post point, score).
        - metadata: account_name, post_id, match_threshold, count and items, where items is
          a list of {"推导选题点": str, "帖子选题点": str, "匹配分数": float}.
        - error: set on failure (missing topic file, invalid arguments, or any exception
          raised while matching).
    """
    topic_path = _post_topic_file(account_name, post_id)
    if not topic_path.is_file():
        return ToolResult(
            title="帖子选题点文件不存在",
            output=f"帖子选题点文件不存在: {topic_path}",
            error="Post topic file not found",
        )
    # Arguments may be produced by an LLM: a bare string would otherwise be iterated
    # character by character, and None would fail inside the normalizer.
    if isinstance(derivation_output_points, str):
        raw_points = [derivation_output_points]
    elif derivation_output_points is None:
        raw_points = []
    else:
        raw_points = derivation_output_points
    try:
        derivation_points = _to_derivation_points(raw_points)
        if not derivation_points:
            return ToolResult(
                title="参数无效",
                output="derivation_output_points 不能为空,且需为字符串列表",
                error="Invalid derivation_output_points",
            )
        # Coerce up front so a non-numeric threshold is reported (via the except below)
        # regardless of whether any score comparison would happen.
        threshold = float(match_threshold)
        matched = await match_derivation_to_post_points(
            derivation_points, account_name, post_id, threshold
        )
        if not matched:
            output = f"未找到 combined_score >= {threshold} 的匹配"
        else:
            output = "\n".join(
                f"- 推导: {x['推导选题点']}\t帖子: {x['帖子选题点']}\t分数={x['匹配分数']}"
                for x in matched
            )
        return ToolResult(
            title=f"选题点匹配结果 ({account_name}, post_id={post_id})",
            output=output,
            metadata={
                "account_name": account_name,
                "post_id": post_id,
                "match_threshold": threshold,
                "count": len(matched),
                "items": matched,
            },
        )
    except Exception as e:
        return ToolResult(
            title="选题点匹配失败",
            output=str(e),
            error=str(e),
        )
def main() -> None:
    """Local smoke test for the 家有大志 account with a fixed post id and derivation points.

    Prints the matches returned by ``match_derivation_to_post_points``. When the agent
    framework is available (``ToolResult`` is not None), it also prints the output of
    the ``point_match`` tool.
    """
    import asyncio

    account_name = "家有大志"
    post_id = "68fb6a5c000000000302e5de"
    derivation_output_points = ["分享", "创意改造", "柴犬", "不存在的点"]

    def _format(item):
        return f" - 推导: {item['推导选题点']}\t帖子: {item['帖子选题点']}\t分数={item['匹配分数']}"

    async def _run() -> None:
        matches = await match_derivation_to_post_points(
            derivation_output_points, account_name, post_id
        )
        print(f"账号: {account_name}, post_id: {post_id}")
        print(f"推导选题点: {derivation_output_points}")
        print(f"匹配成功 {len(matches)} 条:\n")
        for item in matches:
            print(_format(item))
        # The tool wrapper needs the real ToolResult from agent.tools.
        if ToolResult is None:
            return
        tool_result = await point_match(
            derivation_output_points=derivation_output_points,
            account_name=account_name,
            post_id=post_id,
        )
        print("\n--- Tool 返回 ---")
        print(tool_result.output)

    asyncio.run(_run())


if __name__ == "__main__":
    main()
|