"""Assemble the content-agent run pipeline as a linear LangGraph ``StateGraph``.

Every stage is wrapped so that each invocation records ``started`` and then
``completed`` or ``failed`` stage events through
``run_record.record_stage_event``.
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Callable

from langgraph.graph import END, START, StateGraph

from content_agent.errors import ContentAgentError
from content_agent.business_modules import (
    content_discovery,
    learning_review,
    platform_access,
    policy_version,
    result_source_lookup,
    rule_judgment,
    run_record,
    search_intent,
    walk_engine,
    source_seed,
)
from content_agent.business_modules.content_discovery import pattern_recall
from content_agent.interfaces import (
    GeminiVideoClient,
    PlatformSearchClient,
    PolicyBundleStore,
    QueryVariantClient,
    RuntimeFileStore,
)
from content_agent.models import RunState


@dataclass(frozen=True)
class RunDependencies:
    """External collaborators injected into the run graph's node functions."""

    # File store handed to most stage modules and to run_record for stage events.
    runtime: RuntimeFileStore
    # Used by the search_platform stage and by the bounded walk.
    platform_client: PlatformSearchClient
    # Read by policy_version when loading the policy bundle.
    policy_store: PolicyBundleStore
    # Used by search_intent when planning search queries.
    query_variant_client: QueryVariantClient
    # Used by pattern recall and by the bounded walk.
    gemini_video_client: GeminiVideoClient


def _utc_now_iso() -> str:
    """Return the current UTC wall-clock time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _elapsed_ms(since_monotonic: float) -> int:
    """Return whole milliseconds elapsed since a ``time.monotonic()`` reading."""
    return int((time.monotonic() - since_monotonic) * 1000)


def _instrumented(
    stage: str,
    fn: Callable[[RunState], dict[str, Any]],
    runtime: RuntimeFileStore,
) -> Callable[[RunState], dict[str, Any]]:
    """Wrap a node function so every call is recorded as stage events.

    The wrapper emits ``"started"`` before invoking ``fn``. It then emits either
    ``"completed"`` (with end time and duration) and returns ``fn``'s update, or
    ``"failed"`` (with end time, duration and error details) and re-raises.

    For a ``ContentAgentError`` the recorded error code is
    ``exc.error_code.value`` and the message is ``exc.message``. For any other
    ``Exception`` the code is ``None`` and the message is the exception's class
    name.

    Only ``Exception`` subclasses are caught. ``BaseException``-only
    interrupts propagate without a ``"failed"`` event.

    Args:
        stage: Stage name written into every event.
        fn: Node function to wrap.
        runtime: File store passed through to ``run_record.record_stage_event``.

    Returns:
        A node function with the same signature as ``fn``.
    """

    def emit(state: RunState, status: str, **details: Any) -> None:
        # The fifth positional argument is always 1 in this module. It looks like
        # an attempt counter, since nothing here retries a stage.
        # NOTE(review): confirm against run_record.record_stage_event.
        run_record.record_stage_event(
            runtime,
            state["run_id"],
            state["policy_run_id"],
            stage,
            1,
            status,
            **details,
        )

    def wrapped(state: RunState) -> dict[str, Any]:
        # Monotonic clock for the duration; wall clock for the human-readable stamps.
        clock_start = time.monotonic()
        started_at = _utc_now_iso()
        emit(state, "started", started_at=started_at)
        try:
            update = fn(state)
        except Exception as exc:
            ended_at = _utc_now_iso()
            duration_ms = _elapsed_ms(clock_start)
            if isinstance(exc, ContentAgentError):
                error_code, message = exc.error_code.value, exc.message
            else:
                error_code, message = None, type(exc).__name__
            emit(
                state,
                "failed",
                started_at=started_at,
                ended_at=ended_at,
                duration_ms=duration_ms,
                error_code=error_code,
                message=message,
            )
            raise
        else:
            # Outside the protected region: a failure while recording "completed"
            # propagates as-is and is not reported as a stage failure.
            emit(
                state,
                "completed",
                started_at=started_at,
                ended_at=_utc_now_iso(),
                duration_ms=_elapsed_ms(clock_start),
            )
            return update

    return wrapped


def _tag_step(step: str, update: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``update`` with ``current_step`` set to ``step``.

    ``step`` overrides any ``current_step`` value already in ``update``.
    """
    return {**update, "current_step": step}


def build_run_graph(deps: RunDependencies):
    """Build and compile the linear run pipeline.

    Stages run in this fixed order, from ``START`` to ``END``:
    load_source -> plan_queries -> search_platform -> build_discovered_content
    -> recall_pattern -> load_policy -> evaluate_rules -> execute_walk
    -> record_run -> commit_results -> review_strategy.

    Each node reads keys from the run state and calls one business module. It
    returns a partial state update with ``current_step`` set to its own name.
    ``review_strategy`` additionally sets ``status`` to ``"success"``.

    Several keys read by later stages are never set explicitly here, for
    example ``source_path_records``, ``search_clues`` and ``walk_actions``.
    Presumably they arrive through a module result that is spread into an
    update (load_source, search_platform, build_discovered_content,
    recall_pattern, execute_walk, record_run).
    NOTE(review): verify this against the modules.

    Args:
        deps: Collaborators captured by the node closures.

    Returns:
        The result of ``StateGraph.compile()`` for the assembled graph.
    """
    graph = StateGraph(RunState)

    def load_source(state: RunState) -> dict[str, Any]:
        # "source" is optional in the incoming state, hence .get().
        seeded = source_seed.run(
            state["run_id"], state["policy_run_id"], state.get("source"), deps.runtime
        )
        return _tag_step("load_source", seeded)

    def plan_queries(state: RunState) -> dict[str, Any]:
        queries = search_intent.run(
            state["run_id"],
            state["policy_run_id"],
            state["pattern_seed_pack"],
            deps.runtime,
            deps.query_variant_client,
        )
        return _tag_step("plan_queries", {"search_queries": queries})

    def search_platform(state: RunState) -> dict[str, Any]:
        found = platform_access.run(state["search_queries"], deps.platform_client)
        return _tag_step("search_platform", found)

    def build_discovered_content(state: RunState) -> dict[str, Any]:
        discovered = content_discovery.run(
            state["run_id"],
            state["policy_run_id"],
            state["platform_results"],
            state["source_context"],
            deps.runtime,
        )
        return _tag_step("build_discovered_content", discovered)

    def recall_pattern(state: RunState) -> dict[str, Any]:
        recalled = pattern_recall.run(
            state["run_id"],
            state["policy_run_id"],
            state["discovered_content_items"],
            state["content_media_records"],
            state["evidence_bundles"],
            state["source_context"],
            deps.runtime,
            deps.gemini_video_client,
        )
        return _tag_step("recall_pattern", recalled)

    def load_policy(state: RunState) -> dict[str, Any]:
        # The loaded bundle's own strategy_version replaces the requested one in state.
        bundle = policy_version.run(state["strategy_version"], deps.policy_store)
        return _tag_step(
            "load_policy",
            {
                "policy_bundle": bundle,
                "policy_bundle_id": bundle["policy_bundle_id"],
                "strategy_version": bundle["strategy_version"],
                "strategy_source_ref": bundle["strategy_source_ref"],
            },
        )

    def evaluate_rules(state: RunState) -> dict[str, Any]:
        decisions = rule_judgment.run(
            state["run_id"],
            state["policy_run_id"],
            state["evidence_bundles"],
            state["policy_bundle"],
            deps.runtime,
        )
        return _tag_step("evaluate_rules", {"rule_decisions": decisions})

    def execute_walk(state: RunState) -> dict[str, Any]:
        # State entries forwarded by keyword under their own names, in this order.
        forwarded_keys = (
            "run_id",
            "policy_run_id",
            "pattern_seed_pack",
            "source_context",
            "search_queries",
            "discovered_content_items",
            "content_media_records",
            "evidence_bundles",
            "rule_decisions",
            "policy_bundle",
        )
        walked = walk_engine.run_bounded_walk(
            **{key: state[key] for key in forwarded_keys},
            platform_client=deps.platform_client,
            runtime=deps.runtime,
            gemini_video_client=deps.gemini_video_client,
        )
        return _tag_step("execute_walk", walked)

    def record_run(state: RunState) -> dict[str, Any]:
        recorded = run_record.run(
            state["run_id"],
            state["policy_run_id"],
            state["search_queries"],
            state["discovered_content_items"],
            state["rule_decisions"],
            state["source_path_record_basis"],
            state["policy_bundle"],
            deps.runtime,
            walk_actions=state["walk_actions"],
            # No recorded query failures is treated as an empty list.
            query_failures=state.get("query_failures", []),
        )
        return _tag_step("record_run", recorded)

    def commit_results(state: RunState) -> dict[str, Any]:
        final_output = result_source_lookup.run(
            state["run_id"],
            state["policy_run_id"],
            state["policy_bundle"],
            state["discovered_content_items"],
            state["content_media_records"],
            state["rule_decisions"],
            state["source_path_records"],
            state["search_clues"],
            deps.runtime,
        )
        return _tag_step("commit_results", {"final_output": final_output})

    def review_strategy(state: RunState) -> dict[str, Any]:
        # Last stage: besides the review, it marks the whole run as successful.
        review = learning_review.run(state["run_id"], state["policy_run_id"], deps.runtime)
        return {
            "strategy_review": review,
            "current_step": "review_strategy",
            "status": "success",
        }

    # Single source of truth for both node registration and edge order.
    pipeline: tuple[tuple[str, Callable[[RunState], dict[str, Any]]], ...] = (
        ("load_source", load_source),
        ("plan_queries", plan_queries),
        ("search_platform", search_platform),
        ("build_discovered_content", build_discovered_content),
        ("recall_pattern", recall_pattern),
        ("load_policy", load_policy),
        ("evaluate_rules", evaluate_rules),
        ("execute_walk", execute_walk),
        ("record_run", record_run),
        ("commit_results", commit_results),
        ("review_strategy", review_strategy),
    )

    for stage_name, node_fn in pipeline:
        graph.add_node(stage_name, _instrumented(stage_name, node_fn, deps.runtime))

    # Chain the stages: START -> first, each stage -> the next, last -> END.
    stage_names = [stage_name for stage_name, _ in pipeline]
    graph.add_edge(START, stage_names[0])
    for upstream, downstream in zip(stage_names, stage_names[1:]):
        graph.add_edge(upstream, downstream)
    graph.add_edge(stage_names[-1], END)

    return graph.compile()