_utils.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import re
  2. from dataclasses import dataclass, field
  3. from typing import Any, Dict, List, Optional
  4. from ._const import (
  5. ConfigCode,
  6. DemandRecommendConst,
  7. DemandSource,
  8. MatchMethod,
  9. )
  10. # ──────────────────────────────────────────────
  11. # Dataclasses
  12. # ──────────────────────────────────────────────
  13. @dataclass
  14. class DemandRecord:
  15. """上游需求表的一行原始数据(字段暂按结果表推测,后续按实际表结构对齐)"""
  16. dt: str = ""
  17. action_type: str = ""
  18. match_experiment_id: str = ""
  19. demand_source_crowd: str = ""
  20. demand_strategy: str = ""
  21. match_strategy: str = ""
  22. match_video_rule: str = ""
  23. demand_id: str = ""
  24. crowd_channel: str = ""
  25. crowd_segment: str = ""
  26. crowd_package: str = ""
  27. conversion_target: str = ""
  28. partner: str = ""
  29. account: str = ""
  30. scene_value: str = ""
  31. demand_source: str = ""
  32. drive_dim_time: str = ""
  33. drive_dim_space: str = ""
  34. demand_filter_strategy: str = ""
  35. demand_video_id: int = 0
  36. demand_video_title: str = ""
  37. scene_content_id: str = ""
  38. scene_content_title: str = ""
  39. demand_topic: str = ""
  40. demand_feature_points: str = ""
  41. @classmethod
  42. def from_dict(cls, data: Dict[str, Any]) -> "DemandRecord":
  43. return cls(
  44. dt=str(data.get("dt", "")),
  45. action_type=str(data.get("action_type", "")),
  46. match_experiment_id=str(data.get("match_experiment_id", "")),
  47. demand_source_crowd=str(data.get("demand_source_crowd", "")),
  48. demand_strategy=str(data.get("demand_strategy", "")),
  49. match_strategy=str(data.get("match_strategy", "")),
  50. match_video_rule=str(data.get("match_video_rule", "")),
  51. demand_id=str(data.get("demand_id", "")),
  52. crowd_channel=str(data.get("crowd_channel", "")),
  53. crowd_segment=str(data.get("crowd_segment", "")),
  54. crowd_package=str(data.get("crowd_package", "")),
  55. conversion_target=str(data.get("conversion_target", "")),
  56. partner=str(data.get("partner", "")),
  57. account=str(data.get("account", "")),
  58. scene_value=str(data.get("scene_value", "")),
  59. demand_source=str(data.get("demand_source", "")),
  60. drive_dim_time=str(data.get("drive_dim_time", "")),
  61. drive_dim_space=str(data.get("drive_dim_space", "")),
  62. demand_filter_strategy=str(data.get("demand_filter_strategy", "")),
  63. demand_video_id=int(data.get("demand_video_id", 0) or 0),
  64. demand_video_title=str(data.get("demand_video_title", "")),
  65. scene_content_id=str(data.get("scene_content_id", "")),
  66. scene_content_title=str(data.get("scene_content_title", "")),
  67. demand_topic=str(data.get("demand_topic", "")),
  68. demand_feature_points=str(data.get("demand_feature_points", "")),
  69. )
  70. @dataclass
  71. class MatchStrategy:
  72. """从 DemandRecord 解析出的匹配执行策略"""
  73. demand_id: str
  74. experiment_id: str
  75. dt: str
  76. match_methods: List[str] = field(default_factory=list)
  77. config_codes: List[str] = field(default_factory=list)
  78. top_n: int = DemandRecommendConst.DEFAULT_TOPN
  79. query_text: str = ""
  80. video_id: int = 0
  81. content_id: str = ""
  82. filter_rule: str = ""
  83. multi_recall_fusion: bool = False
  84. @dataclass
  85. class MatchResult:
  86. """单条匹配结果"""
  87. dt: str
  88. demand_id: str
  89. match_experiment_id: str
  90. match_method: str
  91. config_code: str
  92. video_id: int
  93. score: float
  94. rank_position: int = 0
  95. video_title: str = ""
  96. video_detail: Optional[Dict[str, Any]] = None
  97. # ──────────────────────────────────────────────
  98. # Strategy Parser
  99. # ──────────────────────────────────────────────
  100. class DemandStrategyParser:
  101. """解析 DemandRecord → MatchStrategy"""
  102. @staticmethod
  103. def select_config_codes(match_strategy: str) -> List[str]:
  104. """从匹配策略文本中推导 configCode 列表"""
  105. if not match_strategy:
  106. return [DemandRecommendConst.DEFAULT_CONFIG_CODE]
  107. codes: List[str] = []
  108. for keyword, code in DemandRecommendConst.STRATEGY_CONFIG_MAP.items():
  109. if keyword in match_strategy and code not in codes:
  110. codes.append(code)
  111. if not codes:
  112. codes.append(DemandRecommendConst.DEFAULT_CONFIG_CODE)
  113. return codes
  114. @staticmethod
  115. def select_match_methods(demand: DemandRecord) -> List[str]:
  116. """从需求行的 match_video_rule + 可用字段推导匹配方式"""
  117. if not demand.match_video_rule:
  118. return DemandStrategyParser._fallback_methods(demand)
  119. methods: List[str] = []
  120. for keyword, method in DemandRecommendConst.RULE_METHOD_MAP.items():
  121. if keyword in demand.match_video_rule:
  122. if method == MatchMethod.VIDEO_ID and demand.demand_video_id > 0:
  123. methods.append(method)
  124. elif method == MatchMethod.CONTENT_ID and demand.scene_content_id:
  125. methods.append(method)
  126. elif method == MatchMethod.TEXT and (
  127. demand.demand_topic or demand.demand_feature_points
  128. ):
  129. methods.append(method)
  130. if not methods:
  131. return DemandStrategyParser._fallback_methods(demand)
  132. return methods
  133. @staticmethod
  134. def _fallback_methods(demand: DemandRecord) -> List[str]:
  135. """当 match_video_rule 无法解析时,按可用字段兜底推导"""
  136. methods: List[str] = []
  137. if demand.demand_video_id > 0:
  138. methods.append(MatchMethod.VIDEO_ID)
  139. if demand.scene_content_id:
  140. methods.append(MatchMethod.CONTENT_ID)
  141. if demand.demand_topic or demand.demand_feature_points:
  142. methods.append(MatchMethod.TEXT)
  143. if not methods:
  144. methods.append(MatchMethod.TEXT) # 最终兜底
  145. return methods
  146. @staticmethod
  147. def parse_top_n(match_strategy: str) -> int:
  148. """从匹配策略中解析 topN 参数,缺省 10"""
  149. if not match_strategy:
  150. return DemandRecommendConst.DEFAULT_TOPN
  151. m = re.search(r"topN[=:]?\s*(\d+)", match_strategy, re.IGNORECASE)
  152. if m:
  153. return int(m.group(1))
  154. return DemandRecommendConst.DEFAULT_TOPN
  155. @classmethod
  156. def parse(cls, demand: DemandRecord) -> MatchStrategy:
  157. """完整解析一条需求记录为匹配策略"""
  158. return MatchStrategy(
  159. demand_id=demand.demand_id,
  160. experiment_id=demand.match_experiment_id,
  161. dt=demand.dt,
  162. match_methods=cls.select_match_methods(demand),
  163. config_codes=cls.select_config_codes(demand.match_strategy),
  164. top_n=cls.parse_top_n(demand.match_strategy),
  165. query_text=build_query_text(demand.demand_topic, demand.demand_feature_points),
  166. video_id=demand.demand_video_id,
  167. content_id=demand.scene_content_id,
  168. filter_rule=demand.demand_filter_strategy,
  169. multi_recall_fusion=("多路" in (demand.match_strategy or ""))
  170. or (len(cls.select_match_methods(demand)) > 1),
  171. )
  172. # ──────────────────────────────────────────────
  173. # Helpers
  174. # ──────────────────────────────────────────────
  175. def build_query_text(topic: str, feature_points: str) -> str:
  176. """拼接选题 + 特征点为检索文本"""
  177. parts = [p for p in [topic, feature_points] if p and p.strip()]
  178. return "。".join(parts) if parts else ""
  179. def parse_recall_items(
  180. api_response: Dict[str, Any],
  181. strategy: MatchStrategy,
  182. match_method: str,
  183. config_code: str,
  184. ) -> List[MatchResult]:
  185. """解析 API 返回结果为 MatchResult 列表"""
  186. if not api_response or api_response.get("code") != 0:
  187. return []
  188. data = api_response.get("data")
  189. if not data:
  190. return []
  191. # matchTopNVideo 返回 data 直接是 list
  192. if isinstance(data, list):
  193. items = data
  194. else:
  195. # recallTest 返回 data.items[]
  196. items = data.get("items", [])
  197. results: List[MatchResult] = []
  198. for rank, item in enumerate(items, start=1):
  199. vid = item.get("id") or item.get("videoId", 0)
  200. if not vid:
  201. continue
  202. results.append(MatchResult(
  203. dt=strategy.dt,
  204. demand_id=strategy.demand_id,
  205. match_experiment_id=strategy.experiment_id,
  206. match_method=match_method,
  207. config_code=config_code,
  208. video_id=int(vid),
  209. score=float(item.get("score", 0)),
  210. rank_position=rank,
  211. video_title=str(item.get("title", "")),
  212. video_detail=item.get("videoDetail"),
  213. ))
  214. return results
  215. def merge_multi_recall(
  216. result_groups: List[List[MatchResult]],
  217. top_n: int,
  218. ) -> List[MatchResult]:
  219. """多路召回结果合并:按 video_id 去重,保留最高分,取 top_n"""
  220. merged: Dict[int, MatchResult] = {}
  221. for group in result_groups:
  222. for r in group:
  223. if r.video_id in merged:
  224. if r.score > merged[r.video_id].score:
  225. merged[r.video_id] = r
  226. else:
  227. merged[r.video_id] = r
  228. sorted_results = sorted(merged.values(), key=lambda x: x.score, reverse=True)
  229. for i, r in enumerate(sorted_results[:top_n], start=1):
  230. r.rank_position = i
  231. return sorted_results[:top_n]