# get_video_topic.py
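"""
Look up topic-deconstruction data for highly-liked videos that match a set of
features.

This module began as a placeholder that returned an empty list: the tool
signature and the return structure were fixed first, so that wiring in the
data source only has to fill ``metadata.videos`` and callers never change.
The tool now resolves features to post ids and loads the inspiration / goal /
key points of every matched video.
"""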
  6. from __future__ import annotations
  7. import json
  8. import os
  9. from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
  10. from agent.tools import ToolResult, tool
  11. from db import get_open_aigc_pattern_connection, get_connection
  12. JsonDict = Dict[str, Any]
  13. def _split_features(features: Optional[str]) -> List[str]:
  14. raw = (features or "").strip()
  15. if not raw:
  16. return []
  17. # 同时支持英文/中文逗号
  18. parts = raw.replace(",", ",").split(",")
  19. out: List[str] = []
  20. for p in parts:
  21. s = p.strip()
  22. if s:
  23. out.append(s)
  24. return out
  25. def _coerce_points(value: Any) -> List[str]:
  26. """
  27. 将 DB 字段的 points 归一成 List[str]。
  28. 允许 value 形态:
  29. - list: 过滤出非空字符串
  30. - str: 优先当作 JSON 解析;失败则按逗号/换行切分
  31. - None/其他:空列表
  32. """
  33. if value is None:
  34. return []
  35. if isinstance(value, list):
  36. return [x.strip() for x in value if isinstance(x, str) and x.strip()]
  37. if isinstance(value, str):
  38. s = value.strip()
  39. if not s:
  40. return []
  41. try:
  42. parsed = json.loads(s)
  43. except Exception:
  44. parsed = None
  45. if isinstance(parsed, list):
  46. return [x.strip() for x in parsed if isinstance(x, str) and x.strip()]
  47. # 兜底:按常见分隔符切开
  48. parts: List[str] = []
  49. for token in s.replace(",", ",").replace("\n", ",").split(","):
  50. t = token.strip()
  51. if t:
  52. parts.append(t)
  53. return parts
  54. return []
  55. def _intersect_post_ids(feature_post_id_sets: Sequence[Set[str]]) -> List[str]:
  56. if not feature_post_id_sets:
  57. return []
  58. inter = set(feature_post_id_sets[0])
  59. for s in feature_post_id_sets[1:]:
  60. inter &= s
  61. if not inter:
  62. return []
  63. # 稳定输出:按字符串排序
  64. return sorted(inter)
  65. def _query_post_ids_by_feature(conn, feature: str) -> Set[str]:
  66. """
  67. 在 element_classification_mapping 中按 feature 查关联 post_id。
  68. """
  69. sql = f"""
  70. SELECT DISTINCT post_id
  71. FROM element_classification_mapping
  72. WHERE name = %s
  73. """
  74. with conn.cursor() as cur:
  75. cur.execute(sql, (feature,))
  76. rows = cur.fetchall() or []
  77. out: Set[str] = set()
  78. for r in rows:
  79. pid = (r.get("post_id") if isinstance(r, dict) else None)
  80. if pid is None:
  81. continue
  82. s = str(pid).strip()
  83. if s:
  84. out.add(s)
  85. return out
  86. def _chunked(items: Sequence[str], size: int) -> Iterable[Sequence[str]]:
  87. if size <= 0:
  88. yield items
  89. return
  90. for i in range(0, len(items), size):
  91. yield items[i : i + size]
  92. def _query_points_by_post_ids(conn, post_ids: List[str]) -> Dict[str, JsonDict]:
  93. """
  94. 从 workflow_decode_task_result 取三类 points,按 post_id 映射。
  95. - post_id 对应字段:channel_content_id
  96. - 返回字段:purpose_points / key_points / inspiration_points
  97. """
  98. if not post_ids:
  99. return {}
  100. result: Dict[str, JsonDict] = {}
  101. # IN 过长时分批(默认 200)
  102. batch_size = int(os.getenv("GET_VIDEO_TOPIC_IN_BATCH_SIZE", "200"))
  103. for batch in _chunked(post_ids, batch_size):
  104. placeholders = ",".join(["%s"] * len(batch))
  105. sql = f"""
  106. SELECT
  107. channel_content_id,
  108. purpose_points,
  109. key_points,
  110. inspiration_points
  111. FROM workflow_decode_task_result
  112. WHERE channel_content_id IN ({placeholders})
  113. """
  114. with conn.cursor() as cur:
  115. cur.execute(sql, tuple(batch))
  116. rows = cur.fetchall() or []
  117. for r in rows:
  118. if not isinstance(r, dict):
  119. continue
  120. pid = str(r.get("channel_content_id") or "").strip()
  121. if not pid:
  122. continue
  123. result[pid] = {
  124. "goal_points": _coerce_points(r.get("purpose_points")),
  125. "key_points": _coerce_points(r.get("key_points")),
  126. "inspiration_points": _coerce_points(r.get("inspiration_points")),
  127. }
  128. return result
  129. @tool(description="根据特征匹配高赞视频,并返回每个视频的灵感点/目的点/关键点列表(当前占位返回空)")
  130. async def get_video_topic(
  131. features: str = "",
  132. limit: int = 20,
  133. ) -> ToolResult:
  134. """
  135. Args:
  136. features: 特征字符串,逗号分隔(可空)。例:"健康,防骗"
  137. limit: 期望返回的最大视频数(当前占位实现不使用)
  138. Returns:
  139. ToolResult:
  140. - metadata.videos: List[{
  141. "inspiration_points": [...],
  142. "goal_points": [...],
  143. "key_points": [...]
  144. }]
  145. """
  146. feature_list = _split_features(features)
  147. if not feature_list:
  148. return ToolResult(
  149. title="选题解构",
  150. output="features 为空:返回空视频列表(videos=0)。",
  151. metadata={"videos": [], "features": [], "limit": limit, "post_ids": []},
  152. )
  153. open_aigc_conn = get_open_aigc_pattern_connection()
  154. supply_conn = get_connection()
  155. try:
  156. feature_sets: List[Set[str]] = []
  157. for f in feature_list:
  158. feature_sets.append(_query_post_ids_by_feature(open_aigc_conn, f))
  159. post_ids = _intersect_post_ids(feature_sets)
  160. if limit and limit > 0:
  161. post_ids = post_ids[: int(limit)]
  162. points_map = _query_points_by_post_ids(supply_conn, post_ids)
  163. videos: List[JsonDict] = []
  164. for pid in post_ids:
  165. item = points_map.get(pid) or {
  166. "inspiration_points": [],
  167. "goal_points": [],
  168. "key_points": [],
  169. }
  170. # 按约定:每条“视频”只输出三类 points,不额外带 post_id
  171. videos.append(
  172. {
  173. "inspiration_points": item.get("inspiration_points") or [],
  174. "goal_points": item.get("goal_points") or [],
  175. "key_points": item.get("key_points") or [],
  176. }
  177. )
  178. return ToolResult(
  179. title="选题解构",
  180. output=f"命中特征交集 post_id={len(post_ids)},返回 videos={len(videos)}。",
  181. metadata={
  182. "videos": videos,
  183. "features": feature_list,
  184. "limit": limit,
  185. # 调试/可追溯:不放在 videos 条目里,避免污染“每条视频字段约定”
  186. "post_ids": post_ids,
  187. },
  188. long_term_memory=f"Get video topic points by features: {','.join(feature_list)} (videos={len(videos)})",
  189. )
  190. finally:
  191. conn.close()