# demand_source.py (4.4 KB)
  1. from __future__ import annotations
  2. import json
  3. from typing import Any, Callable
  4. from content_agent.errors import ContentAgentError, ErrorCode
  5. from content_agent.integrations.database_runtime import ContentSupplyDbConfig
  6. ConnectionFactory = Callable[[], Any]
  7. class DemandSourceService:
  8. def __init__(
  9. self,
  10. config: ContentSupplyDbConfig,
  11. connection_factory: ConnectionFactory | None = None,
  12. ) -> None:
  13. self.config = config
  14. self._connection_factory = connection_factory or config.connect
  15. def get_by_id(self, demand_content_id: int) -> dict[str, Any]:
  16. row = self._fetch_one(
  17. """
  18. SELECT id, merge_leve2, name, reason, suggestion, score, dt, ext_data
  19. FROM demand_content
  20. WHERE id = %s
  21. LIMIT 1
  22. """,
  23. (demand_content_id,),
  24. )
  25. if not row:
  26. raise ContentAgentError(
  27. ErrorCode.INVALID_SOURCE,
  28. "demand_content not found",
  29. {"demand_content_id": demand_content_id},
  30. status_code=400,
  31. )
  32. return self.to_source_payload(row)
  33. def get_by_run_label(self, run_label: str) -> dict[str, Any]:
  34. row = self._fetch_one(
  35. """
  36. SELECT id, merge_leve2, name, reason, suggestion, score, dt, ext_data
  37. FROM demand_content
  38. WHERE JSON_UNQUOTE(JSON_EXTRACT(ext_data, '$.run_label')) = %s
  39. ORDER BY id ASC
  40. LIMIT 1
  41. """,
  42. (run_label,),
  43. )
  44. if not row:
  45. raise ContentAgentError(
  46. ErrorCode.INVALID_SOURCE,
  47. "demand_content run_label not found",
  48. {"run_label": run_label},
  49. status_code=400,
  50. )
  51. return self.to_source_payload(row)
  52. def get_default_pg_pattern_source(self) -> dict[str, Any]:
  53. row = self._fetch_one(
  54. """
  55. SELECT id, merge_leve2, name, reason, suggestion, score, dt, ext_data
  56. FROM demand_content
  57. WHERE JSON_UNQUOTE(JSON_EXTRACT(ext_data, '$.evidence_pack.pattern_source_system')) = 'pg_pattern_v2'
  58. AND JSON_UNQUOTE(JSON_EXTRACT(ext_data, '$.evidence_pack.validation_status')) = 'passed'
  59. ORDER BY id ASC
  60. LIMIT 1
  61. """,
  62. (),
  63. )
  64. if not row:
  65. raise ContentAgentError(
  66. ErrorCode.INVALID_SOURCE,
  67. "default pg_pattern_v2 demand_content not found",
  68. {"selector": "default_pg_pattern_v2_passed"},
  69. status_code=400,
  70. )
  71. return self.to_source_payload(row)
  72. def to_source_payload(self, row: dict[str, Any]) -> dict[str, Any]:
  73. ext_data = _decode_ext_data(row.get("ext_data"))
  74. if not isinstance(ext_data, dict) or not ext_data.get("evidence_pack"):
  75. raise ContentAgentError(
  76. ErrorCode.INVALID_SOURCE,
  77. "demand_content missing ext_data.evidence_pack",
  78. {"demand_content_id": row.get("id")},
  79. status_code=400,
  80. )
  81. return {
  82. "id": row.get("id"),
  83. "demand_content_id": str(row.get("id") or ""),
  84. "merge_leve2": row.get("merge_leve2"),
  85. "name": row.get("name"),
  86. "reason": row.get("reason"),
  87. "suggestion": row.get("suggestion"),
  88. "score": row.get("score"),
  89. "dt": row.get("dt"),
  90. "ext_data": ext_data,
  91. "raw_demand_content": {
  92. "id": row.get("id"),
  93. "merge_leve2": row.get("merge_leve2"),
  94. "name": row.get("name"),
  95. "reason": row.get("reason"),
  96. "suggestion": row.get("suggestion"),
  97. "score": row.get("score"),
  98. "dt": row.get("dt"),
  99. "ext_data": ext_data,
  100. },
  101. }
  102. def _fetch_one(self, sql: str, params: tuple[Any, ...]) -> dict[str, Any] | None:
  103. with self._connection_factory() as conn:
  104. with conn.cursor() as cur:
  105. cur.execute(sql, params)
  106. return cur.fetchone()
  107. def _decode_ext_data(value: Any) -> dict[str, Any] | None:
  108. if isinstance(value, dict):
  109. return value
  110. if isinstance(value, str) and value.strip():
  111. decoded = json.loads(value)
  112. if isinstance(decoded, dict):
  113. return decoded
  114. return None