|
|
@@ -3,6 +3,8 @@ import json
|
|
|
import os
|
|
|
import sys
|
|
|
import argparse
|
|
|
+import time
|
|
|
+import hashlib
|
|
|
from datetime import datetime
|
|
|
from typing import Literal, Optional
|
|
|
|
|
|
@@ -15,6 +17,8 @@ from lib.client import get_model
|
|
|
MODEL_NAME = "google/gemini-2.5-flash"
|
|
|
# 得分提升阈值:sug或组合词必须比来源query提升至少此幅度才能进入下一轮
|
|
|
REQUIRED_SCORE_GAIN = 0.02
|
|
|
+SUG_CACHE_TTL = 24 * 3600 # 24小时
|
|
|
+SUG_CACHE_DIR = os.path.join(os.path.dirname(__file__), "data", "sug_cache")
|
|
|
from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
|
|
|
from script.search.xiaohongshu_search import XiaohongshuSearch
|
|
|
from script.search.xiaohongshu_detail import XiaohongshuDetail
|
|
|
@@ -1781,6 +1785,73 @@ scope_category_evaluator = Agent[None](
|
|
|
# v121 新增辅助函数
|
|
|
# ============================================================================
|
|
|
|
|
|
def _ensure_sug_cache_dir():
    """Create the SUG cache directory if it does not already exist."""
    # exist_ok makes this idempotent; the isdir pre-check merely skips
    # the syscall on the common warm path.
    if not os.path.isdir(SUG_CACHE_DIR):
        os.makedirs(SUG_CACHE_DIR, exist_ok=True)
|
|
|
+
|
|
|
+
|
|
|
def _sug_cache_path(keyword: str) -> str:
    """Return the cache file path for *keyword*.

    The file name is the MD5 hex digest of the UTF-8 encoded keyword,
    which keeps arbitrary (e.g. CJK) keywords filesystem-safe.
    """
    digest = hashlib.md5(keyword.encode("utf-8")).hexdigest()
    return os.path.join(SUG_CACHE_DIR, digest + ".json")
|
|
|
+
|
|
|
+
|
|
|
def load_sug_cache(keyword: str) -> Optional[list[str]]:
    """Read cached SUG suggestions for *keyword* from the persistent cache.

    Returns the cached list of suggestions, or None when:
    - *keyword* is empty/falsy,
    - no cache file exists,
    - the entry is older than SUG_CACHE_TTL,
    - the file is unreadable or malformed (logged, best-effort).
    """
    if not keyword:
        return None

    cache_path = _sug_cache_path(keyword)
    try:
        # EAFP: a single getmtime() replaces the original exists()+getmtime()
        # pair, which could raise if the file was deleted between the two calls.
        mtime = os.path.getmtime(cache_path)
    except OSError:
        return None

    if time.time() - mtime > SUG_CACHE_TTL:
        # Entry expired; treat as a miss so the caller re-fetches.
        return None

    try:
        with open(cache_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        suggestions = data.get("suggestions")
        # Only trust a well-formed payload; anything else falls through to None.
        if isinstance(suggestions, list):
            return suggestions
    except Exception as exc:
        print(f" ⚠️ 读取SUG缓存失败({keyword}): {exc}")
    return None
|
|
|
+
|
|
|
+
|
|
|
def save_sug_cache(keyword: str, suggestions: list[str]):
    """Write SUG suggestions for *keyword* to the persistent cache.

    Best-effort: failures are logged and swallowed so a cache problem never
    breaks the main flow. The write goes to a temp file followed by
    os.replace(), so a concurrent load_sug_cache() never observes a
    half-written JSON file.
    """
    if not keyword or not isinstance(suggestions, list):
        # Nothing sensible to persist; silently skip (cache is optional).
        return

    _ensure_sug_cache_dir()
    cache_path = _sug_cache_path(keyword)
    tmp_path = cache_path + ".tmp"
    try:
        payload = {
            "keyword": keyword,
            "suggestions": suggestions,
            "timestamp": datetime.now().isoformat()
        }
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        # Atomic rename on both POSIX and Windows.
        os.replace(tmp_path, cache_path)
    except Exception as exc:
        print(f" ⚠️ 写入SUG缓存失败({keyword}): {exc}")
|
|
|
+
|
|
|
+
|
|
|
def get_suggestions_with_cache(keyword: str, api: XiaohongshuSearchRecommendations) -> list[str]:
    """Fetch SUG suggestions for *keyword*, consulting the persistent cache first.

    A cache hit (within SUG_CACHE_TTL) skips the API call entirely. Fresh,
    non-empty API results are written back to the cache; empty or failed
    results are NOT cached so the next call retries the API.

    Always returns a list (possibly empty), matching the annotation — the
    original could leak the API's None through to callers.
    """
    cached = load_sug_cache(keyword)
    if cached is not None:
        print(f" 📦 SUG缓存命中: {keyword} ({len(cached)} 个)")
        return cached

    suggestions = api.get_recommendations(keyword=keyword)
    if suggestions:
        save_sug_cache(keyword, suggestions)
        return suggestions
    # Normalize falsy API results (None / empty) to [] so every path
    # satisfies the declared return type; callers test truthiness, so
    # this is backward-compatible.
    return []
|
|
|
+
|
|
|
+
|
|
|
def get_ordered_subsets(words: list[str], min_len: int = 1) -> list[list[str]]:
|
|
|
"""
|
|
|
生成words的所有有序子集(可跳过但不可重排)
|
|
|
@@ -2841,7 +2912,7 @@ async def run_round(
|
|
|
sug_list_list = [] # list of list
|
|
|
for q in q_list:
|
|
|
print(f"\n 处理q: {q.text}")
|
|
|
- suggestions = xiaohongshu_api.get_recommendations(keyword=q.text)
|
|
|
+ suggestions = get_suggestions_with_cache(q.text, xiaohongshu_api)
|
|
|
|
|
|
q_sug_list = []
|
|
|
if suggestions:
|
|
|
@@ -3530,7 +3601,7 @@ async def run_round_v2(
|
|
|
sug_details = {}
|
|
|
|
|
|
for q in query_input:
|
|
|
- suggestions = xiaohongshu_api.get_recommendations(keyword=q.text)
|
|
|
+ suggestions = get_suggestions_with_cache(q.text, xiaohongshu_api)
|
|
|
if suggestions:
|
|
|
print(f" {q.text}: 获取到 {len(suggestions)} 个SUG")
|
|
|
for sug_text in suggestions:
|
|
|
@@ -3571,91 +3642,85 @@ async def run_round_v2(
|
|
|
"type": "sug"
|
|
|
})
|
|
|
|
|
|
- # 步骤3: 搜索高分SUG
|
|
|
- print(f"\n[步骤3] 搜索高分SUG(阈值 > {sug_threshold})...")
|
|
|
- high_score_sugs = [sug for sug in all_sugs if sug.score_with_o > sug_threshold]
|
|
|
- print(f" 找到 {len(high_score_sugs)} 个高分SUG")
|
|
|
-
|
|
|
- search_list = []
|
|
|
- # extraction_results = {} # 内容提取流程已断开
|
|
|
-
|
|
|
- if len(high_score_sugs) > 0:
|
|
|
- async def search_for_sug(sug: Sug) -> Search:
|
|
|
- """返回Search结果"""
|
|
|
- print(f" 搜索: {sug.text}")
|
|
|
- # post_extractions = {} # 内容提取流程已断开
|
|
|
-
|
|
|
- try:
|
|
|
- search_result = xiaohongshu_search.search(keyword=sug.text)
|
|
|
- # xiaohongshu_search.search() 已经返回解析后的数据
|
|
|
- notes = search_result.get("data", {}).get("data", [])
|
|
|
- post_list = []
|
|
|
- for note in notes[:10]:
|
|
|
+ # 定义通用搜索函数(供步骤2.5、3、5.5共用)
|
|
|
+ async def search_keyword(text: str, score: float, source_type: str) -> Search:
|
|
|
+ """通用搜索函数"""
|
|
|
+ print(f" 搜索: {text} (来源: {source_type})")
|
|
|
+ try:
|
|
|
+ search_result = xiaohongshu_search.search(keyword=text)
|
|
|
+ notes = search_result.get("data", {}).get("data", [])
|
|
|
+ post_list = []
|
|
|
+
|
|
|
+ for note in notes[:10]:
|
|
|
+ try:
|
|
|
+ post = process_note_data(note)
|
|
|
+ post_list.append(post)
|
|
|
+ except Exception as e:
|
|
|
+ print(f" ⚠️ 解析帖子失败 {note.get('id', 'unknown')}: {str(e)[:50]}")
|
|
|
+
|
|
|
+ # 补充详情信息(仅视频类型需要补充视频URL)
|
|
|
+ video_posts = [p for p in post_list if p.type == "video"]
|
|
|
+ if video_posts:
|
|
|
+ print(f" 补充详情({len(video_posts)}个视频)...")
|
|
|
+ for post in video_posts:
|
|
|
try:
|
|
|
- post = process_note_data(note)
|
|
|
-
|
|
|
- # # 🆕 多模态提取(搜索后立即处理) - 内容提取流程已断开
|
|
|
- # if post.type == "normal" and len(post.images) > 0:
|
|
|
- # extraction = await extract_post_images(post)
|
|
|
- # if extraction:
|
|
|
- # post_extractions[post.note_id] = extraction
|
|
|
-
|
|
|
- post_list.append(post)
|
|
|
+ detail_response = xiaohongshu_detail.get_detail(post.note_id)
|
|
|
+ enrich_post_with_detail(post, detail_response)
|
|
|
except Exception as e:
|
|
|
- print(f" ⚠️ 解析帖子失败 {note.get('id', 'unknown')}: {str(e)[:50]}")
|
|
|
+ print(f" ⚠️ 详情补充失败 {post.note_id}: {str(e)[:50]}")
|
|
|
|
|
|
- # 补充详情信息(仅视频类型需要补充视频URL)
|
|
|
- video_posts = [p for p in post_list if p.type == "video"]
|
|
|
- if video_posts:
|
|
|
- print(f" 补充详情({len(video_posts)}个视频)...")
|
|
|
- for post in video_posts:
|
|
|
- try:
|
|
|
- detail_response = xiaohongshu_detail.get_detail(post.note_id)
|
|
|
- enrich_post_with_detail(post, detail_response)
|
|
|
- except Exception as e:
|
|
|
- print(f" ⚠️ 详情补充失败 {post.note_id}: {str(e)[:50]}")
|
|
|
+ print(f" → 找到 {len(post_list)} 个帖子")
|
|
|
+ return Search(text=text, score_with_o=score, post_list=post_list)
|
|
|
+ except Exception as e:
|
|
|
+ print(f" ✗ 搜索失败: {e}")
|
|
|
+ return Search(text=text, score_with_o=score, post_list=[])
|
|
|
|
|
|
- print(f" → 找到 {len(post_list)} 个帖子")
|
|
|
+ # 初始化search_list
|
|
|
+ search_list = []
|
|
|
|
|
|
- return Search(
|
|
|
- text=sug.text,
|
|
|
- score_with_o=sug.score_with_o,
|
|
|
- from_q=sug.from_q,
|
|
|
- post_list=post_list
|
|
|
- )
|
|
|
- # , post_extractions # 内容提取流程已断开
|
|
|
+ # 步骤2.5: 搜索高分query_input
|
|
|
+ print(f"\n[步骤2.5] 搜索高分输入query(阈值 > {sug_threshold})...")
|
|
|
+ high_score_queries = [q for q in query_input if q.score_with_o > sug_threshold]
|
|
|
+ print(f" 找到 {len(high_score_queries)} 个高分输入query")
|
|
|
|
|
|
- except Exception as e:
|
|
|
- print(f" ✗ 搜索失败: {e}")
|
|
|
- return Search(
|
|
|
- text=sug.text,
|
|
|
- score_with_o=sug.score_with_o,
|
|
|
- from_q=sug.from_q,
|
|
|
- post_list=[]
|
|
|
- )
|
|
|
- # , {} # 内容提取流程已断开
|
|
|
+ if high_score_queries:
|
|
|
+ query_search_tasks = [search_keyword(q.text, q.score_with_o, "query_input")
|
|
|
+ for q in high_score_queries]
|
|
|
+ query_searches = await asyncio.gather(*query_search_tasks)
|
|
|
+ search_list.extend(query_searches)
|
|
|
|
|
|
- search_tasks = [search_for_sug(sug) for sug in high_score_sugs]
|
|
|
- results = await asyncio.gather(*search_tasks)
|
|
|
+ # 评估搜索结果中的帖子
|
|
|
+ if enable_evaluation:
|
|
|
+ print(f"\n[评估] 评估query_input搜索结果中的帖子...")
|
|
|
+ for search in query_searches:
|
|
|
+ if search.post_list:
|
|
|
+ print(f" 评估来自 '{search.text}' 的 {len(search.post_list)} 个帖子")
|
|
|
+ for post in search.post_list:
|
|
|
+ knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level = await evaluate_post_v3(post, o, semaphore=None)
|
|
|
+ if knowledge_eval:
|
|
|
+ apply_evaluation_v3_to_post(post, knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level)
|
|
|
+
|
|
|
+ # 步骤3: 搜索高分SUG
|
|
|
+ print(f"\n[步骤3] 搜索高分SUG(阈值 > {sug_threshold})...")
|
|
|
+ high_score_sugs = [sug for sug in all_sugs if sug.score_with_o > sug_threshold]
|
|
|
+ print(f" 找到 {len(high_score_sugs)} 个高分SUG")
|
|
|
|
|
|
- # 收集搜索结果
|
|
|
- for search in results:
|
|
|
- search_list.append(search)
|
|
|
- # extraction_results.update(extractions) # 内容提取流程已断开
|
|
|
+ if high_score_sugs:
|
|
|
+ sug_search_tasks = [search_keyword(sug.text, sug.score_with_o, "sug")
|
|
|
+ for sug in high_score_sugs]
|
|
|
+ sug_searches = await asyncio.gather(*sug_search_tasks)
|
|
|
+ search_list.extend(sug_searches)
|
|
|
|
|
|
# 评估搜索结果中的帖子
|
|
|
if enable_evaluation:
|
|
|
- print(f"\n[评估] 评估搜索结果中的帖子...")
|
|
|
- for search in search_list:
|
|
|
+ print(f"\n[评估] 评估SUG搜索结果中的帖子...")
|
|
|
+ for search in sug_searches:
|
|
|
if search.post_list:
|
|
|
print(f" 评估来自 '{search.text}' 的 {len(search.post_list)} 个帖子")
|
|
|
- # 对每个帖子进行评估 (V3)
|
|
|
for post in search.post_list:
|
|
|
knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level = await evaluate_post_v3(post, o, semaphore=None)
|
|
|
if knowledge_eval:
|
|
|
apply_evaluation_v3_to_post(post, knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level)
|
|
|
- else:
|
|
|
- print(f"\n[评估] 实时评估已关闭 (使用 --enable-evaluation 启用)")
|
|
|
|
|
|
# 步骤4: 生成N域组合
|
|
|
print(f"\n[步骤4] 生成{round_num}域组合...")
|
|
|
@@ -3753,6 +3818,29 @@ async def run_round_v2(
|
|
|
comb.score_with_o > score for score in flat_scores
|
|
|
)
|
|
|
|
|
|
+ # 步骤5.5: 搜索高分组合词
|
|
|
+ print(f"\n[步骤5.5] 搜索高分组合词(阈值 > {sug_threshold})...")
|
|
|
+ high_score_combinations = [comb for comb in domain_combinations
|
|
|
+ if comb.score_with_o > sug_threshold]
|
|
|
+ print(f" 找到 {len(high_score_combinations)} 个高分组合词")
|
|
|
+
|
|
|
+ if high_score_combinations:
|
|
|
+ comb_search_tasks = [search_keyword(comb.text, comb.score_with_o, "combination")
|
|
|
+ for comb in high_score_combinations]
|
|
|
+ comb_searches = await asyncio.gather(*comb_search_tasks)
|
|
|
+ search_list.extend(comb_searches)
|
|
|
+
|
|
|
+ # 评估搜索结果中的帖子
|
|
|
+ if enable_evaluation:
|
|
|
+ print(f"\n[评估] 评估组合词搜索结果中的帖子...")
|
|
|
+ for search in comb_searches:
|
|
|
+ if search.post_list:
|
|
|
+ print(f" 评估来自 '{search.text}' 的 {len(search.post_list)} 个帖子")
|
|
|
+ for post in search.post_list:
|
|
|
+ knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level = await evaluate_post_v3(post, o, semaphore=None)
|
|
|
+ if knowledge_eval:
|
|
|
+ apply_evaluation_v3_to_post(post, knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level)
|
|
|
+
|
|
|
# 步骤6: 构建 q_list_next(组合 + 高分SUG)
|
|
|
print(f"\n[步骤6] 生成下轮输入...")
|
|
|
q_list_next: list[Q] = []
|