aigc_channel.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
"""
AIGC-Channel platform implementations (9 platforms).

Backend: aigc-channel.aiddit.com
Platforms: xhs / gzh / sph / github / toutiao / douyin / bili / zhihu / weibo

At import time every platform is registered with its search / detail
(and, where supported, suggest) implementation.
"""
  6. import json
  7. import re
  8. from typing import Any, Dict, List, Optional
  9. import httpx
  10. from agent.tools.models import ToolResult
  11. from agent.tools.utils.image import build_image_grid, encode_base64, load_images
  12. from agent.tools.builtin.content.registry import (
  13. PlatformDef, ParamSpec, register_platform,
  14. )
  15. BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
  16. DEFAULT_TIMEOUT = 60.0
  17. # aigc-channel returns search-highlighted titles like
  18. # '<em class="highlight">关键词</em>'. Strip before any rendering / scoring use.
  19. _HTML_TAG_RE = re.compile(r"<[^>]+>")
  20. def _strip_html(text: Optional[str]) -> str:
  21. if not text:
  22. return ""
  23. return _HTML_TAG_RE.sub("", text)
  24. _SPH_TITLE_MAX = 20 # sph normalized title 截断字符数
  25. def _normalize_sph_post(post: Dict[str, Any]) -> None:
  26. """In-place: 视频号没有独立 title,后端把 caption 塞进 title 字段而 body_text 留空。
  27. 把整段 title 搬到 body_text,title 取剥 HTML 后前 20 字 + '...' 作为短摘要。
  28. 幂等:如果 body_text 已经有内容则不动,避免重复迁移或覆盖;title 已经 <=20 字
  29. 也不强加省略号。
  30. """
  31. if not isinstance(post, dict):
  32. return
  33. raw_title = post.get("title") or ""
  34. body = post.get("body_text") or ""
  35. body = body.strip() if isinstance(body, str) else ""
  36. if not raw_title or body:
  37. return
  38. clean = _strip_html(raw_title).strip()
  39. if not clean:
  40. return
  41. post["body_text"] = clean
  42. if len(clean) > _SPH_TITLE_MAX:
  43. post["title"] = clean[:_SPH_TITLE_MAX] + "..."
  44. else:
  45. post["title"] = clean
  46. # ── 平台注册 ──
  47. _XHS_SEARCH_PARAMS = {
  48. "sort_type": ParamSpec(
  49. values=["综合排序", "最新发布", "最多点赞"],
  50. default="综合排序",
  51. ),
  52. "publish_time": ParamSpec(
  53. values=["不限", "近1天", "近7天", "近30天"],
  54. default="不限",
  55. ),
  56. "content_type": ParamSpec(
  57. values=["不限", "图文", "视频", "文章"],
  58. default="不限",
  59. ),
  60. "filter_note_range": ParamSpec(
  61. values=["不限", "1分钟以内", "1-5分钟", "5分钟以上"],
  62. default="不限",
  63. note="仅视频内容生效",
  64. ),
  65. }
  66. _COMMON_CONTENT_TYPE = {
  67. "content_type": ParamSpec(
  68. values=["视频", "图文"],
  69. default="",
  70. note="留空不限",
  71. ),
  72. }
  73. # 9 个中文平台定义
  74. _AIGC_PLATFORMS = [
  75. PlatformDef(id="xhs", name="小红书", aliases=["RED", "xiaohongshu"], search_params=_XHS_SEARCH_PARAMS, supports_suggest=True),
  76. PlatformDef(id="gzh", name="公众号", aliases=["微信公众号", "wechat"], search_params=_COMMON_CONTENT_TYPE),
  77. PlatformDef(id="sph", name="视频号", aliases=["微信视频号"], search_params=_COMMON_CONTENT_TYPE),
  78. PlatformDef(id="github", name="GitHub", aliases=["gh"], search_params=_COMMON_CONTENT_TYPE),
  79. PlatformDef(id="toutiao", name="头条", aliases=["今日头条", "toutiao"], search_params=_COMMON_CONTENT_TYPE, supports_suggest=True),
  80. PlatformDef(id="douyin", name="抖音", aliases=["TikTok"], search_params=_COMMON_CONTENT_TYPE, supports_suggest=True),
  81. PlatformDef(id="bili", name="B站", aliases=["哔哩哔哩", "bilibili"], search_params=_COMMON_CONTENT_TYPE, supports_suggest=True),
  82. PlatformDef(id="zhihu", name="知乎", aliases=[], search_params=_COMMON_CONTENT_TYPE, supports_suggest=True),
  83. PlatformDef(id="weibo", name="微博", aliases=["sina"], search_params=_COMMON_CONTENT_TYPE),
  84. ]
  85. # suggest API 额外支持 wx(微信搜一搜),但它不是搜索平台
  86. _SUGGEST_ONLY_CHANNELS = {"wx": "微信"}
  87. # ── 搜索实现 ──
  88. async def search(
  89. platform_id: str,
  90. keyword: str,
  91. max_count: int = 20,
  92. cursor: str = "",
  93. extras: Optional[Dict[str, Any]] = None,
  94. ) -> ToolResult:
  95. """AIGC-Channel 统一搜索"""
  96. extras = extras or {}
  97. if platform_id == "xhs":
  98. payload = {
  99. "type": platform_id,
  100. "keyword": keyword,
  101. "cursor": cursor,
  102. "content_type": extras.get("content_type", "不限"),
  103. "sort_type": extras.get("sort_type", "综合排序"),
  104. "publish_time": extras.get("publish_time", "不限"),
  105. "filter_note_range": extras.get("filter_note_range", "不限"),
  106. }
  107. else:
  108. payload = {
  109. "type": platform_id,
  110. "keyword": keyword,
  111. "cursor": cursor or "0",
  112. "max_count": max_count,
  113. "content_type": extras.get("content_type", ""),
  114. }
  115. try:
  116. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  117. response = await client.post(
  118. f"{BASE_URL}/data",
  119. json=payload,
  120. headers={"Content-Type": "application/json"},
  121. )
  122. response.raise_for_status()
  123. data = response.json()
  124. except httpx.HTTPStatusError as e:
  125. return ToolResult(title="搜索失败", output="", error=f"HTTP {e.response.status_code}: {e.response.text}")
  126. except Exception as e:
  127. return ToolResult(title="搜索失败", output="", error=str(e))
  128. posts = data.get("data", [])
  129. # sph 字段 normalization:title 太长(后端把 caption 塞进 title),
  130. # 把它搬到 body_text,title 取前 20 字。在评分 / summary / cache 之前做。
  131. if platform_id == "sph":
  132. for p in posts:
  133. _normalize_sph_post(p)
  134. # 构建概览摘要
  135. summary_list = []
  136. # 动态导入评价模块
  137. try:
  138. from examples.process_pipeline.script.evaluate_source_quality import SourceQualityEvaluator
  139. evaluator = SourceQualityEvaluator()
  140. except ImportError:
  141. evaluator = None
  142. # 视频帖在评分前先并发探测 mp4 duration(HTTP Range,不下载视频流),
  143. # 让 evaluator 用真实时长替代 body 长度作为内容信号。
  144. if evaluator and posts:
  145. try:
  146. from agent.tools.builtin.content.transcription import probe_durations_for_posts
  147. await probe_durations_for_posts(platform_id, posts, concurrency=8)
  148. except Exception as e:
  149. import logging
  150. logging.getLogger(__name__).info("duration probe failed: %s", e)
  151. for idx, post in enumerate(posts, 1):
  152. body = post.get("body_text", "") or ""
  153. title = post.get("title") or body[:20] or ""
  154. score_info = {}
  155. if evaluator:
  156. try:
  157. eval_res = evaluator.evaluate_post(post)
  158. score_info = {
  159. "quality_score": eval_res["total_score"],
  160. "quality_grade": eval_res["grade"]
  161. }
  162. post["_quality_score"] = eval_res["total_score"]
  163. post["_quality_grade"] = eval_res["grade"]
  164. except Exception:
  165. pass
  166. summary_item = {
  167. "index": idx,
  168. "title": title,
  169. "body_text": body[:100] + ("..." if len(body) > 100 else ""),
  170. "like_count": post.get("like_count"),
  171. "comment_count": post.get("comment_count"),
  172. "channel": post.get("channel"),
  173. "link": post.get("link"),
  174. "content_type": post.get("content_type"),
  175. }
  176. summary_item.update(score_info)
  177. summary_list.append(summary_item)
  178. # 封面拼图
  179. images = []
  180. try:
  181. collage_obj = await _build_collage(posts)
  182. if collage_obj:
  183. images.append(collage_obj)
  184. except Exception as e:
  185. import logging
  186. logging.getLogger(__name__).warning("Error generating collage: %s", e)
  187. return ToolResult(
  188. title=f"搜索: {keyword} ({platform_id})",
  189. output=json.dumps({"data": summary_list}, ensure_ascii=False, indent=2),
  190. long_term_memory=f"Searched '{keyword}' on {platform_id}, {len(posts)} results. Use content_detail to view full details.",
  191. images=images,
  192. metadata={"posts": posts}, # 完整数据传给上层缓存
  193. )
  194. # ── 详情实现(从缓存获取,不需要额外 HTTP) ──
  195. MAX_DETAIL_IMAGES = 10 # detail 中保留的图片总数上限(含拼图)
  196. KEEP_INDIVIDUAL = 8 # 单张图片保留数量;剩余图片合并为 1 张拼图
  197. async def _build_images_collage(urls: List[str]) -> Optional[Dict[str, Any]]:
  198. """将一组图片 URL 拼成单张网格图"""
  199. if not urls:
  200. return None
  201. loaded = await load_images(urls)
  202. valid_images = [img for (_, img) in loaded if img is not None]
  203. if not valid_images:
  204. return None
  205. grid = build_image_grid(images=valid_images, labels=None)
  206. import io
  207. buf = io.BytesIO()
  208. grid.save(buf, format="PNG")
  209. img_bytes = buf.getvalue()
  210. try:
  211. from agent.tools.builtin.file.image_cdn import _upload_bytes_to_oss
  212. import hashlib
  213. md5_hash = hashlib.md5(img_bytes).hexdigest()[:12]
  214. filename = f"collage_detail_{md5_hash}.png"
  215. cdn_url = await _upload_bytes_to_oss(img_bytes, filename)
  216. return {"type": "url", "url": cdn_url}
  217. except Exception as e:
  218. import logging
  219. logging.getLogger(__name__).warning("Failed to upload detail collage to CDN: %s", e)
  220. b64, _ = encode_base64(grid, format="PNG")
  221. return {"type": "base64", "media_type": "image/png", "data": b64}
  222. async def detail(
  223. post: Dict[str, Any],
  224. extras: Optional[Dict[str, Any]] = None,
  225. platform_id: str = "",
  226. ) -> ToolResult:
  227. """返回单条帖子的完整内容;sph/douyin 视频会通过 Deepgram 自动转写。"""
  228. title = post.get("title") or post.get("body_text", "")[:30] or "无标题"
  229. img_urls = [u for u in post.get("images", []) if u]
  230. images = []
  231. if len(img_urls) > MAX_DETAIL_IMAGES:
  232. # 保留前 KEEP_INDIVIDUAL 张原图,剩余拼成 1 张网格图
  233. for u in img_urls[:KEEP_INDIVIDUAL]:
  234. images.append({"type": "url", "url": u})
  235. collage = await _build_images_collage(img_urls[KEEP_INDIVIDUAL:])
  236. if collage:
  237. images.append(collage)
  238. else:
  239. for u in img_urls:
  240. images.append({"type": "url", "url": u})
  241. # 视频字幕:任何 aigc-channel 平台只要 post.videos 字段非空就触发 Deepgram 转写。
  242. # 下载策略在 transcription._download_video 里按 platform 分支,未指定的平台走
  243. # "yt-dlp on page URL → httpx direct" 两步兜底。
  244. #
  245. # 三态语义(跟 extract_sources._needs_transcribe 对齐):
  246. # 字段缺失 → 没尝试过,跑 Deepgram
  247. # 字段 = "" → 尝试过但失败,跳过(保护 Deepgram 额度)
  248. # 字段 = text → 已成功,复用
  249. extras_d = extras or {}
  250. has_video = bool(post.get("videos"))
  251. field_present = "video_transcript" in post
  252. transcript_text: Optional[str] = post.get("video_transcript") or None
  253. if (
  254. not field_present
  255. and has_video
  256. and extras_d.get("include_transcript", True)
  257. ):
  258. from agent.tools.builtin.content.transcription import transcribe_video_from_post
  259. transcribe_error: Optional[str] = None
  260. try:
  261. transcript_text = await transcribe_video_from_post(platform_id, post)
  262. except Exception as e:
  263. transcript_text = None
  264. transcribe_error = f"{type(e).__name__}: {e}"
  265. import logging as _logging
  266. _logging.getLogger(__name__).warning(
  267. "transcribe_video_from_post raised for %s: %s", platform_id, e
  268. )
  269. # 三态写回:成功 = text;失败/None = "" 作为"已尝试"标记,下次 cache hit 直接短路。
  270. final_value = transcript_text or ""
  271. post["video_transcript"] = final_value
  272. if not final_value:
  273. # 失败原因暴露到 output JSON,方便 agent/用户判断要不要重试或换平台
  274. post["_transcribe_error"] = (
  275. transcribe_error
  276. or "transcribe returned None (下载/抽音/Deepgram 任一步失败,见 logger.warning)"
  277. )
  278. # cache writeback 不再以"成功"为前提:失败的 "" 也写回,让下次 cache hit 短路掉
  279. import os as _os
  280. from agent.tools.builtin.content import cache as _cache
  281. trace_id = extras_d.get("__trace_id__") or _os.getenv("TRACE_ID")
  282. content_id = (
  283. post.get("channel_content_id")
  284. or post.get("content_id")
  285. or post.get("video_id")
  286. )
  287. if trace_id and content_id:
  288. _cache.update_post_field(
  289. trace_id, platform_id, content_id, "video_transcript", final_value
  290. )
  291. # transcript already embedded as post["video_transcript"] inside the JSON dump;
  292. # no need to repeat as a separate section.
  293. output_text = json.dumps(post, ensure_ascii=False, indent=2)
  294. memory_suffix = " +transcript" if transcript_text else ""
  295. return ToolResult(
  296. title=f"详情: {title}",
  297. output=output_text,
  298. long_term_memory=f"Viewed detail: {title}{memory_suffix}",
  299. images=images,
  300. )
  301. # ── 建议词实现 ──
  302. async def suggest(channel: str, keyword: str) -> ToolResult:
  303. """获取搜索建议词"""
  304. try:
  305. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  306. response = await client.post(
  307. f"{BASE_URL}/suggest",
  308. json={"type": channel, "keyword": keyword},
  309. headers={"Content-Type": "application/json"},
  310. )
  311. response.raise_for_status()
  312. data = response.json()
  313. except Exception as e:
  314. return ToolResult(title="建议词获取失败", output="", error=str(e))
  315. suggestion_count = sum(len(item.get("list", [])) for item in data.get("data", []))
  316. return ToolResult(
  317. title=f"建议词: {keyword} ({channel})",
  318. output=json.dumps(data, ensure_ascii=False, indent=2),
  319. long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel}",
  320. )
  321. # ── 拼图辅助 ──
  322. async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[str]:
  323. """封面图网格拼图"""
  324. urls, titles = [], []
  325. for post in posts:
  326. imgs = post.get("images", [])
  327. if imgs and imgs[0]:
  328. urls.append(imgs[0])
  329. base_title = _strip_html(post.get("title", ""))
  330. score = post.get("_quality_score")
  331. if score is not None:
  332. title_with_score = f"[{score}分] {base_title}"
  333. else:
  334. title_with_score = base_title
  335. titles.append(title_with_score)
  336. if not urls:
  337. return None
  338. loaded = await load_images(urls)
  339. valid_images, valid_labels = [], []
  340. for (_, img), title in zip(loaded, titles):
  341. if img is not None:
  342. valid_images.append(img)
  343. valid_labels.append(title)
  344. if not valid_images:
  345. return None
  346. grid = build_image_grid(images=valid_images, labels=valid_labels)
  347. import io
  348. buf = io.BytesIO()
  349. grid.save(buf, format="PNG")
  350. img_bytes = buf.getvalue()
  351. # 尝试上传到 CDN,替换冗长的 base64
  352. try:
  353. from agent.tools.builtin.file.image_cdn import _upload_bytes_to_oss
  354. import hashlib
  355. md5_hash = hashlib.md5(img_bytes).hexdigest()[:12]
  356. filename = f"collage_search_{md5_hash}.png"
  357. cdn_url = await _upload_bytes_to_oss(img_bytes, filename)
  358. return {"type": "url", "url": cdn_url}
  359. except Exception as e:
  360. import logging
  361. logging.getLogger(__name__).warning("Failed to upload collage to CDN: %s", e)
  362. # 降级:还是用 base64 但可能会超长
  363. b64, _ = encode_base64(grid, format="PNG")
  364. return {"type": "base64", "media_type": "image/png", "data": b64}
  365. # ── 注册所有 AIGC 平台 ──
  366. def _register_all():
  367. for p in _AIGC_PLATFORMS:
  368. p.search_impl = search
  369. # Bind each platform's id into detail_impl so the shared detail() knows
  370. # whether to trigger video transcription (only for sph/douyin).
  371. p.detail_impl = (
  372. lambda post, extras, _pid=p.id: detail(post, extras, _pid) # noqa: B023 (default-arg captures pid)
  373. )
  374. if p.supports_suggest:
  375. p.suggest_impl = suggest
  376. p.suggest_channels = [p.id]
  377. register_platform(p)
  378. # wx 只有 suggest,没有搜索
  379. # suggest 调用时 channel 传 "wx",但不注册为独立平台
  380. _register_all()