_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import re
  2. from dataclasses import dataclass, field
  3. from typing import Any, Dict, List, Optional
  4. from ._const import (
  5. ConfigCode,
  6. DemandRecommendConst,
  7. DemandSource,
  8. MatchMethod,
  9. )
  10. # ──────────────────────────────────────────────
  11. # Dataclasses
  12. # ──────────────────────────────────────────────
  13. @dataclass
  14. class DemandRecord:
  15. """上游需求表的一行原始数据(字段暂按结果表推测,后续按实际表结构对齐)"""
  16. dt: str = ""
  17. action_type: str = ""
  18. match_experiment_id: str = ""
  19. demand_source_crowd: str = ""
  20. demand_strategy: str = ""
  21. match_strategy: str = ""
  22. match_video_rule: str = ""
  23. demand_id: str = ""
  24. crowd_channel: str = ""
  25. crowd_segment: str = ""
  26. crowd_package: str = ""
  27. conversion_target: str = ""
  28. partner: str = ""
  29. account: str = ""
  30. scene_value: str = ""
  31. demand_source: str = ""
  32. drive_dim_time: str = ""
  33. drive_dim_space: str = ""
  34. demand_filter_strategy: str = ""
  35. demand_video_id: int = 0
  36. demand_video_title: str = ""
  37. scene_content_id: str = ""
  38. scene_content_title: str = ""
  39. demand_topic: str = ""
  40. demand_feature_points: str = ""
  41. @classmethod
  42. def from_dict(cls, data: Dict[str, Any]) -> "DemandRecord":
  43. return cls(
  44. dt=str(data.get("dt", "")),
  45. action_type=str(data.get("action_type", "")),
  46. match_experiment_id=str(data.get("match_experiment_id", "")),
  47. demand_source_crowd=str(data.get("demand_source_crowd", "")),
  48. demand_strategy=str(data.get("demand_strategy", "")),
  49. match_strategy=str(data.get("match_strategy", "")),
  50. match_video_rule=str(data.get("match_video_rule", "")),
  51. demand_id=str(data.get("demand_id", "")),
  52. crowd_channel=str(data.get("crowd_channel", "")),
  53. crowd_segment=str(data.get("crowd_segment", "")),
  54. crowd_package=str(data.get("crowd_package", "")),
  55. conversion_target=str(data.get("conversion_target", "")),
  56. partner=str(data.get("partner", "")),
  57. account=str(data.get("account", "")),
  58. scene_value=str(data.get("scene_value", "")),
  59. demand_source=str(data.get("demand_source", "")),
  60. drive_dim_time=str(data.get("drive_dim_time", "")),
  61. drive_dim_space=str(data.get("drive_dim_space", "")),
  62. demand_filter_strategy=str(data.get("demand_filter_strategy", "")),
  63. demand_video_id=int(data.get("demand_video_id", 0) or 0),
  64. demand_video_title=str(data.get("demand_video_title", "")),
  65. scene_content_id=str(data.get("scene_content_id", "")),
  66. scene_content_title=str(data.get("scene_content_title", "")),
  67. demand_topic=str(data.get("demand_topic", "")),
  68. demand_feature_points=str(data.get("demand_feature_points", "")),
  69. )
  70. @dataclass
  71. class MatchStrategy:
  72. """从 DemandRecord 解析出的匹配执行策略"""
  73. demand_id: str
  74. experiment_id: str
  75. dt: str
  76. match_methods: List[str] = field(default_factory=list)
  77. config_codes: List[str] = field(default_factory=list)
  78. top_n: int = DemandRecommendConst.DEFAULT_TOPN
  79. query_text: str = ""
  80. video_id: int = 0
  81. content_id: str = ""
  82. filter_rule: str = ""
  83. multi_recall_fusion: bool = False
  84. @dataclass
  85. class MatchResult:
  86. """单条匹配结果"""
  87. dt: str
  88. demand_id: str
  89. match_experiment_id: str
  90. match_method: str
  91. config_code: str
  92. video_id: int
  93. score: float
  94. rank_position: int = 0
  95. video_title: str = ""
  96. video_detail: Optional[Dict[str, Any]] = None
  97. # recallWithScore 专用字段,非 scoring 模式时为 0
  98. sim: float = 0.0
  99. sim_norm: float = 0.0
  100. rov: float = 0.0
  101. rov_norm: float = 0.0
  102. # ──────────────────────────────────────────────
  103. # Strategy Parser
  104. # ──────────────────────────────────────────────
  105. class DemandStrategyParser:
  106. """解析 DemandRecord → MatchStrategy"""
  107. @staticmethod
  108. def select_config_codes(match_strategy: str) -> List[str]:
  109. """从匹配策略文本中推导 configCode 列表"""
  110. if not match_strategy:
  111. return [DemandRecommendConst.DEFAULT_CONFIG_CODE]
  112. codes: List[str] = []
  113. for keyword, code in DemandRecommendConst.STRATEGY_CONFIG_MAP.items():
  114. if keyword in match_strategy and code not in codes:
  115. codes.append(code)
  116. if not codes:
  117. codes.append(DemandRecommendConst.DEFAULT_CONFIG_CODE)
  118. return codes
  119. @staticmethod
  120. def select_match_methods(demand: DemandRecord) -> List[str]:
  121. """从需求行的 match_video_rule + 可用字段推导匹配方式"""
  122. if not demand.match_video_rule:
  123. return DemandStrategyParser._fallback_methods(demand)
  124. methods: List[str] = []
  125. for keyword, method in DemandRecommendConst.RULE_METHOD_MAP.items():
  126. if keyword in demand.match_video_rule:
  127. if method == MatchMethod.VIDEO_ID and demand.demand_video_id > 0:
  128. methods.append(method)
  129. elif method == MatchMethod.CONTENT_ID and demand.scene_content_id:
  130. methods.append(method)
  131. elif method == MatchMethod.TEXT and (
  132. demand.demand_topic or demand.demand_feature_points
  133. ):
  134. methods.append(method)
  135. if not methods:
  136. return DemandStrategyParser._fallback_methods(demand)
  137. return methods
  138. @staticmethod
  139. def _fallback_methods(demand: DemandRecord) -> List[str]:
  140. """当 match_video_rule 无法解析时,按可用字段兜底推导"""
  141. methods: List[str] = []
  142. if demand.demand_video_id > 0:
  143. methods.append(MatchMethod.VIDEO_ID)
  144. if demand.scene_content_id:
  145. methods.append(MatchMethod.CONTENT_ID)
  146. if demand.demand_topic or demand.demand_feature_points:
  147. methods.append(MatchMethod.TEXT)
  148. if not methods:
  149. methods.append(MatchMethod.TEXT) # 最终兜底
  150. return methods
  151. @staticmethod
  152. def parse_top_n(match_strategy: str) -> int:
  153. """从匹配策略中解析 topN 参数,缺省 10"""
  154. if not match_strategy:
  155. return DemandRecommendConst.DEFAULT_TOPN
  156. m = re.search(r"topN[=:]?\s*(\d+)", match_strategy, re.IGNORECASE)
  157. if m:
  158. return int(m.group(1))
  159. return DemandRecommendConst.DEFAULT_TOPN
  160. @classmethod
  161. def parse(cls, demand: DemandRecord) -> MatchStrategy:
  162. """完整解析一条需求记录为匹配策略"""
  163. return MatchStrategy(
  164. demand_id=demand.demand_id,
  165. experiment_id=demand.match_experiment_id,
  166. dt=demand.dt,
  167. match_methods=cls.select_match_methods(demand),
  168. config_codes=cls.select_config_codes(demand.match_strategy),
  169. top_n=cls.parse_top_n(demand.match_strategy),
  170. query_text=build_query_text(demand.demand_topic, demand.demand_feature_points),
  171. video_id=demand.demand_video_id,
  172. content_id=demand.scene_content_id,
  173. filter_rule=demand.demand_filter_strategy,
  174. multi_recall_fusion=("多路" in (demand.match_strategy or ""))
  175. or (len(cls.select_match_methods(demand)) > 1),
  176. )
  177. # ──────────────────────────────────────────────
  178. # Helpers
  179. # ──────────────────────────────────────────────
  180. def build_query_text(topic: str, feature_points: str) -> str:
  181. """拼接选题 + 特征点为检索文本"""
  182. parts = [p for p in [topic, feature_points] if p and p.strip()]
  183. return "。".join(parts) if parts else ""
  184. def parse_recall_items(
  185. api_response: Dict[str, Any],
  186. strategy: MatchStrategy,
  187. match_method: str,
  188. config_code: str,
  189. ) -> List[MatchResult]:
  190. """解析 API 返回结果为 MatchResult 列表"""
  191. if not api_response or api_response.get("code") != 0:
  192. return []
  193. data = api_response.get("data")
  194. if not data:
  195. return []
  196. # matchTopNVideo 返回 data 直接是 list
  197. if isinstance(data, list):
  198. items = data
  199. else:
  200. # recallTest 返回 data.items[]
  201. items = data.get("items", [])
  202. results: List[MatchResult] = []
  203. for rank, item in enumerate(items, start=1):
  204. vid = item.get("id") or item.get("videoId", 0)
  205. if not vid:
  206. continue
  207. results.append(MatchResult(
  208. dt=strategy.dt,
  209. demand_id=strategy.demand_id,
  210. match_experiment_id=strategy.experiment_id,
  211. match_method=match_method,
  212. config_code=config_code,
  213. video_id=int(vid),
  214. score=float(item.get("score", 0)),
  215. rank_position=rank,
  216. video_title=str(item.get("title", "")),
  217. video_detail=item.get("videoDetail"),
  218. sim=float(item.get("sim", 0)),
  219. sim_norm=float(item.get("simNorm", 0)),
  220. rov=float(item.get("rov", 0)),
  221. rov_norm=float(item.get("rovNorm", 0)),
  222. ))
  223. return results
  224. def parse_scored_items(
  225. api_response: Dict[str, Any],
  226. strategy: MatchStrategy,
  227. config_code: str,
  228. ) -> List[MatchResult]:
  229. """解析 recallWithScore 返回的 scored items 为 MatchResult 列表"""
  230. if not api_response or api_response.get("code") != 0:
  231. return []
  232. data = api_response.get("data")
  233. if not data:
  234. return []
  235. items = data.get("items", [])
  236. results: List[MatchResult] = []
  237. for rank, item in enumerate(items, start=1):
  238. vid = item.get("videoId", 0)
  239. if not vid:
  240. continue
  241. detail = item.get("videoDetail") or {}
  242. # 优先从 videoDetail 取真实标题,取不到用 text(向量化选题)
  243. raw_title = detail.get("title") or detail.get("选题") or str(item.get("text", ""))
  244. results.append(MatchResult(
  245. dt=strategy.dt,
  246. demand_id=strategy.demand_id,
  247. match_experiment_id=strategy.experiment_id,
  248. match_method=MatchMethod.TEXT,
  249. config_code=item.get("configCode", config_code),
  250. video_id=int(vid),
  251. score=float(item.get("score", 0) or item.get("sim", 0)),
  252. rank_position=rank,
  253. video_title=raw_title,
  254. video_detail=detail,
  255. sim=float(item.get("sim", 0)),
  256. sim_norm=float(item.get("simNorm", 0)),
  257. rov=float(item.get("rov", 0)),
  258. rov_norm=float(item.get("rovNorm", 0)),
  259. ))
  260. return results
  261. def merge_multi_recall(
  262. result_groups: List[List[MatchResult]],
  263. top_n: int,
  264. ) -> List[MatchResult]:
  265. """多路召回结果合并:按 video_id 去重,保留最高分,取 top_n"""
  266. merged: Dict[int, MatchResult] = {}
  267. for group in result_groups:
  268. for r in group:
  269. if r.video_id in merged:
  270. if r.score > merged[r.video_id].score:
  271. merged[r.video_id] = r
  272. else:
  273. merged[r.video_id] = r
  274. sorted_results = sorted(merged.values(), key=lambda x: x.score, reverse=True)
  275. for i, r in enumerate(sorted_results[:top_n], start=1):
  276. r.rank_position = i
  277. return sorted_results[:top_n]