| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
"""
根据特征匹配高赞视频的选题解构信息。
Match high-like videos by the intersection of their features and return each
video's topic-deconstruction points (inspiration / goal / key).
Post ids are resolved from element_classification_mapping; the point lists are
read from workflow_decode_task_result.
"""
- from __future__ import annotations
- import json
- import os
- from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
- from agent.tools import ToolResult, tool
- from db import get_open_aigc_pattern_connection, get_connection
- JsonDict = Dict[str, Any]
- def _split_features(features: Optional[str]) -> List[str]:
- raw = (features or "").strip()
- if not raw:
- return []
- # 同时支持英文/中文逗号
- parts = raw.replace(",", ",").split(",")
- out: List[str] = []
- for p in parts:
- s = p.strip()
- if s:
- out.append(s)
- return out
- def _coerce_points(value: Any) -> List[str]:
- """
- 将 DB 字段的 points 归一成 List[str]。
- 允许 value 形态:
- - list: 过滤出非空字符串
- - str: 优先当作 JSON 解析;失败则按逗号/换行切分
- - None/其他:空列表
- """
- if value is None:
- return []
- if isinstance(value, list):
- return [x.strip() for x in value if isinstance(x, str) and x.strip()]
- if isinstance(value, str):
- s = value.strip()
- if not s:
- return []
- try:
- parsed = json.loads(s)
- except Exception:
- parsed = None
- if isinstance(parsed, list):
- return [x.strip() for x in parsed if isinstance(x, str) and x.strip()]
- # 兜底:按常见分隔符切开
- parts: List[str] = []
- for token in s.replace(",", ",").replace("\n", ",").split(","):
- t = token.strip()
- if t:
- parts.append(t)
- return parts
- return []
- def _intersect_post_ids(feature_post_id_sets: Sequence[Set[str]]) -> List[str]:
- if not feature_post_id_sets:
- return []
- inter = set(feature_post_id_sets[0])
- for s in feature_post_id_sets[1:]:
- inter &= s
- if not inter:
- return []
- # 稳定输出:按字符串排序
- return sorted(inter)
- def _query_post_ids_by_feature(conn, feature: str) -> Set[str]:
- """
- 在 element_classification_mapping 中按 feature 查关联 post_id。
- """
- sql = f"""
- SELECT DISTINCT post_id
- FROM element_classification_mapping
- WHERE name = %s
- """
- with conn.cursor() as cur:
- cur.execute(sql, (feature,))
- rows = cur.fetchall() or []
- out: Set[str] = set()
- for r in rows:
- pid = (r.get("post_id") if isinstance(r, dict) else None)
- if pid is None:
- continue
- s = str(pid).strip()
- if s:
- out.add(s)
- return out
- def _chunked(items: Sequence[str], size: int) -> Iterable[Sequence[str]]:
- if size <= 0:
- yield items
- return
- for i in range(0, len(items), size):
- yield items[i : i + size]
def _query_points_by_post_ids(conn, post_ids: List[str]) -> Dict[str, JsonDict]:
    """Fetch the three "points" columns from workflow_decode_task_result.

    The post_id corresponds to the channel_content_id column. The IN clause
    is issued in batches (default 200; override via the
    GET_VIDEO_TOPIC_IN_BATCH_SIZE env var) to keep statements bounded.

    Returns:
        Mapping of post_id -> {"goal_points", "key_points",
        "inspiration_points"}, each value a list of strings. Non-dict rows
        and rows without a usable channel_content_id are skipped.
    """
    if not post_ids:
        return {}
    batch_size = int(os.getenv("GET_VIDEO_TOPIC_IN_BATCH_SIZE", "200"))
    points_by_id: Dict[str, JsonDict] = {}
    for chunk in _chunked(post_ids, batch_size):
        # Placeholders only — values are bound by the driver, never interpolated.
        marks = ",".join(["%s"] * len(chunk))
        sql = f"""
        SELECT
            channel_content_id,
            purpose_points,
            key_points,
            inspiration_points
        FROM workflow_decode_task_result
        WHERE channel_content_id IN ({marks})
        """
        with conn.cursor() as cur:
            cur.execute(sql, tuple(chunk))
            rows = cur.fetchall() or []
        for row in rows:
            if not isinstance(row, dict):
                continue
            pid = str(row.get("channel_content_id") or "").strip()
            if not pid:
                continue
            # DB column purpose_points is exposed as goal_points downstream.
            points_by_id[pid] = {
                "goal_points": _coerce_points(row.get("purpose_points")),
                "key_points": _coerce_points(row.get("key_points")),
                "inspiration_points": _coerce_points(row.get("inspiration_points")),
            }
    return points_by_id
@tool(description="根据特征匹配高赞视频,并返回每个视频的灵感点/目的点/关键点列表")
async def get_video_topic(
    features: str = "",
    limit: int = 20,
) -> ToolResult:
    """Match liked videos by feature intersection and return their point lists.

    Bug fixes vs. previous version:
      - the ``finally`` block called ``conn.close()`` on a name that never
        existed (NameError on every invocation, and both real connections
        leaked); both connections are now closed via nested try/finally;
      - ``supply_conn`` is acquired inside the outer ``try`` so a failure in
        ``get_connection()`` cannot leak ``open_aigc_conn``;
      - stale "placeholder" wording removed from the description/docstring —
        ``limit`` is applied and real data is returned.

    Args:
        features: Comma-separated feature string (ASCII or full-width commas);
            may be empty. Example: "健康,防骗".
        limit: Maximum number of videos to return; values <= 0 mean no cap.

    Returns:
        ToolResult with metadata.videos: List[{
            "inspiration_points": [...],
            "goal_points": [...],
            "key_points": [...]
        }].
    """
    feature_list = _split_features(features)
    if not feature_list:
        return ToolResult(
            title="选题解构",
            output="features 为空:返回空视频列表(videos=0)。",
            metadata={"videos": [], "features": [], "limit": limit, "post_ids": []},
        )
    open_aigc_conn = get_open_aigc_pattern_connection()
    try:
        supply_conn = get_connection()
        try:
            # One post_id set per feature; a video must match ALL features.
            feature_sets: List[Set[str]] = [
                _query_post_ids_by_feature(open_aigc_conn, f) for f in feature_list
            ]
            post_ids = _intersect_post_ids(feature_sets)
            if limit and limit > 0:
                post_ids = post_ids[: int(limit)]
            points_map = _query_points_by_post_ids(supply_conn, post_ids)
            videos: List[JsonDict] = []
            for pid in post_ids:
                item = points_map.get(pid) or {
                    "inspiration_points": [],
                    "goal_points": [],
                    "key_points": [],
                }
                # Contract: each "video" entry carries only the three point
                # lists — no post_id, to keep the per-item schema stable.
                videos.append(
                    {
                        "inspiration_points": item.get("inspiration_points") or [],
                        "goal_points": item.get("goal_points") or [],
                        "key_points": item.get("key_points") or [],
                    }
                )
            return ToolResult(
                title="选题解构",
                output=f"命中特征交集 post_id={len(post_ids)},返回 videos={len(videos)}。",
                metadata={
                    "videos": videos,
                    "features": feature_list,
                    "limit": limit,
                    # Debug/traceability only; kept out of the per-video entries.
                    "post_ids": post_ids,
                },
                long_term_memory=f"Get video topic points by features: {','.join(feature_list)} (videos={len(videos)})",
            )
        finally:
            supply_conn.close()
    finally:
        open_aigc_conn.close()
|