get_video_topic.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
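"""
Match highly-liked videos by feature and return their topic-deconstruction points.

The tool signature and the return structure are kept stable so that callers do
not have to change when the underlying data source changes. Post ids are looked
up per feature in `topic_pattern_element`, intersected across all features, and
their points are read from `workflow_decode_task_result` into metadata.videos.
"""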
  6. from __future__ import annotations
  7. import asyncio
  8. import json
  9. import os
  10. import sys
  11. from pathlib import Path
  12. from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
  13. # 兼容直接执行该文件:将仓库根目录加入模块搜索路径
  14. sys.path.insert(0, str(Path(__file__).resolve().parents[1])) # content_finder 根目录(用于 db/utils)
  15. sys.path.insert(0, str(Path(__file__).resolve().parents[3])) # 仓库根目录(用于 agent)
  16. from agent.tools import ToolResult, tool
  17. import pymysql
  18. from db import get_open_aigc_pattern_connection, get_connection
  19. from utils.tool_logging import format_tool_result_for_log, log_tool_call
  20. JsonDict = Dict[str, Any]
  21. def _split_features(features: Optional[str]) -> List[str]:
  22. raw = (features or "").strip()
  23. if not raw:
  24. return []
  25. # 同时支持英文/中文逗号
  26. parts = raw.replace(",", ",").split(",")
  27. out: List[str] = []
  28. for p in parts:
  29. s = p.strip()
  30. if s:
  31. out.append(s)
  32. return out
  33. def _coerce_points(value: Any) -> List[str]:
  34. """
  35. 将 DB 字段的 points 归一成 List[str]。
  36. 允许 value 形态:
  37. - list: 过滤出非空字符串
  38. - str: 优先当作 JSON 解析;失败则按逗号/换行切分
  39. - None/其他:空列表
  40. """
  41. if value is None:
  42. return []
  43. if isinstance(value, list):
  44. return [x.strip() for x in value if isinstance(x, str) and x.strip()]
  45. if isinstance(value, str):
  46. s = value.strip()
  47. if not s:
  48. return []
  49. try:
  50. parsed = json.loads(s)
  51. except Exception:
  52. parsed = None
  53. if isinstance(parsed, list):
  54. return [x.strip() for x in parsed if isinstance(x, str) and x.strip()]
  55. # 兜底:按常见分隔符切开
  56. parts: List[str] = []
  57. for token in s.replace(",", ",").replace("\n", ",").split(","):
  58. t = token.strip()
  59. if t:
  60. parts.append(t)
  61. return parts
  62. return []
  63. def _intersect_post_ids(feature_post_id_sets: Sequence[Set[str]]) -> List[str]:
  64. if not feature_post_id_sets:
  65. return []
  66. inter = set(feature_post_id_sets[0])
  67. for s in feature_post_id_sets[1:]:
  68. inter &= s
  69. if not inter:
  70. return []
  71. # 稳定输出:按字符串排序
  72. return sorted(inter)
  73. def _query_post_ids_by_feature(conn, feature: str) -> Set[str]:
  74. """
  75. 在 element_classification_mapping 中按 feature 查关联 post_id。
  76. """
  77. sql = f"""
  78. SELECT DISTINCT post_id
  79. FROM topic_pattern_element
  80. WHERE name = %s
  81. """
  82. with conn.cursor() as cur:
  83. cur.execute(sql, (feature,))
  84. rows = cur.fetchall() or []
  85. out: Set[str] = set()
  86. for r in rows:
  87. pid = (r.get("post_id") if isinstance(r, dict) else None)
  88. if pid is None:
  89. continue
  90. s = str(pid).strip()
  91. if s:
  92. out.add(s)
  93. return out
  94. def _chunked(items: Sequence[str], size: int) -> Iterable[Sequence[str]]:
  95. if size <= 0:
  96. yield items
  97. return
  98. for i in range(0, len(items), size):
  99. yield items[i : i + size]
  100. def _coerce_limit(value: Any, default: int = 20) -> int:
  101. if value is None:
  102. return default
  103. if isinstance(value, bool):
  104. return default
  105. if isinstance(value, int):
  106. return value
  107. if isinstance(value, str):
  108. s = value.strip()
  109. if not s:
  110. return default
  111. try:
  112. return int(s)
  113. except Exception:
  114. return default
  115. try:
  116. return int(value)
  117. except Exception:
  118. return default
  119. def _query_points_by_post_ids(conn, post_ids: List[str]) -> Dict[str, JsonDict]:
  120. """
  121. 从 workflow_decode_task_result 取三类 points,按 post_id 映射。
  122. - post_id 对应字段:channel_content_id
  123. - 返回字段:purpose_points / key_points / inspiration_points
  124. """
  125. if not post_ids:
  126. return {}
  127. result: Dict[str, JsonDict] = {}
  128. # IN 过长时分批(默认 200)
  129. batch_size = int(os.getenv("GET_VIDEO_TOPIC_IN_BATCH_SIZE", "200"))
  130. for batch in _chunked(post_ids, batch_size):
  131. placeholders = ",".join(["%s"] * len(batch))
  132. sql = f"""
  133. SELECT
  134. channel_content_id,
  135. purpose_points,
  136. key_points,
  137. inspiration_points
  138. FROM workflow_decode_task_result
  139. WHERE channel_content_id IN ({placeholders})
  140. """
  141. with conn.cursor() as cur:
  142. cur.execute(sql, tuple(batch))
  143. rows = cur.fetchall() or []
  144. for r in rows:
  145. if not isinstance(r, dict):
  146. continue
  147. pid = str(r.get("channel_content_id") or "").strip()
  148. if not pid:
  149. continue
  150. result[pid] = {
  151. "goal_points": _coerce_points(r.get("purpose_points")),
  152. "key_points": _coerce_points(r.get("key_points")),
  153. "inspiration_points": _coerce_points(r.get("inspiration_points")),
  154. }
  155. return result
  156. @tool(description="根据特征匹配高赞视频,并返回每个视频的灵感点/目的点/关键点列表(当前占位返回空)")
  157. async def get_video_topic(
  158. features: str = "",
  159. limit: int = 20,
  160. ) -> ToolResult:
  161. """
  162. Args:
  163. features: 特征字符串,逗号分隔(可空)。例:"健康,防骗"
  164. limit: 期望返回的最大视频数(当前占位实现不使用)
  165. Returns:
  166. ToolResult:
  167. - metadata.videos: List[{
  168. "inspiration_points": [...],
  169. "goal_points": [...],
  170. "key_points": [...]
  171. }]
  172. """
  173. feature_list = _split_features(features)
  174. limit_int = _coerce_limit(limit, default=20)
  175. call_params: Dict[str, Any] = {"features": features, "feature_list": feature_list, "limit": limit_int}
  176. if not feature_list:
  177. result = ToolResult(
  178. title="选题解构",
  179. output="features 为空:返回空视频列表(videos=0)。",
  180. metadata={"videos": [], "features": [], "limit": limit_int, "post_ids": []},
  181. )
  182. log_tool_call(
  183. "工具调用:get_video_topic -> 高赞视频case选题点匹配",
  184. call_params,
  185. json.dumps(result.metadata.get("videos", []), ensure_ascii=False, indent=2),
  186. )
  187. return result
  188. open_aigc_conn = get_open_aigc_pattern_connection()
  189. supply_conn = get_connection()
  190. try:
  191. feature_sets: List[Set[str]] = []
  192. for f in feature_list:
  193. feature_sets.append(_query_post_ids_by_feature(open_aigc_conn, f))
  194. post_ids = _intersect_post_ids(feature_sets)
  195. if limit_int > 0:
  196. post_ids = post_ids[:limit_int]
  197. points_map = _query_points_by_post_ids(supply_conn, post_ids)
  198. videos: List[JsonDict] = []
  199. for pid in post_ids:
  200. item = points_map.get(pid) or {
  201. "inspiration_points": [],
  202. "goal_points": [],
  203. "key_points": [],
  204. }
  205. # 按约定:每条“视频”只输出三类 points,不额外带 post_id
  206. videos.append(
  207. {
  208. "id": pid,
  209. "选题点": {
  210. "灵感点": item.get("inspiration_points") or [],
  211. "目的点": item.get("goal_points") or [],
  212. "关键点": item.get("key_points") or [],
  213. }
  214. }
  215. )
  216. result = ToolResult(
  217. title="选题解构",
  218. output=f"命中特征交集 post_id={len(post_ids)},返回 videos={json.dumps(videos, ensure_ascii=False)}。",
  219. metadata={
  220. "videos": videos,
  221. "features": feature_list,
  222. "limit": limit_int,
  223. # 调试/可追溯:不放在 videos 条目里,避免污染“每条视频字段约定”
  224. "post_ids": post_ids,
  225. },
  226. long_term_memory=f"Get video topic points by features: {','.join(feature_list)} (videos={videos})",
  227. )
  228. log_tool_call(
  229. "工具调用:get_video_topic -> 高赞视频case选题点匹配",
  230. call_params,
  231. json.dumps(result.metadata.get("videos", []), ensure_ascii=False),
  232. )
  233. return result
  234. finally:
  235. supply_conn.close()
  236. open_aigc_conn.close()
  237. async def main():
  238. result = await get_video_topic(
  239. features ="毛泽东1965年预言"
  240. )
  241. print(result.output)
  242. if __name__ == "__main__":
  243. asyncio.run(main())