|
|
@@ -7,49 +7,159 @@
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
-from dataclasses import dataclass
|
|
|
-from typing import Any, Dict, List, Optional
|
|
|
+import json
|
|
|
+import os
|
|
|
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
|
|
|
|
|
|
from agent.tools import ToolResult, tool
|
|
|
|
|
|
+from db import get_open_aigc_pattern_connection, get_connection
|
|
|
+
|
|
|
JsonDict = Dict[str, Any]
|
|
|
|
|
|
|
|
|
-@dataclass(frozen=True)
|
|
|
-class VideoTopicItem:
|
|
|
+def _split_features(features: Optional[str]) -> List[str]:
|
|
|
+ raw = (features or "").strip()
|
|
|
+ if not raw:
|
|
|
+ return []
|
|
|
+ # 同时支持英文/中文逗号
|
|
|
+ parts = raw.replace(",", ",").split(",")
|
|
|
+ out: List[str] = []
|
|
|
+ for p in parts:
|
|
|
+ s = p.strip()
|
|
|
+ if s:
|
|
|
+ out.append(s)
|
|
|
+ return out
|
|
|
+
|
|
|
+
|
|
|
+def _coerce_points(value: Any) -> List[str]:
|
|
|
"""
|
|
|
- 单条视频的选题点结构(仅保留三类列表)。
|
|
|
+ 将 DB 字段的 points 归一成 List[str]。
|
|
|
|
|
|
- - inspiration_points: 灵感点列表
|
|
|
- - goal_points: 目的点列表
|
|
|
- - key_points: 关键点列表
|
|
|
+ 允许 value 形态:
|
|
|
+ - list: 过滤出非空字符串
|
|
|
+ - str: 优先当作 JSON 解析;失败则按逗号/换行切分
|
|
|
+ - None/其他:空列表
|
|
|
"""
|
|
|
|
|
|
- inspiration_points: List[str]
|
|
|
- goal_points: List[str]
|
|
|
- key_points: List[str]
|
|
|
+ if value is None:
|
|
|
+ return []
|
|
|
+ if isinstance(value, list):
|
|
|
+ return [x.strip() for x in value if isinstance(x, str) and x.strip()]
|
|
|
+ if isinstance(value, str):
|
|
|
+ s = value.strip()
|
|
|
+ if not s:
|
|
|
+ return []
|
|
|
+ try:
|
|
|
+ parsed = json.loads(s)
|
|
|
+ except Exception:
|
|
|
+ parsed = None
|
|
|
+ if isinstance(parsed, list):
|
|
|
+ return [x.strip() for x in parsed if isinstance(x, str) and x.strip()]
|
|
|
+ # 兜底:按常见分隔符切开
|
|
|
+ parts: List[str] = []
|
|
|
+ for token in s.replace(",", ",").replace("\n", ",").split(","):
|
|
|
+ t = token.strip()
|
|
|
+ if t:
|
|
|
+ parts.append(t)
|
|
|
+ return parts
|
|
|
+ return []
|
|
|
|
|
|
- def to_dict(self) -> JsonDict:
|
|
|
- return {
|
|
|
- "inspiration_points": self.inspiration_points,
|
|
|
- "goal_points": self.goal_points,
|
|
|
- "key_points": self.key_points,
|
|
|
- }
|
|
|
|
|
|
+def _intersect_post_ids(feature_post_id_sets: Sequence[Set[str]]) -> List[str]:
|
|
|
+ if not feature_post_id_sets:
|
|
|
+ return []
|
|
|
+ inter = set(feature_post_id_sets[0])
|
|
|
+ for s in feature_post_id_sets[1:]:
|
|
|
+ inter &= s
|
|
|
+ if not inter:
|
|
|
+ return []
|
|
|
+ # 稳定输出:按字符串排序
|
|
|
+ return sorted(inter)
|
|
|
|
|
|
-def _empty_videos() -> List[JsonDict]:
|
|
|
- # 约定:返回“视频列表”,但当前无接口先返回空。
|
|
|
- return []
|
|
|
+
|
|
|
+def _query_post_ids_by_feature(conn, feature: str) -> Set[str]:
|
|
|
+ """
|
|
|
+ 在 element_classification_mapping 中按 feature 查关联 post_id。
|
|
|
+ """
|
|
|
+
|
|
|
+    sql = """
|
|
|
+ SELECT DISTINCT post_id
|
|
|
+ FROM element_classification_mapping
|
|
|
+ WHERE name = %s
|
|
|
+ """
|
|
|
+ with conn.cursor() as cur:
|
|
|
+ cur.execute(sql, (feature,))
|
|
|
+ rows = cur.fetchall() or []
|
|
|
+ out: Set[str] = set()
|
|
|
+ for r in rows:
|
|
|
+ pid = (r.get("post_id") if isinstance(r, dict) else None)
|
|
|
+ if pid is None:
|
|
|
+ continue
|
|
|
+ s = str(pid).strip()
|
|
|
+ if s:
|
|
|
+ out.add(s)
|
|
|
+ return out
|
|
|
+
|
|
|
+
|
|
|
+def _chunked(items: Sequence[str], size: int) -> Iterable[Sequence[str]]:
|
|
|
+ if size <= 0:
|
|
|
+ yield items
|
|
|
+ return
|
|
|
+ for i in range(0, len(items), size):
|
|
|
+ yield items[i : i + size]
|
|
|
+
|
|
|
+
|
|
|
+def _query_points_by_post_ids(conn, post_ids: List[str]) -> Dict[str, JsonDict]:
|
|
|
+ """
|
|
|
+ 从 workflow_decode_task_result 取三类 points,按 post_id 映射。
|
|
|
+
|
|
|
+ - post_id 对应字段:channel_content_id
|
|
|
+ - 返回字段:purpose_points / key_points / inspiration_points
|
|
|
+ """
|
|
|
+
|
|
|
+ if not post_ids:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ result: Dict[str, JsonDict] = {}
|
|
|
+ # IN 过长时分批(默认 200)
|
|
|
+ batch_size = int(os.getenv("GET_VIDEO_TOPIC_IN_BATCH_SIZE", "200"))
|
|
|
+ for batch in _chunked(post_ids, batch_size):
|
|
|
+ placeholders = ",".join(["%s"] * len(batch))
|
|
|
+ sql = f"""
|
|
|
+ SELECT
|
|
|
+ channel_content_id,
|
|
|
+ purpose_points,
|
|
|
+ key_points,
|
|
|
+ inspiration_points
|
|
|
+ FROM workflow_decode_task_result
|
|
|
+ WHERE channel_content_id IN ({placeholders})
|
|
|
+ """
|
|
|
+ with conn.cursor() as cur:
|
|
|
+ cur.execute(sql, tuple(batch))
|
|
|
+ rows = cur.fetchall() or []
|
|
|
+ for r in rows:
|
|
|
+ if not isinstance(r, dict):
|
|
|
+ continue
|
|
|
+ pid = str(r.get("channel_content_id") or "").strip()
|
|
|
+ if not pid:
|
|
|
+ continue
|
|
|
+ result[pid] = {
|
|
|
+ "goal_points": _coerce_points(r.get("purpose_points")),
|
|
|
+ "key_points": _coerce_points(r.get("key_points")),
|
|
|
+ "inspiration_points": _coerce_points(r.get("inspiration_points")),
|
|
|
+ }
|
|
|
+ return result
|
|
|
|
|
|
|
|
|
-@tool(description="根据特征匹配高赞视频,并返回每个视频的灵感点/目的点/关键点列表(当前占位返回空)")
+@tool(description="根据特征匹配高赞视频,并返回每个视频的灵感点/目的点/关键点列表")
|
|
|
async def get_video_topic(
|
|
|
- features: Optional[List[str]] = None,
|
|
|
+ features: str = "",
|
|
|
limit: int = 20,
|
|
|
) -> ToolResult:
|
|
|
"""
|
|
|
Args:
|
|
|
- features: 特征/关键词列表(可空)
|
|
|
+ features: 特征字符串,逗号分隔(可空)。例:"健康,防骗"
|
|
|
-        limit: 期望返回的最大视频数(当前占位实现不使用)
+        limit: 期望返回的最大视频数
|
|
|
|
|
|
Returns:
|
|
|
@@ -61,13 +171,54 @@ async def get_video_topic(
|
|
|
}]
|
|
|
"""
|
|
|
|
|
|
- _ = features
|
|
|
- _ = limit
|
|
|
-
|
|
|
- videos = _empty_videos()
|
|
|
- return ToolResult(
|
|
|
- title="选题解构(占位)",
|
|
|
- output=f"当前无可用接口,临时返回空视频列表(videos=0)。",
|
|
|
- metadata={"videos": videos, "features": features or [], "limit": limit},
|
|
|
- long_term_memory="Get video topic decomposition (placeholder, empty result).",
|
|
|
- )
|
|
|
+ feature_list = _split_features(features)
|
|
|
+ if not feature_list:
|
|
|
+ return ToolResult(
|
|
|
+ title="选题解构",
|
|
|
+ output="features 为空:返回空视频列表(videos=0)。",
|
|
|
+ metadata={"videos": [], "features": [], "limit": limit, "post_ids": []},
|
|
|
+ )
|
|
|
+
|
|
|
+ open_aigc_conn = get_open_aigc_pattern_connection()
|
|
|
+ supply_conn = get_connection()
|
|
|
+ try:
|
|
|
+ feature_sets: List[Set[str]] = []
|
|
|
+ for f in feature_list:
|
|
|
+ feature_sets.append(_query_post_ids_by_feature(open_aigc_conn, f))
|
|
|
+
|
|
|
+ post_ids = _intersect_post_ids(feature_sets)
|
|
|
+ if limit and limit > 0:
|
|
|
+ post_ids = post_ids[: int(limit)]
|
|
|
+
|
|
|
+ points_map = _query_points_by_post_ids(supply_conn, post_ids)
|
|
|
+
|
|
|
+ videos: List[JsonDict] = []
|
|
|
+ for pid in post_ids:
|
|
|
+ item = points_map.get(pid) or {
|
|
|
+ "inspiration_points": [],
|
|
|
+ "goal_points": [],
|
|
|
+ "key_points": [],
|
|
|
+ }
|
|
|
+ # 按约定:每条“视频”只输出三类 points,不额外带 post_id
|
|
|
+ videos.append(
|
|
|
+ {
|
|
|
+ "inspiration_points": item.get("inspiration_points") or [],
|
|
|
+ "goal_points": item.get("goal_points") or [],
|
|
|
+ "key_points": item.get("key_points") or [],
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ return ToolResult(
|
|
|
+ title="选题解构",
|
|
|
+ output=f"命中特征交集 post_id={len(post_ids)},返回 videos={len(videos)}。",
|
|
|
+ metadata={
|
|
|
+ "videos": videos,
|
|
|
+ "features": feature_list,
|
|
|
+ "limit": limit,
|
|
|
+ # 调试/可追溯:不放在 videos 条目里,避免污染“每条视频字段约定”
|
|
|
+ "post_ids": post_ids,
|
|
|
+ },
|
|
|
+ long_term_memory=f"Get video topic points by features: {','.join(feature_list)} (videos={len(videos)})",
|
|
|
+ )
|
|
|
+ finally:
|
|
|
+        open_aigc_conn.close()
+        supply_conn.close()
|