Ver Fonte

增加高赞内容选题点查找

jihuaqiang há 1 dia atrás
pai
commit
1e179e9a5f

+ 1 - 1
examples/content_finder/core.py

@@ -59,7 +59,7 @@ from tools import (
 logger = logging.getLogger(__name__)
 
 # 默认搜索词
-DEFAULT_QUERY = "伟人功绩"
+DEFAULT_QUERY = "毛泽东,反腐倡廉"
 DEFAULT_DEMAND_ID = 1
 
 

+ 2 - 0
examples/content_finder/db/__init__.py

@@ -7,6 +7,7 @@
 """
 
 from .connection import get_connection
+from .open_aigc_pattern_connection import get_open_aigc_pattern_connection
 from .schedule import (
     get_next_unprocessed_demand,
     create_task_record,
@@ -17,6 +18,7 @@ from .store_results import upsert_good_authors, insert_contents, update_content_
 
 __all__ = [
     "get_connection",
+    "get_open_aigc_pattern_connection",
     "get_next_unprocessed_demand",
     "create_task_record",
     "update_task_status",

+ 38 - 0
examples/content_finder/db/open_aigc_pattern_connection.py

@@ -0,0 +1,38 @@
+"""open_aigc_pattern 数据库连接封装(与 content_finder 默认 DB_* 不同)。"""
+
+from __future__ import annotations
+
+import os
+from urllib.parse import unquote, urlparse
+
+import pymysql
+
+
+def get_open_aigc_pattern_connection():
+    """
+    获取 open_aigc_pattern 数据库连接。
+    """
+
+
+    host = os.getenv("OPEN_AIGC_PATTERN_DB_HOST", "").strip()
+    port = int(os.getenv("OPEN_AIGC_PATTERN_DB_PORT", "3306"))
+    user = os.getenv("OPEN_AIGC_PATTERN_DB_USER", "").strip()
+    password = os.getenv("OPEN_AIGC_PATTERN_DB_PASSWORD", "")
+    database = os.getenv("OPEN_AIGC_PATTERN_DB_NAME", "open_aigc_pattern").strip()
+    if not all([host, user, database]):
+        raise ValueError(
+            "open_aigc_pattern 数据库未配置:请设置 OPEN_AIGC_PATTERN_DB_URL "
+            "或 OPEN_AIGC_PATTERN_DB_HOST/USER/PASSWORD/NAME"
+        )
+
+    return pymysql.connect(
+        host=host,
+        port=port,
+        user=user,
+        password=password,
+        database=database,
+        charset="utf8mb4",
+        cursorclass=pymysql.cursors.DictCursor,
+        autocommit=True,
+    )
+

+ 183 - 32
examples/content_finder/tools/get_video_topic.py

@@ -7,49 +7,159 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+import json
+import os
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
 
 from agent.tools import ToolResult, tool
 
+from db import get_open_aigc_pattern_connection, get_connection
+
 JsonDict = Dict[str, Any]
 
 
-@dataclass(frozen=True)
-class VideoTopicItem:
+def _split_features(features: Optional[str]) -> List[str]:
+    raw = (features or "").strip()
+    if not raw:
+        return []
+    # 同时支持英文/中文逗号
+    parts = raw.replace(",", ",").split(",")
+    out: List[str] = []
+    for p in parts:
+        s = p.strip()
+        if s:
+            out.append(s)
+    return out
+
+
+def _coerce_points(value: Any) -> List[str]:
     """
-    单条视频的选题点结构(仅保留三类列表)。
+    将 DB 字段的 points 归一成 List[str]
 
-    - inspiration_points: 灵感点列表
-    - goal_points: 目的点列表
-    - key_points: 关键点列表
+    允许 value 形态:
+    - list: 过滤出非空字符串
+    - str: 优先当作 JSON 解析;失败则按逗号/换行切分
+    - None/其他:空列表
     """
 
-    inspiration_points: List[str]
-    goal_points: List[str]
-    key_points: List[str]
+    if value is None:
+        return []
+    if isinstance(value, list):
+        return [x.strip() for x in value if isinstance(x, str) and x.strip()]
+    if isinstance(value, str):
+        s = value.strip()
+        if not s:
+            return []
+        try:
+            parsed = json.loads(s)
+        except Exception:
+            parsed = None
+        if isinstance(parsed, list):
+            return [x.strip() for x in parsed if isinstance(x, str) and x.strip()]
+        # 兜底:按常见分隔符切开
+        parts: List[str] = []
+        for token in s.replace(",", ",").replace("\n", ",").split(","):
+            t = token.strip()
+            if t:
+                parts.append(t)
+        return parts
+    return []
 
-    def to_dict(self) -> JsonDict:
-        return {
-            "inspiration_points": self.inspiration_points,
-            "goal_points": self.goal_points,
-            "key_points": self.key_points,
-        }
 
+def _intersect_post_ids(feature_post_id_sets: Sequence[Set[str]]) -> List[str]:
+    if not feature_post_id_sets:
+        return []
+    inter = set(feature_post_id_sets[0])
+    for s in feature_post_id_sets[1:]:
+        inter &= s
+        if not inter:
+            return []
+    # 稳定输出:按字符串排序
+    return sorted(inter)
 
-def _empty_videos() -> List[JsonDict]:
-    # 约定:返回“视频列表”,但当前无接口先返回空。
-    return []
+
+def _query_post_ids_by_feature(conn, feature: str) -> Set[str]:
+    """
+    在 element_classification_mapping 中按 feature 查关联 post_id。
+    """
+
+    sql = f"""
+    SELECT DISTINCT post_id
+    FROM element_classification_mapping
+    WHERE name = %s
+    """
+    with conn.cursor() as cur:
+        cur.execute(sql, (feature,))
+        rows = cur.fetchall() or []
+        out: Set[str] = set()
+        for r in rows:
+            pid = (r.get("post_id") if isinstance(r, dict) else None)
+            if pid is None:
+                continue
+            s = str(pid).strip()
+            if s:
+                out.add(s)
+        return out
+
+
+def _chunked(items: Sequence[str], size: int) -> Iterable[Sequence[str]]:
+    if size <= 0:
+        yield items
+        return
+    for i in range(0, len(items), size):
+        yield items[i : i + size]
+
+
+def _query_points_by_post_ids(conn, post_ids: List[str]) -> Dict[str, JsonDict]:
+    """
+    从 workflow_decode_task_result 取三类 points,按 post_id 映射。
+
+    - post_id 对应字段:channel_content_id
+    - 返回字段:purpose_points / key_points / inspiration_points
+    """
+
+    if not post_ids:
+        return {}
+
+    result: Dict[str, JsonDict] = {}
+    # IN 过长时分批(默认 200)
+    batch_size = int(os.getenv("GET_VIDEO_TOPIC_IN_BATCH_SIZE", "200"))
+    for batch in _chunked(post_ids, batch_size):
+        placeholders = ",".join(["%s"] * len(batch))
+        sql = f"""
+        SELECT
+          channel_content_id,
+          purpose_points,
+          key_points,
+          inspiration_points
+        FROM workflow_decode_task_result
+        WHERE channel_content_id IN ({placeholders})
+        """
+        with conn.cursor() as cur:
+            cur.execute(sql, tuple(batch))
+            rows = cur.fetchall() or []
+            for r in rows:
+                if not isinstance(r, dict):
+                    continue
+                pid = str(r.get("channel_content_id") or "").strip()
+                if not pid:
+                    continue
+                result[pid] = {
+                    "goal_points": _coerce_points(r.get("purpose_points")),
+                    "key_points": _coerce_points(r.get("key_points")),
+                    "inspiration_points": _coerce_points(r.get("inspiration_points")),
+                }
+    return result
 
 
 @tool(description="根据特征匹配高赞视频,并返回每个视频的灵感点/目的点/关键点列表(当前占位返回空)")
 async def get_video_topic(
-    features: Optional[List[str]] = None,
+    features: str = "",
     limit: int = 20,
 ) -> ToolResult:
     """
     Args:
-        features: 特征/关键词列表(可空)
+        features: 特征字符串,逗号分隔(可空)。例:"健康,防骗"
         limit: 期望返回的最大视频数(当前占位实现不使用)
 
     Returns:
@@ -61,13 +171,54 @@ async def get_video_topic(
             }]
     """
 
-    _ = features
-    _ = limit
-
-    videos = _empty_videos()
-    return ToolResult(
-        title="选题解构(占位)",
-        output=f"当前无可用接口,临时返回空视频列表(videos=0)。",
-        metadata={"videos": videos, "features": features or [], "limit": limit},
-        long_term_memory="Get video topic decomposition (placeholder, empty result).",
-    )
+    feature_list = _split_features(features)
+    if not feature_list:
+        return ToolResult(
+            title="选题解构",
+            output="features 为空:返回空视频列表(videos=0)。",
+            metadata={"videos": [], "features": [], "limit": limit, "post_ids": []},
+        )
+
+    open_aigc_conn = get_open_aigc_pattern_connection()
+    supply_conn = get_connection()
+    try:
+        feature_sets: List[Set[str]] = []
+        for f in feature_list:
+            feature_sets.append(_query_post_ids_by_feature(open_aigc_conn, f))
+
+        post_ids = _intersect_post_ids(feature_sets)
+        if limit and limit > 0:
+            post_ids = post_ids[: int(limit)]
+
+        points_map = _query_points_by_post_ids(supply_conn, post_ids)
+
+        videos: List[JsonDict] = []
+        for pid in post_ids:
+            item = points_map.get(pid) or {
+                "inspiration_points": [],
+                "goal_points": [],
+                "key_points": [],
+            }
+            # 按约定:每条“视频”只输出三类 points,不额外带 post_id
+            videos.append(
+                {
+                    "inspiration_points": item.get("inspiration_points") or [],
+                    "goal_points": item.get("goal_points") or [],
+                    "key_points": item.get("key_points") or [],
+                }
+            )
+
+        return ToolResult(
+            title="选题解构",
+            output=f"命中特征交集 post_id={len(post_ids)},返回 videos={len(videos)}。",
+            metadata={
+                "videos": videos,
+                "features": feature_list,
+                "limit": limit,
+                # 调试/可追溯:不放在 videos 条目里,避免污染“每条视频字段约定”
+                "post_ids": post_ids,
+            },
+            long_term_memory=f"Get video topic points by features: {','.join(feature_list)} (videos={len(videos)})",
+        )
+    finally:
+        conn.close()