| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436 |
- """
- 搜索评估工具:搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。
- 处理流程:
- 1. 接收 query_list(多个搜索 query),并发处理
- 2. 每个 query:使用 xhs(失败或空则用 zhihu)搜索帖子
- 3. 并发对每篇帖子调用 LLM 判断人设匹配 & 提取关键词
- 4. 对匹配人设的帖子,调用 match_derivation_to_post_points 匹配选题点
- 5. 返回按 query 分组的评估结果字典
- 6. 支持本地文件缓存(.cache/search/{account_name}/{post_id}/)
- """
- import asyncio
- import hashlib
- import json
- import logging
- import re
- import sys
- from pathlib import Path
- from typing import Dict, List, Optional
logger = logging.getLogger(__name__)

import httpx

# Ensure utils/tools resolve whether the file is run directly or loaded as a
# package (also lets IDEs jump to definitions).
_root = Path(__file__).resolve().parent.parent
if str(_root) not in sys.path:
    sys.path.insert(0, str(_root))

from tools.point_match import match_derivation_to_post_points

try:
    from agent.tools import tool, ToolResult, ToolContext
    from agent.llm.openrouter import openrouter_llm_call
except ImportError:
    # Fallback stubs so this module can still be imported without the agent
    # framework installed: @tool becomes a no-op decorator and the agent
    # symbols are None (callers must check before use).
    def tool(*args, **kwargs):
        return lambda f: f
    ToolResult = None
    ToolContext = None
    openrouter_llm_call = None

# Base paths for account input data, prompt templates, and the search cache.
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
_TOOLS_DIR = Path(__file__).resolve().parent
_CACHE_ROOT = Path(__file__).resolve().parent.parent / ".cache" / "search"

BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
DEFAULT_TIMEOUT = 60.0
# LLM model with multimodal (vision + text) support.
EVAL_LLM_MODEL = "google/gemini-3-flash-preview"
# Max images sent per post, to bound token usage.
MAX_IMAGES_PER_POST = 20
def _load_match_and_extract_prompt() -> str:
    """Return the system-prompt template for persona matching & keyword extraction."""
    prompt_path = _TOOLS_DIR / "match_and_extract_prompt.md"
    return prompt_path.read_text(encoding="utf-8")
def _load_persona_text(account_name: str) -> str:
    """Load the account's persona summary as a readable JSON string.

    Falls back to a short "no persona data" note when the summary file does
    not exist.
    """
    persona_file = (
        _BASE_INPUT / account_name / "处理后数据" / "persona_data" / "persona_summary.json"
    )
    if not persona_file.is_file():
        logger.warning("_load_persona_text: persona file not found: %s", persona_file)
        return f"账号:{account_name}(暂无人设数据)"
    with open(persona_file, "r", encoding="utf-8") as fh:
        data = json.load(fh)
    # Drop the intermediate-reasoning field before handing the persona to the
    # LLM: it bloats the prompt and leaks analysis details.
    if isinstance(data, dict):
        data.pop("分析过程", None)
    logger.debug("_load_persona_text: loaded persona for account=%s", account_name)
    return json.dumps(data, ensure_ascii=False, indent=2)
async def _do_search(query: str, channel: str) -> Optional[List[dict]]:
    """Run a single search request against one channel.

    Returns the list of posts, or None when the request fails or yields no
    results (the caller uses None to trigger the fallback channel).
    """
    logger.debug("_do_search: channel=%s, query=%s", channel, query)
    request_body = {
        "type": channel,
        "keyword": query,
        "cursor": "0",
        "max_count": 10,
        "content_type": "图文",
    }
    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            response = await client.post(
                f"{BASE_URL}/data",
                json=request_body,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            payload = response.json()
            posts = payload.get("data") or []
            logger.info(
                "_do_search: channel=%s, query=%s -> %d posts",
                channel, query, len(posts),
            )
            return posts or None
    except Exception as exc:
        # Best-effort: any transport/HTTP/parse failure is logged and treated
        # as "no results" so the caller can fall back to another channel.
        logger.warning("_do_search: channel=%s, query=%s failed: %s", channel, query, exc)
        return None
async def _search_posts(query: str) -> List[dict]:
    """Search xhs first (with spaces stripped); fall back to zhihu.

    Returns an empty list when both channels fail or return nothing.
    """
    # xhs matches better on space-free keywords.
    primary = await _do_search(query.replace(" ", ""), "xhs")
    if primary:
        logger.info("_search_posts: using xhs, %d posts for query=%s", len(primary), query)
        return primary
    fallback = await _do_search(query, "zhihu")
    if fallback:
        logger.info(
            "_search_posts: xhs empty/failed, using zhihu, %d posts for query=%s",
            len(fallback),
            query,
        )
        return fallback
    logger.warning("_search_posts: no posts from xhs or zhihu for query=%s", query)
    return []
def _build_user_message_content(post: dict) -> List[dict]:
    """Build an OpenAI-style multimodal user-message content list for one post.

    The first part is a JSON text dump of the post's key fields, followed by
    up to MAX_IMAGES_PER_POST image_url parts (empty URLs are skipped).
    """
    text_payload = json.dumps(
        {
            "channel_content_id": post.get("channel_content_id", ""),
            "title": post.get("title", ""),
            "body_text": post.get("body_text", ""),
        },
        ensure_ascii=False,
    )
    content: List[dict] = [{"type": "text", "text": text_payload}]
    image_urls = (post.get("images") or [])[:MAX_IMAGES_PER_POST]
    content.extend(
        {"type": "image_url", "image_url": {"url": url}}
        for url in image_urls
        if url
    )
    return content
- def _extract_json_object(content: str) -> dict:
- """从 LLM 回复中解析第一个 JSON 对象(允许被 ```json ... ``` 包裹)"""
- content = content.strip()
- m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content)
- if m:
- content = m.group(1).strip()
- # 找到最外层 { ... }
- start = content.find("{")
- end = content.rfind("}")
- if start != -1 and end != -1:
- content = content[start : end + 1]
- return json.loads(content)
async def _eval_single_post(
    post: dict,
    system_prompt: str,
    account_name: str,
    post_id: str,
) -> dict:
    """Evaluate one post against the account persona.

    Steps:
      1. Ask the LLM whether the post matches the persona and extract keywords.
      2. For persona-matched posts with keywords, map the keywords to topic
         points via match_derivation_to_post_points.

    Returns the full evaluation dict; on any failure an "error" key is added
    and the defaults (no match, empty keywords/matches) are kept.
    """
    cid = post.get("channel_content_id", "")
    evaluation: dict = {
        "channel_content_id": cid,
        "title": post.get("title", ""),
        "body_text": post.get("body_text", ""),
        "images": post.get("images") or [],
        "persona_match_result": False,
        "persona_match_reason": "",
        "post_keywords": [],
        "point_match_results": [],
    }
    try:
        logger.debug(
            "_eval_single_post: evaluating post_id=%s, title=%s",
            cid,
            (evaluation["title"] or "")[:40],
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": _build_user_message_content(post)},
        ]
        reply = await openrouter_llm_call(messages, model=EVAL_LLM_MODEL)
        content = reply.get("content", "")
        if not content:
            evaluation["error"] = "LLM 未返回内容"
            logger.warning("_eval_single_post: post_id=%s LLM returned empty content", cid)
            return evaluation
        parsed = _extract_json_object(content)
        evaluation["persona_match_result"] = bool(parsed.get("persona_match_result", False))
        evaluation["persona_match_reason"] = parsed.get("persona_match_reason", "")
        evaluation["post_keywords"] = parsed.get("post_keywords") or []
        logger.info(
            "_eval_single_post: post_id=%s persona_match=%s keywords=%s",
            cid,
            evaluation["persona_match_result"],
            evaluation["post_keywords"],
        )
        # Topic-point matching is only worthwhile for persona-matched posts
        # that actually yielded keywords.
        if evaluation["persona_match_result"] and evaluation["post_keywords"]:
            matches = await match_derivation_to_post_points(
                evaluation["post_keywords"], account_name, post_id
            )
            evaluation["point_match_results"] = matches
            logger.info(
                "_eval_single_post: post_id=%s point_match count=%d",
                cid,
                len(matches),
            )
    except Exception as exc:
        logger.exception("_eval_single_post: post_id=%s error: %s", cid, exc)
        evaluation["error"] = str(exc)
    return evaluation
- def _cache_key(query: str) -> str:
- """将 query 转为安全的文件名:使用 MD5 哈希避免特殊字符问题"""
- h = hashlib.md5(query.encode("utf-8")).hexdigest()[:12]
- safe = re.sub(r'[^\w\u4e00-\u9fff]+', '_', query)[:60].strip('_')
- return f"{safe}_{h}"
def _get_cache_path(account_name: str, post_id: str, query: str) -> Path:
    """Cache file location: .cache/search/{account_name}/{post_id}/{key}.json."""
    filename = _cache_key(query) + ".json"
    return _CACHE_ROOT / account_name / post_id / filename
def _read_cache(account_name: str, post_id: str, query: str) -> Optional[List[dict]]:
    """Return the cached post list for a query, or None when absent/invalid.

    A cache file whose top-level JSON value is not a list is silently treated
    as a miss.
    """
    cache_file = _get_cache_path(account_name, post_id, query)
    if not cache_file.is_file():
        return None
    try:
        with open(cache_file, "r", encoding="utf-8") as fh:
            payload = json.load(fh)
    except Exception as exc:
        logger.warning("_read_cache: failed to read cache for query=%s: %s", query, exc)
        return None
    if isinstance(payload, list):
        logger.info("_read_cache: hit cache for query=%s, %d items", query, len(payload))
        return payload
    return None
def _write_cache(account_name: str, post_id: str, query: str, results: List[dict]) -> None:
    """Persist evaluation results for one query; failures are logged, not raised."""
    cache_file = _get_cache_path(account_name, post_id, query)
    try:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_file, "w", encoding="utf-8") as fh:
            json.dump(results, fh, ensure_ascii=False, indent=2)
        logger.info("_write_cache: wrote cache for query=%s, %d items", query, len(results))
    except Exception as exc:
        logger.warning("_write_cache: failed to write cache for query=%s: %s", query, exc)
async def _search_and_eval_single_query(
    query: str,
    system_prompt: str,
    account_name: str,
    post_id: str,
) -> List[dict]:
    """Search, evaluate, and topic-match posts for one query, with file caching."""
    cached = _read_cache(account_name, post_id, query)
    if cached is not None:
        return cached
    posts = await _search_posts(query)
    if not posts:
        logger.warning("_search_and_eval_single_query: no posts for query=%s", query)
        # Cache the empty outcome too, so dead queries are not re-searched.
        _write_cache(account_name, post_id, query, [])
        return []
    logger.info("_search_and_eval_single_query: got %d posts for query=%s", len(posts), query)
    # Evaluate all posts concurrently; gather preserves input order.
    evaluations: List[dict] = await asyncio.gather(
        *(_eval_single_post(p, system_prompt, account_name, post_id) for p in posts)
    )
    _write_cache(account_name, post_id, query, evaluations)
    return evaluations
@tool()
async def search_and_eval(
    account_name: str,
    post_id: str,
    query_list: List[str],
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """
    Search posts for each query, judge persona fit, extract keywords, and match
    topic points. Queries are processed concurrently and results are grouped by
    query. Local file cache: .cache/search/{account_name}/{post_id}/ with one
    JSON file per query.

    Args:
        account_name: Account name; used to load persona data and topic-point files.
        post_id: Post ID; locates the topic-point matching file.
        query_list: List of search query strings.

    Returns:
        ToolResult whose output is a JSON dict grouped by query:
        {
            "query1": [post evaluation list],
            "query2": [post evaluation list],
            ...
        }
        Each post evaluation contains:
        - channel_content_id, title, body_text, images
        - persona_match_result: whether the post matches the persona (bool)
        - post_keywords: keywords extracted from the post
        - point_match_results: keyword/topic-point match results
    """
    logger.info(
        "search_and_eval: account_name=%s post_id=%s query_list=%s",
        account_name,
        post_id,
        query_list,
    )
    # NOTE(review): the tool is currently hard-disabled — it always returns the
    # "unavailable" result below. The original code used a bare `if True:`,
    # which left the entire remainder of this function as unreachable dead
    # code. The named flag preserves that behavior while making the kill
    # switch explicit; confirm with the owner why it was turned off before
    # flipping it back.
    tool_disabled = True
    if tool_disabled:
        return ToolResult(
            title="搜索评估工具不可用",
            output="搜索评估工具不可用"
        )
    if not query_list:
        return ToolResult(
            title="搜索评估: 空 query_list",
            output="{}",
        )
    try:
        # Build the shared system prompt: template with persona spliced in.
        prompt_template = _load_match_and_extract_prompt()
        persona_text = _load_persona_text(account_name)
        system_prompt = prompt_template.replace("{persona}", persona_text)
        # Process all queries concurrently; gather preserves input order so
        # results can be zipped back to their queries.
        tasks = [
            _search_and_eval_single_query(q, system_prompt, account_name, post_id)
            for q in query_list
        ]
        all_results: List[List[dict]] = await asyncio.gather(*tasks)
        grouped: Dict[str, List[dict]] = {}
        total_posts = 0
        total_matched = 0
        for query, results in zip(query_list, all_results):
            grouped[query] = results
            total_posts += len(results)
            total_matched += sum(1 for r in results if r.get("persona_match_result"))
        logger.info(
            "search_and_eval: done. queries=%d total_posts=%d persona_matched=%d",
            len(query_list),
            total_posts,
            total_matched,
        )
        output = json.dumps(grouped, ensure_ascii=False, indent=2)
        return ToolResult(
            title=(
                f"搜索评估: {len(query_list)} 个 query "
                f"(共 {total_posts} 条帖子,{total_matched} 条匹配人设)"
            ),
            output=output,
            metadata={"search_and_eval summary": f"{len(query_list)} queries, found {total_posts} posts, {total_matched} matched persona"},
        )
    except Exception as e:
        logger.exception("search_and_eval: failed: %s", e)
        return ToolResult(
            title="搜索评估失败",
            output="",
            error=str(e),
        )
def main() -> None:
    """Local smoke test: run search_and_eval against the 家有大志 test account."""
    import asyncio

    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )
    account_name = "家有大志"
    post_id = "68fb6a5c000000000302e5de"
    query_list = ["柴犬 鞋子 啃坏"]

    async def run():
        # The agent framework may be absent (see the import fallback at the
        # top of the file); bail out rather than crash.
        if ToolResult is None:
            print("agent 依赖未安装,无法直接运行 tool 版本")
            return
        result = await search_and_eval(
            account_name=account_name,
            post_id=post_id,
            query_list=query_list,
        )
        if result.error:
            print(f"Error: {result.error}")
            return
        print(result.title)
        grouped = json.loads(result.output)
        for query, items in grouped.items():
            print(f"\n === query: {query} ({len(items)} posts) ===")
            for item in items:
                print(
                    f" [{item.get('persona_match_result')}] {item.get('title', '')[:30]}"
                    f" | keywords: {item.get('post_keywords')}"
                    f" | matches: {len(item.get('point_match_results', []))}"
                )

    asyncio.run(run())
if __name__ == "__main__":
    # When run as a script, put the repository root (four levels up) on
    # sys.path so absolute project imports resolve, then run the smoke test.
    _project_root = str(Path(__file__).resolve().parent.parent.parent.parent)
    if _project_root not in sys.path:
        sys.path.insert(0, _project_root)
    main()
|