|
|
@@ -3,215 +3,68 @@
|
|
|
|
|
|
约定:
|
|
|
- 输入参数:trace_id(字符串)
|
|
|
-- 数据来源:.cache/traces/{trace_id}/recommendations.json
|
|
|
+- 数据来源:{TRACE_DIR}/{trace_id}/output.json
|
|
|
- 表结构:good_authors, contents(字段见下面 SQL 注释)
|
|
|
"""
|
|
|
-
|
|
|
+import asyncio
|
|
|
import json
|
|
|
import logging
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
-from typing import Any, Dict, List, Optional
|
|
|
-
|
|
|
-import pymysql
|
|
|
+from typing import Any, Dict
|
|
|
|
|
|
from agent.tools import tool, ToolResult
|
|
|
|
|
|
-logger = logging.getLogger(__name__)
|
|
|
+from db import get_connection, insert_contents, upsert_good_authors
|
|
|
|
|
|
-
|
|
|
-def _get_connection():
|
|
|
- host = os.getenv("DB_HOST", "rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com")
|
|
|
- port = int(os.getenv("DB_PORT", "3306"))
|
|
|
- user = os.getenv("DB_USER", "content_rw")
|
|
|
- password = os.getenv("DB_PASSWORD", "bC1aH4bA1lB0")
|
|
|
- database = os.getenv("DB_NAME", "content-deconstruction-supply")
|
|
|
-
|
|
|
- return pymysql.connect(
|
|
|
- host=host,
|
|
|
- port=port,
|
|
|
- user=user,
|
|
|
- password=password,
|
|
|
- database=database,
|
|
|
- charset="utf8mb4",
|
|
|
- cursorclass=pymysql.cursors.DictCursor,
|
|
|
- autocommit=True,
|
|
|
- )
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
-def _load_recommendations(trace_id: str) -> Dict[str, Any]:
|
|
|
- """
|
|
|
- 按约定路径读取推荐结果:
|
|
|
- - 优先:{TRACE_DIR}/{trace_id}/output.json (与 output_schema.md 保持一致)
|
|
|
- - 兼容:{TRACE_DIR}/{trace_id}/recommendations.json
|
|
|
- """
|
|
|
+def _load_output(trace_id: str) -> Dict[str, Any]:
|
|
|
+ """从 {TRACE_DIR}/{trace_id}/output.json 读取输出数据。"""
|
|
|
trace_root = Path(os.getenv("TRACE_DIR", ".cache/traces"))
|
|
|
- base = trace_root / trace_id
|
|
|
+ path = trace_root / trace_id / "output.json"
|
|
|
|
|
|
- candidates = [
|
|
|
- base / "output.json",
|
|
|
- base / "recommendations.json",
|
|
|
- ]
|
|
|
+ if not path.exists():
|
|
|
+ raise FileNotFoundError(f"output.json not found for trace_id={trace_id}: {path}")
|
|
|
|
|
|
- for path in candidates:
|
|
|
- if path.exists():
|
|
|
- with path.open("r", encoding="utf-8") as f:
|
|
|
- return json.load(f)
|
|
|
-
|
|
|
- raise FileNotFoundError(
|
|
|
- f"no recommendations JSON found for trace_id={trace_id}, tried: "
|
|
|
- + ", ".join(str(p) for p in candidates)
|
|
|
- )
|
|
|
+ with path.open("r", encoding="utf-8") as f:
|
|
|
+ return json.load(f)
|
|
|
|
|
|
|
|
|
-def _upsert_good_authors(
|
|
|
- conn,
|
|
|
- trace_id: str,
|
|
|
- good_account_block: Optional[Dict[str, Any]],
|
|
|
-) -> int:
|
|
|
- """
|
|
|
- 将 good_account_expansion 中的 accounts 写入 good_authors 表。
|
|
|
-
|
|
|
- 约定表结构示例:
|
|
|
- CREATE TABLE demand_find_author (
|
|
|
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
|
|
- trace_id VARCHAR(64) NOT NULL,
|
|
|
- author_name VARCHAR(255) NOT NULL,
|
|
|
- author_link VARCHAR(512) NOT NULL,
|
|
|
- reason TEXT,
|
|
|
- expanded_count INT DEFAULT 0,
|
|
|
- PRIMARY KEY (id),
|
|
|
- KEY idx_demand_find_author_trace (trace_id),
|
|
|
- UNIQUE KEY uk_demand_find_author_trace_author (trace_id, author_link)
|
|
|
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
|
|
- """
|
|
|
- if not good_account_block:
|
|
|
- return 0
|
|
|
-
|
|
|
- if not good_account_block.get("found"):
|
|
|
- return 0
|
|
|
-
|
|
|
- accounts: List[Dict[str, Any]] = good_account_block.get("accounts") or []
|
|
|
- if not accounts:
|
|
|
- return 0
|
|
|
-
|
|
|
- sql = """
|
|
|
- INSERT INTO demand_find_author (trace_id, author_name, author_link, reason, expanded_count)
|
|
|
- VALUES (%s, %s, %s, %s, %s)
|
|
|
- ON DUPLICATE KEY UPDATE
|
|
|
- reason = VALUES(reason),
|
|
|
- expanded_count = VALUES(expanded_count)
|
|
|
- """
|
|
|
- with conn.cursor() as cur:
|
|
|
- rows = 0
|
|
|
- for acc in accounts:
|
|
|
- author_name = acc.get("account_name") or acc.get("author_name") or ""
|
|
|
- author_link = acc.get("author_link") or ""
|
|
|
- if not author_name or not author_link:
|
|
|
- # 如果只给了 sec_uid,可以由上层补 author_link
|
|
|
- sec_uid = acc.get("sec_uid")
|
|
|
- if sec_uid and not author_link:
|
|
|
- author_link = f"https://www.douyin.com/user/{sec_uid}"
|
|
|
- if not author_name or not author_link:
|
|
|
- continue
|
|
|
-
|
|
|
- reason = acc.get("reason") or ""
|
|
|
- expanded_count = int(acc.get("expanded_count") or 0)
|
|
|
- cur.execute(sql, (trace_id, author_name, author_link, reason, expanded_count))
|
|
|
- rows += cur.rowcount
|
|
|
- return rows
|
|
|
-
|
|
|
-
|
|
|
-def _insert_contents(
|
|
|
- conn,
|
|
|
- trace_id: str,
|
|
|
- contents: List[Dict[str, Any]],
|
|
|
-) -> int:
|
|
|
- """
|
|
|
- 将 contents 列表写入 contents 表。
|
|
|
-
|
|
|
- 约定表结构示例:
|
|
|
- CREATE TABLE demand_find_content_result (
|
|
|
- id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
|
|
|
- trace_id VARCHAR(64) NOT NULL,
|
|
|
- rank INT NOT NULL,
|
|
|
- content_link VARCHAR(512) NOT NULL,
|
|
|
- title TEXT NOT NULL,
|
|
|
- author_name VARCHAR(255) NOT NULL,
|
|
|
- author_link VARCHAR(512) NOT NULL,
|
|
|
- digg_count BIGINT DEFAULT 0,
|
|
|
- comment_count BIGINT DEFAULT 0,
|
|
|
- share_count BIGINT DEFAULT 0,
|
|
|
- portrait_source VARCHAR(255),
|
|
|
- elderly_ratio VARCHAR(255),
|
|
|
- elderly_tgi VARCHAR(255),
|
|
|
- recommendation_reason TEXT,
|
|
|
- PRIMARY KEY (id),
|
|
|
- KEY idx_demand_find_content_trace (trace_id),
|
|
|
- KEY idx_demand_find_content_author (author_link)
|
|
|
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
|
|
- """
|
|
|
- if not contents:
|
|
|
- return 0
|
|
|
-
|
|
|
- sql = """
|
|
|
- INSERT INTO demand_find_content_result (
|
|
|
- trace_id, rank, content_link, title, author_name, author_link,
|
|
|
- digg_count, comment_count, share_count,
|
|
|
- portrait_source, elderly_ratio, elderly_tgi, recommendation_reason
|
|
|
- ) VALUES (
|
|
|
- %s, %s, %s, %s, %s, %s,
|
|
|
- %s, %s, %s,
|
|
|
- %s, %s, %s, %s
|
|
|
- )
|
|
|
- """
|
|
|
- with conn.cursor() as cur:
|
|
|
- rows = 0
|
|
|
- for item in contents:
|
|
|
- cur.execute(
|
|
|
- sql,
|
|
|
- (
|
|
|
- trace_id,
|
|
|
- int(item.get("rank") or 0),
|
|
|
- item.get("content_link") or "",
|
|
|
- item.get("title") or "",
|
|
|
- item.get("author_name") or "",
|
|
|
- item.get("author_link") or "",
|
|
|
- int(item.get("heat_metrics", {}).get("digg_count") or 0),
|
|
|
- int(item.get("heat_metrics", {}).get("comment_count") or 0),
|
|
|
- int(item.get("heat_metrics", {}).get("share_count") or 0),
|
|
|
- item.get("portrait_source") or "",
|
|
|
- str(item.get("elderly_ratio") or ""),
|
|
|
- str(item.get("elderly_tgi") or ""),
|
|
|
- item.get("recommendation_reason") or "",
|
|
|
- ),
|
|
|
- )
|
|
|
- rows += cur.rowcount
|
|
|
- return rows
|
|
|
-
|
|
|
-
|
|
|
-@tool(description="将推荐结果写入 MySQL(good_authors + contents)")
|
|
|
+@tool(description="将推荐结果写入 MySQL")
|
|
|
async def store_results_mysql(trace_id: str) -> ToolResult:
|
|
|
"""
|
|
|
- 根据 trace_id 读取对应的 recommendations.json,并写入 MySQL 的两个表:
|
|
|
- - demand_find_author:优质账号信息
|
|
|
- - demand_find_content_result:推荐内容列表
|
|
|
+ 根据 trace_id 读取 output.json,并写入 MySQL。
|
|
|
+ demand_content_id 从 output 的 demand_id 字段获取,需在 output_schema 中输出。
|
|
|
"""
|
|
|
try:
|
|
|
- data = _load_recommendations(trace_id)
|
|
|
+ data = _load_output(trace_id)
|
|
|
except Exception as e:
|
|
|
- msg = f"加载 recommendations.json 失败: {e}"
|
|
|
+ msg = f"加载 output.json 失败: {e}"
|
|
|
+ logger.error(msg)
|
|
|
+ return ToolResult(title="存储推荐结果", output=msg, metadata={"ok": False, "error": str(e)})
|
|
|
+
|
|
|
+ demand_content_id = data.get("demand_id")
|
|
|
+ if demand_content_id is not None and not isinstance(demand_content_id, int):
|
|
|
+ try:
|
|
|
+ demand_content_id = int(demand_content_id)
|
|
|
+ except (ValueError, TypeError):
|
|
|
+ demand_content_id = None
|
|
|
+ if demand_content_id is None:
|
|
|
+ msg = "demand_id 必填:请在 output 的 demand_id 字段中输出(来自 user 消息的搜索词 id)"
|
|
|
logger.error(msg)
|
|
|
- return ToolResult(output=msg, metadata={"ok": False, "error": str(e)})
|
|
|
+ return ToolResult(title="存储推荐结果", output=msg, metadata={"ok": False, "error": msg})
|
|
|
|
|
|
conn = None
|
|
|
try:
|
|
|
- conn = _get_connection()
|
|
|
- good_block = data.get("good_account_expansion") or data.get("good_accounts")
|
|
|
+ conn = get_connection()
|
|
|
+ good_block = data.get("good_account_expansion")
|
|
|
contents = data.get("contents") or []
|
|
|
+ query = data.get("query") or ""
|
|
|
|
|
|
- authors_rows = _upsert_good_authors(conn, trace_id, good_block)
|
|
|
- contents_rows = _insert_contents(conn, trace_id, contents)
|
|
|
+ authors_rows = upsert_good_authors(conn, trace_id, good_block)
|
|
|
+ contents_rows = insert_contents(conn, trace_id, query, demand_content_id, contents)
|
|
|
|
|
|
output = (
|
|
|
f"MySQL 写入完成:demand_find_author 影响行数={authors_rows}, "
|
|
|
@@ -219,6 +72,7 @@ async def store_results_mysql(trace_id: str) -> ToolResult:
|
|
|
)
|
|
|
logger.info(output)
|
|
|
return ToolResult(
|
|
|
+ title="存储推荐结果",
|
|
|
output=output,
|
|
|
metadata={
|
|
|
"ok": True,
|
|
|
@@ -230,8 +84,17 @@ async def store_results_mysql(trace_id: str) -> ToolResult:
|
|
|
except Exception as e:
|
|
|
msg = f"写入 MySQL 失败: {e}"
|
|
|
logger.error(msg, exc_info=True)
|
|
|
- return ToolResult(output=msg, metadata={"ok": False, "error": str(e)})
|
|
|
+ return ToolResult(title="存储推荐结果", output=msg, metadata={"ok": False, "error": str(e)})
|
|
|
finally:
|
|
|
if conn is not None:
|
|
|
conn.close()
|
|
|
|
|
|
+async def main():
|
|
|
+ result = await store_results_mysql(
|
|
|
+ trace_id="7b211fa6-f0d6-4f98-a6f5-689e6af64748",
|
|
|
+ )
|
|
|
+ # ToolResult 是 dataclass,用 vars 输出
|
|
|
+ print(vars(result))
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ asyncio.run(main())
|