"""Look up topic-deconstruction points for videos that match a set of features.

Flow implemented in this module:

1. Split the comma-separated ``features`` string into feature names.
2. For each feature, collect the ``post_id`` values mapped to it in
   ``element_classification_mapping``.
3. Intersect those sets so that only posts matching *every* feature remain.
4. Fetch ``purpose_points`` / ``key_points`` / ``inspiration_points`` for the
   surviving ids from ``workflow_decode_task_result``.

The tool signature and the ``metadata.videos`` structure are the stable
contract for callers; data-source details may change behind them.
"""

from __future__ import annotations

import json
import os
from contextlib import closing
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set

from agent.tools import ToolResult, tool
from db import get_open_aigc_pattern_connection, get_connection

JsonDict = Dict[str, Any]

# Full-width (Chinese) comma, written as an escape so the literal cannot be
# silently normalised into an ASCII comma by editors or tooling.
_FULLWIDTH_COMMA = "\uff0c"

# Fallback size of one ``IN (...)`` batch when the env override is unusable.
_DEFAULT_IN_BATCH_SIZE = 200


def _split_features(features: Optional[str]) -> List[str]:
    """Split a comma-separated feature string into trimmed, non-empty names.

    Both the ASCII comma and the full-width Chinese comma act as separators.
    ``None`` or a blank string yields an empty list.
    """
    raw = (features or "").strip()
    if not raw:
        return []
    normalized = raw.replace(_FULLWIDTH_COMMA, ",")
    return [part.strip() for part in normalized.split(",") if part.strip()]


def _coerce_points(value: Any) -> List[str]:
    """Normalise a DB ``*_points`` field into a list of non-empty strings.

    Accepted shapes of ``value``:

    - ``list``: non-string and blank items are dropped, the rest are stripped.
    - ``str``: parsed as JSON first; when that does not produce a list the
      text is split on ASCII commas, full-width commas and newlines.
    - ``None`` or anything else: empty list.
    """
    if value is None:
        return []
    if isinstance(value, list):
        return [x.strip() for x in value if isinstance(x, str) and x.strip()]
    if not isinstance(value, str):
        return []

    text = value.strip()
    if not text:
        return []

    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        # Not valid JSON (or pathologically nested): use the plain-text split.
        parsed = None
    if isinstance(parsed, list):
        return [x.strip() for x in parsed if isinstance(x, str) and x.strip()]

    # Fallback: split on the common separators.
    flattened = text.replace(_FULLWIDTH_COMMA, ",").replace("\n", ",")
    return [token.strip() for token in flattened.split(",") if token.strip()]


def _intersect_post_ids(feature_post_id_sets: Sequence[Set[str]]) -> List[str]:
    """Return the ids present in every set, sorted as strings.

    An empty input sequence, or an empty intersection, yields ``[]``.
    """
    if not feature_post_id_sets:
        return []
    inter = set(feature_post_id_sets[0])
    for ids in feature_post_id_sets[1:]:
        inter &= ids
        if not inter:
            return []
    # Stable output: string sort order.
    return sorted(inter)


def _row_value(row: Any, key: str, index: int) -> Any:
    """Read one column from a fetched row.

    Dict rows are read by ``key``; tuple/list rows are read by ``index``, the
    column's position in the SELECT list. Any other row shape yields ``None``.
    """
    if isinstance(row, dict):
        return row.get(key)
    if isinstance(row, (tuple, list)) and len(row) > index:
        return row[index]
    return None


def _query_post_ids_by_feature(conn, feature: str) -> Set[str]:
    """Return the distinct ``post_id`` values mapped to ``feature``.

    Queries ``element_classification_mapping`` where ``name`` equals
    ``feature`` (bound as a query parameter). Ids are stringified and
    stripped; NULL or blank ids are skipped.
    """
    sql = """
        SELECT DISTINCT post_id
        FROM element_classification_mapping
        WHERE name = %s
    """
    with conn.cursor() as cur:
        cur.execute(sql, (feature,))
        rows = cur.fetchall() or []

    out: Set[str] = set()
    for row in rows:
        pid = _row_value(row, "post_id", 0)
        if pid is None:
            continue
        text = str(pid).strip()
        if text:
            out.add(text)
    return out


def _chunked(items: Sequence[str], size: int) -> Iterable[Sequence[str]]:
    """Yield consecutive slices of ``items`` holding at most ``size`` elements.

    A non-positive ``size`` disables chunking: ``items`` is yielded once, whole.
    """
    if size <= 0:
        yield items
        return
    for start in range(0, len(items), size):
        yield items[start : start + size]


def _in_batch_size() -> int:
    """Read the IN-clause batch size from ``GET_VIDEO_TOPIC_IN_BATCH_SIZE``.

    Falls back to ``_DEFAULT_IN_BATCH_SIZE`` when the variable is unset or is
    not an integer, so a bad environment value cannot crash the tool.
    """
    raw = os.getenv("GET_VIDEO_TOPIC_IN_BATCH_SIZE", "")
    try:
        return int(raw)
    except ValueError:
        return _DEFAULT_IN_BATCH_SIZE


def _query_points_by_post_ids(conn, post_ids: List[str]) -> Dict[str, JsonDict]:
    """Fetch the three point lists for ``post_ids``, keyed by post id.

    Reads ``workflow_decode_task_result``, where the post id is stored in
    ``channel_content_id``. Each entry of the returned mapping has the keys
    ``goal_points`` (from column ``purpose_points``), ``key_points`` and
    ``inspiration_points``, each normalised with ``_coerce_points``.

    Ids are queried in batches so the ``IN (...)`` list stays bounded. When
    several rows share an id, the last one fetched wins.
    """
    if not post_ids:
        return {}

    result: Dict[str, JsonDict] = {}
    batch_size = _in_batch_size()

    for batch in _chunked(post_ids, batch_size):
        # Only "%s" markers are interpolated; the ids themselves are bound
        # as query parameters.
        placeholders = ",".join(["%s"] * len(batch))
        sql = f"""
            SELECT channel_content_id, purpose_points, key_points, inspiration_points
            FROM workflow_decode_task_result
            WHERE channel_content_id IN ({placeholders})
        """
        with conn.cursor() as cur:
            cur.execute(sql, tuple(batch))
            rows = cur.fetchall() or []

        for row in rows:
            pid = str(_row_value(row, "channel_content_id", 0) or "").strip()
            if not pid:
                continue
            result[pid] = {
                "goal_points": _coerce_points(_row_value(row, "purpose_points", 1)),
                "key_points": _coerce_points(_row_value(row, "key_points", 2)),
                "inspiration_points": _coerce_points(
                    _row_value(row, "inspiration_points", 3)
                ),
            }

    return result


# NOTE(review): the description below still says the tool is a placeholder
# that returns nothing, while the implementation queries the databases. It is
# runtime text, so it is left untouched here; confirm and update separately.
@tool(description="根据特征匹配高赞视频,并返回每个视频的灵感点/目的点/关键点列表(当前占位返回空)")
async def get_video_topic(
    features: str = "",
    limit: int = 20,
) -> ToolResult:
    """Return the point lists of videos that match every given feature.

    Args:
        features: Comma-separated feature names (ASCII or full-width commas).
            May be empty, in which case no query is made.
        limit: Maximum number of videos to return. It is applied to the
            sorted id intersection before the points are fetched; a falsy or
            non-positive value disables the cap.

    Returns:
        ToolResult whose ``metadata`` holds:

        - ``videos``: one dict per matched post id with exactly the keys
          ``inspiration_points``, ``goal_points`` and ``key_points``. Posts
          without a decode result get three empty lists.
        - ``features``: the parsed feature list.
        - ``limit``: the ``limit`` argument, echoed back.
        - ``post_ids``: the matched ids, in the same order as ``videos``.
    """
    feature_list = _split_features(features)
    if not feature_list:
        return ToolResult(
            title="选题解构",
            output="features 为空:返回空视频列表(videos=0)。",
            metadata={"videos": [], "features": [], "limit": limit, "post_ids": []},
        )

    # ``closing`` guarantees both connections are released, including when
    # the second one fails to open or a query raises.
    with closing(get_open_aigc_pattern_connection()) as open_aigc_conn, closing(
        get_connection()
    ) as supply_conn:
        feature_sets: List[Set[str]] = []
        for feature in feature_list:
            ids = _query_post_ids_by_feature(open_aigc_conn, feature)
            feature_sets.append(ids)
            if not ids:
                # The intersection is already empty; skip remaining lookups.
                break

        post_ids = _intersect_post_ids(feature_sets)
        if limit and limit > 0:
            post_ids = post_ids[: int(limit)]

        points_map = _query_points_by_post_ids(supply_conn, post_ids)

        # By convention each "video" entry carries only the three point
        # lists; the post id is deliberately not included in the entry.
        videos: List[JsonDict] = []
        for pid in post_ids:
            item = points_map.get(pid) or {}
            videos.append(
                {
                    "inspiration_points": item.get("inspiration_points") or [],
                    "goal_points": item.get("goal_points") or [],
                    "key_points": item.get("key_points") or [],
                }
            )

        return ToolResult(
            title="选题解构",
            output=f"命中特征交集 post_id={len(post_ids)},返回 videos={len(videos)}。",
            metadata={
                "videos": videos,
                "features": feature_list,
                "limit": limit,
                # For debugging/traceability; kept out of the video entries
                # so the per-video field contract stays clean.
                "post_ids": post_ids,
            },
            long_term_memory=f"Get video topic points by features: {','.join(feature_list)} (videos={len(videos)})",
        )