topic_search.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. """视频选题检索:根据关键词在解构结果中匹配,返回 top5"""
  2. import json
  3. from typing import List, Dict, Any, Optional, Tuple
  4. from utils.sync_mysql_help import mysql
  5. from utils.params import TopicSearchParam
  6. TOP_N = 5
  7. SEARCH_FIELDS = ("inspiration_points", "purpose_points", "key_points")
  8. FALLBACK_LIMIT = 3000 # 降级时单次最多拉取条数
  9. def _to_points_list(val: Any) -> List[str]:
  10. """将逗号分隔字符串或列表转为列表格式"""
  11. if val is None:
  12. return []
  13. if isinstance(val, list):
  14. return [str(v).strip() for v in val if v]
  15. if isinstance(val, str):
  16. return [s.strip() for s in val.split(",") if s.strip()]
  17. return [str(val)]
  18. def _extract_search_text(val: Any) -> str:
  19. """从字段值提取文本:支持字符串或列表(逗号分隔)"""
  20. if val is None:
  21. return ""
  22. if isinstance(val, str):
  23. return val.strip()
  24. if isinstance(val, list):
  25. return ",".join(str(v).strip() for v in val if v)
  26. return str(val)
  27. def _concat_search_fields(row: Dict[str, Any]) -> str:
  28. """将检索字段拼接为待匹配文本"""
  29. parts = []
  30. for field in SEARCH_FIELDS:
  31. text = _extract_search_text(row.get(field))
  32. if text:
  33. parts.append(text)
  34. return ",".join(parts)
  35. def _calc_match_score(text: str, keywords: List[str]) -> int:
  36. """计算匹配度:关键词在文本中出现的次数(不区分大小写)"""
  37. if not text or not keywords:
  38. return 0
  39. text_lower = text.lower()
  40. score = 0
  41. for kw in keywords:
  42. if kw and kw.lower() in text_lower:
  43. score += 1
  44. return score
  45. PAYLOAD_FIELDS = (*SEARCH_FIELDS, "topic_fusion_result")
  46. def _escape_like(kw: str) -> str:
  47. """转义 LIKE 中的特殊字符:% _ \\"""
  48. return kw.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
  49. def _build_keyword_like_conds(keywords: List[str]) -> Tuple[str, list]:
  50. """构造关键词的 LIKE 条件,返回 (SQL 片段, 参数列表)"""
  51. if not keywords:
  52. return "1=0", []
  53. placeholders = []
  54. params = []
  55. for kw in keywords:
  56. kw = kw.strip()
  57. if not kw:
  58. continue
  59. escaped = _escape_like(kw.lower())
  60. like_val = f"%{escaped}%"
  61. for field in SEARCH_FIELDS:
  62. placeholders.append(f"(LOWER({field}) LIKE LOWER(%s))")
  63. params.append(like_val)
  64. if not placeholders:
  65. return "1=0", []
  66. return "(" + " OR ".join(placeholders) + ")", params
  67. def _parse_result_payload(payload: Any) -> Dict[str, Any]:
  68. """从 result_payload 解析出检索字段及 topic_fusion_result"""
  69. if not payload:
  70. return {}
  71. if isinstance(payload, str):
  72. try:
  73. payload = json.loads(payload)
  74. except json.JSONDecodeError:
  75. return {}
  76. if not isinstance(payload, dict):
  77. return {}
  78. return {f: payload.get(f) for f in PAYLOAD_FIELDS}
  79. def _fetch_decode_results(keywords: List[str]) -> List[Dict[str, Any]]:
  80. """获取有检索字段且匹配关键词的解构结果。优先用独立列,否则从 result_payload 解析"""
  81. kw_cond, kw_params = _build_keyword_like_conds(keywords)
  82. base_cond = """
  83. (inspiration_points IS NOT NULL AND inspiration_points != '')
  84. OR (purpose_points IS NOT NULL AND purpose_points != '')
  85. OR (key_points IS NOT NULL AND key_points != '')
  86. """
  87. try:
  88. fields = ", ".join(SEARCH_FIELDS) + ", topic_fusion_result, task_id, channel_content_id, title, images, video_url"
  89. sql = f"""
  90. SELECT {fields}
  91. FROM workflow_decode_task_result
  92. WHERE ({base_cond}) AND ({kw_cond})
  93. """
  94. rows = mysql.fetchall(sql, tuple(kw_params) if kw_params else None)
  95. return list(rows) if rows else []
  96. except Exception:
  97. pass
  98. # 降级:从 result_payload 解析,限制条数减少全表扫描
  99. sql = f"""
  100. SELECT task_id, channel_content_id, title, images, video_url, result_payload
  101. FROM workflow_decode_task_result
  102. WHERE result_payload IS NOT NULL AND result_payload != ''
  103. LIMIT {FALLBACK_LIMIT}
  104. """
  105. rows = mysql.fetchall(sql)
  106. if not rows:
  107. return []
  108. out = []
  109. for r in rows:
  110. parsed = _parse_result_payload(r.get("result_payload"))
  111. merged = {**r, **parsed}
  112. if _concat_search_fields(merged):
  113. text = _concat_search_fields(merged)
  114. if _calc_match_score(text, keywords) > 0:
  115. out.append(merged)
  116. return out
  117. def _build_result_item(row: Dict[str, Any], score: int) -> Dict[str, Any]:
  118. """构建单条返回结果,*_points 转为列表格式"""
  119. return {
  120. "inspiration_points": _to_points_list(row.get("inspiration_points")),
  121. "purpose_points": _to_points_list(row.get("purpose_points")),
  122. "key_points": _to_points_list(row.get("key_points")),
  123. "topic_fusion_result": row.get("topic_fusion_result"),
  124. "score": score,
  125. }
  126. def search_topics(param: TopicSearchParam) -> List[Dict[str, Any]]:
  127. """
  128. 根据关键词检索视频选题,返回匹配度最高的 top5。
  129. 无匹配时返回空数组。
  130. """
  131. keywords = [k.strip() for k in param.keywords if k and isinstance(k, str)]
  132. if not keywords:
  133. return []
  134. rows = _fetch_decode_results(keywords)
  135. scored: List[tuple] = []
  136. for row in rows:
  137. text = _concat_search_fields(row)
  138. score = _calc_match_score(text, keywords)
  139. if score > 0:
  140. scored.append((row, score))
  141. scored.sort(key=lambda x: x[1], reverse=True)
  142. top = scored[:TOP_N]
  143. return [_build_result_item(row, score) for row, score in top]