| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539 |
- from __future__ import annotations
- import json
- from pathlib import Path
- from typing import Any
- from fastapi.testclient import TestClient
- from content_agent import api
- from content_agent.errors import ErrorCode
- from content_agent.integrations.composite_runtime import CompositeRuntimeStore
- from content_agent.integrations.database_runtime import ContentSupplyDbConfig
- from content_agent.integrations.demand_source import DemandSourceService
- from content_agent.integrations.mock_platform import MockPlatformClient
- from content_agent.integrations.runtime_files import LocalRuntimeFileStore
- from content_agent.run_service import RunService
- from content_agent.schemas import RunStartRequest
- from tests.p1_helpers import (
- FakeDemandSource,
- FakeQueryVariantClient,
- REAL_SOURCE_FIXTURE,
- real_source_payload,
- )
def test_composite_runtime_writes_primary_before_local(tmp_path):
    """The composite store must write to the primary store before the export.

    Each fake store keeps its own ``calls`` log, and those per-store logs
    cannot show the order *across* stores. Both stores' write methods are
    therefore also wrapped so that every call lands in one shared, ordered
    timeline, which is then checked for primary-before-export ordering.
    """
    primary = _FakeRuntimeStore()
    export = _FakeRuntimeStore(run_dir=tmp_path / "export")
    timeline: list[tuple[str, str]] = []

    def _track(target: _FakeRuntimeStore, label: str) -> None:
        # Shadow the bound methods on the instance so each call is logged in
        # the shared timeline before delegating to the original method.
        for method_name in ("write_json", "append_jsonl"):
            original = getattr(target, method_name)

            # Bind loop variables as defaults to avoid late-binding closures.
            def _wrapper(*args, _original=original, _name=method_name, **kwargs):
                timeline.append((label, _name))
                return _original(*args, **kwargs)

            setattr(target, method_name, _wrapper)

    _track(primary, "primary")
    _track(export, "export")
    store = CompositeRuntimeStore(primary, export)

    store.write_json("run_001", "final_output.json", {"run_id": "run_001"})
    store.append_jsonl("run_001", "run_events.jsonl", [{"run_id": "run_001"}])

    assert primary.calls == [
        ("write_json", "final_output.json"),
        ("append_jsonl", "run_events.jsonl"),
    ]
    assert export.calls == [
        ("write_json", "final_output.json"),
        ("append_jsonl", "run_events.jsonl"),
    ]
    # The property the test name promises: for every operation, the primary
    # write happens before the export write.
    for operation in ("write_json", "append_jsonl"):
        assert timeline.index(("primary", operation)) < timeline.index(
            ("export", operation)
        ), f"{operation} reached the export store before the primary store"
def test_composite_runtime_db_failure_blocks_local_export(tmp_path):
    """A failing primary write must propagate and never reach the export store."""
    primary = _FakeRuntimeStore(fail_writes=True)
    export = _FakeRuntimeStore(run_dir=tmp_path / "export")
    store = CompositeRuntimeStore(primary, export)

    raised = None
    try:
        store.write_json("run_001", "final_output.json", {"run_id": "run_001"})
    except RuntimeError as exc:
        raised = exc
    if raised is None:
        raise AssertionError("expected primary write failure")
    assert "primary failed" in str(raised)

    # Only the primary was attempted; the export saw nothing.
    assert primary.calls == [("write_json", "final_output.json")]
    assert export.calls == []
def test_run_service_success_records_run_policy_and_lifecycle_events(tmp_path):
    """A successful mock run persists its run record, policy run and lifecycle events."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    demand_source = FakeDemandSource(real_source_payload(demand_content_id=123))
    service = RunService(
        runtime=runtime,
        demand_source=demand_source,
        query_variant_client=FakeQueryVariantClient(),
    )

    state = service.start_run(RunStartRequest(platform_mode="mock"))

    assert state["status"] == "success"
    # No selector was given, so the default PG-pattern source is used.
    assert demand_source.calls == ["get_default_pg_pattern_source"]
    source_context = service.read_json(state["run_id"], "source_context.json")
    assert source_context["demand_content_id"] == "123"

    first_record = runtime.run_records[0]
    assert first_record["status"] == "running"
    assert first_record["source_ref"]["source_type"] == "demand_content_default"

    first_updates = runtime.run_updates[0]["updates"]
    assert first_updates["demand_content_id"] == 123
    assert first_updates["source_ref"]["demand_content_id"] == 123
    assert runtime.run_updates[-1]["updates"]["status"] == "success"

    policy_run = runtime.policy_runs[0]
    assert policy_run["policy_bundle_id"] == "douyin_policy_bundle_v1"
    assert policy_run["decision_summary"]["decision_action_counts"]

    lifecycle = runtime.lifecycle_events
    event_ids = [event["event_id"] for event in lifecycle]
    assert event_ids == ["lifecycle_start", "lifecycle_success"]
    # Lifecycle events must carry the fixed ids above, not "evt_"-prefixed ones.
    assert all(not event_id.startswith("evt_") for event_id in event_ids)
    assert lifecycle[0]["raw_payload"]["source_ref"]["demand_content_id"] == 123
def test_run_service_partial_platform_failure_records_partial_success(tmp_path):
    """One failing platform query downgrades the run to ``partial_success``.

    ``_PartialFailurePlatformClient`` raises only for query ``q_002``. The run
    must still finish, and that failure must be visible in the returned state,
    the final run update, the policy run, the last lifecycle event, the run
    events and the search clues.
    """
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    demand_source = FakeDemandSource(real_source_payload(demand_content_id=123))
    service = RunService(
        runtime=runtime,
        demand_source=demand_source,
        query_variant_client=FakeQueryVariantClient(),
    )
    service._platform_client = lambda platform, platform_mode: _PartialFailurePlatformClient()

    state = service.start_run(RunStartRequest(platform_mode="real"))

    assert state["status"] == "partial_success"
    assert state["query_failures"][0]["search_query_id"] == "q_002"
    assert runtime.run_updates[-1]["updates"]["status"] == "partial_success"
    assert runtime.policy_runs[0]["status"] == "partial_success"
    assert runtime.lifecycle_events[-1]["event_id"] == "lifecycle_success"
    assert runtime.lifecycle_events[-1]["status"] == "partial_success"
    assert runtime.lifecycle_events[-1]["raw_payload"]["query_failures"]

    run_events = service.read_jsonl(state["run_id"], "run_events.jsonl")
    assert any(event["event_type"] == "platform_query_failed" for event in run_events)

    search_clues = service.read_jsonl(state["run_id"], "search_clues.jsonl")
    # Use a default so a missing clue fails with a clear assertion message
    # instead of a bare StopIteration escaping from next().
    failed_clue = next(
        (clue for clue in search_clues if clue["search_query_id"] == "q_002"),
        None,
    )
    assert failed_clue is not None, "no search clue recorded for failed query q_002"
    assert failed_clue["search_query_effect_status"] == "failed"
def test_run_service_all_platform_queries_fail_records_failed_query_details(tmp_path):
    """When every platform query fails, the run fails and each query's failure is recorded."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    demand_source = FakeDemandSource(real_source_payload(demand_content_id=123))
    service = RunService(
        runtime=runtime,
        demand_source=demand_source,
        query_variant_client=FakeQueryVariantClient(),
    )
    # Every search raises, so no query can succeed.
    service._platform_client = lambda platform, platform_mode: _AllFailurePlatformClient()

    state = service.start_run(RunStartRequest(platform_mode="real"))

    assert state["status"] == "failed"
    assert state["error_code"] == ErrorCode.PLATFORM_REQUEST_FAILED.value
    failed_query_ids = [
        failure["search_query_id"]
        for failure in state["error_detail"]["query_failures"]
    ]
    assert failed_query_ids == ["q_001", "q_002", "q_003", "q_004"]

    final_updates = runtime.run_updates[-1]["updates"]
    assert final_updates["status"] == "failed"
    assert final_updates["error_detail"]["query_failures"]

    last_event = runtime.lifecycle_events[-1]
    assert last_event["event_id"] == "lifecycle_failed"
    assert last_event["raw_payload"]["error_detail"]["query_failures"]

    run_id = state["run_id"]
    search_queries = service.read_jsonl(run_id, "search_queries.jsonl")
    assert {query["search_query_effect_status"] for query in search_queries} == {"failed"}
    assert all(
        query["raw_payload"]["query_failure"]["status"] == "failed"
        for query in search_queries
    )

    search_clues = service.read_jsonl(run_id, "search_clues.jsonl")
    assert [clue["search_query_id"] for clue in search_clues] == failed_query_ids
    assert {clue["search_query_effect_status"] for clue in search_clues} == {"failed"}
    assert {clue["walk_next_step"] for clue in search_clues} == {"stop_search_query"}

    run_events = service.read_jsonl(run_id, "run_events.jsonl")
    platform_failures = list(
        filter(lambda event: event["event_type"] == "platform_query_failed", run_events)
    )
    assert [event["input_ref"] for event in platform_failures] == [
        f"search_queries.jsonl:{query_id}" for query_id in failed_query_ids
    ]
    assert {event["status"] for event in platform_failures} == {"failed"}
def test_run_service_query_generation_failure_records_error_code(tmp_path):
    """An exception from the query-variant client fails the run with QUERY_GENERATION_FAILED."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    demand_source = FakeDemandSource(real_source_payload(demand_content_id=123))
    failing_client = FakeQueryVariantClient(error=RuntimeError("model unavailable"))
    service = RunService(
        runtime=runtime,
        demand_source=demand_source,
        query_variant_client=failing_client,
    )

    state = service.start_run(RunStartRequest(platform_mode="mock"))
    expected_code = ErrorCode.QUERY_GENERATION_FAILED.value

    assert state["status"] == "failed"
    assert state["error_code"] == expected_code
    assert state["error_detail"]["reason"] == "llm_variant_exception"

    final_updates = runtime.run_updates[-1]["updates"]
    assert final_updates["status"] == "failed"
    assert final_updates["error_code"] == expected_code

    last_event = runtime.lifecycle_events[-1]
    assert last_event["event_id"] == "lifecycle_failed"
    assert last_event["error_code"] == expected_code
def test_run_service_duplicate_query_variant_records_query_generation_failed(tmp_path):
    """A variant identical to its seed term is rejected as a query-generation failure."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    demand_source = FakeDemandSource(real_source_payload(demand_content_id=123))
    # The "variant" for the seed term is the seed term itself.
    echo_client = FakeQueryVariantClient({"爱国情感": "爱国情感"})
    service = RunService(
        runtime=runtime,
        demand_source=demand_source,
        query_variant_client=echo_client,
    )

    state = service.start_run(RunStartRequest(platform_mode="mock"))
    expected_code = ErrorCode.QUERY_GENERATION_FAILED.value

    assert state["status"] == "failed"
    assert state["error_code"] == expected_code
    assert state["error_detail"]["reason"] == "llm_variant_same_as_seed"
    assert runtime.run_updates[-1]["updates"]["error_code"] == expected_code
def test_run_service_generic_query_variant_records_query_generation_failed(tmp_path):
    """A generic variant ("热门视频") is rejected as a query-generation failure."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    demand_source = FakeDemandSource(real_source_payload(demand_content_id=123))
    generic_client = FakeQueryVariantClient({"爱国情感": "热门视频"})
    service = RunService(
        runtime=runtime,
        demand_source=demand_source,
        query_variant_client=generic_client,
    )

    state = service.start_run(RunStartRequest(platform_mode="mock"))
    expected_code = ErrorCode.QUERY_GENERATION_FAILED.value

    assert state["status"] == "failed"
    assert state["error_code"] == expected_code
    assert state["error_detail"]["reason"] == "llm_variant_generic"
    assert runtime.run_updates[-1]["updates"]["error_code"] == expected_code
def test_run_service_malformed_query_variant_result_records_query_generation_failed(tmp_path):
    """A client returning no usable result (``None``) fails with ``llm_variant_result_invalid``."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    demand_source = FakeDemandSource(real_source_payload(demand_content_id=123))
    service = RunService(
        runtime=runtime,
        demand_source=demand_source,
        query_variant_client=_MalformedQueryVariantClient(),
    )

    state = service.start_run(RunStartRequest(platform_mode="mock"))
    expected_code = ErrorCode.QUERY_GENERATION_FAILED.value

    assert state["status"] == "failed"
    assert state["error_code"] == expected_code
    assert state["error_detail"]["reason"] == "llm_variant_result_invalid"
    assert runtime.run_updates[-1]["updates"]["error_code"] == expected_code
def test_run_service_missing_query_variant_client_fails_at_p2(tmp_path):
    """Without a query-variant client the run fails at query generation (P2)."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    service = RunService(runtime=runtime)
    request = RunStartRequest(platform_mode="mock", source=str(REAL_SOURCE_FIXTURE))

    state = service.start_run(request)
    expected_code = ErrorCode.QUERY_GENERATION_FAILED.value

    assert state["status"] == "failed"
    assert state["error_code"] == expected_code
    assert state["error_detail"]["reason"] == "query variant client is not configured"
    assert runtime.run_updates[-1]["updates"]["error_code"] == expected_code
def test_run_service_without_selector_requires_configured_demand_source(tmp_path):
    """With no selector and no demand source, the run fails with DB_CONFIG_MISSING."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    service = RunService(runtime=runtime)

    state = service.start_run(RunStartRequest(platform_mode="mock"))

    assert state["status"] == "failed"
    assert state["error_code"] == ErrorCode.DB_CONFIG_MISSING.value
    assert state["error_detail"]["selector"] == "default_pg_pattern_v2_passed"

    # The run was still registered before the failure was detected.
    initial_source_ref = runtime.run_records[0]["source_ref"]
    assert initial_source_ref["source_type"] == "demand_content_default"
    assert runtime.run_updates[-1]["updates"]["status"] == "failed"
    assert runtime.lifecycle_events[-1]["event_id"] == "lifecycle_failed"
def test_run_service_failure_records_error_code_and_failed_lifecycle(tmp_path):
    """A nonexistent source file yields INVALID_SOURCE (HTTP 400) in state, record and lifecycle."""
    runtime = _SpyRuntimeStore(tmp_path / "runtime")
    service = RunService(runtime=runtime)
    missing_source = str(tmp_path / "missing.json")

    state = service.start_run(RunStartRequest(platform_mode="mock", source=missing_source))
    expected_code = ErrorCode.INVALID_SOURCE.value

    assert state["status"] == "failed"
    assert state["error_code"] == expected_code
    assert state["http_status_code"] == 400

    final_updates = runtime.run_updates[-1]["updates"]
    assert final_updates["status"] == "failed"
    assert final_updates["error_code"] == expected_code

    last_event = runtime.lifecycle_events[-1]
    assert last_event["event_id"] == "lifecycle_failed"
    assert last_event["error_code"] == expected_code
def test_demand_source_service_maps_demand_content_row_to_source_payload():
    """A ``demand_content`` row is mapped into a source payload with parsed ``ext_data``."""
    evidence = {"evidence_pack": {"pattern_source_system": "pg_pattern_v2"}}
    row = {
        "id": 123,
        "merge_leve2": "PG Pattern",
        "name": "爱国情感,人物故事",
        "reason": "reason",
        "suggestion": "suggestion",
        "score": 0.91,
        "dt": "2026-06-07",
        # ext_data is stored as a JSON string in the row.
        "ext_data": json.dumps(evidence),
    }
    connection = _DemandConnection(row)
    service = DemandSourceService(_config(), connection_factory=lambda: connection)

    payload = service.get_by_id(123)

    assert payload["demand_content_id"] == "123"
    assert payload["ext_data"]["evidence_pack"]["pattern_source_system"] == "pg_pattern_v2"
    assert payload["raw_demand_content"]["id"] == 123

    sql, params = connection.statements[0]
    assert "FROM demand_content" in sql
    assert params == [123]
def test_demand_source_service_run_label_uses_deterministic_selector():
    """Lookup by run label issues a deterministic (ordered, single-row) query."""
    connection = _DemandConnection(None)
    service = DemandSourceService(_config(), connection_factory=lambda: connection)

    try:
        service.get_by_run_label("smoke")
    except Exception:
        # The fake returns no row, which may make the lookup raise; only the
        # SQL that was issued matters for this test.
        pass

    sql, params = connection.statements[0]
    expected_fragments = (
        "JSON_UNQUOTE(JSON_EXTRACT(ext_data, '$.run_label'))",
        "ORDER BY id ASC",
        "LIMIT 1",
    )
    for fragment in expected_fragments:
        assert fragment in sql
    assert params == ["smoke"]
def test_demand_source_service_default_uses_real_pg_pattern_selector():
    """The default source lookup filters on passed ``pg_pattern_v2`` evidence packs."""
    connection = _DemandConnection(None)
    service = DemandSourceService(_config(), connection_factory=lambda: connection)

    try:
        service.get_default_pg_pattern_source()
    except Exception:
        # The fake returns no row, which may make the lookup raise; only the
        # SQL that was issued matters for this test.
        pass

    sql, params = connection.statements[0]
    expected_fragments = (
        "$.evidence_pack.pattern_source_system",
        "$.evidence_pack.validation_status",
        "pg_pattern_v2",
        "passed",
        "ORDER BY id ASC",
        "LIMIT 1",
    )
    for fragment in expected_fragments:
        assert fragment in sql
    # The selector values are inlined in the SQL, so no bind parameters.
    assert params == []
def test_api_mutual_exclusion_returns_invalid_request(tmp_path, monkeypatch):
    """Passing both ``source`` and ``demand_content_id`` is rejected with 422 INVALID_REQUEST."""
    runtime_root = tmp_path / "runtime" / "v1"
    monkeypatch.setattr(api, "service", RunService(runtime_root=runtime_root))
    client = TestClient(api.app)
    conflicting_request = {
        "platform_mode": "mock",
        "source": "source.json",
        "demand_content_id": 123,
    }

    response = client.post("/runs", json=conflicting_request)

    assert response.status_code == 422
    detail = response.json()["detail"]
    assert detail["error_code"] == ErrorCode.INVALID_REQUEST.value
def test_api_missing_source_returns_invalid_source(tmp_path, monkeypatch):
    """A nonexistent source path is reported as 400 INVALID_SOURCE."""
    runtime_root = tmp_path / "runtime" / "v1"
    monkeypatch.setattr(api, "service", RunService(runtime_root=runtime_root))
    client = TestClient(api.app)
    request_body = {"platform_mode": "mock", "source": str(tmp_path / "missing.json")}

    response = client.post("/runs", json=request_body)

    assert response.status_code == 400
    detail = response.json()["detail"]
    assert detail["error_code"] == ErrorCode.INVALID_SOURCE.value
def test_api_404_returns_structured_run_not_found(tmp_path, monkeypatch):
    """Fetching an unknown run returns a structured 404 RUN_NOT_FOUND body."""
    runtime_root = tmp_path / "runtime" / "v1"
    monkeypatch.setattr(api, "service", RunService(runtime_root=runtime_root))
    client = TestClient(api.app)

    response = client.get("/runs/missing_run")

    assert response.status_code == 404
    detail = response.json()["detail"]
    assert detail["error_code"] == ErrorCode.RUN_NOT_FOUND.value
- class _SpyRuntimeStore:
- def __init__(self, base_dir: Path) -> None:
- self.local = LocalRuntimeFileStore(base_dir)
- self.run_records: list[dict[str, Any]] = []
- self.run_updates: list[dict[str, Any]] = []
- self.policy_runs: list[dict[str, Any]] = []
- self.lifecycle_events: list[dict[str, Any]] = []
- self.publish_jobs: list[dict[str, Any]] = []
- self.author_assets: list[dict[str, Any]] = []
- self.author_asset_roles: list[dict[str, Any]] = []
- self.search_clue_assets: list[dict[str, Any]] = []
- self.search_clue_asset_evidence: list[dict[str, Any]] = []
- def prepare_run(self, run_id: str) -> Path:
- return self.local.prepare_run(run_id)
- def run_dir(self, run_id: str) -> Path:
- return self.local.run_dir(run_id)
- def write_json(self, run_id: str, filename: str, data: dict[str, Any]) -> Path:
- return self.local.write_json(run_id, filename, data)
- def update_json(self, run_id: str, filename: str, data: dict[str, Any]) -> Path:
- return self.local.update_json(run_id, filename, data)
- def append_jsonl(self, run_id: str, filename: str, rows: list[dict[str, Any]]) -> Path:
- return self.local.append_jsonl(run_id, filename, rows)
- def read_json(self, run_id: str, filename: str) -> dict[str, Any]:
- return self.local.read_json(run_id, filename)
- def read_jsonl(self, run_id: str, filename: str) -> list[dict[str, Any]]:
- return self.local.read_jsonl(run_id, filename)
- def file_status(self, run_id: str) -> dict[str, bool]:
- return self.local.file_status(run_id)
- def create_run_record(self, record: dict[str, Any]) -> None:
- self.run_records.append(dict(record))
- def update_run_record(self, run_id: str, updates: dict[str, Any]) -> None:
- self.run_updates.append({"run_id": run_id, "updates": dict(updates)})
- def record_policy_run(self, record: dict[str, Any]) -> None:
- self.policy_runs.append(dict(record))
- def append_run_event_records(
- self,
- run_id: str,
- policy_run_id: str,
- rows: list[dict[str, Any]],
- ) -> None:
- self.lifecycle_events.extend(dict(row) for row in rows)
- def write_publish_jobs(
- self,
- run_id: str,
- policy_run_id: str,
- rows: list[dict[str, Any]],
- ) -> None:
- self.publish_jobs.extend(dict(row) for row in rows)
- def write_author_assets(self, rows: list[dict[str, Any]]) -> None:
- self.author_assets.extend(dict(row) for row in rows)
- def write_author_asset_roles(self, rows: list[dict[str, Any]]) -> None:
- self.author_asset_roles.extend(dict(row) for row in rows)
- def write_search_clue_assets(self, rows: list[dict[str, Any]]) -> None:
- self.search_clue_assets.extend(dict(row) for row in rows)
- def write_search_clue_asset_evidence(self, rows: list[dict[str, Any]]) -> None:
- self.search_clue_asset_evidence.extend(dict(row) for row in rows)
- def read_performance_feedback(
- self,
- run_id: str,
- policy_run_id: str,
- ) -> list[dict[str, Any]]:
- return []
class _FakeRuntimeStore(_SpyRuntimeStore):
    """Spy store whose write methods only log the call and optionally fail.

    ``calls`` collects ``(method_name, filename)`` tuples in call order. With
    ``fail_writes=True`` every write raises ``RuntimeError("primary failed")``
    after being logged. Successful writes return the would-be file path
    without writing anything.
    """

    def __init__(self, run_dir: Path | None = None, fail_writes: bool = False) -> None:
        # NOTE(review): with no run_dir this falls back to a cwd-relative
        # "fake_runtime" directory; whether anything is created there depends
        # on LocalRuntimeFileStore — confirm.
        super().__init__(run_dir or Path("fake_runtime"))
        self.calls: list[tuple[str, str]] = []
        self.fail_writes = fail_writes

    def _record_call(self, method_name: str, run_id: str, filename: str) -> Path:
        # Shared body of the three write overrides: log, maybe fail, return path.
        self.calls.append((method_name, filename))
        if self.fail_writes:
            raise RuntimeError("primary failed")
        return self.run_dir(run_id) / filename

    def write_json(self, run_id: str, filename: str, data: dict[str, Any]) -> Path:
        return self._record_call("write_json", run_id, filename)

    def update_json(self, run_id: str, filename: str, data: dict[str, Any]) -> Path:
        return self._record_call("update_json", run_id, filename)

    def append_jsonl(self, run_id: str, filename: str, rows: list[dict[str, Any]]) -> Path:
        return self._record_call("append_jsonl", run_id, filename)
- class _MalformedQueryVariantClient:
- def generate_variant(self, *, seed_term: str, evidence_context: dict[str, Any]):
- return None
- class _PartialFailurePlatformClient:
- def __init__(self) -> None:
- self.mock = MockPlatformClient()
- def search(self, search_query: dict[str, Any]) -> list[dict[str, Any]]:
- if search_query["search_query_id"] == "q_002":
- raise RuntimeError("temporary platform failure")
- return self.mock.search(search_query)
- class _AllFailurePlatformClient:
- def search(self, search_query: dict[str, Any]) -> list[dict[str, Any]]:
- raise RuntimeError("platform unavailable")
- class _DemandCursor:
- def __init__(self, connection: "_DemandConnection") -> None:
- self.connection = connection
- def __enter__(self) -> "_DemandCursor":
- return self
- def __exit__(self, *_args) -> None:
- return None
- def execute(self, sql: str, params=None) -> None:
- self.connection.statements.append((sql, list(params or [])))
- def fetchone(self):
- return self.connection.row
- class _DemandConnection:
- def __init__(self, row: dict[str, Any] | None) -> None:
- self.row = row
- self.statements: list[tuple[str, list[Any]]] = []
- def __enter__(self) -> "_DemandConnection":
- return self
- def __exit__(self, *_args) -> None:
- return None
- def cursor(self) -> _DemandCursor:
- return _DemandCursor(self)
def _config() -> ContentSupplyDbConfig:
    """Build a placeholder database config for the DemandSourceService tests.

    The credentials are dummies; the tests in this module supply fake
    connection factories, so these values are not used to open a real
    connection there.
    """
    settings = {
        "host": "127.0.0.1",
        "port": 3306,
        "user": "content_rw",
        "password": "dummy_password",
        "database": "content-deconstruction-supply",
    }
    return ContentSupplyDbConfig(**settings)
|