"""视频选题检索:根据关键词在解构结果中匹配,返回 top5""" import json from typing import List, Dict, Any, Optional, Tuple from utils.sync_mysql_help import mysql from utils.params import TopicSearchParam TOP_N = 5 SEARCH_FIELDS = ("inspiration_points", "purpose_points", "key_points") FALLBACK_LIMIT = 3000 # 降级时单次最多拉取条数 def _to_points_list(val: Any) -> List[str]: """将逗号分隔字符串或列表转为列表格式""" if val is None: return [] if isinstance(val, list): return [str(v).strip() for v in val if v] if isinstance(val, str): return [s.strip() for s in val.split(",") if s.strip()] return [str(val)] def _extract_search_text(val: Any) -> str: """从字段值提取文本:支持字符串或列表(逗号分隔)""" if val is None: return "" if isinstance(val, str): return val.strip() if isinstance(val, list): return ",".join(str(v).strip() for v in val if v) return str(val) def _concat_search_fields(row: Dict[str, Any]) -> str: """将检索字段拼接为待匹配文本""" parts = [] for field in SEARCH_FIELDS: text = _extract_search_text(row.get(field)) if text: parts.append(text) return ",".join(parts) def _calc_match_score(text: str, keywords: List[str]) -> int: """计算匹配度:关键词在文本中出现的次数(不区分大小写)""" if not text or not keywords: return 0 text_lower = text.lower() score = 0 for kw in keywords: if kw and kw.lower() in text_lower: score += 1 return score PAYLOAD_FIELDS = (*SEARCH_FIELDS, "topic_fusion_result") def _escape_like(kw: str) -> str: """转义 LIKE 中的特殊字符:% _ \\""" return kw.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") def _build_keyword_like_conds(keywords: List[str]) -> Tuple[str, list]: """构造关键词的 LIKE 条件,返回 (SQL 片段, 参数列表)""" if not keywords: return "1=0", [] placeholders = [] params = [] for kw in keywords: kw = kw.strip() if not kw: continue escaped = _escape_like(kw.lower()) like_val = f"%{escaped}%" for field in SEARCH_FIELDS: placeholders.append(f"(LOWER({field}) LIKE LOWER(%s))") params.append(like_val) if not placeholders: return "1=0", [] return "(" + " OR ".join(placeholders) + ")", params def _parse_result_payload(payload: Any) -> Dict[str, Any]: """从 result_payload 解析出检索字段及 topic_fusion_result""" if not payload: return {} if isinstance(payload, str): try: payload = json.loads(payload) except json.JSONDecodeError: return {} if not isinstance(payload, dict): return {} return {f: payload.get(f) for f in PAYLOAD_FIELDS} def _fetch_decode_results(keywords: List[str]) -> List[Dict[str, Any]]: """获取有检索字段且匹配关键词的解构结果。优先用独立列,否则从 result_payload 解析""" kw_cond, kw_params = _build_keyword_like_conds(keywords) base_cond = """ (inspiration_points IS NOT NULL AND inspiration_points != '') OR (purpose_points IS NOT NULL AND purpose_points != '') OR (key_points IS NOT NULL AND key_points != '') """ try: fields = ", ".join(SEARCH_FIELDS) + ", topic_fusion_result, task_id, channel_content_id, title, images, video_url" sql = f""" SELECT {fields} FROM workflow_decode_task_result WHERE ({base_cond}) AND ({kw_cond}) """ rows = mysql.fetchall(sql, tuple(kw_params) if kw_params else None) return list(rows) if rows else [] except Exception: pass # 降级:从 result_payload 解析,限制条数减少全表扫描 sql = f""" SELECT task_id, channel_content_id, title, images, video_url, result_payload FROM workflow_decode_task_result WHERE result_payload IS NOT NULL AND result_payload != '' LIMIT {FALLBACK_LIMIT} """ rows = mysql.fetchall(sql) if not rows: return [] out = [] for r in rows: parsed = _parse_result_payload(r.get("result_payload")) merged = {**r, **parsed} if _concat_search_fields(merged): text = _concat_search_fields(merged) if _calc_match_score(text, keywords) > 0: out.append(merged) return out def _build_result_item(row: Dict[str, Any], score: int) -> Dict[str, Any]: """构建单条返回结果,*_points 转为列表格式""" return { "inspiration_points": _to_points_list(row.get("inspiration_points")), "purpose_points": _to_points_list(row.get("purpose_points")), "key_points": _to_points_list(row.get("key_points")), "topic_fusion_result": row.get("topic_fusion_result"), "score": score, } def search_topics(param: TopicSearchParam) -> List[Dict[str, Any]]: """ 根据关键词检索视频选题,返回匹配度最高的 top5。 无匹配时返回空数组。 """ keywords = [k.strip() for k in param.keywords if k and isinstance(k, str)] if not keywords: return [] rows = _fetch_decode_results(keywords) scored: List[tuple] = [] for row in rows: text = _concat_search_fields(row) score = _calc_match_score(text, keywords) if score > 0: scored.append((row, score)) scored.sort(key=lambda x: x[1], reverse=True) top = scored[:TOP_N] return [_build_result_item(row, score) for row, score in top]