| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- from __future__ import annotations
- import time
- from dataclasses import dataclass
- from datetime import datetime, timezone
- from typing import Any, Callable
- from langgraph.graph import END, START, StateGraph
- from content_agent.errors import ContentAgentError
- from content_agent.business_modules import (
- content_discovery,
- learning_review,
- platform_access,
- policy_version,
- result_source_lookup,
- rule_judgment,
- run_record,
- search_intent,
- walk_engine,
- source_seed,
- )
- from content_agent.business_modules.content_discovery import pattern_recall
- from content_agent.interfaces import (
- GeminiVideoClient,
- PlatformSearchClient,
- PolicyBundleStore,
- QueryVariantClient,
- RuntimeFileStore,
- )
- from content_agent.models import RunState
@dataclass(frozen=True)
class RunDependencies:
    """Immutable bundle of the external collaborators a run graph needs.

    ``build_run_graph`` closes over one instance of this class, so every stage
    of a compiled graph shares the same clients and stores. The instance is
    frozen: attributes cannot be reassigned after construction.
    """

    # Store handed to the business modules and to stage-event recording.
    runtime: RuntimeFileStore
    # Client used for platform searches (initial search and bounded walk).
    platform_client: PlatformSearchClient
    # Source of versioned policy bundles.
    policy_store: PolicyBundleStore
    # Client used while planning search-query variants.
    query_variant_client: QueryVariantClient
    # Video client used by pattern recall and the bounded walk.
    gemini_video_client: GeminiVideoClient
def _instrumented(stage: str, fn: Callable[[RunState], dict[str, Any]], runtime: RuntimeFileStore):
    """Wrap a graph node so each invocation is recorded as stage events.

    The returned callable:

    * records a ``"started"`` event (with ``started_at``) before calling *fn*;
    * records a ``"completed"`` event (``started_at``, ``ended_at``,
      ``duration_ms``) after *fn* returns, then returns *fn*'s result;
    * records a ``"failed"`` event (timing plus ``error_code`` / ``message``)
      if *fn* raises, then re-raises.

    Every event goes through ``run_record.record_stage_event`` with
    ``state["run_id"]``, ``state["policy_run_id"]``, *stage* and the constant
    ``1`` (presumably an attempt counter — confirm against ``run_record``).

    Args:
        stage: Stage name attached to every recorded event.
        fn: Node function taking the run state and returning a dict.
        runtime: Store passed through to ``run_record.record_stage_event``.

    Returns:
        A callable with the same ``(state) -> dict`` shape as *fn*; it carries
        *fn*'s ``__name__``/``__doc__``/annotations via ``functools.wraps``.

    Raises:
        Whatever *fn* raises, unchanged. If recording the ``"failed"`` event
        itself raises, that secondary error is logged and suppressed so it
        cannot mask the stage's real error.
    """
    # Local imports keep this block self-contained; they run once per stage
    # at graph-build time, not on every node invocation.
    import functools
    import logging

    logger = logging.getLogger(__name__)

    def _record(state: RunState, status: str, **fields: Any) -> None:
        # Single place for the positional prefix shared by all three events.
        run_record.record_stage_event(
            runtime, state["run_id"], state["policy_run_id"], stage, 1, status,
            **fields,
        )

    @functools.wraps(fn)
    def wrapped(state: RunState) -> dict[str, Any]:
        # Monotonic clock for duration (immune to wall-clock jumps);
        # wall-clock UTC timestamps for the human-readable start/end times.
        started_monotonic = time.monotonic()
        started_at = datetime.now(timezone.utc).isoformat()
        _record(state, "started", started_at=started_at)
        try:
            result = fn(state)
        except Exception as exc:
            # Domain errors carry a structured code/message; anything else is
            # reported by exception type name only.
            is_domain_error = isinstance(exc, ContentAgentError)
            try:
                _record(
                    state, "failed",
                    started_at=started_at,
                    ended_at=datetime.now(timezone.utc).isoformat(),
                    duration_ms=int((time.monotonic() - started_monotonic) * 1000),
                    error_code=exc.error_code.value if is_domain_error else None,
                    message=exc.message if is_domain_error else type(exc).__name__,
                )
            except Exception:
                # Failure recording is best-effort: never let it replace the
                # stage's own exception, but do not lose it silently either.
                logger.exception(
                    "Could not record 'failed' event for stage %s; "
                    "re-raising the original stage error",
                    stage,
                )
            # Bare raise re-raises the stage exception (`exc`): the inner
            # handler above has exited, restoring it as the active exception.
            raise
        _record(
            state, "completed",
            started_at=started_at,
            ended_at=datetime.now(timezone.utc).isoformat(),
            duration_ms=int((time.monotonic() - started_monotonic) * 1000),
        )
        return result

    return wrapped
def build_run_graph(deps: RunDependencies):
    """Build and compile the linear LangGraph pipeline for one run.

    Each stage is a closure over *deps*. It reads the keys it needs from the
    run state, delegates to a business module, and returns a dict update that
    always includes ``current_step`` set to the stage's own name. Every stage
    is wrapped with ``_instrumented`` so that stage events are recorded. The
    stages are then chained strictly in the order listed in ``pipeline``,
    from ``START`` through to ``END``.

    Args:
        deps: Clients and stores shared by every stage.

    Returns:
        Whatever ``StateGraph.compile()`` returns for the assembled graph.
    """
    builder = StateGraph(RunState)

    def run_ids(state: RunState) -> tuple[Any, Any]:
        # Most stages take (run_id, policy_run_id) as their leading arguments.
        return state["run_id"], state["policy_run_id"]

    def load_source(state: RunState) -> dict[str, Any]:
        # "source" is read with .get(), so a missing key is passed as None.
        seeded = source_seed.run(*run_ids(state), state.get("source"), deps.runtime)
        return {**seeded, "current_step": "load_source"}

    def plan_queries(state: RunState) -> dict[str, Any]:
        planned = search_intent.run(
            *run_ids(state),
            state["pattern_seed_pack"],
            deps.runtime,
            deps.query_variant_client,
        )
        return {"search_queries": planned, "current_step": "plan_queries"}

    def search_platform(state: RunState) -> dict[str, Any]:
        hits = platform_access.run(state["search_queries"], deps.platform_client)
        return {**hits, "current_step": "search_platform"}

    def build_discovered_content(state: RunState) -> dict[str, Any]:
        discovered = content_discovery.run(
            *run_ids(state),
            state["platform_results"],
            state["source_context"],
            deps.runtime,
        )
        return {**discovered, "current_step": "build_discovered_content"}

    def recall_pattern(state: RunState) -> dict[str, Any]:
        recalled = pattern_recall.run(
            *run_ids(state),
            state["discovered_content_items"],
            state["content_media_records"],
            state["evidence_bundles"],
            state["source_context"],
            deps.runtime,
            deps.gemini_video_client,
        )
        return {**recalled, "current_step": "recall_pattern"}

    def load_policy(state: RunState) -> dict[str, Any]:
        policy = policy_version.run(state["strategy_version"], deps.policy_store)
        # Surface the bundle's identifying fields as top-level state keys too;
        # strategy_version is overwritten with the value the bundle reports.
        update: dict[str, Any] = {"policy_bundle": policy}
        for key in ("policy_bundle_id", "strategy_version", "strategy_source_ref"):
            update[key] = policy[key]
        update["current_step"] = "load_policy"
        return update

    def evaluate_rules(state: RunState) -> dict[str, Any]:
        verdicts = rule_judgment.run(
            *run_ids(state),
            state["evidence_bundles"],
            state["policy_bundle"],
            deps.runtime,
        )
        return {"rule_decisions": verdicts, "current_step": "evaluate_rules"}

    def execute_walk(state: RunState) -> dict[str, Any]:
        walked = walk_engine.run_bounded_walk(
            run_id=state["run_id"],
            policy_run_id=state["policy_run_id"],
            pattern_seed_pack=state["pattern_seed_pack"],
            source_context=state["source_context"],
            search_queries=state["search_queries"],
            discovered_content_items=state["discovered_content_items"],
            content_media_records=state["content_media_records"],
            evidence_bundles=state["evidence_bundles"],
            rule_decisions=state["rule_decisions"],
            policy_bundle=state["policy_bundle"],
            platform_client=deps.platform_client,
            runtime=deps.runtime,
            gemini_video_client=deps.gemini_video_client,
        )
        return {**walked, "current_step": "execute_walk"}

    def record_run(state: RunState) -> dict[str, Any]:
        recorded = run_record.run(
            *run_ids(state),
            state["search_queries"],
            state["discovered_content_items"],
            state["rule_decisions"],
            state["source_path_record_basis"],
            state["policy_bundle"],
            deps.runtime,
            walk_actions=state["walk_actions"],
            # Absent query_failures are passed as an empty list.
            query_failures=state.get("query_failures", []),
        )
        return {**recorded, "current_step": "record_run"}

    def commit_results(state: RunState) -> dict[str, Any]:
        output = result_source_lookup.run(
            *run_ids(state),
            state["policy_bundle"],
            state["discovered_content_items"],
            state["content_media_records"],
            state["rule_decisions"],
            state["source_path_records"],
            state["search_clues"],
            deps.runtime,
        )
        return {"final_output": output, "current_step": "commit_results"}

    def review_strategy(state: RunState) -> dict[str, Any]:
        assessment = learning_review.run(*run_ids(state), deps.runtime)
        # Final stage: the only one that sets the run's status.
        return {
            "strategy_review": assessment,
            "current_step": "review_strategy",
            "status": "success",
        }

    # Single source of truth for stage order: nodes are registered and
    # chained in exactly this sequence.
    pipeline: list[tuple[str, Callable[[RunState], dict[str, Any]]]] = [
        ("load_source", load_source),
        ("plan_queries", plan_queries),
        ("search_platform", search_platform),
        ("build_discovered_content", build_discovered_content),
        ("recall_pattern", recall_pattern),
        ("load_policy", load_policy),
        ("evaluate_rules", evaluate_rules),
        ("execute_walk", execute_walk),
        ("record_run", record_run),
        ("commit_results", commit_results),
        ("review_strategy", review_strategy),
    ]

    for stage_name, handler in pipeline:
        builder.add_node(stage_name, _instrumented(stage_name, handler, deps.runtime))

    # START -> first stage -> ... -> last stage -> END, one edge per hop.
    sequence = [START, *(stage_name for stage_name, _ in pipeline), END]
    for upstream, downstream in zip(sequence, sequence[1:]):
        builder.add_edge(upstream, downstream)

    return builder.compile()
|