# store_results.py
  """
  Write recommendation results to the database
  (tables: demand_find_author and demand_find_content_result).
  """
  4. import json
  5. from typing import Any, Dict, List, Optional, Tuple
  6. from .connection import get_connection
  def _normalize_content_tags(value: Any) -> str:
      """Flatten a ``content_tags`` value of any supported shape into one string.

      Args:
          value: Tags in whatever shape the caller received them: ``None``,
              a string, a list/tuple/set/frozenset of items, a dict, or any
              other object.

      Returns:
          - ``""`` for ``None``;
          - a string unchanged (it is not stripped);
          - for a collection, the stripped, non-empty items joined with
            ``","``; ``None`` items are skipped and set members are sorted so
            the result is repeatable;
          - compact JSON (non-ASCII kept as-is) for a dict;
          - ``str(value)`` for anything else.
      """
      if value is None:
          return ""
      if isinstance(value, str):
          return value
      if isinstance(value, (list, tuple, set, frozenset)):
          # Strip every item once; skipping None keeps a missing tag from being
          # stored as the literal text "None".
          cleaned = [str(tag).strip() for tag in value if tag is not None]
          cleaned = [tag for tag in cleaned if tag]
          if isinstance(value, (set, frozenset)):
              # Sets have no stable iteration order; sort for a repeatable result.
              cleaned.sort()
          return ",".join(cleaned)
      if isinstance(value, dict):
          return json.dumps(value, ensure_ascii=False, separators=(",", ":"))
      return str(value)
  # Content-item keys that describe how an item was found, each paired with the
  # label used when rendering demand_find_content_result.process_trace as text.
  _PROCESS_TRACE_FIELDS: Tuple[Tuple[str, str], ...] = (
      ("strategy_type", "寻找策略"),
      ("from_case_aweme_id", "case内容id"),
      ("from_case_point", "灵感点"),
      ("search_keyword", "搜索词"),
      ("channel", "渠道"),
      ("find_way", "寻找方式"),
  )


  def _format_process_trace_text(item: Dict[str, Any]) -> str:
      """Render the search-process fields of a content item as multi-line text.

      Every key listed in ``_PROCESS_TRACE_FIELDS`` whose value is not ``None``
      and is non-blank after ``str(...).strip()`` becomes one
      ``"<label>: <value>"`` line, in table order. When no key yields a line,
      the item's legacy ``process_trace`` value is returned instead, converted
      to ``str`` and stripped ("" when that is missing as well).
      """
      labelled = (
          (caption, item.get(field)) for field, caption in _PROCESS_TRACE_FIELDS
      )
      stripped = (
          (caption, str(raw).strip()) for caption, raw in labelled if raw is not None
      )
      rendered = "\n".join(f"{caption}: {text}" for caption, text in stripped if text)
      if rendered:
          return rendered

      # Nothing usable in the structured fields: fall back to the legacy value.
      fallback = item.get("process_trace")
      return "" if fallback is None else str(fallback).strip()
  def _author_metric_text(value: Any) -> str:
      """Render an audience metric as text; missing or empty values become ""."""
      is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
      # A numeric 0 is a real measurement, so it must not be treated as "missing".
      if is_number or value:
          return str(value)
      return ""


  def upsert_good_authors(
      conn,
      trace_id: str,
      good_account_block: Optional[Dict[str, Any]],
  ) -> int:
      """Write the accounts of a ``good_account_expansion`` block to ``demand_find_author``.

      Two input shapes are accepted:
        - standard: ``{"enabled": true, "accounts": [...]}``
        - degraded: a bare list of accounts (fallback for when the agent did
          not follow the schema strictly)

      For each account the author name is taken from ``author_nickname`` or
      ``account_name`` and the link from ``author_url``; when the link is
      missing it is built from ``author_sec_uid`` / ``sec_uid``. Entries that
      are not dicts, or that end up without a name or a link, are skipped.
      Rows are written with ``INSERT ... ON DUPLICATE KEY UPDATE``; on a
      duplicate only elderly_ratio, elderly_tgi and remark are refreshed.

      Args:
          conn: Database connection whose ``cursor()`` is usable as a context
              manager. This function does not call ``commit``.
          trace_id: Value stored in the ``trace_id`` column of every row.
          good_account_block: The block in one of the two shapes above, or
              ``None`` / empty.

      Returns:
          The sum of ``cursor.rowcount`` over all executed statements; 0 when
          the block is empty, disabled, or contains no usable account.
          NOTE(review): with MySQL an ON DUPLICATE KEY update is typically
          reported as 2 affected rows, so this may not equal the number of
          accounts written -- confirm against the driver in use.
      """
      if not good_account_block:
          return 0
      if isinstance(good_account_block, list):
          accounts: List[Dict[str, Any]] = good_account_block
      else:
          if not good_account_block.get("enabled"):
              return 0
          accounts = good_account_block.get("accounts") or []
      if not accounts:
          return 0

      # NOTE(review): content_tags is inserted but not refreshed on a duplicate
      # key -- confirm whether that is intentional.
      sql = """
          INSERT INTO demand_find_author (trace_id, author_name, author_link, elderly_ratio, elderly_tgi, remark, content_tags)
          VALUES (%s, %s, %s, %s, %s, %s, %s)
          ON DUPLICATE KEY UPDATE
          elderly_ratio = VALUES(elderly_ratio),
          elderly_tgi = VALUES(elderly_tgi),
          remark = VALUES(remark)
      """
      rows = 0
      with conn.cursor() as cur:
          for acc in accounts:
              if not isinstance(acc, dict):
                  # Agent output is not always schema-clean; skip a malformed
                  # entry instead of aborting the whole batch.
                  continue
              # Schema keys: author_nickname / author_sec_uid / author_url.
              # Aliases commonly emitted by the agent are accepted as well:
              # account_name, sec_uid.
              author_name = acc.get("author_nickname") or acc.get("account_name") or ""
              author_link = acc.get("author_url") or ""
              sec_uid = acc.get("author_sec_uid") or acc.get("sec_uid")
              if not author_link and sec_uid:
                  author_link = f"https://www.douyin.com/user/{sec_uid}"
              if not author_name or not author_link:
                  continue
              remark = acc.get("reason") or acc.get("remark") or ""
              content_tags = _normalize_content_tags(acc.get("content_tags"))
              cur.execute(
                  sql,
                  (
                      trace_id,
                      author_name,
                      author_link,
                      _author_metric_text(acc.get("age_50_plus_ratio")),
                      _author_metric_text(acc.get("age_50_plus_tgi")),
                      remark or None,
                      content_tags or None,
                  ),
              )
              rows += cur.rowcount
      return rows
  def fetch_demand_content_dt(conn, demand_content_id: int) -> Optional[Any]:
      """Look up ``dt`` for one ``demand_content`` row by its ``id``.

      Per the schedule convention ``dt`` is mostly a YYYYMMDD integer. The
      fetched row is read with ``.get("dt")``, so the cursor is expected to
      return mapping-style rows.

      Returns:
          The row's ``dt`` value, or ``None`` when no row matches.
      """
      query = "SELECT dt FROM demand_content WHERE id = %s LIMIT 1"
      with conn.cursor() as cursor:
          cursor.execute(query, (demand_content_id,))
          record = cursor.fetchone()
          return record.get("dt") if record else None
  def _content_int(value: Any) -> int:
      """Best-effort integer conversion for agent-supplied counters (0 when unusable)."""
      for convert in (int, float):
          try:
              return int(convert(value))
          except (TypeError, ValueError, OverflowError):
              continue
      return 0


  def _content_metric_text(*candidates: Any) -> str:
      """Return the first usable candidate as text, or "" when none is usable.

      A numeric value (including 0) is usable; any other value is usable only
      when it is truthy.
      """
      for candidate in candidates:
          is_number = isinstance(candidate, (int, float)) and not isinstance(candidate, bool)
          if is_number or candidate:
              return str(candidate)
      return ""


  def insert_contents(
      conn,
      trace_id: str,
      query: str,
      demand_content_id: int,
      contents: List[Dict[str, Any]],
      dt: Optional[Any] = None,
  ) -> int:
      """Insert a list of content items into ``demand_find_content_result``.

      One row is inserted per dict item; items that are not dicts are skipped.
      Field handling:
        - rank comes from ``rank`` or ``rank_no``;
        - counters come from ``statistics`` (``like_count`` is accepted as a
          non-standard alias of ``digg_count``);
        - the 50+ ratio comes from ``portrait_data.age_50_plus_ratio`` or, as a
          fallback, ``portrait_data.age_distribution["50+"]``;
        - rank and counters that cannot be parsed as numbers are stored as 0;
        - ``statistics`` / ``portrait_data`` / ``age_distribution`` values that
          are not dicts are treated as empty;
        - ``process_trace`` is rendered by ``_format_process_trace_text``.

      Args:
          conn: Database connection whose ``cursor()`` is usable as a context
              manager. This function does not call ``commit``.
          trace_id: Stored in the ``trace_id`` column of every row.
          query: Stored in the ``query`` column of every row.
          demand_content_id: Stored in the ``demand_content_id`` column.
          contents: Content items to write.
          dt: ``demand_content.dt`` matching ``demand_content_id``; pass
              ``None`` when it could not be looked up.

      Returns:
          The sum of ``cursor.rowcount`` over all executed inserts; 0 when
          ``contents`` is empty.
      """
      if not contents:
          return 0

      sql = """
          INSERT INTO demand_find_content_result (
          trace_id, query, rank_no, aweme_id, video_url, title, author_name, author_id, author_link,
          digg_count, comment_count, share_count,
          portrait_source, elderly_ratio, elderly_tgi, recommendation_reason,
          demand_content_id, dt, channel, process_trace
          ) VALUES (
          %s, %s, %s, %s, %s, %s, %s, %s, %s,
          %s, %s, %s,
          %s, %s, %s, %s,
          %s, %s, %s, %s
          )
      """
      rows = 0
      with conn.cursor() as cur:
          for item in contents:
              if not isinstance(item, dict):
                  # Agent output is not always schema-clean; skip a malformed
                  # entry instead of aborting the whole batch.
                  continue
              stats = item.get("statistics")
              if not isinstance(stats, dict):
                  stats = {}
              portrait = item.get("portrait_data")
              if not isinstance(portrait, dict):
                  portrait = {}
              # age_distribution is a non-standard structure the agent sometimes
              # emits; it is only used as a fallback source for the 50+ ratio.
              age_dist = portrait.get("age_distribution")
              if not isinstance(age_dist, dict):
                  age_dist = {}
              cur.execute(
                  sql,
                  (
                      trace_id,
                      query,
                      _content_int(item.get("rank") or item.get("rank_no")),
                      item.get("aweme_id") or "",
                      item.get("video_url") or "",
                      item.get("title") or "",
                      item.get("author_nickname") or "",
                      item.get("author_sec_uid") or "",
                      item.get("author_url") or "",
                      # like_count is a non-standard field name the agent
                      # sometimes emits instead of digg_count.
                      _content_int(stats.get("digg_count") or stats.get("like_count")),
                      _content_int(stats.get("comment_count")),
                      _content_int(stats.get("share_count")),
                      portrait.get("source") or "",
                      _content_metric_text(
                          portrait.get("age_50_plus_ratio"), age_dist.get("50+")
                      ),
                      _content_metric_text(portrait.get("age_50_plus_tgi")),
                      item.get("reason") or "",
                      demand_content_id,
                      dt,
                      item.get("channel") or "",
                      _format_process_trace_text(item),
                  ),
              )
              rows += cur.rowcount
      return rows
  def update_content_plan_ids(
      trace_id: str,
      aweme_ids: List[str],
      crawler_plan_id: str = "",
      produce_plan_id: str = "",
      publish_plan_id: str = "",
  ) -> int:
      """Set plan-id columns on selected rows of ``demand_find_content_result``.

      Conventions:
        - rows are located by ``(trace_id, aweme_id)``;
        - any subset of crawler / produce / publish plan ids may be given;
          only the ones that are non-empty after trimming are written;
        - nothing is executed unless at least one plan id is non-empty;
        - ``aweme_ids`` must be a non-empty ``list``, otherwise 0 is returned;
        - the database connection is obtained and closed inside this function.

      Returns:
          The sum of ``cursor.rowcount`` over one UPDATE per aweme id, or 0
          when a guard above short-circuits.
      """
      if not (aweme_ids and isinstance(aweme_ids, list)):
          return 0

      # (column, trimmed value) pairs in the order they appear in the SET clause.
      plan_columns = (
          ("crawler_plan_id", (crawler_plan_id or "").strip()),
          ("produce_plan_id", (produce_plan_id or "").strip()),
          ("publish_plan_id", (publish_plan_id or "").strip()),
      )
      assignments = [(column, value) for column, value in plan_columns if value]
      if not assignments:
          return 0

      set_clause = ", ".join(f"{column} = %s" for column, _ in assignments)
      plan_values = tuple(value for _, value in assignments)
      statement = f"""
          UPDATE demand_find_content_result
          SET {set_clause}
          WHERE trace_id = %s AND aweme_id = %s
      """

      connection = get_connection()
      try:
          affected = 0
          with connection.cursor() as cursor:
              for aweme_id in aweme_ids:
                  cursor.execute(statement, plan_values + (trace_id, aweme_id))
                  affected += cursor.rowcount
          return affected
      finally:
          connection.close()
  def update_web_html_url(trace_id: str, web_html_url: str) -> int:
      """Write ``web_html_url`` back to every ``demand_find_content_result`` row of a trace.

      Conventions:
        - ``trace_id`` is the name of the output sub-directory;
        - ``web_html_url`` is the public OSS URL;
        - one ``trace_id`` may map to several content rows; all are updated.

      Both arguments are trimmed first; if either is empty nothing is executed.
      The database connection is obtained and closed inside this function.

      Returns:
          ``cursor.rowcount`` of the UPDATE, or 0 when an argument is empty.
      """
      trace_key = (trace_id or "").strip()
      html_url = (web_html_url or "").strip()
      if not (trace_key and html_url):
          return 0

      statement = """
          UPDATE demand_find_content_result
          SET web_html_url = %s
          WHERE trace_id = %s
      """
      connection = get_connection()
      try:
          with connection.cursor() as cursor:
              cursor.execute(statement, (html_url, trace_key))
              affected = cursor.rowcount
          return affected
      finally:
          connection.close()
  def update_process_trace_by_aweme_id(
      *,
      trace_id: str,
      aweme_id: str,
      process_trace_text: str,
      channel: str = "抖音",
  ) -> int:
      """Write ``process_trace`` (TEXT) and ``channel`` back to one content row.

      The row is located by ``(trace_id, aweme_id)`` in
      ``demand_find_content_result``.

      Conventions:
        - ``trace_id`` is the name of the output sub-directory;
        - ``aweme_id`` is the content's unique id (column ``aweme_id``);
        - ``process_trace_text`` is a JSON-serialised string or raw text;
        - ``channel`` defaults to "抖音"; a blank value falls back to that
          default as well.

      All arguments are trimmed first; if ``trace_id``, ``aweme_id`` or
      ``process_trace_text`` is empty nothing is executed. The database
      connection is obtained and closed inside this function.

      Returns:
          ``cursor.rowcount`` of the UPDATE, or 0 when a required argument is
          empty.
      """
      trace_key = (trace_id or "").strip()
      content_key = (aweme_id or "").strip()
      trace_text = (process_trace_text or "").strip()
      channel_name = (channel or "").strip() or "抖音"
      if not (trace_key and content_key and trace_text):
          return 0

      statement = """
          UPDATE demand_find_content_result
          SET process_trace = %s,
          channel = %s
          WHERE trace_id = %s AND aweme_id = %s
      """
      connection = get_connection()
      try:
          with connection.cursor() as cursor:
              cursor.execute(
                  statement, (trace_text, channel_name, trace_key, content_key)
              )
              affected = cursor.rowcount
          return affected
      finally:
          connection.close()
  287. finally:
  288. conn.close()