# graph.py (8.2 KB)
from __future__ import annotations

import time
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable

from langgraph.graph import END, START, StateGraph

from content_agent.errors import ContentAgentError
from content_agent.business_modules import (
    content_discovery,
    learning_review,
    platform_access,
    policy_version,
    result_source_lookup,
    rule_judgment,
    run_record,
    search_intent,
    walk_engine,
    source_seed,
)
from content_agent.business_modules.content_discovery import pattern_recall
from content_agent.interfaces import (
    GeminiVideoClient,
    PlatformSearchClient,
    PolicyBundleStore,
    QueryVariantClient,
    RuntimeFileStore,
)
from content_agent.models import RunState


  29. @dataclass(frozen=True)
  30. class RunDependencies:
  31. runtime: RuntimeFileStore
  32. platform_client: PlatformSearchClient
  33. policy_store: PolicyBundleStore
  34. query_variant_client: QueryVariantClient
  35. gemini_video_client: GeminiVideoClient
  36. def _instrumented(stage: str, fn: Callable[[RunState], dict[str, Any]], runtime: RuntimeFileStore):
  37. def wrapped(state: RunState) -> dict[str, Any]:
  38. started_monotonic = time.monotonic()
  39. started_at = datetime.now(timezone.utc).isoformat()
  40. run_record.record_stage_event(
  41. runtime, state["run_id"], state["policy_run_id"], stage, 1, "started",
  42. started_at=started_at,
  43. )
  44. try:
  45. result = fn(state)
  46. except Exception as exc:
  47. run_record.record_stage_event(
  48. runtime, state["run_id"], state["policy_run_id"], stage, 1, "failed",
  49. started_at=started_at,
  50. ended_at=datetime.now(timezone.utc).isoformat(),
  51. duration_ms=int((time.monotonic() - started_monotonic) * 1000),
  52. error_code=exc.error_code.value if isinstance(exc, ContentAgentError) else None,
  53. message=exc.message if isinstance(exc, ContentAgentError) else type(exc).__name__,
  54. )
  55. raise
  56. run_record.record_stage_event(
  57. runtime, state["run_id"], state["policy_run_id"], stage, 1, "completed",
  58. started_at=started_at,
  59. ended_at=datetime.now(timezone.utc).isoformat(),
  60. duration_ms=int((time.monotonic() - started_monotonic) * 1000),
  61. )
  62. return result
  63. return wrapped
  64. def build_run_graph(deps: RunDependencies):
  65. graph = StateGraph(RunState)
  66. def load_source(state: RunState) -> dict[str, Any]:
  67. result = source_seed.run(
  68. state["run_id"], state["policy_run_id"], state.get("source"), deps.runtime
  69. )
  70. return {**result, "current_step": "load_source"}
  71. def plan_queries(state: RunState) -> dict[str, Any]:
  72. search_queries = search_intent.run(
  73. state["run_id"],
  74. state["policy_run_id"],
  75. state["pattern_seed_pack"],
  76. deps.runtime,
  77. deps.query_variant_client,
  78. )
  79. return {"search_queries": search_queries, "current_step": "plan_queries"}
  80. def search_platform(state: RunState) -> dict[str, Any]:
  81. result = platform_access.run(state["search_queries"], deps.platform_client)
  82. return {**result, "current_step": "search_platform"}
  83. def build_discovered_content(state: RunState) -> dict[str, Any]:
  84. result = content_discovery.run(
  85. state["run_id"],
  86. state["policy_run_id"],
  87. state["platform_results"],
  88. state["source_context"],
  89. deps.runtime,
  90. )
  91. return {**result, "current_step": "build_discovered_content"}
  92. def recall_pattern(state: RunState) -> dict[str, Any]:
  93. result = pattern_recall.run(
  94. state["run_id"],
  95. state["policy_run_id"],
  96. state["discovered_content_items"],
  97. state["content_media_records"],
  98. state["evidence_bundles"],
  99. state["source_context"],
  100. deps.runtime,
  101. deps.gemini_video_client,
  102. )
  103. return {**result, "current_step": "recall_pattern"}
  104. def load_policy(state: RunState) -> dict[str, Any]:
  105. bundle = policy_version.run(state["strategy_version"], deps.policy_store)
  106. return {
  107. "policy_bundle": bundle,
  108. "policy_bundle_id": bundle["policy_bundle_id"],
  109. "strategy_version": bundle["strategy_version"],
  110. "strategy_source_ref": bundle["strategy_source_ref"],
  111. "current_step": "load_policy",
  112. }
  113. def evaluate_rules(state: RunState) -> dict[str, Any]:
  114. decisions = rule_judgment.run(
  115. state["run_id"],
  116. state["policy_run_id"],
  117. state["evidence_bundles"],
  118. state["policy_bundle"],
  119. deps.runtime,
  120. )
  121. return {"rule_decisions": decisions, "current_step": "evaluate_rules"}
  122. def execute_walk(state: RunState) -> dict[str, Any]:
  123. result = walk_engine.run_bounded_walk(
  124. run_id=state["run_id"],
  125. policy_run_id=state["policy_run_id"],
  126. pattern_seed_pack=state["pattern_seed_pack"],
  127. source_context=state["source_context"],
  128. search_queries=state["search_queries"],
  129. discovered_content_items=state["discovered_content_items"],
  130. content_media_records=state["content_media_records"],
  131. evidence_bundles=state["evidence_bundles"],
  132. rule_decisions=state["rule_decisions"],
  133. policy_bundle=state["policy_bundle"],
  134. platform_client=deps.platform_client,
  135. runtime=deps.runtime,
  136. gemini_video_client=deps.gemini_video_client,
  137. )
  138. return {**result, "current_step": "execute_walk"}
  139. def record_run(state: RunState) -> dict[str, Any]:
  140. result = run_record.run(
  141. state["run_id"],
  142. state["policy_run_id"],
  143. state["search_queries"],
  144. state["discovered_content_items"],
  145. state["rule_decisions"],
  146. state["source_path_record_basis"],
  147. state["policy_bundle"],
  148. deps.runtime,
  149. walk_actions=state["walk_actions"],
  150. query_failures=state.get("query_failures", []),
  151. )
  152. return {**result, "current_step": "record_run"}
  153. def commit_results(state: RunState) -> dict[str, Any]:
  154. final_output = result_source_lookup.run(
  155. state["run_id"],
  156. state["policy_run_id"],
  157. state["policy_bundle"],
  158. state["discovered_content_items"],
  159. state["content_media_records"],
  160. state["rule_decisions"],
  161. state["source_path_records"],
  162. state["search_clues"],
  163. deps.runtime,
  164. )
  165. return {"final_output": final_output, "current_step": "commit_results"}
  166. def review_strategy(state: RunState) -> dict[str, Any]:
  167. review = learning_review.run(state["run_id"], state["policy_run_id"], deps.runtime)
  168. return {"strategy_review": review, "current_step": "review_strategy", "status": "success"}
  169. nodes: dict[str, Callable[[RunState], dict[str, Any]]] = {
  170. "load_source": load_source,
  171. "plan_queries": plan_queries,
  172. "search_platform": search_platform,
  173. "build_discovered_content": build_discovered_content,
  174. "recall_pattern": recall_pattern,
  175. "load_policy": load_policy,
  176. "evaluate_rules": evaluate_rules,
  177. "execute_walk": execute_walk,
  178. "record_run": record_run,
  179. "commit_results": commit_results,
  180. "review_strategy": review_strategy,
  181. }
  182. for stage, fn in nodes.items():
  183. graph.add_node(stage, _instrumented(stage, fn, deps.runtime))
  184. graph.add_edge(START, "load_source")
  185. graph.add_edge("load_source", "plan_queries")
  186. graph.add_edge("plan_queries", "search_platform")
  187. graph.add_edge("search_platform", "build_discovered_content")
  188. graph.add_edge("build_discovered_content", "recall_pattern")
  189. graph.add_edge("recall_pattern", "load_policy")
  190. graph.add_edge("load_policy", "evaluate_rules")
  191. graph.add_edge("evaluate_rules", "execute_walk")
  192. graph.add_edge("execute_walk", "record_run")
  193. graph.add_edge("record_run", "commit_results")
  194. graph.add_edge("commit_results", "review_strategy")
  195. graph.add_edge("review_strategy", END)
  196. return graph.compile()