  1. """
  2. 搜索评估工具:搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。
  3. 处理流程:
  4. 1. 接收 query_list(多个搜索 query),并发处理
  5. 2. 每个 query:使用 xhs(失败或空则用 zhihu)搜索帖子
  6. 3. 并发对每篇帖子调用 LLM 判断人设匹配 & 提取关键词
  7. 4. 对匹配人设的帖子,调用 match_derivation_to_post_points 匹配选题点
  8. 5. 返回按 query 分组的评估结果字典
  9. 6. 支持本地文件缓存(.cache/search/{account_name}/{post_id}/)
  10. """
  11. import asyncio
  12. import hashlib
  13. import json
  14. import logging
  15. import re
  16. import sys
  17. from pathlib import Path
  18. from typing import Dict, List, Optional
  19. logger = logging.getLogger(__name__)
  20. import httpx
  21. # 保证直接运行或作为包加载时都能解析 utils/tools(IDE 可跳转)
  22. _root = Path(__file__).resolve().parent.parent
  23. if str(_root) not in sys.path:
  24. sys.path.insert(0, str(_root))
  25. from tools.point_match import match_derivation_to_post_points
  26. try:
  27. from agent.tools import tool, ToolResult, ToolContext
  28. from agent.llm.openrouter import openrouter_llm_call
  29. except ImportError:
  30. def tool(*args, **kwargs):
  31. return lambda f: f
  32. ToolResult = None
  33. ToolContext = None
  34. openrouter_llm_call = None
  35. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  36. _TOOLS_DIR = Path(__file__).resolve().parent
  37. _CACHE_ROOT = Path(__file__).resolve().parent.parent / ".cache" / "search"
  38. BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
  39. DEFAULT_TIMEOUT = 60.0
  40. # 支持多模态(视觉+文本)的 LLM 模型
  41. EVAL_LLM_MODEL = "google/gemini-3-flash-preview"
  42. # 每篇帖子最多传入的图片数量(避免 token 过多)
  43. MAX_IMAGES_PER_POST = 20
  44. def _load_match_and_extract_prompt() -> str:
  45. """读取帖子人设匹配 & 关键词提取的 system prompt 模板"""
  46. prompt_file = _TOOLS_DIR / "match_and_extract_prompt.md"
  47. with open(prompt_file, "r", encoding="utf-8") as f:
  48. return f.read()
  49. def _load_persona_text(account_name: str) -> str:
  50. """读取账号人设摘要,返回可读字符串;文件不存在时返回空人设提示"""
  51. persona_file = _BASE_INPUT / account_name / "处理后数据" / "persona_data" / "persona_summary.json"
  52. if not persona_file.is_file():
  53. logger.warning("_load_persona_text: persona file not found: %s", persona_file)
  54. return f"账号:{account_name}(暂无人设数据)"
  55. with open(persona_file, "r", encoding="utf-8") as f:
  56. data = json.load(f)
  57. # 去掉不需要给 LLM 的中间推理字段,避免 prompt 过长或泄露分析细节
  58. if isinstance(data, dict):
  59. data.pop("分析过程", None)
  60. logger.debug("_load_persona_text: loaded persona for account=%s", account_name)
  61. return json.dumps(data, ensure_ascii=False, indent=2)
  62. async def _do_search(query: str, channel: str) -> Optional[List[dict]]:
  63. """执行单次搜索,返回帖子列表;失败或空列表返回 None"""
  64. logger.debug("_do_search: channel=%s, query=%s", channel, query)
  65. payload = {
  66. "type": channel,
  67. "keyword": query,
  68. "cursor": "0",
  69. "max_count": 10,
  70. "content_type": "图文",
  71. }
  72. try:
  73. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  74. resp = await client.post(
  75. f"{BASE_URL}/data",
  76. json=payload,
  77. headers={"Content-Type": "application/json"},
  78. )
  79. resp.raise_for_status()
  80. data = resp.json()
  81. posts = data.get("data") or []
  82. count = len(posts) if posts else 0
  83. logger.info("_do_search: channel=%s, query=%s -> %d posts", channel, query, count)
  84. return posts if posts else None
  85. except Exception as e:
  86. logger.warning("_do_search: channel=%s, query=%s failed: %s", channel, query, e)
  87. return None
  88. async def _search_posts(query: str) -> List[dict]:
  89. """优先用 xhs 搜索,失败或空则用 zhihu,返回帖子列表"""
  90. xhs_query = query.replace(" ", "")
  91. posts = await _do_search(xhs_query, "xhs")
  92. if posts:
  93. logger.info("_search_posts: using xhs, %d posts for query=%s", len(posts), query)
  94. return posts
  95. posts = await _do_search(query, "zhihu")
  96. if posts:
  97. logger.info("_search_posts: xhs empty/failed, using zhihu, %d posts for query=%s", len(posts), query)
  98. else:
  99. logger.warning("_search_posts: no posts from xhs or zhihu for query=%s", query)
  100. return posts or []
  101. def _build_user_message_content(post: dict) -> List[dict]:
  102. """
  103. 将帖子数据构建为 OpenAI 多模态 user message content。
  104. 包含帖子文本描述 + 前 MAX_IMAGES_PER_POST 张图片。
  105. """
  106. parts: List[dict] = []
  107. # 文本部分:将帖子的关键字段序列化给 LLM
  108. post_text = json.dumps(
  109. {
  110. "channel_content_id": post.get("channel_content_id", ""),
  111. "title": post.get("title", ""),
  112. "body_text": post.get("body_text", ""),
  113. },
  114. ensure_ascii=False,
  115. )
  116. parts.append({"type": "text", "text": post_text})
  117. # 图片部分
  118. images = post.get("images") or []
  119. for img_url in images[:MAX_IMAGES_PER_POST]:
  120. if img_url:
  121. parts.append({"type": "image_url", "image_url": {"url": img_url}})
  122. return parts
  123. def _extract_json_object(content: str) -> dict:
  124. """从 LLM 回复中解析第一个 JSON 对象(允许被 ```json ... ``` 包裹)"""
  125. content = content.strip()
  126. m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content)
  127. if m:
  128. content = m.group(1).strip()
  129. # 找到最外层 { ... }
  130. start = content.find("{")
  131. end = content.rfind("}")
  132. if start != -1 and end != -1:
  133. content = content[start : end + 1]
  134. return json.loads(content)
  135. async def _eval_single_post(
  136. post: dict,
  137. system_prompt: str,
  138. account_name: str,
  139. post_id: str,
  140. ) -> dict:
  141. """
  142. 评估单篇帖子:
  143. 1. 调用 LLM 判断人设匹配并提取关键词
  144. 2. 若匹配,调用 match_derivation_to_post_points 匹配选题点
  145. 返回完整评估结果字典。
  146. """
  147. post_cid = post.get("channel_content_id", "")
  148. result: dict = {
  149. "channel_content_id": post_cid,
  150. "title": post.get("title", ""),
  151. "body_text": post.get("body_text", ""),
  152. "images": post.get("images") or [],
  153. "persona_match_result": False,
  154. "persona_match_reason": "",
  155. "post_keywords": [],
  156. "point_match_results": [],
  157. }
  158. try:
  159. logger.debug("_eval_single_post: evaluating post_id=%s, title=%s", post_cid, (result["title"] or "")[:40])
  160. user_content = _build_user_message_content(post)
  161. messages = [
  162. {"role": "system", "content": system_prompt},
  163. {"role": "user", "content": user_content},
  164. ]
  165. llm_result = await openrouter_llm_call(messages, model=EVAL_LLM_MODEL)
  166. content = llm_result.get("content", "")
  167. if not content:
  168. result["error"] = "LLM 未返回内容"
  169. logger.warning("_eval_single_post: post_id=%s LLM returned empty content", post_cid)
  170. return result
  171. parsed = _extract_json_object(content)
  172. result["persona_match_result"] = bool(parsed.get("persona_match_result", False))
  173. result["persona_match_reason"] = parsed.get("persona_match_reason", "")
  174. result["post_keywords"] = parsed.get("post_keywords") or []
  175. logger.info(
  176. "_eval_single_post: post_id=%s persona_match=%s keywords=%s",
  177. post_cid,
  178. result["persona_match_result"],
  179. result["post_keywords"],
  180. )
  181. # 仅对与人设匹配的帖子做选题点匹配
  182. if result["persona_match_result"] and result["post_keywords"]:
  183. matched = await match_derivation_to_post_points(
  184. result["post_keywords"], account_name, post_id
  185. )
  186. result["point_match_results"] = matched
  187. logger.info(
  188. "_eval_single_post: post_id=%s point_match count=%d",
  189. post_cid,
  190. len(matched),
  191. )
  192. except Exception as e:
  193. logger.exception("_eval_single_post: post_id=%s error: %s", post_cid, e)
  194. result["error"] = str(e)
  195. return result
  196. def _cache_key(query: str) -> str:
  197. """将 query 转为安全的文件名:使用 MD5 哈希避免特殊字符问题"""
  198. h = hashlib.md5(query.encode("utf-8")).hexdigest()[:12]
  199. safe = re.sub(r'[^\w\u4e00-\u9fff]+', '_', query)[:60].strip('_')
  200. return f"{safe}_{h}"
def _get_cache_path(account_name: str, post_id: str, query: str) -> Path:
    """Return the cache file path: .cache/search/{account_name}/{post_id}/{key}.json"""
    return _CACHE_ROOT / account_name / post_id / f"{_cache_key(query)}.json"
  203. def _read_cache(account_name: str, post_id: str, query: str) -> Optional[List[dict]]:
  204. """读取缓存,存在且合法则返回帖子列表,否则返回 None"""
  205. path = _get_cache_path(account_name, post_id, query)
  206. if not path.is_file():
  207. return None
  208. try:
  209. with open(path, "r", encoding="utf-8") as f:
  210. data = json.load(f)
  211. if isinstance(data, list):
  212. logger.info("_read_cache: hit cache for query=%s, %d items", query, len(data))
  213. return data
  214. except Exception as e:
  215. logger.warning("_read_cache: failed to read cache for query=%s: %s", query, e)
  216. return None
  217. def _write_cache(account_name: str, post_id: str, query: str, results: List[dict]) -> None:
  218. """写入缓存"""
  219. path = _get_cache_path(account_name, post_id, query)
  220. try:
  221. path.parent.mkdir(parents=True, exist_ok=True)
  222. with open(path, "w", encoding="utf-8") as f:
  223. json.dump(results, f, ensure_ascii=False, indent=2)
  224. logger.info("_write_cache: wrote cache for query=%s, %d items", query, len(results))
  225. except Exception as e:
  226. logger.warning("_write_cache: failed to write cache for query=%s: %s", query, e)
  227. async def _search_and_eval_single_query(
  228. query: str,
  229. system_prompt: str,
  230. account_name: str,
  231. post_id: str,
  232. ) -> List[dict]:
  233. """处理单个 query 的搜索、评估、匹配流程,支持缓存"""
  234. cached = _read_cache(account_name, post_id, query)
  235. if cached is not None:
  236. return cached
  237. posts = await _search_posts(query)
  238. if not posts:
  239. logger.warning("_search_and_eval_single_query: no posts for query=%s", query)
  240. _write_cache(account_name, post_id, query, [])
  241. return []
  242. logger.info("_search_and_eval_single_query: got %d posts for query=%s", len(posts), query)
  243. tasks = [
  244. _eval_single_post(post, system_prompt, account_name, post_id)
  245. for post in posts
  246. ]
  247. results: List[dict] = await asyncio.gather(*tasks)
  248. _write_cache(account_name, post_id, query, results)
  249. return results
  250. @tool()
  251. async def search_and_eval(
  252. account_name: str,
  253. post_id: str,
  254. query_list: List[str],
  255. context: Optional[ToolContext] = None,
  256. ) -> ToolResult:
  257. """
  258. 搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。
  259. 支持多个 query 并发处理,结果按 query 分组返回。
  260. 本地文件缓存:.cache/search/{account_name}/{post_id}/ 下每个 query 一个 JSON 文件。
  261. Args:
  262. account_name: 账号名称,用于读取人设数据和选题点文件
  263. post_id: 帖子ID,用于定位选题点匹配文件
  264. query_list: 搜索词列表,每个元素为一个 query 字符串
  265. Returns:
  266. ToolResult,output 为 JSON 格式的按 query 分组的结果字典:
  267. {
  268. "query1": [帖子评估结果列表],
  269. "query2": [帖子评估结果列表],
  270. ...
  271. }
  272. 每个帖子评估结果包含:
  273. - channel_content_id, title, body_text, images
  274. - persona_match_result: 是否与账号人设匹配(bool)
  275. - post_keywords: 提取的帖子关键词列表
  276. - point_match_results: 关键词与帖子选题点的匹配结果列表
  277. """
  278. logger.info(
  279. "search_and_eval: account_name=%s post_id=%s query_list=%s",
  280. account_name,
  281. post_id,
  282. query_list,
  283. )
  284. if True:
  285. return ToolResult(
  286. title="搜索评估工具不可用",
  287. output="搜索评估工具不可用"
  288. )
  289. if not query_list:
  290. return ToolResult(
  291. title="搜索评估: 空 query_list",
  292. output="{}",
  293. )
  294. try:
  295. prompt_template = _load_match_and_extract_prompt()
  296. persona_text = _load_persona_text(account_name)
  297. system_prompt = prompt_template.replace("{persona}", persona_text)
  298. tasks = [
  299. _search_and_eval_single_query(q, system_prompt, account_name, post_id)
  300. for q in query_list
  301. ]
  302. all_results: List[List[dict]] = await asyncio.gather(*tasks)
  303. grouped: Dict[str, List[dict]] = {}
  304. total_posts = 0
  305. total_matched = 0
  306. for query, results in zip(query_list, all_results):
  307. grouped[query] = results
  308. total_posts += len(results)
  309. total_matched += sum(1 for r in results if r.get("persona_match_result"))
  310. logger.info(
  311. "search_and_eval: done. queries=%d total_posts=%d persona_matched=%d",
  312. len(query_list),
  313. total_posts,
  314. total_matched,
  315. )
  316. output = json.dumps(grouped, ensure_ascii=False, indent=2)
  317. return ToolResult(
  318. title=(
  319. f"搜索评估: {len(query_list)} 个 query "
  320. f"(共 {total_posts} 条帖子,{total_matched} 条匹配人设)"
  321. ),
  322. output=output,
  323. metadata={"search_and_eval summary": f"{len(query_list)} queries, found {total_posts} posts, {total_matched} matched persona"},
  324. )
  325. except Exception as e:
  326. logger.exception("search_and_eval: failed: %s", e)
  327. return ToolResult(
  328. title="搜索评估失败",
  329. output="",
  330. error=str(e),
  331. )
  332. def main() -> None:
  333. """本地测试:用家有大志账号测试搜索评估"""
  334. import asyncio
  335. logging.basicConfig(
  336. level=logging.DEBUG,
  337. format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
  338. datefmt="%H:%M:%S",
  339. )
  340. account_name = "家有大志"
  341. post_id = "68fb6a5c000000000302e5de"
  342. query_list = ["柴犬 鞋子 啃坏"]
  343. async def run():
  344. if ToolResult is None:
  345. print("agent 依赖未安装,无法直接运行 tool 版本")
  346. return
  347. result = await search_and_eval(
  348. account_name=account_name,
  349. post_id=post_id,
  350. query_list=query_list,
  351. )
  352. if result.error:
  353. print(f"Error: {result.error}")
  354. else:
  355. print(result.title)
  356. grouped = json.loads(result.output)
  357. for query, items in grouped.items():
  358. print(f"\n === query: {query} ({len(items)} posts) ===")
  359. for item in items:
  360. print(
  361. f" [{item.get('persona_match_result')}] {item.get('title', '')[:30]}"
  362. f" | keywords: {item.get('post_keywords')}"
  363. f" | matches: {len(item.get('point_match_results', []))}"
  364. )
  365. asyncio.run(run())
if __name__ == "__main__":
    # Make the project root importable when this file is run as a script.
    _project_root = str(Path(__file__).resolve().parent.parent.parent.parent)
    if _project_root not in sys.path:
        sys.path.insert(0, _project_root)
    main()