|
|
@@ -0,0 +1,359 @@
|
|
|
+"""
|
|
|
+搜索评估工具:搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。
|
|
|
+
|
|
|
+处理流程:
|
|
|
+1. 使用 xhs(失败或空则用 zhihu)搜索帖子
|
|
|
+2. 并发对每篇帖子调用 LLM 判断人设匹配 & 提取关键词
|
|
|
+3. 对匹配人设的帖子,调用 match_derivation_to_post_points 匹配选题点
|
|
|
+4. 返回完整评估结果列表
|
|
|
+"""
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import re
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any, Dict, List, Optional
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+import httpx
|
|
|
+
|
|
|
# Ensure utils/tools resolve both when run directly and when loaded as a
# package (also lets IDEs jump to definitions).
_root = Path(__file__).resolve().parent.parent
if str(_root) not in sys.path:
    sys.path.insert(0, str(_root))
|
|
|
+
|
|
|
+from tools.point_match import match_derivation_to_post_points
|
|
|
+
|
|
|
try:
    from agent.tools import tool, ToolResult, ToolContext
    from agent.llm.openrouter import openrouter_llm_call
except ImportError:
    # Fallback stubs so this module can still be imported (e.g. for local
    # testing) when the agent framework is not installed.
    def tool(*args, **kwargs):
        # No-op decorator: returns the wrapped function unchanged.
        return lambda f: f
    ToolResult = None
    ToolContext = None
    openrouter_llm_call = None
|
|
|
+
|
|
|
# Base directory holding per-account input data.
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
# Directory containing this tool module (prompt templates live alongside it).
_TOOLS_DIR = Path(__file__).resolve().parent

# Search service endpoint.
BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
# HTTP timeout (seconds) for search requests.
DEFAULT_TIMEOUT = 60.0

# LLM model supporting multimodal (vision + text) input.
EVAL_LLM_MODEL = "google/gemini-3-flash-preview"

# Max images passed to the LLM per post (keeps token usage bounded).
MAX_IMAGES_PER_POST = 20
|
|
|
+
|
|
|
+
|
|
|
def _load_match_and_extract_prompt() -> str:
    """Load the system-prompt template for persona matching & keyword extraction.

    Returns:
        Raw contents of ``match_and_extract_prompt.md`` located next to this module.

    Raises:
        FileNotFoundError: if the template file is missing.
    """
    # Path.read_text is the idiomatic one-shot read (opens and closes the file).
    return (_TOOLS_DIR / "match_and_extract_prompt.md").read_text(encoding="utf-8")
|
|
|
+
|
|
|
+
|
|
|
def _load_persona_text(account_name: str) -> str:
    """Load the account's persona summary as a readable string.

    Args:
        account_name: Account whose persona data should be loaded.

    Returns:
        Pretty-printed JSON of the persona summary, or a placeholder string
        (Chinese, consumed by the LLM prompt) when no persona file exists.
    """
    persona_file = _BASE_INPUT / account_name / "persona_data" / "persona_summary.json"
    if not persona_file.is_file():
        logger.warning("_load_persona_text: persona file not found: %s", persona_file)
        return f"账号:{account_name}(暂无人设数据)"
    # read_text + json.loads replaces the manual open/json.load pair.
    data = json.loads(persona_file.read_text(encoding="utf-8"))
    logger.debug("_load_persona_text: loaded persona for account=%s", account_name)
    return json.dumps(data, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+
|
|
|
async def _do_search(query: str, channel: str) -> Optional[List[dict]]:
    """Run a single search request against one channel.

    Args:
        query: Search keyword.
        channel: Channel identifier passed as the payload ``type`` (e.g. "xhs", "zhihu").

    Returns:
        A non-empty list of post dicts, or None on failure / empty result so the
        caller can fall back to another channel.
    """
    logger.debug("_do_search: channel=%s, query=%s", channel, query)
    payload = {
        "type": channel,
        "keyword": query,
        "cursor": "0",
        "max_count": 5,
        "content_type": "图文",  # restrict to image/text posts
    }
    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            resp = await client.post(
                f"{BASE_URL}/data",
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            resp.raise_for_status()
            data = resp.json()
        # `or []` guarantees a list, so len() is always safe here
        # (the original recomputed a redundant `count` conditional).
        posts = data.get("data") or []
        logger.info("_do_search: channel=%s, query=%s -> %d posts", channel, query, len(posts))
        return posts or None
    except Exception as e:
        # Best-effort: any network/HTTP/parse failure degrades to None rather
        # than propagating, matching the fallback contract above.
        logger.warning("_do_search: channel=%s, query=%s failed: %s", channel, query, e)
        return None
|
|
|
+
|
|
|
+
|
|
|
async def _search_posts(query: str) -> List[dict]:
    """Search posts, preferring xhs and falling back to zhihu.

    Args:
        query: Search keyword.

    Returns:
        List of post dicts; empty list when both channels fail or return nothing.
    """
    xhs_posts = await _do_search(query, "xhs")
    if xhs_posts:
        logger.info("_search_posts: using xhs, %d posts for query=%s", len(xhs_posts), query)
        return xhs_posts

    zhihu_posts = await _do_search(query, "zhihu")
    if not zhihu_posts:
        logger.warning("_search_posts: no posts from xhs or zhihu for query=%s", query)
        return []
    logger.info("_search_posts: xhs empty/failed, using zhihu, %d posts for query=%s", len(zhihu_posts), query)
    return zhihu_posts
|
|
|
+
|
|
|
+
|
|
|
def _build_user_message_content(post: dict) -> List[dict]:
    """Build an OpenAI-style multimodal user-message content list for one post.

    Produces a single text part (the post's key fields serialized as JSON)
    followed by up to MAX_IMAGES_PER_POST image parts.
    """
    # Text part: serialize the key post fields for the LLM.
    summary = {
        "channel_content_id": post.get("channel_content_id", ""),
        "title": post.get("title", ""),
        "body_text": post.get("body_text", ""),
    }
    content: List[dict] = [
        {"type": "text", "text": json.dumps(summary, ensure_ascii=False)}
    ]

    # Image parts: cap the count and skip empty URLs.
    image_urls = (post.get("images") or [])[:MAX_IMAGES_PER_POST]
    content.extend(
        {"type": "image_url", "image_url": {"url": url}}
        for url in image_urls
        if url
    )
    return content
|
|
|
+
|
|
|
+
|
|
|
+def _extract_json_object(content: str) -> dict:
|
|
|
+ """从 LLM 回复中解析第一个 JSON 对象(允许被 ```json ... ``` 包裹)"""
|
|
|
+ content = content.strip()
|
|
|
+ m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content)
|
|
|
+ if m:
|
|
|
+ content = m.group(1).strip()
|
|
|
+ # 找到最外层 { ... }
|
|
|
+ start = content.find("{")
|
|
|
+ end = content.rfind("}")
|
|
|
+ if start != -1 and end != -1:
|
|
|
+ content = content[start : end + 1]
|
|
|
+ return json.loads(content)
|
|
|
+
|
|
|
+
|
|
|
async def _eval_single_post(
    post: dict,
    system_prompt: str,
    account_name: str,
    post_id: str,
) -> dict:
    """Evaluate a single post.

    1. Ask the LLM whether the post matches the account persona and to
       extract keywords.
    2. If it matches and keywords were found, run
       match_derivation_to_post_points on the keywords.

    Returns:
        Full evaluation dict; on failure an "error" key is added and the
        default (non-matching) fields are kept.
    """
    cid = post.get("channel_content_id", "")
    evaluation: dict = {
        "channel_content_id": cid,
        "title": post.get("title", ""),
        "body_text": post.get("body_text", ""),
        "images": post.get("images") or [],
        "persona_match_result": False,
        "persona_match_reason": "",
        "post_keywords": [],
        "point_match_results": [],
    }

    try:
        logger.debug("_eval_single_post: evaluating post_id=%s, title=%s", cid, (evaluation["title"] or "")[:40])
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": _build_user_message_content(post)},
        ]

        llm_result = await openrouter_llm_call(messages, model=EVAL_LLM_MODEL)
        content = llm_result.get("content", "")
        if not content:
            evaluation["error"] = "LLM 未返回内容"
            logger.warning("_eval_single_post: post_id=%s LLM returned empty content", cid)
            return evaluation

        parsed = _extract_json_object(content)
        evaluation["persona_match_result"] = bool(parsed.get("persona_match_result", False))
        evaluation["persona_match_reason"] = parsed.get("persona_match_reason", "")
        evaluation["post_keywords"] = parsed.get("post_keywords") or []
        logger.info(
            "_eval_single_post: post_id=%s persona_match=%s keywords=%s",
            cid,
            evaluation["persona_match_result"],
            evaluation["post_keywords"],
        )

        # Topic-point matching is only done for persona-matched posts.
        if evaluation["persona_match_result"] and evaluation["post_keywords"]:
            matches = await match_derivation_to_post_points(
                evaluation["post_keywords"], account_name, post_id
            )
            evaluation["point_match_results"] = matches
            logger.info(
                "_eval_single_post: post_id=%s point_match count=%d",
                cid,
                len(matches),
            )

    except Exception as exc:
        logger.exception("_eval_single_post: post_id=%s error: %s", cid, exc)
        evaluation["error"] = str(exc)

    return evaluation
|
|
|
+
|
|
|
+
|
|
|
@tool(
    description=(
        "搜索帖子并评估是否与账号人设匹配,提取帖子关键词并与帖子选题点进行匹配。"
        "参数:account_name 账号名称;post_id 帖子ID;query 搜索词。"
    )
)
async def search_and_eval(
    account_name: str,
    post_id: str,
    query: str,
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """Search posts and evaluate each against the account persona.

    Args:
        account_name: Account name; used to load persona data and topic-point files.
        post_id: Post ID; used to locate the topic-point match file.
        query: Search keyword.
        context: Optional tool-execution context (unused here).

    Returns:
        ToolResult whose output is a JSON list of per-post evaluations, each with:
        - channel_content_id: post ID
        - title: post title
        - body_text: post body
        - images: image URL list
        - persona_match_result: whether the post matches the persona (bool)
        - post_keywords: extracted keyword list
        - point_match_results: keyword-to-topic-point match results
          (derived point, post point, match score per item)
    """
    logger.info(
        "search_and_eval: account_name=%s post_id=%s query=%s",
        account_name,
        post_id,
        query,
    )
    try:
        # Step 1: search for posts.
        posts = await _search_posts(query)
        if not posts:
            logger.warning("search_and_eval: no posts found for query=%s", query)
            return ToolResult(
                title=f"搜索评估: {query}",
                output="[]",
                long_term_memory=f"search_and_eval: query='{query}', no posts found",
            )

        logger.info("search_and_eval: got %d posts, loading prompt and persona", len(posts))
        # Step 2: build the system prompt with the account persona injected.
        system_prompt = _load_match_and_extract_prompt().replace(
            "{persona}", _load_persona_text(account_name)
        )

        # Step 3: evaluate all posts concurrently.
        results: List[dict] = await asyncio.gather(
            *(_eval_single_post(post, system_prompt, account_name, post_id) for post in posts)
        )

        matched_count = sum(1 for item in results if item.get("persona_match_result"))
        error_count = sum(1 for item in results if item.get("error"))
        logger.info(
            "search_and_eval: done. total=%d persona_matched=%d errors=%d",
            len(results),
            matched_count,
            error_count,
        )
        output = json.dumps(results, ensure_ascii=False, indent=2)
        logger.info("search_and_eval: output=%s", output)

        return ToolResult(
            title=(
                f"搜索评估: {query} "
                f"(共 {len(results)} 条,{matched_count} 条匹配人设)"
            ),
            output=output,
            long_term_memory=(
                f"search_and_eval: query='{query}', "
                f"found {len(results)} posts, {matched_count} matched persona"
            ),
            metadata={"items": results},
        )

    except Exception as e:
        logger.exception("search_and_eval: failed: %s", e)
        return ToolResult(
            title="搜索评估失败",
            output="",
            error=str(e),
        )
|
|
|
+
|
|
|
+
|
|
|
def main() -> None:
    """Local smoke test: run search_and_eval for a sample account.

    Configures DEBUG logging, then searches a fixed query for the sample
    account and prints a one-line summary per evaluated post. The redundant
    function-local ``import asyncio`` was removed: asyncio is already
    imported at module level.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )
    # Fixed sample account / post / query used for the local test run.
    account_name = "家有大志"
    post_id = "68fb6a5c000000000302e5de"
    query = "柴犬 鞋子"

    async def run():
        # Bail out when the agent framework is unavailable (ToolResult stubbed
        # to None by the import fallback at module top).
        if ToolResult is None:
            print("agent 依赖未安装,无法直接运行 tool 版本")
            return
        result = await search_and_eval(
            account_name=account_name,
            post_id=post_id,
            query=query,
        )
        if result.error:
            print(f"Error: {result.error}")
        else:
            print(result.title)
            data = json.loads(result.output)
            for item in data:
                print(
                    f" [{item.get('persona_match_result')}] {item.get('title', '')[:30]}"
                    f" | keywords: {item.get('post_keywords')}"
                    f" | matches: {len(item.get('point_match_results', []))}"
                )

    asyncio.run(run())
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Add the project root (four directory levels up) to sys.path so absolute
    # imports such as `agent.*` resolve when this file is run as a script.
    _project_root = str(Path(__file__).resolve().parent.parent.parent.parent)
    if _project_root not in sys.path:
        sys.path.insert(0, _project_root)
    main()
|