# graph.py (10 KB)
  1. from __future__ import annotations
  2. import time
  3. from dataclasses import dataclass
  4. from datetime import datetime, timezone
  5. from typing import Any, Callable
  6. from langgraph.graph import END, START, StateGraph
  7. from content_agent.errors import ContentAgentError
  8. from content_agent.integrations import oss_archive
  9. from content_agent.business_modules import (
  10. content_discovery,
  11. learning_review,
  12. policy_version,
  13. progressive_screening,
  14. result_source_lookup,
  15. rule_judgment,
  16. run_record,
  17. search_intent,
  18. walk_engine,
  19. source_seed,
  20. )
  21. from content_agent.business_modules.content_discovery import pattern_recall
  22. from content_agent.interfaces import (
  23. GeminiVideoClient,
  24. PlatformSearchClient,
  25. PolicyBundleStore,
  26. QueryVariantClient,
  27. RuntimeFileStore,
  28. )
  29. from content_agent.models import RunState
  30. @dataclass(frozen=True)
  31. class RunDependencies:
  32. runtime: RuntimeFileStore
  33. platform_client: PlatformSearchClient
  34. policy_store: PolicyBundleStore
  35. query_variant_client: QueryVariantClient
  36. gemini_video_client: GeminiVideoClient
  37. def _instrumented(stage: str, fn: Callable[[RunState], dict[str, Any]], runtime: RuntimeFileStore):
  38. def wrapped(state: RunState) -> dict[str, Any]:
  39. started_monotonic = time.monotonic()
  40. started_at = datetime.now(timezone.utc).isoformat()
  41. run_record.record_stage_event(
  42. runtime, state["run_id"], state["policy_run_id"], stage, 1, "started",
  43. started_at=started_at,
  44. )
  45. try:
  46. result = fn(state)
  47. except Exception as exc:
  48. run_record.record_stage_event(
  49. runtime, state["run_id"], state["policy_run_id"], stage, 1, "failed",
  50. started_at=started_at,
  51. ended_at=datetime.now(timezone.utc).isoformat(),
  52. duration_ms=int((time.monotonic() - started_monotonic) * 1000),
  53. error_code=exc.error_code.value if isinstance(exc, ContentAgentError) else None,
  54. message=exc.message if isinstance(exc, ContentAgentError) else type(exc).__name__,
  55. )
  56. raise
  57. run_record.record_stage_event(
  58. runtime, state["run_id"], state["policy_run_id"], stage, 1, "completed",
  59. started_at=started_at,
  60. ended_at=datetime.now(timezone.utc).isoformat(),
  61. duration_ms=int((time.monotonic() - started_monotonic) * 1000),
  62. )
  63. return result
  64. return wrapped
  65. def build_run_graph(deps: RunDependencies):
  66. graph = StateGraph(RunState)
  67. def load_source(state: RunState) -> dict[str, Any]:
  68. result = source_seed.run(
  69. state["run_id"], state["policy_run_id"], state.get("source"), deps.runtime
  70. )
  71. return {**result, "current_step": "load_source"}
  72. def plan_queries(state: RunState) -> dict[str, Any]:
  73. search_queries = search_intent.run(
  74. state["run_id"],
  75. state["policy_run_id"],
  76. state["pattern_seed_pack"],
  77. deps.runtime,
  78. deps.query_variant_client,
  79. strategy_version=state.get("strategy_version"),
  80. platform=state.get("platform", ""),
  81. )
  82. return {"search_queries": search_queries, "current_step": "plan_queries"}
  83. def search_platform(state: RunState) -> dict[str, Any]:
  84. archive_dispatcher = (
  85. oss_archive.AsyncArchiveDispatcher(deps.runtime, state["run_id"])
  86. if state.get("platform_mode") == "real"
  87. else None
  88. )
  89. result = progressive_screening.run(
  90. run_id=state["run_id"],
  91. policy_run_id=state["policy_run_id"],
  92. search_queries=state["search_queries"],
  93. source_context=state["source_context"],
  94. policy_bundle=state["policy_bundle"],
  95. platform_client=deps.platform_client,
  96. runtime=deps.runtime,
  97. gemini_video_client=deps.gemini_video_client,
  98. archive_dispatcher=archive_dispatcher,
  99. platform=state.get("platform", ""),
  100. )
  101. return {**result, "current_step": "search_platform"}
  102. def build_discovered_content(state: RunState) -> dict[str, Any]:
  103. if state.get("discovered_content_items") is not None:
  104. return {"current_step": "build_discovered_content"}
  105. result = content_discovery.run(
  106. state["run_id"],
  107. state["policy_run_id"],
  108. state["platform_results"],
  109. state["source_context"],
  110. deps.runtime,
  111. )
  112. return {**result, "current_step": "build_discovered_content"}
  113. def recall_pattern(state: RunState) -> dict[str, Any]:
  114. if state.get("pattern_recall_evidence") is not None:
  115. return {"current_step": "recall_pattern"}
  116. result = pattern_recall.run(
  117. state["run_id"],
  118. state["policy_run_id"],
  119. state["discovered_content_items"],
  120. state["content_media_records"],
  121. state["evidence_bundles"],
  122. state["source_context"],
  123. deps.runtime,
  124. deps.gemini_video_client,
  125. )
  126. return {**result, "current_step": "recall_pattern"}
  127. def load_policy(state: RunState) -> dict[str, Any]:
  128. bundle = policy_version.run(state["strategy_version"], deps.policy_store)
  129. return {
  130. "policy_bundle": bundle,
  131. "policy_bundle_id": bundle["policy_bundle_id"],
  132. "strategy_version": bundle["strategy_version"],
  133. "strategy_source_ref": bundle["strategy_source_ref"],
  134. "current_step": "load_policy",
  135. }
  136. def evaluate_rules(state: RunState) -> dict[str, Any]:
  137. if state.get("rule_decisions") is not None:
  138. return {"current_step": "evaluate_rules"}
  139. decisions = rule_judgment.run(
  140. state["run_id"],
  141. state["policy_run_id"],
  142. state["evidence_bundles"],
  143. state["policy_bundle"],
  144. deps.runtime,
  145. )
  146. return {"rule_decisions": decisions, "current_step": "evaluate_rules"}
  147. def execute_walk(state: RunState) -> dict[str, Any]:
  148. # M9D Gate 2:仅非抖音游走的新标签 query 过 50+ 判定;抖音不判(query_gate=None)。
  149. query_gate = None
  150. if state.get("platform") and state.get("platform") != "douyin":
  151. query_gate = lambda tag: search_intent._gate2_keep(tag, deps.query_variant_client)
  152. # 与 search_platform 一致:real 模式给游走也挂 OSS 归档调度器,游走带回视频才能下载/传 OSS/被 Gemini 看。
  153. archive_dispatcher = (
  154. oss_archive.AsyncArchiveDispatcher(deps.runtime, state["run_id"])
  155. if state.get("platform_mode") == "real"
  156. else None
  157. )
  158. result = walk_engine.run_bounded_walk(
  159. run_id=state["run_id"],
  160. policy_run_id=state["policy_run_id"],
  161. pattern_seed_pack=state["pattern_seed_pack"],
  162. source_context=state["source_context"],
  163. search_queries=state["search_queries"],
  164. discovered_content_items=state["discovered_content_items"],
  165. content_media_records=state["content_media_records"],
  166. evidence_bundles=state["evidence_bundles"],
  167. rule_decisions=state["rule_decisions"],
  168. policy_bundle=state["policy_bundle"],
  169. platform_client=deps.platform_client,
  170. runtime=deps.runtime,
  171. gemini_video_client=deps.gemini_video_client,
  172. query_gate=query_gate,
  173. archive_dispatcher=archive_dispatcher,
  174. )
  175. return {**result, "current_step": "execute_walk"}
  176. def record_run(state: RunState) -> dict[str, Any]:
  177. result = run_record.run(
  178. state["run_id"],
  179. state["policy_run_id"],
  180. state["search_queries"],
  181. state["discovered_content_items"],
  182. state["rule_decisions"],
  183. state["source_path_record_basis"],
  184. state["policy_bundle"],
  185. deps.runtime,
  186. walk_actions=state["walk_actions"],
  187. query_failures=state.get("query_failures", []),
  188. )
  189. return {**result, "current_step": "record_run"}
  190. def commit_results(state: RunState) -> dict[str, Any]:
  191. final_output = result_source_lookup.run(
  192. state["run_id"],
  193. state["policy_run_id"],
  194. state["policy_bundle"],
  195. state["discovered_content_items"],
  196. state["content_media_records"],
  197. state["rule_decisions"],
  198. state["source_path_records"],
  199. state["search_clues"],
  200. deps.runtime,
  201. )
  202. return {"final_output": final_output, "current_step": "commit_results"}
  203. def review_strategy(state: RunState) -> dict[str, Any]:
  204. review = learning_review.run(state["run_id"], state["policy_run_id"], deps.runtime)
  205. return {"strategy_review": review, "current_step": "review_strategy", "status": "success"}
  206. nodes: dict[str, Callable[[RunState], dict[str, Any]]] = {
  207. "load_source": load_source,
  208. "plan_queries": plan_queries,
  209. "search_platform": search_platform,
  210. "build_discovered_content": build_discovered_content,
  211. "recall_pattern": recall_pattern,
  212. "load_policy": load_policy,
  213. "evaluate_rules": evaluate_rules,
  214. "execute_walk": execute_walk,
  215. "record_run": record_run,
  216. "commit_results": commit_results,
  217. "review_strategy": review_strategy,
  218. }
  219. for stage, fn in nodes.items():
  220. graph.add_node(stage, _instrumented(stage, fn, deps.runtime))
  221. graph.add_edge(START, "load_source")
  222. graph.add_edge("load_source", "plan_queries")
  223. graph.add_edge("plan_queries", "load_policy")
  224. graph.add_edge("load_policy", "search_platform")
  225. graph.add_edge("search_platform", "build_discovered_content")
  226. graph.add_edge("build_discovered_content", "recall_pattern")
  227. graph.add_edge("recall_pattern", "evaluate_rules")
  228. graph.add_edge("evaluate_rules", "execute_walk")
  229. graph.add_edge("execute_walk", "record_run")
  230. graph.add_edge("record_run", "commit_results")
  231. graph.add_edge("commit_results", "review_strategy")
  232. graph.add_edge("review_strategy", END)
  233. return graph.compile()