aigc_channel.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. """
  2. AIGC-Channel 平台实现(9 个中文平台)
  3. 后端:aigc-channel.aiddit.com
  4. 平台:xhs / gzh / sph / github / toutiao / douyin / bili / zhihu / weibo
  5. """
  6. import json
  7. from typing import Any, Dict, List, Optional
  8. import httpx
  9. from agent.tools.models import ToolResult
  10. from agent.tools.utils.image import build_image_grid, encode_base64, load_images
  11. from agent.tools.builtin.content.registry import (
  12. PlatformDef, ParamSpec, register_platform,
  13. )
  # Root URL of the AIGC-Channel backend; endpoint paths ("/data", "/suggest")
  # are appended to it by the functions below.
  BASE_URL: str = "http://aigc-channel.aiddit.com/aigc/channel"

  # Timeout (seconds) handed to httpx.AsyncClient for every request.
  DEFAULT_TIMEOUT: float = 60.0
  16. # ── 平台注册 ──
  # Search filters declared for the xhs platform. Each entry names the
  # allowed values and the default value of one filter.
  _XHS_SEARCH_PARAMS = dict(
      sort_type=ParamSpec(
          values=["综合排序", "最新发布", "最多点赞"],
          default="综合排序",
      ),
      publish_time=ParamSpec(
          values=["不限", "近1天", "近7天", "近30天"],
          default="不限",
      ),
      content_type=ParamSpec(
          values=["不限", "图文", "视频", "文章"],
          default="不限",
      ),
      # The note text reads "only takes effect for video content".
      filter_note_range=ParamSpec(
          values=["不限", "1分钟以内", "1-5分钟", "5分钟以上"],
          default="不限",
          note="仅视频内容生效",
      ),
  )
  # Single content-type filter shared by every non-xhs platform definition.
  # The note text reads "leave empty for no restriction", matching the empty
  # default.
  _COMMON_CONTENT_TYPE = dict(
      content_type=ParamSpec(
          values=["视频", "图文"],
          default="",
          note="留空不限",
      ),
  )
  43. # 9 个中文平台定义
  # Platform definitions served by this backend. The implementations are
  # attached to each entry at registration time.
  _AIGC_PLATFORMS = [
      PlatformDef(
          id="xhs",
          name="小红书",
          aliases=["RED", "xiaohongshu"],
          search_params=_XHS_SEARCH_PARAMS,
          supports_suggest=True,
      ),
      PlatformDef(
          id="gzh",
          name="公众号",
          aliases=["微信公众号", "wechat"],
          search_params=_COMMON_CONTENT_TYPE,
      ),
      PlatformDef(
          id="sph",
          name="视频号",
          aliases=["微信视频号"],
          search_params=_COMMON_CONTENT_TYPE,
      ),
      PlatformDef(
          id="github",
          name="GitHub",
          aliases=["gh"],
          search_params=_COMMON_CONTENT_TYPE,
      ),
      PlatformDef(
          id="toutiao",
          name="头条",
          aliases=["今日头条", "toutiao"],
          search_params=_COMMON_CONTENT_TYPE,
          supports_suggest=True,
      ),
      PlatformDef(
          id="douyin",
          name="抖音",
          aliases=["TikTok"],
          search_params=_COMMON_CONTENT_TYPE,
          supports_suggest=True,
      ),
      PlatformDef(
          id="bili",
          name="B站",
          aliases=["哔哩哔哩", "bilibili"],
          search_params=_COMMON_CONTENT_TYPE,
          supports_suggest=True,
      ),
      PlatformDef(
          id="zhihu",
          name="知乎",
          aliases=[],
          search_params=_COMMON_CONTENT_TYPE,
          supports_suggest=True,
      ),
      PlatformDef(
          id="weibo",
          name="微博",
          aliases=["sina"],
          search_params=_COMMON_CONTENT_TYPE,
      ),
  ]

  # The suggest API additionally accepts "wx" (WeChat search), but that
  # channel is not a search platform.
  _SUGGEST_ONLY_CHANNELS = dict(wx="微信")
  57. # ── 搜索实现 ──
  def _build_search_payload(
      platform_id: str,
      keyword: str,
      max_count: int,
      cursor: str,
      extras: Dict[str, Any],
  ) -> Dict[str, Any]:
      """Build the JSON body posted to the ``/data`` endpoint."""
      if platform_id == "xhs":
          # xhs has its own filter set; max_count is not part of its payload.
          return {
              "type": platform_id,
              "keyword": keyword,
              "cursor": cursor,
              "content_type": extras.get("content_type", "不限"),
              "sort_type": extras.get("sort_type", "综合排序"),
              "publish_time": extras.get("publish_time", "不限"),
              "filter_note_range": extras.get("filter_note_range", "不限"),
          }
      return {
          "type": platform_id,
          "keyword": keyword,
          # An empty cursor is sent as "0" for the non-xhs platforms.
          "cursor": cursor or "0",
          "max_count": max_count,
          "content_type": extras.get("content_type", ""),
      }


  def _summarize_post(index: int, post: Dict[str, Any]) -> Dict[str, Any]:
      """Condense one post into the short overview entry used in search output."""
      body = post.get("body_text", "") or ""
      # Posts without a title borrow the first 20 characters of their body.
      title = post.get("title") or body[:20] or ""
      return {
          "index": index,
          "title": title,
          "body_text": body[:100] + ("..." if len(body) > 100 else ""),
          "like_count": post.get("like_count"),
          "comment_count": post.get("comment_count"),
          "channel": post.get("channel"),
          "link": post.get("link"),
          "content_type": post.get("content_type"),
      }


  async def search(
      platform_id: str,
      keyword: str,
      max_count: int = 20,
      cursor: str = "",
      extras: Optional[Dict[str, Any]] = None,
  ) -> ToolResult:
      """Run a keyword search against the AIGC-Channel ``/data`` endpoint.

      Args:
          platform_id: Value sent as the ``type`` field; ``"xhs"`` selects the
              xhs-specific payload.
          keyword: Search keyword.
          max_count: Forwarded for every platform except xhs.
          cursor: Cursor value forwarded to the backend.
          extras: Optional filter values (``content_type`` and, for xhs,
              ``sort_type`` / ``publish_time`` / ``filter_note_range``).

      Returns:
          A ToolResult whose ``output`` is a JSON overview of the posts, whose
          ``images`` holds the cover collage when one could be built, and whose
          ``metadata["posts"]`` carries the full post dicts for the caller to
          cache. On request failure a ToolResult with ``error`` set is returned
          instead of raising.
      """
      payload = _build_search_payload(platform_id, keyword, max_count, cursor, extras or {})

      try:
          async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
              response = await client.post(
                  f"{BASE_URL}/data",
                  json=payload,
                  headers={"Content-Type": "application/json"},
              )
              response.raise_for_status()
              data = response.json()
      except httpx.HTTPStatusError as e:
          return ToolResult(title="搜索失败", output="", error=f"HTTP {e.response.status_code}: {e.response.text}")
      except Exception as e:
          return ToolResult(title="搜索失败", output="", error=str(e))

      # The body is only trusted as far as it is checked: a non-object response,
      # a null/non-list "data" field, or non-dict entries yield no posts rather
      # than an exception outside the try block above.
      raw_posts = data.get("data") if isinstance(data, dict) else None
      if isinstance(raw_posts, list):
          posts = [post for post in raw_posts if isinstance(post, dict)]
      else:
          posts = []

      # Overview entries are numbered from 1, in the same order as `posts`.
      summary_list = [_summarize_post(idx, post) for idx, post in enumerate(posts, 1)]

      # The cover collage is best-effort: a failure is logged and the search
      # result is still returned without an image.
      images = []
      try:
          collage_obj = await _build_collage(posts)
          if collage_obj:
              images.append(collage_obj)
      except Exception as e:
          import logging
          logging.getLogger(__name__).warning("Error generating collage: %s", e)

      return ToolResult(
          title=f"搜索: {keyword} ({platform_id})",
          output=json.dumps({"data": summary_list}, ensure_ascii=False, indent=2),
          long_term_memory=f"Searched '{keyword}' on {platform_id}, {len(posts)} results. Use content_detail to view full details.",
          images=images,
          metadata={"posts": posts},  # full data handed to the caller for caching
      )
  130. # ── 详情实现(从缓存获取,不需要额外 HTTP) ──
  async def detail(post: Dict[str, Any], extras: Optional[Dict[str, Any]] = None) -> ToolResult:
      """Return the full content of a single post.

      Args:
          post: Post dict to display. No HTTP request is made here.
          extras: Not read by this implementation.

      Returns:
          A ToolResult whose ``output`` is the post serialised as JSON and
          whose ``images`` lists every non-empty entry of ``post["images"]``
          as a URL image.
      """
      # `or ""` / `or []` also cover keys that are present with a None value,
      # which the plain .get() default does not.
      body = post.get("body_text") or ""
      title = post.get("title") or body[:30] or "无标题"

      images = [
          {"type": "url", "url": img_url}
          for img_url in (post.get("images") or [])
          if img_url
      ]

      return ToolResult(
          title=f"详情: {title}",
          output=json.dumps(post, ensure_ascii=False, indent=2),
          long_term_memory=f"Viewed detail: {title}",
          images=images,
      )
  144. # ── 建议词实现 ──
  async def suggest(channel: str, keyword: str) -> ToolResult:
      """Fetch search suggestions for a keyword from the ``/suggest`` endpoint.

      Args:
          channel: Value sent as the ``type`` field of the request.
          keyword: Keyword to get suggestions for.

      Returns:
          A ToolResult whose ``output`` is the raw response serialised as JSON.
          On request failure a ToolResult with ``error`` set is returned
          instead of raising.
      """
      try:
          async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
              response = await client.post(
                  f"{BASE_URL}/suggest",
                  json={"type": channel, "keyword": keyword},
                  headers={"Content-Type": "application/json"},
              )
              response.raise_for_status()
              data = response.json()
      except httpx.HTTPStatusError as e:
          # Same reporting as search(): include the status code and body.
          return ToolResult(
              title="建议词获取失败",
              output="",
              error=f"HTTP {e.response.status_code}: {e.response.text}",
          )
      except Exception as e:
          return ToolResult(title="建议词获取失败", output="", error=str(e))

      # Count defensively: a non-object body, a null "data" field, or a group
      # without a usable "list" contributes zero instead of raising.
      groups = data.get("data") if isinstance(data, dict) else None
      suggestion_count = 0
      for group in groups if isinstance(groups, list) else []:
          items = group.get("list") if isinstance(group, dict) else None
          if isinstance(items, list):
              suggestion_count += len(items)

      return ToolResult(
          title=f"建议词: {keyword} ({channel})",
          output=json.dumps(data, ensure_ascii=False, indent=2),
          long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel}",
      )
  164. # ── 拼图辅助 ──
  async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
      """Build a grid collage from the first image of each post.

      Args:
          posts: Post dicts; only ``images[0]`` and ``title`` are read.

      Returns:
          ``None`` when no post has a first image or none could be loaded.
          Otherwise an image dict: ``{"type": "url", "url": ...}`` when the
          upload succeeded, or a base64 PNG dict as the fallback.
      """
      import hashlib
      import io
      import logging

      logger = logging.getLogger(__name__)

      urls, titles = [], []
      for post in posts:
          imgs = post.get("images") or []
          if imgs and imgs[0]:
              urls.append(imgs[0])
              titles.append(post.get("title", "") or "")

      if not urls:
          return None

      loaded = await load_images(urls)

      # NOTE(review): pairing by position assumes load_images returns one
      # (key, image) pair per input URL, in input order. Confirm against the
      # helper.
      valid_images, valid_labels = [], []
      for (_, img), title in zip(loaded, titles):
          if img is not None:
              valid_images.append(img)
              valid_labels.append(title)

      if not valid_images:
          return None

      grid = build_image_grid(images=valid_images, labels=valid_labels)

      buf = io.BytesIO()
      grid.save(buf, format="PNG")
      img_bytes = buf.getvalue()

      # Prefer a CDN URL over an inline base64 payload, which can be very long.
      try:
          from agent.tools.builtin.file.image_cdn import _upload_bytes_to_oss

          # The content hash is only used to derive a file name.
          md5_hash = hashlib.md5(img_bytes).hexdigest()[:12]
          filename = f"collage_search_{md5_hash}.png"
          cdn_url = await _upload_bytes_to_oss(img_bytes, filename)
          # The upload helper's contract is not visible here; treat an empty
          # result as a failure rather than returning an image with no URL.
          if cdn_url:
              return {"type": "url", "url": cdn_url}
          logger.warning("Failed to upload collage to CDN: %s", "empty URL returned")
      except Exception as e:
          logger.warning("Failed to upload collage to CDN: %s", e)

      # Fallback: inline base64, which may be very long.
      b64, _ = encode_base64(grid, format="PNG")
      return {"type": "base64", "media_type": "image/png", "data": b64}
  202. # ── 注册所有 AIGC 平台 ──
  def _register_all():
      """Attach this module's implementations to each AIGC platform and register it."""
      for platform in _AIGC_PLATFORMS:
          platform.search_impl, platform.detail_impl = search, detail
          if platform.supports_suggest:
              # A suggest-capable platform uses its own id as its only channel.
              platform.suggest_impl = suggest
              platform.suggest_channels = [platform.id]
          register_platform(platform)
      # "wx" offers suggestions only, not search. Callers pass channel "wx" to
      # suggest(), and it is not registered as a platform of its own.


  _register_all()