# demand_pool_writer.py
"""Write recent hot-topic demands into the Hive demand-pool table."""
  2. from __future__ import annotations
  3. import re
  4. from datetime import datetime
  5. from typing import Any
  6. from app.aliyun_odps.client import get_odps_client
  7. from app.hot_content.demand_hive_export import build_hive_rows_from_export_groups
  8. from app.hot_content.exceptions import HotContentFlowError
  9. from app.hot_content.repository import HotContentRepository
  10. from app.hot_content.timezone import SHANGHAI_TZ
  11. from app.hot_content.types import FlowConfig
  12. IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?$")
  13. HIVE_COLUMNS = (
  14. "strategy",
  15. "demand_id",
  16. "demand_name",
  17. "weight",
  18. "type",
  19. "video_count",
  20. "video_list",
  21. "extend",
  22. )
  23. def _safe_identifier(name: str) -> str:
  24. value = name.strip()
  25. if not IDENTIFIER_RE.match(value):
  26. raise HotContentFlowError(f"invalid sql identifier: {name}")
  27. return value
  28. def _escape_sql_string(value: str) -> str:
  29. return value.replace("'", "''")
  30. class HotDemandPoolWriter:
  31. def __init__(self, config: FlowConfig, repository: HotContentRepository):
  32. self.config = config
  33. self.repository = repository
  34. def sync_today(self) -> dict[str, Any]:
  35. partition_dt = datetime.now(SHANGHAI_TZ).date().strftime("%Y%m%d")
  36. export_groups = self.repository.list_demand_export_groups()
  37. hive_rows = build_hive_rows_from_export_groups(
  38. export_groups,
  39. strategy=self.config.hot_demand_pool_strategy,
  40. partition_dt=partition_dt,
  41. wxindex_threshold=self.config.wxindex_score_threshold,
  42. )
  43. written_count = self._write_partition(
  44. hive_rows=hive_rows,
  45. partition_dt=partition_dt,
  46. strategy=self.config.hot_demand_pool_strategy,
  47. )
  48. return {
  49. "partition_dt": partition_dt,
  50. "strategy": self.config.hot_demand_pool_strategy,
  51. "source_record_count": len(export_groups),
  52. "hive_row_count": len(hive_rows),
  53. "written_count": written_count,
  54. "target_table": self.config.demand_pool_source_table,
  55. }
  56. def _write_partition(
  57. self,
  58. *,
  59. hive_rows: list[dict[str, Any]],
  60. partition_dt: str,
  61. strategy: str,
  62. ) -> int:
  63. table_name = _safe_identifier(self.config.demand_pool_source_table)
  64. odps_client = get_odps_client()
  65. table = odps_client.get_table(table_name)
  66. partition_spec = f"dt={partition_dt}"
  67. preserved_rows = self._read_preserved_rows(
  68. table=table,
  69. partition_spec=partition_spec,
  70. strategy=strategy,
  71. )
  72. payload_rows = preserved_rows + [
  73. self._to_write_row(row) for row in hive_rows
  74. ]
  75. if not payload_rows and table.exist_partition(partition_spec):
  76. odps_client.write_table(
  77. table_name,
  78. [],
  79. partition=partition_spec,
  80. create_partition=True,
  81. overwrite=True,
  82. )
  83. return 0
  84. odps_client.write_table(
  85. table_name,
  86. payload_rows,
  87. partition=partition_spec,
  88. create_partition=True,
  89. overwrite=True,
  90. )
  91. return len(hive_rows)
  92. @staticmethod
  93. def _read_preserved_rows(
  94. *,
  95. table: Any,
  96. partition_spec: str,
  97. strategy: str,
  98. ) -> list[list[Any]]:
  99. if not table.exist_partition(partition_spec):
  100. return []
  101. preserved_rows: list[list[Any]] = []
  102. with table.open_reader(partition=partition_spec) as reader:
  103. for record in reader:
  104. if str(record["strategy"] or "") == strategy:
  105. continue
  106. preserved_rows.append(
  107. [
  108. record["strategy"],
  109. record["demand_id"],
  110. record["demand_name"],
  111. record["weight"],
  112. record["type"],
  113. record["video_count"],
  114. record["video_list"],
  115. record["extend"],
  116. ]
  117. )
  118. return preserved_rows
  119. @staticmethod
  120. def _to_write_row(row: dict[str, Any]) -> list[Any]:
  121. return [
  122. row["strategy"],
  123. row["demand_id"],
  124. row["demand_name"],
  125. float(row["weight"]),
  126. row["type"],
  127. row["video_count"],
  128. row["video_list"],
  129. row["extend"],
  130. ]
  131. def sync_hot_demands_to_hive(
  132. config: FlowConfig,
  133. repository: HotContentRepository,
  134. ) -> dict[str, Any]:
  135. writer = HotDemandPoolWriter(config, repository)
  136. return writer.sync_today()