# store_results_mysql.py
"""
将推荐结果写入 MySQL(优质作者表 + 内容表)。
约定:
- 输入参数:trace_id(字符串)
- 数据来源:{TRACE_DIR}/{trace_id}/output.json(优先)或 recommendations.json(兼容旧版)
- 目标表:demand_find_author, demand_find_content_result(字段见各函数内 SQL 注释)
"""
  8. import json
  9. import logging
  10. import os
  11. from pathlib import Path
  12. from typing import Any, Dict, List, Optional
  13. import pymysql
  14. from agent.tools import tool, ToolResult
  15. logger = logging.getLogger(__name__)
  16. def _get_connection():
  17. host = os.getenv("DB_HOST", "rm-t4nh1xx6o2a6vj8qu3o.mysql.singapore.rds.aliyuncs.com")
  18. port = int(os.getenv("DB_PORT", "3306"))
  19. user = os.getenv("DB_USER", "content_rw")
  20. password = os.getenv("DB_PASSWORD", "bC1aH4bA1lB0")
  21. database = os.getenv("DB_NAME", "content-deconstruction-supply")
  22. return pymysql.connect(
  23. host=host,
  24. port=port,
  25. user=user,
  26. password=password,
  27. database=database,
  28. charset="utf8mb4",
  29. cursorclass=pymysql.cursors.DictCursor,
  30. autocommit=True,
  31. )
  32. def _load_recommendations(trace_id: str) -> Dict[str, Any]:
  33. """
  34. 按约定路径读取推荐结果:
  35. - 优先:{TRACE_DIR}/{trace_id}/output.json (与 output_schema.md 保持一致)
  36. - 兼容:{TRACE_DIR}/{trace_id}/recommendations.json
  37. """
  38. trace_root = Path(os.getenv("TRACE_DIR", ".cache/traces"))
  39. base = trace_root / trace_id
  40. candidates = [
  41. base / "output.json",
  42. base / "recommendations.json",
  43. ]
  44. for path in candidates:
  45. if path.exists():
  46. with path.open("r", encoding="utf-8") as f:
  47. return json.load(f)
  48. raise FileNotFoundError(
  49. f"no recommendations JSON found for trace_id={trace_id}, tried: "
  50. + ", ".join(str(p) for p in candidates)
  51. )
  52. def _upsert_good_authors(
  53. conn,
  54. trace_id: str,
  55. good_account_block: Optional[Dict[str, Any]],
  56. ) -> int:
  57. """
  58. 将 good_account_expansion 中的 accounts 写入 good_authors 表。
  59. 约定表结构示例:
  60. CREATE TABLE demand_find_author (
  61. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
  62. trace_id VARCHAR(64) NOT NULL,
  63. author_name VARCHAR(255) NOT NULL,
  64. author_link VARCHAR(512) NOT NULL,
  65. reason TEXT,
  66. expanded_count INT DEFAULT 0,
  67. PRIMARY KEY (id),
  68. KEY idx_demand_find_author_trace (trace_id),
  69. UNIQUE KEY uk_demand_find_author_trace_author (trace_id, author_link)
  70. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
  71. """
  72. if not good_account_block:
  73. return 0
  74. if not good_account_block.get("found"):
  75. return 0
  76. accounts: List[Dict[str, Any]] = good_account_block.get("accounts") or []
  77. if not accounts:
  78. return 0
  79. sql = """
  80. INSERT INTO demand_find_author (trace_id, author_name, author_link, reason, expanded_count)
  81. VALUES (%s, %s, %s, %s, %s)
  82. ON DUPLICATE KEY UPDATE
  83. reason = VALUES(reason),
  84. expanded_count = VALUES(expanded_count)
  85. """
  86. with conn.cursor() as cur:
  87. rows = 0
  88. for acc in accounts:
  89. author_name = acc.get("account_name") or acc.get("author_name") or ""
  90. author_link = acc.get("author_link") or ""
  91. if not author_name or not author_link:
  92. # 如果只给了 sec_uid,可以由上层补 author_link
  93. sec_uid = acc.get("sec_uid")
  94. if sec_uid and not author_link:
  95. author_link = f"https://www.douyin.com/user/{sec_uid}"
  96. if not author_name or not author_link:
  97. continue
  98. reason = acc.get("reason") or ""
  99. expanded_count = int(acc.get("expanded_count") or 0)
  100. cur.execute(sql, (trace_id, author_name, author_link, reason, expanded_count))
  101. rows += cur.rowcount
  102. return rows
  103. def _insert_contents(
  104. conn,
  105. trace_id: str,
  106. contents: List[Dict[str, Any]],
  107. ) -> int:
  108. """
  109. 将 contents 列表写入 contents 表。
  110. 约定表结构示例:
  111. CREATE TABLE demand_find_content_result (
  112. id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
  113. trace_id VARCHAR(64) NOT NULL,
  114. rank INT NOT NULL,
  115. content_link VARCHAR(512) NOT NULL,
  116. title TEXT NOT NULL,
  117. author_name VARCHAR(255) NOT NULL,
  118. author_link VARCHAR(512) NOT NULL,
  119. digg_count BIGINT DEFAULT 0,
  120. comment_count BIGINT DEFAULT 0,
  121. share_count BIGINT DEFAULT 0,
  122. portrait_source VARCHAR(255),
  123. elderly_ratio VARCHAR(255),
  124. elderly_tgi VARCHAR(255),
  125. recommendation_reason TEXT,
  126. PRIMARY KEY (id),
  127. KEY idx_demand_find_content_trace (trace_id),
  128. KEY idx_demand_find_content_author (author_link)
  129. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
  130. """
  131. if not contents:
  132. return 0
  133. sql = """
  134. INSERT INTO demand_find_content_result (
  135. trace_id, rank, content_link, title, author_name, author_link,
  136. digg_count, comment_count, share_count,
  137. portrait_source, elderly_ratio, elderly_tgi, recommendation_reason
  138. ) VALUES (
  139. %s, %s, %s, %s, %s, %s,
  140. %s, %s, %s,
  141. %s, %s, %s, %s
  142. )
  143. """
  144. with conn.cursor() as cur:
  145. rows = 0
  146. for item in contents:
  147. cur.execute(
  148. sql,
  149. (
  150. trace_id,
  151. int(item.get("rank") or 0),
  152. item.get("content_link") or "",
  153. item.get("title") or "",
  154. item.get("author_name") or "",
  155. item.get("author_link") or "",
  156. int(item.get("heat_metrics", {}).get("digg_count") or 0),
  157. int(item.get("heat_metrics", {}).get("comment_count") or 0),
  158. int(item.get("heat_metrics", {}).get("share_count") or 0),
  159. item.get("portrait_source") or "",
  160. str(item.get("elderly_ratio") or ""),
  161. str(item.get("elderly_tgi") or ""),
  162. item.get("recommendation_reason") or "",
  163. ),
  164. )
  165. rows += cur.rowcount
  166. return rows
  167. @tool(description="将推荐结果写入 MySQL(good_authors + contents)")
  168. async def store_results_mysql(trace_id: str) -> ToolResult:
  169. """
  170. 根据 trace_id 读取对应的 recommendations.json,并写入 MySQL 的两个表:
  171. - demand_find_author:优质账号信息
  172. - demand_find_content_result:推荐内容列表
  173. """
  174. try:
  175. data = _load_recommendations(trace_id)
  176. except Exception as e:
  177. msg = f"加载 recommendations.json 失败: {e}"
  178. logger.error(msg)
  179. return ToolResult(output=msg, metadata={"ok": False, "error": str(e)})
  180. conn = None
  181. try:
  182. conn = _get_connection()
  183. good_block = data.get("good_account_expansion") or data.get("good_accounts")
  184. contents = data.get("contents") or []
  185. authors_rows = _upsert_good_authors(conn, trace_id, good_block)
  186. contents_rows = _insert_contents(conn, trace_id, contents)
  187. output = (
  188. f"MySQL 写入完成:demand_find_author 影响行数={authors_rows}, "
  189. f"demand_find_content_result 插入条数={contents_rows}"
  190. )
  191. logger.info(output)
  192. return ToolResult(
  193. output=output,
  194. metadata={
  195. "ok": True,
  196. "trace_id": trace_id,
  197. "good_authors_affected": authors_rows,
  198. "contents_inserted": contents_rows,
  199. },
  200. )
  201. except Exception as e:
  202. msg = f"写入 MySQL 失败: {e}"
  203. logger.error(msg, exc_info=True)
  204. return ToolResult(output=msg, metadata={"ok": False, "error": str(e)})
  205. finally:
  206. if conn is not None:
  207. conn.close()