source_context.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. from __future__ import annotations
  2. import copy
  3. import json
  4. from pathlib import Path
  5. from typing import Any
  6. from content_agent.constants import RUNTIME_SCHEMA_VERSION
  7. from content_agent.errors import ContentAgentError, ErrorCode
# Key that payloads may still carry in place of "run_id"; the expression
# evaluates to "trace_id". The functions below pop it from incoming data.
# NOTE(review): the literal is presumably split in two so that searching the
# code base for the legacy name finds no hits -- confirm before simplifying.
LEGACY_RUN_ID_KEY = "tr" + "ace_id"
  9. def load_source_context(run_id: str, source: str | dict[str, Any] | None) -> dict[str, Any]:
  10. if isinstance(source, dict):
  11. data = _source_context_from_demand_content(_normalize_source_row(source))
  12. elif source:
  13. path = Path(source)
  14. if not path.exists():
  15. raise ContentAgentError(
  16. ErrorCode.INVALID_SOURCE,
  17. "source file not found",
  18. {"source": source},
  19. status_code=400,
  20. )
  21. data = _load_source_payload(path)
  22. else:
  23. raise ContentAgentError(
  24. ErrorCode.INVALID_SOURCE,
  25. "source payload is required",
  26. {"selector": "source"},
  27. status_code=400,
  28. )
  29. data["run_id"] = run_id
  30. data["schema_version"] = RUNTIME_SCHEMA_VERSION
  31. data.pop(LEGACY_RUN_ID_KEY, None)
  32. data.setdefault("ext_data", {})
  33. evidence_pack = data["ext_data"].get("evidence_pack") or {}
  34. data["ext_data"]["evidence_pack"] = normalize_evidence_pack(
  35. run_id,
  36. evidence_pack,
  37. )
  38. return data
  39. def normalize_evidence_pack(
  40. run_id: str,
  41. evidence_pack: dict[str, Any],
  42. ) -> dict[str, Any]:
  43. normalized: dict[str, Any] = {}
  44. incoming = copy.deepcopy(evidence_pack)
  45. upstream_run_id = incoming.pop(LEGACY_RUN_ID_KEY, None) or incoming.get("run_id")
  46. normalized.update(incoming)
  47. if upstream_run_id and upstream_run_id != run_id:
  48. normalized["upstream_run_id"] = upstream_run_id
  49. normalized["run_id"] = run_id
  50. if not normalized.get("source_kind") and normalized.get("pattern_scope"):
  51. normalized["source_kind"] = normalized["pattern_scope"]
  52. if not normalized.get("case_id_type") and normalized.get("pattern_id_type"):
  53. normalized["case_id_type"] = normalized["pattern_id_type"]
  54. if not normalized.get("source_certainty") and normalized.get("validation_method"):
  55. normalized["source_certainty"] = normalized["validation_method"]
  56. normalized["itemset_ids"] = list(normalized.get("itemset_ids") or [])
  57. normalized["itemset_items"] = list(normalized.get("itemset_items") or [])
  58. normalized["category_bindings"] = list(normalized.get("category_bindings") or [])
  59. normalized["element_bindings"] = list(normalized.get("element_bindings") or [])
  60. normalized["matched_post_ids"] = list(normalized.get("matched_post_ids") or [])
  61. normalized["video_ids"] = list(normalized.get("video_ids") or [])
  62. normalized["case_ids"] = list(normalized.get("case_ids") or [])
  63. normalized["decode_case_ids"] = list(normalized.get("decode_case_ids") or [])
  64. normalized["seed_terms"] = list(normalized.get("seed_terms") or [])
  65. _validate_pg_evidence_pack(normalized)
  66. return normalized
  67. def build_pattern_seed_pack(
  68. run_id: str,
  69. policy_run_id: str,
  70. source_context: dict[str, Any],
  71. ) -> dict[str, Any]:
  72. evidence_pack = source_context["ext_data"]["evidence_pack"]
  73. seed_terms = _unique_terms(evidence_pack.get("seed_terms", []))
  74. return {
  75. "schema_version": RUNTIME_SCHEMA_VERSION,
  76. "run_id": run_id,
  77. "policy_run_id": policy_run_id,
  78. "pattern_source_system": evidence_pack.get("pattern_source_system"),
  79. "pattern_execution_id": evidence_pack["pattern_execution_id"],
  80. "mining_config_id": evidence_pack.get("mining_config_id"),
  81. "source_post_id": evidence_pack.get("source_post_id"),
  82. "case_id_type": evidence_pack.get("case_id_type"),
  83. "itemsets": evidence_pack.get("itemset_items", []),
  84. "support": evidence_pack.get("support"),
  85. "absolute_support": evidence_pack.get("absolute_support"),
  86. "seed_terms": seed_terms,
  87. "category_bindings": evidence_pack.get("category_bindings", []),
  88. "element_bindings": evidence_pack.get("element_bindings", []),
  89. "matched_post_ids": evidence_pack.get("matched_post_ids", []),
  90. "video_ids": evidence_pack.get("video_ids", []),
  91. "case_ids": evidence_pack.get("case_ids", []),
  92. "decode_case_ids": evidence_pack.get("decode_case_ids", []),
  93. "source_certainty": evidence_pack.get("source_certainty"),
  94. "validation_status": evidence_pack.get("validation_status"),
  95. }
  96. def _load_source_payload(path: Path) -> dict[str, Any]:
  97. payload = json.loads(path.read_text(encoding="utf-8"))
  98. if isinstance(payload, list):
  99. for row in payload:
  100. if not isinstance(row, dict):
  101. continue
  102. normalized = _normalize_source_row(row)
  103. if normalized.get("ext_data", {}).get("evidence_pack"):
  104. return _source_context_from_demand_content(normalized)
  105. raise ValueError(f"{path} does not contain a row with ext_data.evidence_pack")
  106. if isinstance(payload, dict):
  107. normalized = _normalize_source_row(payload)
  108. if normalized.get("ext_data", {}).get("evidence_pack"):
  109. return normalized
  110. raise ValueError(f"{path} does not contain ext_data.evidence_pack")
  111. def _source_context_from_demand_content(row: dict[str, Any]) -> dict[str, Any]:
  112. return {
  113. "run_id": row.get("run_id"),
  114. "demand_content_id": str(row.get("id") or row.get("demand_content_id") or ""),
  115. "merge_leve2": row.get("merge_leve2"),
  116. "name": row.get("name"),
  117. "suggestion": row.get("suggestion"),
  118. "score": row.get("score"),
  119. "dt": row.get("dt"),
  120. "ext_data": copy.deepcopy(row["ext_data"]),
  121. "raw_demand_content": copy.deepcopy(row.get("raw_demand_content")),
  122. }
  123. def _normalize_source_row(row: dict[str, Any]) -> dict[str, Any]:
  124. normalized = dict(row)
  125. ext_data = normalized.get("ext_data")
  126. if isinstance(ext_data, str) and ext_data.strip():
  127. normalized["ext_data"] = json.loads(ext_data)
  128. return normalized
  129. def _validate_pg_evidence_pack(evidence_pack: dict[str, Any]) -> None:
  130. expected_values = {
  131. "pattern_source_system": "pg_pattern_v2",
  132. "validation_status": "passed",
  133. }
  134. for field, expected in expected_values.items():
  135. if evidence_pack.get(field) != expected:
  136. raise ValueError(f"invalid evidence_pack.{field}: expected {expected}")
  137. alias_expected_values = {
  138. ("source_kind", "pattern_scope"): "pattern_itemset",
  139. ("case_id_type", "pattern_id_type"): "post_id",
  140. ("source_certainty", "validation_method"): "db_validated",
  141. }
  142. for fields, expected in alias_expected_values.items():
  143. values = [
  144. evidence_pack.get(field)
  145. for field in fields
  146. if evidence_pack.get(field) is not None
  147. ]
  148. if not values or any(value != expected for value in values):
  149. raise ValueError(f"invalid evidence_pack.{fields[0]}: expected {expected}")
  150. required_fields = [
  151. "source_post_id",
  152. "pattern_execution_id",
  153. "mining_config_id",
  154. "itemset_ids",
  155. "itemset_items",
  156. "category_bindings",
  157. "support",
  158. "absolute_support",
  159. "matched_post_ids",
  160. "video_ids",
  161. "case_ids",
  162. "seed_terms",
  163. ]
  164. for field in required_fields:
  165. value = evidence_pack.get(field)
  166. if value is None or value == "" or value == []:
  167. raise ValueError(f"missing evidence_pack.{field}")
  168. if evidence_pack["source_post_id"] not in evidence_pack["matched_post_ids"]:
  169. raise ValueError("evidence_pack.source_post_id must be in matched_post_ids")
  170. def _unique_terms(terms: list[str]) -> list[str]:
  171. return list(dict.fromkeys(term for term in terms if term))