| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- from __future__ import annotations
- import copy
- import json
- from pathlib import Path
- from typing import Any
- from content_agent.constants import RUNTIME_SCHEMA_VERSION
- from content_agent.errors import ContentAgentError, ErrorCode
# Legacy key under which upstream payloads carried the run identifier. It is
# stripped from loaded source contexts and, when found inside an evidence
# pack, its value is treated as the upstream run id.
# NOTE(review): the literal is assembled from two fragments, presumably so
# the full key name never appears verbatim in this file — confirm.
LEGACY_RUN_ID_KEY = "tr" + "ace_id"
def load_source_context(run_id: str, source: str | dict[str, Any] | None) -> dict[str, Any]:
    """Load a source context and stamp it with the current run metadata.

    Args:
        run_id: Identifier written to the context and to its evidence pack.
        source: Either an in-memory source row (dict) or a path to a JSON
            file holding a source payload.

    Returns:
        The source context with ``run_id`` and ``schema_version`` set, the
        legacy run-id key removed, and ``ext_data["evidence_pack"]`` replaced
        by its normalized form.

    Raises:
        ContentAgentError: ``INVALID_SOURCE`` when ``source`` is empty or the
            path does not exist.
        ValueError: When the payload carries no usable evidence pack or the
            evidence pack fails validation.
    """
    if isinstance(source, dict):
        data = _source_context_from_demand_content(_normalize_source_row(source))
    elif source:
        path = Path(source)
        if not path.exists():
            raise ContentAgentError(
                ErrorCode.INVALID_SOURCE,
                "source file not found",
                {"source": source},
                status_code=400,
            )
        data = _load_source_payload(path)
    else:
        raise ContentAgentError(
            ErrorCode.INVALID_SOURCE,
            "source payload is required",
            {"selector": "source"},
            status_code=400,
        )

    data["run_id"] = run_id
    data["schema_version"] = RUNTIME_SCHEMA_VERSION
    data.pop(LEGACY_RUN_ID_KEY, None)

    # ``ext_data`` may be present but unusable (None, or a blank string that
    # was never parsed). ``setdefault`` would keep such a value and the lookup
    # below would fail with AttributeError, so replace anything that is not a
    # dict with an empty one and let evidence-pack validation reject it.
    ext_data = data.get("ext_data")
    if not isinstance(ext_data, dict):
        ext_data = {}
        data["ext_data"] = ext_data

    # Same reasoning for the pack itself: a non-dict value cannot be
    # normalized, so hand over an empty pack, which validation refuses.
    evidence_pack = ext_data.get("evidence_pack")
    if not isinstance(evidence_pack, dict):
        evidence_pack = {}

    ext_data["evidence_pack"] = normalize_evidence_pack(run_id, evidence_pack)
    return data
def normalize_evidence_pack(
    run_id: str,
    evidence_pack: dict[str, Any],
) -> dict[str, Any]:
    """Return a validated, normalized copy of ``evidence_pack``.

    The input is deep-copied and never mutated. On the copy:

    * the legacy run-id key is removed; its value (or, failing that, the
      pack's own ``run_id``) is kept as ``upstream_run_id`` when it differs
      from ``run_id``;
    * ``run_id`` is overwritten with the given ``run_id``;
    * ``source_kind``, ``case_id_type`` and ``source_certainty`` are filled
      from ``pattern_scope``, ``pattern_id_type`` and ``validation_method``
      when they are empty and the alias has a value;
    * the list-valued fields are coerced with ``list(...)``, with missing or
      empty values becoming ``[]``.

    Raises:
        ValueError: When the normalized pack fails
            ``_validate_pg_evidence_pack``.
    """
    alias_backfills = (
        ("source_kind", "pattern_scope"),
        ("case_id_type", "pattern_id_type"),
        ("source_certainty", "validation_method"),
    )
    list_fields = (
        "itemset_ids",
        "itemset_items",
        "category_bindings",
        "element_bindings",
        "matched_post_ids",
        "video_ids",
        "case_ids",
        "decode_case_ids",
        "seed_terms",
    )

    incoming = copy.deepcopy(evidence_pack)
    legacy_run_id = incoming.pop(LEGACY_RUN_ID_KEY, None)
    upstream_run_id = legacy_run_id or incoming.get("run_id")

    normalized: dict[str, Any] = {**incoming}
    if upstream_run_id and upstream_run_id != run_id:
        normalized["upstream_run_id"] = upstream_run_id
    normalized["run_id"] = run_id

    for target, alias in alias_backfills:
        if normalized.get(target):
            continue
        alias_value = normalized.get(alias)
        if alias_value:
            normalized[target] = alias_value

    for field in list_fields:
        current = normalized.get(field)
        normalized[field] = list(current or [])

    _validate_pg_evidence_pack(normalized)
    return normalized
def build_pattern_seed_pack(
    run_id: str,
    policy_run_id: str,
    source_context: dict[str, Any],
) -> dict[str, Any]:
    """Build the pattern seed pack from a source context's evidence pack.

    Reads ``source_context["ext_data"]["evidence_pack"]`` and projects it
    into a flat dict tagged with the runtime schema version, ``run_id`` and
    ``policy_run_id``. ``seed_terms`` is de-duplicated; ``itemset_items`` is
    exposed as ``itemsets``. List values are passed through as the same
    objects held by the evidence pack, not copied.

    Raises:
        KeyError: When ``ext_data``, ``evidence_pack`` or
            ``pattern_execution_id`` is missing.
    """
    pack = source_context["ext_data"]["evidence_pack"]
    lookup = pack.get
    deduped_terms = _unique_terms(lookup("seed_terms", []))

    seed_pack: dict[str, Any] = {
        "schema_version": RUNTIME_SCHEMA_VERSION,
        "run_id": run_id,
        "policy_run_id": policy_run_id,
        "pattern_source_system": lookup("pattern_source_system"),
        # The only field read strictly: a missing id is an error here.
        "pattern_execution_id": pack["pattern_execution_id"],
    }

    for name in ("mining_config_id", "source_post_id", "case_id_type"):
        seed_pack[name] = lookup(name)

    seed_pack["itemsets"] = lookup("itemset_items", [])
    seed_pack["support"] = lookup("support")
    seed_pack["absolute_support"] = lookup("absolute_support")
    seed_pack["seed_terms"] = deduped_terms

    for name in (
        "category_bindings",
        "element_bindings",
        "matched_post_ids",
        "video_ids",
        "case_ids",
        "decode_case_ids",
    ):
        seed_pack[name] = lookup(name, [])

    for name in ("source_certainty", "validation_status"):
        seed_pack[name] = lookup(name)

    return seed_pack
def _evidence_pack_or_none(row: dict[str, Any]) -> Any:
    """Return ``row["ext_data"]["evidence_pack"]``, or None if ``ext_data`` is not a dict."""
    ext_data = row.get("ext_data")
    if not isinstance(ext_data, dict):
        return None
    return ext_data.get("evidence_pack")


def _load_source_payload(path: Path) -> dict[str, Any]:
    """Read a JSON source payload from ``path`` and return a source context.

    A list payload is scanned in order; the first dict row that carries a
    non-empty ``ext_data.evidence_pack`` is converted into a source context.
    A dict payload is returned (after row normalization) as-is when it
    carries a non-empty ``ext_data.evidence_pack``.

    Rows whose ``ext_data`` is present but not a dict (for example ``null``
    or a blank string) are treated as having no evidence pack instead of
    aborting the scan with ``AttributeError``.

    Raises:
        ValueError: When no usable evidence pack is found, or when the file
            is not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    payload = json.loads(path.read_text(encoding="utf-8"))

    if isinstance(payload, list):
        for row in payload:
            if not isinstance(row, dict):
                continue
            normalized = _normalize_source_row(row)
            if _evidence_pack_or_none(normalized):
                return _source_context_from_demand_content(normalized)
        raise ValueError(f"{path} does not contain a row with ext_data.evidence_pack")

    if isinstance(payload, dict):
        normalized = _normalize_source_row(payload)
        if _evidence_pack_or_none(normalized):
            # NOTE(review): unlike the list branch, a dict payload is returned
            # without going through _source_context_from_demand_content —
            # presumably it is already a saved source context; confirm.
            return normalized

    raise ValueError(f"{path} does not contain ext_data.evidence_pack")
- def _source_context_from_demand_content(row: dict[str, Any]) -> dict[str, Any]:
- return {
- "run_id": row.get("run_id"),
- "demand_content_id": str(row.get("id") or row.get("demand_content_id") or ""),
- "merge_leve2": row.get("merge_leve2"),
- "name": row.get("name"),
- "suggestion": row.get("suggestion"),
- "score": row.get("score"),
- "dt": row.get("dt"),
- "ext_data": copy.deepcopy(row["ext_data"]),
- "raw_demand_content": copy.deepcopy(row.get("raw_demand_content")),
- }
- def _normalize_source_row(row: dict[str, Any]) -> dict[str, Any]:
- normalized = dict(row)
- ext_data = normalized.get("ext_data")
- if isinstance(ext_data, str) and ext_data.strip():
- normalized["ext_data"] = json.loads(ext_data)
- return normalized
- def _validate_pg_evidence_pack(evidence_pack: dict[str, Any]) -> None:
- expected_values = {
- "pattern_source_system": "pg_pattern_v2",
- "validation_status": "passed",
- }
- for field, expected in expected_values.items():
- if evidence_pack.get(field) != expected:
- raise ValueError(f"invalid evidence_pack.{field}: expected {expected}")
- alias_expected_values = {
- ("source_kind", "pattern_scope"): "pattern_itemset",
- ("case_id_type", "pattern_id_type"): "post_id",
- ("source_certainty", "validation_method"): "db_validated",
- }
- for fields, expected in alias_expected_values.items():
- values = [
- evidence_pack.get(field)
- for field in fields
- if evidence_pack.get(field) is not None
- ]
- if not values or any(value != expected for value in values):
- raise ValueError(f"invalid evidence_pack.{fields[0]}: expected {expected}")
- required_fields = [
- "source_post_id",
- "pattern_execution_id",
- "mining_config_id",
- "itemset_ids",
- "itemset_items",
- "category_bindings",
- "support",
- "absolute_support",
- "matched_post_ids",
- "video_ids",
- "case_ids",
- "seed_terms",
- ]
- for field in required_fields:
- value = evidence_pack.get(field)
- if value is None or value == "" or value == []:
- raise ValueError(f"missing evidence_pack.{field}")
- if evidence_pack["source_post_id"] not in evidence_pack["matched_post_ids"]:
- raise ValueError("evidence_pack.source_post_id must be in matched_post_ids")
- def _unique_terms(terms: list[str]) -> list[str]:
- return list(dict.fromkeys(term for term in terms if term))
|