| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450 |
"""
AIGC-Channel platform implementations (nine platforms).

Backend: aigc-channel.aiddit.com
Platforms: xhs / gzh / sph / github / toutiao / douyin / bili / zhihu / weibo
"""
- import json
- import re
- from typing import Any, Dict, List, Optional
- import httpx
- from agent.tools.models import ToolResult
- from agent.tools.utils.image import build_image_grid, encode_base64, load_images
- from agent.tools.builtin.content.registry import (
- PlatformDef, ParamSpec, register_platform,
- )
# Root endpoint of the aigc-channel backend; search() appends "/data" and
# suggest() appends "/suggest".
BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"

# Timeout handed to the httpx.AsyncClient instances created by search() and
# suggest().
DEFAULT_TIMEOUT = 60.0
# The backend highlights search hits inside titles with markup such as
# '<em class="highlight">keyword</em>'. This pattern matches any "<...>" tag so
# the markup can be removed before the text is rendered or scored.
_HTML_TAG_RE = re.compile(r"<[^>]+>")


def _strip_html(text: Optional[str]) -> str:
    """Return ``text`` with every ``<...>`` tag removed.

    Falsy input (``None`` or an empty string) yields an empty string.
    """
    return _HTML_TAG_RE.sub("", text) if text else ""
# Maximum number of characters kept in a normalized sph title.
_SPH_TITLE_MAX = 20


def _normalize_sph_post(post: Dict[str, Any]) -> None:
    """Move an sph caption from ``title`` into ``body_text``, in place.

    Per the original author's note, sph posts have no separate title and the
    backend puts the caption in ``title`` while leaving ``body_text`` empty.
    When ``body_text`` is empty, the tag-stripped title becomes ``body_text``
    and ``title`` is shortened to its first ``_SPH_TITLE_MAX`` characters
    followed by ``"..."``; a title that already fits is kept without the
    ellipsis.

    The call is a no-op when ``post`` is not a dict, when the title is empty
    (before or after stripping tags), or when ``body_text`` already holds
    non-blank text. The last rule makes a second call on the same post leave
    it unchanged.
    """
    if not isinstance(post, dict):
        return

    caption = post.get("title") or ""
    existing_body = post.get("body_text") or ""
    # A non-string body is treated the same as an empty one.
    has_body = isinstance(existing_body, str) and bool(existing_body.strip())
    if has_body or not caption:
        return

    text = _strip_html(caption).strip()
    if not text:
        return

    post["body_text"] = text
    too_long = len(text) > _SPH_TITLE_MAX
    post["title"] = (text[:_SPH_TITLE_MAX] + "...") if too_long else text
- # ── 平台注册 ──
# Search filters accepted by xhs. Each ``default`` matches the fallback that
# search() sends when the caller leaves that filter out.
_XHS_SEARCH_PARAMS = {
    "sort_type": ParamSpec(values=["综合排序", "最新发布", "最多点赞"], default="综合排序"),
    "publish_time": ParamSpec(values=["不限", "近1天", "近7天", "近30天"], default="不限"),
    "content_type": ParamSpec(values=["不限", "图文", "视频", "文章"], default="不限"),
    "filter_note_range": ParamSpec(
        values=["不限", "1分钟以内", "1-5分钟", "5分钟以上"],
        default="不限",
        note="仅视频内容生效",  # note text: "only takes effect for video content"
    ),
}

# The single content_type filter shared by every non-xhs platform below. The
# empty default means "no restriction", as the note text says.
_COMMON_CONTENT_TYPE = {
    "content_type": ParamSpec(values=["视频", "图文"], default="", note="留空不限"),
}
# Definitions of the nine platforms served by the aigc-channel backend.
# Only xhs has its own filter set; the others share _COMMON_CONTENT_TYPE.
_AIGC_PLATFORMS = [
    PlatformDef(
        id="xhs",
        name="小红书",
        aliases=["RED", "xiaohongshu"],
        search_params=_XHS_SEARCH_PARAMS,
        supports_suggest=True,
    ),
    PlatformDef(
        id="gzh",
        name="公众号",
        aliases=["微信公众号", "wechat"],
        search_params=_COMMON_CONTENT_TYPE,
    ),
    PlatformDef(
        id="sph",
        name="视频号",
        aliases=["微信视频号"],
        search_params=_COMMON_CONTENT_TYPE,
    ),
    PlatformDef(
        id="github",
        name="GitHub",
        aliases=["gh"],
        search_params=_COMMON_CONTENT_TYPE,
    ),
    PlatformDef(
        id="toutiao",
        name="头条",
        aliases=["今日头条", "toutiao"],
        search_params=_COMMON_CONTENT_TYPE,
        supports_suggest=True,
    ),
    PlatformDef(
        id="douyin",
        name="抖音",
        aliases=["TikTok"],
        search_params=_COMMON_CONTENT_TYPE,
        supports_suggest=True,
    ),
    PlatformDef(
        id="bili",
        name="B站",
        aliases=["哔哩哔哩", "bilibili"],
        search_params=_COMMON_CONTENT_TYPE,
        supports_suggest=True,
    ),
    PlatformDef(
        id="zhihu",
        name="知乎",
        aliases=[],
        search_params=_COMMON_CONTENT_TYPE,
        supports_suggest=True,
    ),
    PlatformDef(
        id="weibo",
        name="微博",
        aliases=["sina"],
        search_params=_COMMON_CONTENT_TYPE,
    ),
]

# The suggest API additionally accepts "wx" (WeChat search), which is not a
# search platform and therefore has no PlatformDef above.
_SUGGEST_ONLY_CHANNELS = {"wx": "微信"}
- # ── 搜索实现 ──
def _build_search_payload(
    platform_id: str,
    keyword: str,
    max_count: int,
    cursor: str,
    extras: Dict[str, Any],
) -> Dict[str, Any]:
    """Build the JSON body sent to the ``/data`` endpoint for one search."""
    if platform_id == "xhs":
        # xhs takes its own filter set and no max_count.
        return {
            "type": platform_id,
            "keyword": keyword,
            "cursor": cursor,
            "content_type": extras.get("content_type", "不限"),
            "sort_type": extras.get("sort_type", "综合排序"),
            "publish_time": extras.get("publish_time", "不限"),
            "filter_note_range": extras.get("filter_note_range", "不限"),
        }
    # Every other platform shares one shape; an empty cursor is sent as "0".
    return {
        "type": platform_id,
        "keyword": keyword,
        "cursor": cursor or "0",
        "max_count": max_count,
        "content_type": extras.get("content_type", ""),
    }


def _summarize_post(index: int, post: Dict[str, Any], evaluator: Any) -> Dict[str, Any]:
    """Build the overview entry for one post, scoring it when possible.

    Side effect: after a successful evaluation the score and grade are also
    written onto ``post`` as ``_quality_score`` / ``_quality_grade``.
    """
    import logging

    body = post.get("body_text", "") or ""
    raw_title = post.get("title")
    # Titles may carry search-highlight markup; remove it before display.
    clean_title = _strip_html(raw_title) if isinstance(raw_title, str) else raw_title
    title = clean_title or body[:20] or ""

    score_info: Dict[str, Any] = {}
    if evaluator:
        try:
            result = evaluator.evaluate_post(post)
            score, grade = result["total_score"], result["grade"]
        except Exception as e:
            # Scoring is best-effort: the post is still listed, just unscored.
            logging.getLogger(__name__).debug(
                "quality evaluation failed for post #%s: %s", index, e
            )
        else:
            score_info = {"quality_score": score, "quality_grade": grade}
            post["_quality_score"] = score
            post["_quality_grade"] = grade

    summary_item = {
        "index": index,
        "title": title,
        "body_text": body[:100] + ("..." if len(body) > 100 else ""),
        "like_count": post.get("like_count"),
        "comment_count": post.get("comment_count"),
        "channel": post.get("channel"),
        "link": post.get("link"),
        "content_type": post.get("content_type"),
    }
    summary_item.update(score_info)
    return summary_item


async def search(
    platform_id: str,
    keyword: str,
    max_count: int = 20,
    cursor: str = "",
    extras: Optional[Dict[str, Any]] = None,
) -> ToolResult:
    """Run a keyword search against the aigc-channel backend.

    Args:
        platform_id: Backend platform id, sent as the request ``type``.
        keyword: Search keyword.
        max_count: Requested number of results; not sent for ``xhs``.
        cursor: Pagination cursor; an empty value is sent as ``"0"`` for
            every platform except ``xhs``.
        extras: Optional platform-specific filters (``content_type``, and for
            ``xhs`` also ``sort_type`` / ``publish_time`` /
            ``filter_note_range``).

    Returns:
        A ToolResult whose ``output`` is a JSON overview of the posts, whose
        ``images`` holds at most one cover collage, and whose
        ``metadata["posts"]`` carries the full post dicts. On a request
        failure the ToolResult has an empty ``output`` and ``error`` set.
    """
    import logging

    log = logging.getLogger(__name__)
    extras = extras or {}
    payload = _build_search_payload(platform_id, keyword, max_count, cursor, extras)

    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            response = await client.post(
                f"{BASE_URL}/data",
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            data = response.json()
    except httpx.HTTPStatusError as e:
        return ToolResult(title="搜索失败", output="", error=f"HTTP {e.response.status_code}: {e.response.text}")
    except Exception as e:
        return ToolResult(title="搜索失败", output="", error=str(e))

    # The response may be a non-dict, may carry "data": null, or may contain
    # non-dict entries; treat all of those as "no usable post" instead of
    # crashing further down.
    raw_posts = data.get("data") if isinstance(data, dict) else None
    posts = [p for p in raw_posts if isinstance(p, dict)] if isinstance(raw_posts, list) else []

    # sph normalization must happen before scoring, summarizing and caching.
    if platform_id == "sph":
        for p in posts:
            _normalize_sph_post(p)

    # The quality evaluator is optional; without it posts are listed unscored.
    try:
        from examples.process_pipeline.script.evaluate_source_quality import SourceQualityEvaluator
        evaluator = SourceQualityEvaluator()
    except ImportError:
        evaluator = None

    # Probe video durations before scoring so the evaluator can use them.
    # NOTE(review): the probe helper is not visible here; per the original
    # comment it uses HTTP Range requests and does not download the video.
    if evaluator and posts:
        try:
            from agent.tools.builtin.content.transcription import probe_durations_for_posts
            await probe_durations_for_posts(platform_id, posts, concurrency=8)
        except Exception as e:
            log.info("duration probe failed: %s", e)

    summary_list = [
        _summarize_post(idx, post, evaluator) for idx, post in enumerate(posts, 1)
    ]

    # Cover collage is best-effort; a failure only costs the image.
    images = []
    try:
        collage_obj = await _build_collage(posts)
        if collage_obj:
            images.append(collage_obj)
    except Exception as e:
        log.warning("Error generating collage: %s", e)

    return ToolResult(
        title=f"搜索: {keyword} ({platform_id})",
        output=json.dumps({"data": summary_list}, ensure_ascii=False, indent=2),
        long_term_memory=f"Searched '{keyword}' on {platform_id}, {len(posts)} results. Use content_detail to view full details.",
        images=images,
        metadata={"posts": posts},  # full data handed to the caller for caching
    )
- # ── 详情实现(从缓存获取,不需要额外 HTTP) ──
# detail() returns a post's images one by one up to this count; a post with
# more images than this gets its tail merged into a collage instead.
MAX_DETAIL_IMAGES = 10

# When the limit above is exceeded, this many leading images stay individual
# and all remaining ones are merged into a single collage.
KEEP_INDIVIDUAL = 8
async def _build_images_collage(urls: List[str]) -> Optional[Dict[str, Any]]:
    """Merge the images behind ``urls`` into a single grid image.

    Returns ``None`` when ``urls`` is empty or no image could be loaded.
    Otherwise returns an image dict: ``{"type": "url", ...}`` when the PNG was
    uploaded to the CDN, or a ``{"type": "base64", ...}`` fallback when the
    upload step raised.
    """
    if not urls:
        return None

    fetched = await load_images(urls)
    pictures = []
    for _, picture in fetched:
        # Entries whose image failed to load come back as None; drop them.
        if picture is not None:
            pictures.append(picture)
    if not pictures:
        return None

    grid = build_image_grid(images=pictures, labels=None)

    import io

    png_buffer = io.BytesIO()
    grid.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()

    try:
        from agent.tools.builtin.file.image_cdn import _upload_bytes_to_oss
        import hashlib

        # The content hash only serves to derive the file name.
        digest = hashlib.md5(png_bytes).hexdigest()[:12]
        cdn_url = await _upload_bytes_to_oss(png_bytes, f"collage_detail_{digest}.png")
        return {"type": "url", "url": cdn_url}
    except Exception as e:
        import logging

        logging.getLogger(__name__).warning("Failed to upload detail collage to CDN: %s", e)

    # Upload failed: fall back to embedding the grid as base64.
    b64, _ = encode_base64(grid, format="PNG")
    return {"type": "base64", "media_type": "image/png", "data": b64}
async def detail(
    post: Dict[str, Any],
    extras: Optional[Dict[str, Any]] = None,
    platform_id: str = "",
) -> ToolResult:
    """Return the full content of one cached post, transcribing video if needed.

    Transcription runs for any platform when ``post["videos"]`` is non-empty,
    the ``video_transcript`` key is absent, and ``extras["include_transcript"]``
    is not falsy. The key has three states:

    * missing: never attempted, so transcription is run;
    * ``""``: attempted and failed, so it is skipped;
    * text: already transcribed, so the text is reused.

    Args:
        post: The post dict. It is mutated in place: ``video_transcript`` and,
            on failure, ``_transcribe_error`` are written onto it.
        extras: Optional settings; reads ``include_transcript`` and
            ``__trace_id__``.
        platform_id: Platform id forwarded to the transcription and cache
            helpers.

    Returns:
        A ToolResult whose ``output`` is the post serialized as JSON and whose
        ``images`` lists the post's images (see MAX_DETAIL_IMAGES).
    """
    import logging

    log = logging.getLogger(__name__)

    # Titles may carry search-highlight markup; remove it before display.
    # body_text may be present but None, so guard before slicing.
    raw_title = post.get("title")
    clean_title = _strip_html(raw_title) if isinstance(raw_title, str) else raw_title
    title = clean_title or (post.get("body_text") or "")[:30] or "无标题"

    # "images" may be present but None.
    img_urls = [u for u in (post.get("images") or []) if u]
    images: List[Dict[str, Any]] = []
    if len(img_urls) > MAX_DETAIL_IMAGES:
        # Keep the first KEEP_INDIVIDUAL images as-is and merge the rest into
        # one grid image.
        images.extend({"type": "url", "url": u} for u in img_urls[:KEEP_INDIVIDUAL])
        try:
            collage = await _build_images_collage(img_urls[KEEP_INDIVIDUAL:])
        except Exception as e:
            # Same policy as search(): a collage failure only costs the image.
            collage = None
            log.warning("Error generating detail collage: %s", e)
        if collage:
            images.append(collage)
    else:
        images.extend({"type": "url", "url": u} for u in img_urls)

    extras_d = extras or {}
    has_video = bool(post.get("videos"))
    field_present = "video_transcript" in post
    transcript_text: Optional[str] = post.get("video_transcript") or None

    if (
        not field_present
        and has_video
        and extras_d.get("include_transcript", True)
    ):
        from agent.tools.builtin.content.transcription import transcribe_video_from_post

        transcribe_error: Optional[str] = None
        try:
            transcript_text = await transcribe_video_from_post(platform_id, post)
        except Exception as e:
            transcript_text = None
            transcribe_error = f"{type(e).__name__}: {e}"
            log.warning(
                "transcribe_video_from_post raised for %s: %s", platform_id, e
            )

        # Write back in every case: "" marks "attempted and failed" so the
        # next lookup of this post skips transcription.
        final_value = transcript_text or ""
        post["video_transcript"] = final_value
        if not final_value:
            # Expose the failure reason in the JSON output so the reader can
            # decide whether to retry or switch platform.
            post["_transcribe_error"] = (
                transcribe_error
                or "transcribe returned None (下载/抽音/Deepgram 任一步失败,见 logger.warning)"
            )

        import os as _os
        from agent.tools.builtin.content import cache as _cache

        trace_id = extras_d.get("__trace_id__") or _os.getenv("TRACE_ID")
        content_id = (
            post.get("channel_content_id")
            or post.get("content_id")
            or post.get("video_id")
        )
        if trace_id and content_id:
            try:
                _cache.update_post_field(
                    trace_id, platform_id, content_id, "video_transcript", final_value
                )
            except Exception as e:
                # A cache failure must not discard the transcript we already have.
                log.warning("video_transcript cache writeback failed: %s", e)

    # The transcript is already inside the JSON dump as post["video_transcript"],
    # so it is not repeated as a separate section.
    output_text = json.dumps(post, ensure_ascii=False, indent=2)
    memory_suffix = " +transcript" if transcript_text else ""
    return ToolResult(
        title=f"详情: {title}",
        output=output_text,
        long_term_memory=f"Viewed detail: {title}{memory_suffix}",
        images=images,
    )
- # ── 建议词实现 ──
async def suggest(channel: str, keyword: str) -> ToolResult:
    """Fetch search suggestions for ``keyword`` from the ``/suggest`` endpoint.

    Args:
        channel: Channel id, sent as the request ``type``.
        keyword: Keyword to get suggestions for.

    Returns:
        A ToolResult whose ``output`` is the backend response as JSON. On a
        request failure the ToolResult has an empty ``output`` and ``error``
        set.
    """
    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            response = await client.post(
                f"{BASE_URL}/suggest",
                json={"type": channel, "keyword": keyword},
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            data = response.json()
    except httpx.HTTPStatusError as e:
        # Same error format as search().
        return ToolResult(title="建议词获取失败", output="", error=f"HTTP {e.response.status_code}: {e.response.text}")
    except Exception as e:
        return ToolResult(title="建议词获取失败", output="", error=str(e))

    # Count defensively: the response may be a non-dict, and "data" or an
    # item's "list" may be null. Anything unexpected counts as zero.
    groups = data.get("data") if isinstance(data, dict) else None
    suggestion_count = 0
    if isinstance(groups, list):
        for group in groups:
            entries = group.get("list") if isinstance(group, dict) else None
            if isinstance(entries, (list, tuple)):
                suggestion_count += len(entries)

    return ToolResult(
        title=f"建议词: {keyword} ({channel})",
        output=json.dumps(data, ensure_ascii=False, indent=2),
        long_term_memory=f"Got {suggestion_count} suggestions for '{keyword}' on {channel}",
    )
- # ── 拼图辅助 ──
async def _build_collage(posts: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Build one labelled grid image from the first image of each post.

    Each cell is labelled with the post's tag-stripped title, prefixed with
    ``[<score>分]`` when the post carries ``_quality_score``.

    Returns ``None`` when no post has a usable first image or none of the
    images could be loaded. Otherwise returns an image dict:
    ``{"type": "url", ...}`` when the PNG was uploaded to the CDN, or a
    ``{"type": "base64", ...}`` fallback when the upload step raised.
    """
    urls, titles = [], []
    for post in posts or []:
        if not isinstance(post, dict):
            continue
        imgs = post.get("images")
        # Only a non-empty list/tuple with a truthy first entry has a cover.
        if not isinstance(imgs, (list, tuple)) or not imgs or not imgs[0]:
            continue
        base_title = _strip_html(post.get("title", ""))
        score = post.get("_quality_score")
        label = f"[{score}分] {base_title}" if score is not None else base_title
        # urls and titles are appended together so they stay index-aligned.
        urls.append(imgs[0])
        titles.append(label)

    if not urls:
        return None

    loaded = await load_images(urls)
    valid_images, valid_labels = [], []
    # NOTE(review): assumes load_images returns one (url, image) pair per
    # input URL, in input order; the zip below depends on it.
    for (_, img), label in zip(loaded, titles):
        if img is not None:
            valid_images.append(img)
            valid_labels.append(label)
    if not valid_images:
        return None

    grid = build_image_grid(images=valid_images, labels=valid_labels)

    import io

    buf = io.BytesIO()
    grid.save(buf, format="PNG")
    img_bytes = buf.getvalue()

    # Prefer a CDN URL over a lengthy base64 payload.
    try:
        from agent.tools.builtin.file.image_cdn import _upload_bytes_to_oss
        import hashlib

        # The content hash only serves to derive the file name.
        md5_hash = hashlib.md5(img_bytes).hexdigest()[:12]
        filename = f"collage_search_{md5_hash}.png"
        cdn_url = await _upload_bytes_to_oss(img_bytes, filename)
        return {"type": "url", "url": cdn_url}
    except Exception as e:
        import logging

        logging.getLogger(__name__).warning("Failed to upload collage to CDN: %s", e)

    # Upload failed: fall back to base64, which may be very long.
    b64, _ = encode_base64(grid, format="PNG")
    return {"type": "base64", "media_type": "image/png", "data": b64}
- # ── 注册所有 AIGC 平台 ──
def _register_all():
    """Attach the shared implementations to every platform and register it.

    Each platform gets ``search`` as its search implementation and a detail
    callable with its own id bound in. Platforms with ``supports_suggest`` set
    also get ``suggest`` and list their own id as the only suggest channel.
    """

    def _bind_detail(pid):
        # Bind the id through a default argument so every platform keeps its
        # own value; detail() forwards it to the transcription and cache
        # helpers.
        return lambda post, extras, _pid=pid: detail(post, extras, _pid)

    for platform in _AIGC_PLATFORMS:
        platform.search_impl = search
        platform.detail_impl = _bind_detail(platform.id)
        if platform.supports_suggest:
            platform.suggest_impl = suggest
            platform.suggest_channels = [platform.id]
        register_platform(platform)


# "wx" only supports suggest, not search: callers pass channel="wx" to
# suggest() directly, and it is not registered as a platform of its own.
_register_all()
|