"""Load a source payload and normalize its evidence pack for a run."""

from __future__ import annotations

import copy
import json
from pathlib import Path
from typing import Any

from content_agent.constants import RUNTIME_SCHEMA_VERSION
from content_agent.errors import ContentAgentError, ErrorCode

# Evaluates to "trace_id".
# NOTE(review): presumably written as a concatenation so the literal does not
# appear verbatim in the source tree -- confirm before simplifying.
LEGACY_RUN_ID_KEY = "tr" + "ace_id"

# Evidence-pack fields that must hold exactly this value to pass validation.
_EXPECTED_VALUES: dict[str, str] = {
    "pattern_source_system": "pg_pattern_v2",
    "validation_status": "passed",
}

# (canonical field, alias field) -> required value.  The alias back-fills the
# canonical field during normalization; during validation at least one of the
# two must be present and every present one must equal the required value.
_ALIAS_EXPECTED_VALUES: dict[tuple[str, str], str] = {
    ("source_kind", "pattern_scope"): "pattern_itemset",
    ("case_id_type", "pattern_id_type"): "post_id",
    ("source_certainty", "validation_method"): "db_validated",
}

# Fields that are always coerced to a (possibly empty) list.
_LIST_FIELDS: tuple[str, ...] = (
    "itemset_ids",
    "itemset_items",
    "category_bindings",
    "element_bindings",
    "matched_post_ids",
    "video_ids",
    "case_ids",
    "decode_case_ids",
    "seed_terms",
)

# Fields that must be present and neither None, "" nor [].
_REQUIRED_FIELDS: tuple[str, ...] = (
    "source_post_id",
    "pattern_execution_id",
    "mining_config_id",
    "itemset_ids",
    "itemset_items",
    "category_bindings",
    "support",
    "absolute_support",
    "matched_post_ids",
    "video_ids",
    "case_ids",
    "seed_terms",
)


def load_source_context(
    run_id: str,
    source: str | dict[str, Any] | None,
) -> dict[str, Any]:
    """Build the source context for ``run_id`` from a row dict or a JSON file.

    Args:
        run_id: Identifier written into the context and its evidence pack.
        source: Either a row dict, or a path to a JSON file holding one row
            or a list of rows.

    Returns:
        The source context with ``run_id``, ``schema_version`` and a
        normalized, validated ``ext_data.evidence_pack``.

    Raises:
        ContentAgentError: ``source`` is empty/``None``, or the path does not
            exist or is a directory.
        ValueError: The payload carries no evidence pack, or the evidence
            pack fails validation.
    """
    if isinstance(source, dict):
        data = _source_context_from_demand_content(_normalize_source_row(source))
    elif source:
        path = Path(source)
        # A directory "exists" but cannot be read as a JSON file; reject it
        # here instead of letting read_text() raise IsADirectoryError.
        if not path.exists() or path.is_dir():
            raise ContentAgentError(
                ErrorCode.INVALID_SOURCE,
                "source file not found",
                {"source": source},
                status_code=400,
            )
        data = _load_source_payload(path)
    else:
        raise ContentAgentError(
            ErrorCode.INVALID_SOURCE,
            "source payload is required",
            {"selector": "source"},
            status_code=400,
        )

    data["run_id"] = run_id
    data["schema_version"] = RUNTIME_SCHEMA_VERSION
    data.pop(LEGACY_RUN_ID_KEY, None)

    # ``ext_data`` may be absent, None or some non-dict value; treat all of
    # those as empty so the failure surfaces as a validation ValueError
    # rather than an AttributeError on ``.get``.
    ext_data = data.get("ext_data")
    if not isinstance(ext_data, dict):
        ext_data = {}
        data["ext_data"] = ext_data

    evidence_pack = ext_data.get("evidence_pack") or {}
    ext_data["evidence_pack"] = normalize_evidence_pack(run_id, evidence_pack)
    return data


def normalize_evidence_pack(
    run_id: str,
    evidence_pack: dict[str, Any],
) -> dict[str, Any]:
    """Return a normalized, validated deep copy of ``evidence_pack``.

    The legacy run-id key is removed, ``run_id`` is set to the given value
    (a differing incoming id is kept under ``upstream_run_id``), canonical
    fields are back-filled from their aliases, and the list-valued fields are
    coerced to lists.

    Args:
        run_id: Run identifier to stamp on the result.
        evidence_pack: Incoming evidence pack; it is not mutated.

    Returns:
        The normalized evidence pack.

    Raises:
        ValueError: The normalized pack fails ``_validate_pg_evidence_pack``.
    """
    normalized: dict[str, Any] = {}
    incoming = copy.deepcopy(evidence_pack)
    upstream_run_id = incoming.pop(LEGACY_RUN_ID_KEY, None) or incoming.get("run_id")
    normalized.update(incoming)

    if upstream_run_id and upstream_run_id != run_id:
        normalized["upstream_run_id"] = upstream_run_id
    normalized["run_id"] = run_id

    # Fill a missing/falsy canonical field from its alias when the alias is set.
    for canonical, alias in _ALIAS_EXPECTED_VALUES:
        if not normalized.get(canonical) and normalized.get(alias):
            normalized[canonical] = normalized[alias]

    for field in _LIST_FIELDS:
        normalized[field] = list(normalized.get(field) or [])

    _validate_pg_evidence_pack(normalized)
    return normalized


def build_pattern_seed_pack(
    run_id: str,
    policy_run_id: str,
    source_context: dict[str, Any],
) -> dict[str, Any]:
    """Project the evidence pack of ``source_context`` into a seed pack.

    Args:
        run_id: Run identifier stored in the seed pack.
        policy_run_id: Policy run identifier stored in the seed pack.
        source_context: Context containing ``ext_data.evidence_pack``.

    Returns:
        A dict with the schema version, both ids, and selected evidence-pack
        fields; ``seed_terms`` is de-duplicated with falsy terms dropped.

    Raises:
        KeyError: ``ext_data``, ``evidence_pack`` or
            ``pattern_execution_id`` is missing.
    """
    evidence_pack = source_context["ext_data"]["evidence_pack"]
    seed_terms = _unique_terms(evidence_pack.get("seed_terms", []))
    return {
        "schema_version": RUNTIME_SCHEMA_VERSION,
        "run_id": run_id,
        "policy_run_id": policy_run_id,
        "pattern_source_system": evidence_pack.get("pattern_source_system"),
        "pattern_execution_id": evidence_pack["pattern_execution_id"],
        "mining_config_id": evidence_pack.get("mining_config_id"),
        "source_post_id": evidence_pack.get("source_post_id"),
        "case_id_type": evidence_pack.get("case_id_type"),
        "itemsets": evidence_pack.get("itemset_items", []),
        "support": evidence_pack.get("support"),
        "absolute_support": evidence_pack.get("absolute_support"),
        "seed_terms": seed_terms,
        "category_bindings": evidence_pack.get("category_bindings", []),
        "element_bindings": evidence_pack.get("element_bindings", []),
        "matched_post_ids": evidence_pack.get("matched_post_ids", []),
        "video_ids": evidence_pack.get("video_ids", []),
        "case_ids": evidence_pack.get("case_ids", []),
        "decode_case_ids": evidence_pack.get("decode_case_ids", []),
        "source_certainty": evidence_pack.get("source_certainty"),
        "validation_status": evidence_pack.get("validation_status"),
    }


def _has_evidence_pack(row: dict[str, Any]) -> bool:
    """Return True when ``row['ext_data']`` is a dict with a truthy evidence pack."""
    ext_data = row.get("ext_data")
    return isinstance(ext_data, dict) and bool(ext_data.get("evidence_pack"))


def _load_source_payload(path: Path) -> dict[str, Any]:
    """Read the JSON file at ``path`` and return the row carrying an evidence pack.

    For a list payload the first dict row with ``ext_data.evidence_pack`` is
    returned as a projected source context.  For a dict payload the whole
    normalized row is returned.

    Raises:
        ValueError: No usable evidence pack was found (this also covers JSON
            decoding errors, which are ``ValueError`` subclasses).
    """
    payload = json.loads(path.read_text(encoding="utf-8"))

    if isinstance(payload, list):
        for row in payload:
            if not isinstance(row, dict):
                continue
            normalized = _normalize_source_row(row)
            # Rows whose ext_data is null/blank/non-dict are skipped rather
            # than crashing the scan with an AttributeError.
            if _has_evidence_pack(normalized):
                return _source_context_from_demand_content(normalized)
        raise ValueError(f"{path} does not contain a row with ext_data.evidence_pack")

    if isinstance(payload, dict):
        normalized = _normalize_source_row(payload)
        if _has_evidence_pack(normalized):
            # NOTE(review): unlike the list branch, the row is returned
            # without projection through _source_context_from_demand_content
            # -- confirm this asymmetry is intended.
            return normalized

    raise ValueError(f"{path} does not contain ext_data.evidence_pack")


def _source_context_from_demand_content(row: dict[str, Any]) -> dict[str, Any]:
    """Project a row onto the source-context keys.

    ``ext_data`` and ``raw_demand_content`` are deep-copied so later mutation
    of the context does not touch the caller's row.  A missing ``ext_data``
    yields ``None`` here.
    """
    return {
        "run_id": row.get("run_id"),
        "demand_content_id": str(row.get("id") or row.get("demand_content_id") or ""),
        "merge_leve2": row.get("merge_leve2"),
        "name": row.get("name"),
        "suggestion": row.get("suggestion"),
        "score": row.get("score"),
        "dt": row.get("dt"),
        "ext_data": copy.deepcopy(row.get("ext_data")),
        "raw_demand_content": copy.deepcopy(row.get("raw_demand_content")),
    }


def _normalize_source_row(row: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``row`` with a JSON-string ``ext_data`` decoded.

    A blank or non-string ``ext_data`` is left untouched.

    Raises:
        json.JSONDecodeError: ``ext_data`` is a non-blank string that is not
            valid JSON.
    """
    normalized = dict(row)
    ext_data = normalized.get("ext_data")
    if isinstance(ext_data, str) and ext_data.strip():
        normalized["ext_data"] = json.loads(ext_data)
    return normalized


def _validate_pg_evidence_pack(evidence_pack: dict[str, Any]) -> None:
    """Check that ``evidence_pack`` carries the expected values and fields.

    Raises:
        ValueError: A fixed-value field differs from its expected value, an
            aliased field pair is absent or mismatched, a required field is
            ``None``/``""``/``[]``, or ``source_post_id`` is not listed in
            ``matched_post_ids``.
    """
    for field, expected in _EXPECTED_VALUES.items():
        if evidence_pack.get(field) != expected:
            raise ValueError(f"invalid evidence_pack.{field}: expected {expected}")

    for fields, expected in _ALIAS_EXPECTED_VALUES.items():
        values = [
            evidence_pack.get(field)
            for field in fields
            if evidence_pack.get(field) is not None
        ]
        if not values or any(value != expected for value in values):
            raise ValueError(f"invalid evidence_pack.{fields[0]}: expected {expected}")

    for field in _REQUIRED_FIELDS:
        value = evidence_pack.get(field)
        # Explicit comparisons (not truthiness) so that 0 is an accepted value.
        if value is None or value == "" or value == []:
            raise ValueError(f"missing evidence_pack.{field}")

    if evidence_pack["source_post_id"] not in evidence_pack["matched_post_ids"]:
        raise ValueError("evidence_pack.source_post_id must be in matched_post_ids")


def _unique_terms(terms: list[str]) -> list[str]:
    """Return the truthy ``terms`` with duplicates removed, first-seen order kept."""
    return list(dict.fromkeys(term for term in terms if term))