""" 搜索评估工具:搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。 处理流程: 1. 接收 query_list(多个搜索 query),并发处理 2. 每个 query:使用 xhs(失败或空则用 zhihu)搜索帖子 3. 并发对每篇帖子调用 LLM 判断人设匹配 & 提取关键词 4. 对匹配人设的帖子,调用 match_derivation_to_post_points 匹配选题点 5. 返回按 query 分组的评估结果字典 6. 支持本地文件缓存(.cache/search/{account_name}/{post_id}/) """ import asyncio import hashlib import json import logging import re import sys from pathlib import Path from typing import Dict, List, Optional logger = logging.getLogger(__name__) import httpx # 保证直接运行或作为包加载时都能解析 utils/tools(IDE 可跳转) _root = Path(__file__).resolve().parent.parent if str(_root) not in sys.path: sys.path.insert(0, str(_root)) from tools.point_match import match_derivation_to_post_points try: from agent.tools import tool, ToolResult, ToolContext from agent.llm.openrouter import openrouter_llm_call except ImportError: def tool(*args, **kwargs): return lambda f: f ToolResult = None ToolContext = None openrouter_llm_call = None _BASE_INPUT = Path(__file__).resolve().parent.parent / "input" _TOOLS_DIR = Path(__file__).resolve().parent _CACHE_ROOT = Path(__file__).resolve().parent.parent / ".cache" / "search" BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel" DEFAULT_TIMEOUT = 60.0 # 支持多模态(视觉+文本)的 LLM 模型 EVAL_LLM_MODEL = "google/gemini-3-flash-preview" # 每篇帖子最多传入的图片数量(避免 token 过多) MAX_IMAGES_PER_POST = 20 def _load_match_and_extract_prompt() -> str: """读取帖子人设匹配 & 关键词提取的 system prompt 模板""" prompt_file = _TOOLS_DIR / "match_and_extract_prompt.md" with open(prompt_file, "r", encoding="utf-8") as f: return f.read() def _load_persona_text(account_name: str) -> str: """读取账号人设摘要,返回可读字符串;文件不存在时返回空人设提示""" persona_file = _BASE_INPUT / account_name / "处理后数据" / "persona_data" / "persona_summary.json" if not persona_file.is_file(): logger.warning("_load_persona_text: persona file not found: %s", persona_file) return f"账号:{account_name}(暂无人设数据)" with open(persona_file, "r", encoding="utf-8") as f: data = json.load(f) # 去掉不需要给 LLM 的中间推理字段,避免 prompt 过长或泄露分析细节 if isinstance(data, dict): data.pop("分析过程", None) logger.debug("_load_persona_text: loaded persona for account=%s", account_name) return json.dumps(data, ensure_ascii=False, indent=2) async def _do_search(query: str, channel: str) -> Optional[List[dict]]: """执行单次搜索,返回帖子列表;失败或空列表返回 None""" logger.debug("_do_search: channel=%s, query=%s", channel, query) payload = { "type": channel, "keyword": query, "cursor": "0", "max_count": 10, "content_type": "图文", } try: async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client: resp = await client.post( f"{BASE_URL}/data", json=payload, headers={"Content-Type": "application/json"}, ) resp.raise_for_status() data = resp.json() posts = data.get("data") or [] count = len(posts) if posts else 0 logger.info("_do_search: channel=%s, query=%s -> %d posts", channel, query, count) return posts if posts else None except Exception as e: logger.warning("_do_search: channel=%s, query=%s failed: %s", channel, query, e) return None async def _search_posts(query: str) -> List[dict]: """优先用 xhs 搜索,失败或空则用 zhihu,返回帖子列表""" xhs_query = query.replace(" ", "") posts = await _do_search(xhs_query, "xhs") if posts: logger.info("_search_posts: using xhs, %d posts for query=%s", len(posts), query) return posts posts = await _do_search(query, "zhihu") if posts: logger.info("_search_posts: xhs empty/failed, using zhihu, %d posts for query=%s", len(posts), query) else: logger.warning("_search_posts: no posts from xhs or zhihu for query=%s", query) return posts or [] def _build_user_message_content(post: dict) -> List[dict]: """ 将帖子数据构建为 OpenAI 
def _build_user_message_content(post: dict) -> List[dict]:
    """
    Build the post data into OpenAI-style multimodal user message content:
    a text description of the post plus up to MAX_IMAGES_PER_POST images.
    """
    parts: List[dict] = []
    # Text part: serialize the post's key fields for the LLM.
    post_text = json.dumps(
        {
            "channel_content_id": post.get("channel_content_id", ""),
            "title": post.get("title", ""),
            "body_text": post.get("body_text", ""),
        },
        ensure_ascii=False,
    )
    parts.append({"type": "text", "text": post_text})
    # Image parts
    images = post.get("images") or []
    for img_url in images[:MAX_IMAGES_PER_POST]:
        if img_url:
            parts.append({"type": "image_url", "image_url": {"url": img_url}})
    return parts


def _extract_json_object(content: str) -> dict:
    """Parse the first JSON object from an LLM reply (may be wrapped in ```json ... ```)."""
    content = content.strip()
    m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content)
    if m:
        content = m.group(1).strip()
    # Locate the outermost { ... }
    start = content.find("{")
    end = content.rfind("}")
    if start != -1 and end != -1:
        content = content[start : end + 1]
    return json.loads(content)


async def _eval_single_post(
    post: dict,
    system_prompt: str,
    account_name: str,
    post_id: str,
) -> dict:
    """
    Evaluate a single post:
    1. Call the LLM to judge persona match and extract keywords.
    2. If it matches, call match_derivation_to_post_points to match topic points.

    Returns the full evaluation result dict.
    """
    post_cid = post.get("channel_content_id", "")
    result: dict = {
        "channel_content_id": post_cid,
        "title": post.get("title", ""),
        "body_text": post.get("body_text", ""),
        "images": post.get("images") or [],
        "persona_match_result": False,
        "persona_match_reason": "",
        "post_keywords": [],
        "point_match_results": [],
    }
    try:
        logger.debug("_eval_single_post: evaluating post_id=%s, title=%s", post_cid, (result["title"] or "")[:40])
        user_content = _build_user_message_content(post)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ]
        llm_result = await openrouter_llm_call(messages, model=EVAL_LLM_MODEL)
        content = llm_result.get("content", "")
        if not content:
            result["error"] = "LLM returned no content"
            logger.warning("_eval_single_post: post_id=%s LLM returned empty content", post_cid)
            return result

        parsed = _extract_json_object(content)
        result["persona_match_result"] = bool(parsed.get("persona_match_result", False))
        result["persona_match_reason"] = parsed.get("persona_match_reason", "")
        result["post_keywords"] = parsed.get("post_keywords") or []
        logger.info(
            "_eval_single_post: post_id=%s persona_match=%s keywords=%s",
            post_cid,
            result["persona_match_result"],
            result["post_keywords"],
        )

        # Only match topic points for posts that match the persona.
        if result["persona_match_result"] and result["post_keywords"]:
            matched = await match_derivation_to_post_points(
                result["post_keywords"], account_name, post_id
            )
            result["point_match_results"] = matched
            logger.info(
                "_eval_single_post: post_id=%s point_match count=%d",
                post_cid,
                len(matched),
            )
    except Exception as e:
        logger.exception("_eval_single_post: post_id=%s error: %s", post_cid, e)
        result["error"] = str(e)
    return result


def _cache_key(query: str) -> str:
    """Turn a query into a safe filename: a sanitized prefix keeps it readable,
    and an MD5 hash suffix avoids special-character collisions."""
    h = hashlib.md5(query.encode("utf-8")).hexdigest()[:12]
    safe = re.sub(r"[^\w\u4e00-\u9fff]+", "_", query)[:60].strip("_")
    return f"{safe}_{h}"


def _get_cache_path(account_name: str, post_id: str, query: str) -> Path:
    return _CACHE_ROOT / account_name / post_id / f"{_cache_key(query)}.json"


def _read_cache(account_name: str, post_id: str, query: str) -> Optional[List[dict]]:
    """Read the cache; return the post list if it exists and is valid, else None."""
    path = _get_cache_path(account_name, post_id, query)
    if not path.is_file():
        return None
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, list):
            logger.info("_read_cache: hit cache for query=%s, %d items", query, len(data))
            return data
    except Exception as e:
        logger.warning("_read_cache: failed to read cache for query=%s: %s", query, e)
    return None
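# Cache layout sketch (derived from _get_cache_path; the hash suffix is the
# first 12 hex chars of the query's MD5, shown here only as a placeholder):
#
#   .cache/search/{account_name}/{post_id}/柴犬_鞋子_啃坏_<12 hex chars>.json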
def _write_cache(account_name: str, post_id: str, query: str, results: List[dict]) -> None:
    """Write results to the cache."""
    path = _get_cache_path(account_name, post_id, query)
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        logger.info("_write_cache: wrote cache for query=%s, %d items", query, len(results))
    except Exception as e:
        logger.warning("_write_cache: failed to write cache for query=%s: %s", query, e)


async def _search_and_eval_single_query(
    query: str,
    system_prompt: str,
    account_name: str,
    post_id: str,
) -> List[dict]:
    """Run the search / evaluate / match pipeline for a single query, with caching."""
    cached = _read_cache(account_name, post_id, query)
    if cached is not None:
        return cached

    posts = await _search_posts(query)
    if not posts:
        logger.warning("_search_and_eval_single_query: no posts for query=%s", query)
        _write_cache(account_name, post_id, query, [])
        return []
    logger.info("_search_and_eval_single_query: got %d posts for query=%s", len(posts), query)

    tasks = [
        _eval_single_post(post, system_prompt, account_name, post_id)
        for post in posts
    ]
    results: List[dict] = await asyncio.gather(*tasks)
    _write_cache(account_name, post_id, query, results)
    return results
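# Results are cached per (account_name, post_id, query) with no TTL, so cached
# entries never expire on their own. To force a re-run of a stale query, remove
# its cache file first, e.g. (sketch):
#
#   _get_cache_path(account_name, post_id, query).unlink(missing_ok=True)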
@tool()
async def search_and_eval(
    account_name: str,
    post_id: str,
    query_list: List[str],
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """
    Search for posts, judge whether each one matches the account persona,
    extract keywords, and match them against topic points.

    Multiple queries are processed concurrently and the results are grouped by query.
    Local file cache: one JSON file per query under .cache/search/{account_name}/{post_id}/.

    Args:
        account_name: account name, used to load persona data and topic point files
        post_id: post ID, used to locate the topic point match file
        query_list: list of search queries, one string per query

    Returns:
        ToolResult whose output is a JSON dict of results grouped by query:
        {
            "query1": [list of post evaluation results],
            "query2": [list of post evaluation results],
            ...
        }
        Each post evaluation result contains:
        - channel_content_id, title, body_text, images
        - persona_match_result: whether the post matches the account persona (bool)
        - persona_match_reason: the LLM's reasoning for the match verdict
        - post_keywords: list of keywords extracted from the post
        - point_match_results: list of keyword-to-topic-point match results
    """
    logger.info(
        "search_and_eval: account_name=%s post_id=%s query_list=%s",
        account_name, post_id, query_list,
    )
    if not query_list:
        return ToolResult(
            title="search_and_eval: empty query_list",
            output="{}",
        )

    try:
        prompt_template = _load_match_and_extract_prompt()
        persona_text = _load_persona_text(account_name)
        system_prompt = prompt_template.replace("{persona}", persona_text)

        tasks = [
            _search_and_eval_single_query(q, system_prompt, account_name, post_id)
            for q in query_list
        ]
        all_results: List[List[dict]] = await asyncio.gather(*tasks)

        grouped: Dict[str, List[dict]] = {}
        total_posts = 0
        total_matched = 0
        for query, results in zip(query_list, all_results):
            grouped[query] = results
            total_posts += len(results)
            total_matched += sum(1 for r in results if r.get("persona_match_result"))

        logger.info(
            "search_and_eval: done. queries=%d total_posts=%d persona_matched=%d",
            len(query_list), total_posts, total_matched,
        )
        output = json.dumps(grouped, ensure_ascii=False, indent=2)
        return ToolResult(
            title=(
                f"search_and_eval: {len(query_list)} queries "
                f"({total_posts} posts, {total_matched} matched persona)"
            ),
            output=output,
            metadata={"search_and_eval summary": f"{len(query_list)} queries, found {total_posts} posts, {total_matched} matched persona"},
        )
    except Exception as e:
        logger.exception("search_and_eval: failed: %s", e)
        return ToolResult(
            title="search_and_eval failed",
            output="",
            error=str(e),
        )


def main() -> None:
    """Local test: run search-and-eval against the 家有大志 account."""
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )
    account_name = "家有大志"
    post_id = "68fb6a5c000000000302e5de"
    query_list = ["柴犬 鞋子 啃坏"]

    async def run():
        if ToolResult is None:
            print("agent dependencies not installed; cannot run the tool version directly")
            return
        result = await search_and_eval(
            account_name=account_name,
            post_id=post_id,
            query_list=query_list,
        )
        if result.error:
            print(f"Error: {result.error}")
        else:
            print(result.title)
            grouped = json.loads(result.output)
            for query, items in grouped.items():
                print(f"\n=== query: {query} ({len(items)} posts) ===")
                for item in items:
                    print(
                        f"  [{item.get('persona_match_result')}] {item.get('title', '')[:30]}"
                        f" | keywords: {item.get('post_keywords')}"
                        f" | matches: {len(item.get('point_match_results', []))}"
                    )

    asyncio.run(run())


if __name__ == "__main__":
    _project_root = str(Path(__file__).resolve().parent.parent.parent.parent)
    if _project_root not in sys.path:
        sys.path.insert(0, _project_root)
    main()