@@ -0,0 +1,927 @@
+import asyncio
+import json
+import os
+import argparse
+from datetime import datetime
+from itertools import combinations, permutations
+
+from agents import Agent, Runner
+from lib.my_trace import set_trace
+from pydantic import BaseModel, Field
+
+from lib.utils import read_file_as_string
+from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
+
+
+# ============================================================================
+# Concurrency configuration
+# ============================================================================
+# Concurrency for requests against the Xiaohongshu suggestion API
+API_CONCURRENCY_LIMIT = 5
+
+# Concurrency for model-based evaluation (GPT)
+MODEL_CONCURRENCY_LIMIT = 10
+
+
+class RunContext(BaseModel):
+    version: str = Field(..., description="Version of the running script (its file name)")
+    input_files: dict[str, str] = Field(..., description="Map of input file paths")
+    q_with_context: str
+    q_context: str
+    q: str
+    log_url: str
+    log_dir: str
+
+    # Question annotation
+    question_annotation: str | None = Field(default=None, description="Three-layer annotation of the question")
+
+    # Segmentation and combinations
+    keywords: list[str] | None = Field(default=None, description="Extracted keywords")
+    query_combinations: dict[str, list[str]] = Field(default_factory=dict, description="Query combinations per level")
+
+    # New in v6.4: pruning records
+    pruning_info: dict[str, dict] = Field(default_factory=dict, description="Pruning info per level")
+
+    # Exploration results
+    all_sug_queries: list[dict] = Field(default_factory=list, description="All fetched suggestions")
+
+    # Evaluation results
+    evaluation_results: list[dict] = Field(default_factory=list, description="Evaluation results for all suggestions")
+    optimization_result: dict | None = Field(default=None, description="Final optimization result object")
+    final_output: str | None = Field(default=None, description="Final output (formatted text)")
+
+
+# ============================================================================
+# Agent 1: question annotation expert
+# ============================================================================
+question_annotation_instructions = """
+You are a search-intent analysis expert. Given a question (with its background context), annotate the original text on three layers: essence, hard constraints, soft constraints.
+
+## Criteria
+
+**[essence]** - the core intent of the question
+- e.g. how to obtain, tutorial, recommendation, works, review
+
+**[hard]** - objective, factual constraints (clearly verifiable, not a matter of opinion)
+- Cleanly separable categories: region, time, subject, tool, type of operation
+- Signature: changing one yields a completely different category of results
+
+**[soft]** - subjective, graded modifiers (person-dependent, a matter of degree)
+- Things requiring subjective judgment: quality, speed, aesthetics, distinctiveness, degree
+- Signature: changing one still yields the same category of results, just satisfied to a different degree
+
+## Output format
+
+word[essence-description], word[hard-description], word[soft-description]
+
+## Notes
+- Output only the annotated string
+- Use the background context to judge the intent
+""".strip()
+
+question_annotator = Agent[None](
+    name="Question Annotation Expert",
+    instructions=question_annotation_instructions,
+)
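+
+# Illustrative annotation (an assumed example, not actual model output): for
+# the question "如何获取能体现川西秋季特色的高质量风光摄影素材?" the annotator
+# should return something like
+#   获取[essence-how to obtain], 川西[hard-region], 秋季[hard-time],
+#   高质量[soft-quality], 风光摄影素材[hard-subject]
+# i.e. the original wording with a bracketed layer tag appended to each unit.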
+
+
+# ============================================================================
+# Agent 2: word segmentation expert
+# ============================================================================
+segmentation_instructions = """
+You are a Chinese word segmentation expert. Given a sentence, segment it into words.
+
+## Segmentation principles
+
+1. Drop punctuation
+2. Split into the smallest meaningful units
+3. Drop particles, interjections, and auxiliary verbs
+4. Keep interrogative words
+5. Keep content words: nouns, verbs, adjectives, adverbs
+
+## Output requirements
+
+Output the list of segmented words.
+""".strip()
+
+class SegmentationResult(BaseModel):
+    """Segmentation result"""
+    words: list[str] = Field(..., description="List of segmented words")
+    reasoning: str = Field(..., description="Explanation of the segmentation")
+
+segmenter = Agent[None](
+    name="Segmentation Expert",
+    instructions=segmentation_instructions,
+    output_type=SegmentationResult,
+)
+
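+# Illustrative segmentation (assumed, not actual model output): the question
+# "如何获取能体现川西秋季特色的高质量风光摄影素材?" might segment to
+#   words=["如何", "获取", "川西", "秋季", "特色", "高质量", "风光", "摄影", "素材"]
+# with punctuation and function words such as "能", "体现", "的" dropped.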
+
+# ============================================================================
+# Agent 3: evaluation expert (intent match + relevance score)
+# ============================================================================
+eval_instructions = """
+You are a search-query evaluation expert. Given the original question, its annotation, and a suggested query, evaluate two dimensions.
+
+## Input
+
+You will receive:
+1. Original question: the user's original wording
+2. Question annotation: the three-layer annotation of the original question (essence, hard, soft)
+3. Suggested query: the suggestion to evaluate
+
+## Goal
+
+If the user searches with this suggested query, can they find content that satisfies the original need?
+
+## Two-layer scoring
+
+### 1. intent_match = true/false
+
+Does the **intended use** of the suggested query agree with the **essence** of the original question?
+
+**Key: look only at the [essence] annotation**
+- The `[essence-XXX]` tag in the annotation pins down the user's core intent
+- Judge whether the suggested query can fulfill that core intent
+
+**Common essence types:**
+- Find a method / how to obtain → the suggestion should mention methods, channels, sites, sources
+- Find a tutorial → the suggestion should be tutorial or teaching related
+- Find resources/material → the suggestion should be the resource or material itself
+- Find a tool → the suggestion should be a tool recommendation
+- Browse works → the suggestion should showcase works
+
+**Scoring:**
+- true = the suggestion's intent agrees with `[essence]`
+- false = the suggestion's intent does not agree with `[essence]`
+
+### 2. relevance_score = continuous score in 0-1
+
+Given a matching intent, how closely does the suggested query track the original question in **topic, elements, and attributes**?
+
+**Dimensions:**
+- Topic: does the core topic match? (e.g. photography, travel, food)
+- Element coverage: how many of the `[hard-XXX]` constraints are preserved? (region, time, subject, tool, ...)
+- Attribute match: how many of the `[soft-XXX]` modifiers are preserved? (quality, speed, aesthetics, ...)
+
+**Score reference:**
+- 0.9-1.0 = near-perfect match; every [hard] and [soft] element preserved
+- 0.7-0.8 = highly related; every [hard] element preserved, a few [soft] ones missing
+- 0.5-0.6 = moderately related; most [hard] elements preserved, most [soft] ones missing
+- 0.3-0.4 = weakly related; some [hard] elements missing
+- 0-0.2 = essentially unrelated; most [hard] elements missing
+
+## Strategy
+
+1. **Check [essence] first to decide intent_match**: if the intent does not match, return false immediately
+2. **Then use [hard][soft] to score relevance_score**: measure how well elements and attributes are preserved
+
+## Output requirements
+
+Think first, then score. Output in this order:
+
+1. reason: a detailed rationale (analyze before scoring)
+   - What the question's [essence] is, and whether the suggestion matches it
+   - Which [hard] constraints are preserved or missing
+   - Which [soft] modifiers are preserved or missing
+   - Based on the above, justify the intent judgment and the relevance score
+
+2. intent_match: true/false (derived from the analysis above)
+
+3. relevance_score: a float in 0-1 (derived from the analysis above)
+""".strip()
+
+class RelevanceEvaluation(BaseModel):
+    """Evaluation feedback model - intent match + relevance"""
+    reason: str = Field(..., description="Rationale; must justify both the intent judgment and the relevance score")
+    intent_match: bool = Field(..., description="Whether the intent matches")
+    relevance_score: float = Field(..., description="Relevance score in 0-1; higher is more relevant")
+
+evaluator = Agent[None](
+    name="Evaluation Expert",
+    instructions=eval_instructions,
+    output_type=RelevanceEvaluation,
+)
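+
+# Rubric walk-through (assumed numbers, not actual model output): for the
+# question "川西秋季高质量风光摄影素材" and the suggestion "川西风光摄影",
+# the essence (obtaining photography material) still holds → intent_match=true;
+# the [hard] elements 川西 and 风光摄影 are preserved, [hard] 秋季 and
+# [soft] 高质量 are missing → roughly the 0.5-0.6 band, e.g.
+#   RelevanceEvaluation(reason="...", intent_match=True, relevance_score=0.55)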
+
+
+# ============================================================================
+# Core functions
+# ============================================================================
+
+async def annotate_question(q_with_context: str) -> str:
+    """Annotate the question (three layers)"""
+    print("\nAnnotating question...")
+    result = await Runner.run(question_annotator, q_with_context)
+    annotation = str(result.final_output)
+    print(f"Annotation done: {annotation}")
+    return annotation
+
+
+async def segment_text(q: str) -> SegmentationResult:
+    """Segment the question into keywords"""
+    print("\nSegmenting...")
+    result = await Runner.run(segmenter, q)
+    seg_result: SegmentationResult = result.final_output
+    print(f"Words: {seg_result.words}")
+    print(f"Reasoning: {seg_result.reasoning}")
+    return seg_result
+
+
+def generate_query_combinations_single_level(
+    keywords: list[str],
+    size: int
+) -> list[str]:
+    """
+    Generate the query combinations for a single level
+
+    Args:
+        keywords: keyword list
+        size: number of words per combination
+
+    Returns:
+        All query combinations for this level
+    """
+    if size > len(keywords):
+        return []
+
+    # 1-word combinations: order is irrelevant
+    if size == 1:
+        return keywords.copy()
+
+    # Multi-word combinations: choose `size` words (combinations), then order them (permutations)
+    all_queries = []
+    combs = list(combinations(keywords, size))
+    for comb in combs:
+        # Generate every ordering of this combination
+        perms = list(permutations(comb))
+        for perm in perms:
+            query = ''.join(perm)  # concatenate directly, no spaces
+            all_queries.append(query)
+
+    # Deduplicate while preserving order
+    return list(dict.fromkeys(all_queries))
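+
+# Size check (for illustration): n keywords at combination size k yield
+# C(n,k) * k! = n!/(n-k)! queries before deduplication, e.g.
+#   >>> generate_query_combinations_single_level(["川西", "秋季", "摄影"], 2)
+#   ['川西秋季', '秋季川西', '川西摄影', '摄影川西', '秋季摄影', '摄影秋季']
+# The per-level query count therefore grows factorially; the pruning below
+# exists to keep it tractable.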
+
+
+async def evaluate_single_sug_with_semaphore(
+    source_query: str,
+    sug_query: str,
+    original_question: str,
+    question_annotation: str,
+    semaphore: asyncio.Semaphore
+) -> dict:
+    """Evaluate a single suggested query, gated by the semaphore"""
+    async with semaphore:
+        eval_input = f"""
+<original_question>
+{original_question}
+</original_question>
+
+<question_annotation>
+{question_annotation}
+</question_annotation>
+
+<suggested_query>
+{sug_query}
+</suggested_query>
+
+Evaluate the suggested query (analyze first, then score):
+1. reason: a detailed rationale (think it through first)
+2. intent_match: whether the intent matches (true/false)
+3. relevance_score: relevance score (0-1)
+
+Refer to the [essence], [hard], and [soft] tags in the question annotation.
+"""
+        result = await Runner.run(evaluator, eval_input)
+        evaluation: RelevanceEvaluation = result.final_output
+        return {
+            "source_query": source_query,
+            "sug_query": sug_query,
+            "intent_match": evaluation.intent_match,
+            "relevance_score": evaluation.relevance_score,
+            "reason": evaluation.reason,
+        }
+
+
+async def fetch_and_evaluate_level(
+    queries: list[str],
+    original_question: str,
+    question_annotation: str,
+    level_name: str,
+    context: RunContext
+) -> tuple[list[dict], list[dict]]:
+    """
+    Process one level: fetch suggestions and evaluate them
+
+    Returns:
+        (sug_results, evaluations)
+    """
+    xiaohongshu_api = XiaohongshuSearchRecommendations()
+
+    # Semaphores bounding API and model concurrency
+    api_semaphore = asyncio.Semaphore(API_CONCURRENCY_LIMIT)
+    model_semaphore = asyncio.Semaphore(MODEL_CONCURRENCY_LIMIT)
+
+    # Result accumulators
+    sug_results = []
+    all_evaluations = []
+
+    # Progress counters
+    total_queries = len(queries)
+    completed_queries = 0
+    total_sugs = 0
+    completed_evals = 0
+
+    async def get_and_evaluate_single_query(query: str):
+        nonlocal completed_queries, total_sugs, completed_evals
+
+        # Step 1: fetch suggestions. get_recommendations is called synchronously,
+        # so run it in a worker thread; a blocking call here would stall the
+        # event loop and defeat both semaphores.
+        async with api_semaphore:
+            suggestions = await asyncio.to_thread(xiaohongshu_api.get_recommendations, keyword=query)
+            sug_count = len(suggestions) if suggestions else 0
+
+        completed_queries += 1
+        total_sugs += sug_count
+
+        print(f"  [{completed_queries}/{total_queries}] {query} → {sug_count} suggestions")
+
+        sug_result = {
+            "query": query,
+            "suggestions": suggestions or [],
+            "timestamp": datetime.now().isoformat()
+        }
+        sug_results.append(sug_result)
+
+        # Step 2: evaluate these suggestions immediately
+        if suggestions:
+            eval_tasks = []
+            for sug in suggestions:
+                eval_tasks.append(evaluate_single_sug_with_semaphore(
+                    query, sug, original_question, question_annotation, model_semaphore
+                ))
+
+            if eval_tasks:
+                evals = await asyncio.gather(*eval_tasks)
+                all_evaluations.extend(evals)
+                completed_evals += len(evals)
+                print(f"    ↳ evaluated {len(evals)}, {completed_evals} evaluations so far")
+
+    # Process all queries concurrently
+    await asyncio.gather(*[get_and_evaluate_single_query(q) for q in queries])
+
+    print(f"\n{level_name} done: fetched {total_sugs} suggestions, completed {completed_evals} evaluations")
+
+    return sug_results, all_evaluations
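+
+# Concurrency pattern note: the two semaphores are deliberately separate, so
+# with the defaults at most 5 suggestion fetches and 10 evaluator runs are in
+# flight at once, and each query releases its API slot before its evaluations
+# start. A minimal sketch of the same gating pattern (illustrative names):
+#   sem = asyncio.Semaphore(5)
+#   async def bounded(fn, *args):
+#       async with sem:
+#           return await fn(*args)
+#   results = await asyncio.gather(*[bounded(worker, x) for x in items])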
+
+
+def find_intent_matched_keywords(
+    keywords: list[str],
+    evaluations: list[dict]
+) -> set[str]:
+    """
+    Find every keyword that produced at least one suggestion with intent_match=True
+
+    Args:
+        keywords: keywords used at the current level
+        evaluations: evaluation results for this level
+
+    Returns:
+        The set of keywords that have an intent match
+    """
+    matched_keywords = set()
+
+    for keyword in keywords:
+        # Does any suggestion derived from this keyword have intent_match=True?
+        keyword_evals = [
+            e for e in evaluations
+            if e['source_query'] == keyword and e['intent_match'] is True
+        ]
+
+        if keyword_evals:
+            matched_keywords.add(keyword)
+
+    return matched_keywords
+
+
+def find_top_keywords_by_relevance(
+    keywords: list[str],
+    evaluations: list[dict],
+    top_n: int = 2
+) -> list[str]:
+    """
+    Find the top N keywords by relevance_score
+
+    Args:
+        keywords: keywords used at the current level
+        evaluations: evaluation results for this level
+        top_n: number of keywords to keep
+
+    Returns:
+        The top N keywords, ranked by average relevance_score
+    """
+    keyword_scores = {}
+
+    for keyword in keywords:
+        # Collect every evaluation derived from this keyword
+        keyword_evals = [
+            e for e in evaluations
+            if e['source_query'] == keyword
+        ]
+
+        if keyword_evals:
+            # Average relevance_score as the primary signal
+            avg_score = sum(e['relevance_score'] for e in keyword_evals) / len(keyword_evals)
+            # Track the maximum as a secondary tie-breaker
+            max_score = max(e['relevance_score'] for e in keyword_evals)
+            keyword_scores[keyword] = {
+                'avg': avg_score,
+                'max': max_score,
+                'count': len(keyword_evals)
+            }
+
+    if not keyword_scores:
+        return []
+
+    # Sort by average score, then by maximum score, both descending
+    sorted_keywords = sorted(
+        keyword_scores.items(),
+        key=lambda x: (x[1]['avg'], x[1]['max']),
+        reverse=True
+    )
+
+    # Return the top N keywords
+    return [kw for kw, score in sorted_keywords[:top_n]]
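+
+# Worked example (made-up numbers): suppose the level's evaluations give
+#   "川西" → scores [0.8, 0.6] → avg 0.70, max 0.8
+#   "秋季" → scores [0.7, 0.7] → avg 0.70, max 0.7
+#   "摄影" → scores [0.5]      → avg 0.50, max 0.5
+# The sort key (avg, max) breaks the 0.70 tie in favor of 川西 via its higher
+# maximum, so top_n=2 returns ["川西", "秋季"].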
+
+
+def find_qualified_queries(evaluations: list[dict], min_relevance_score: float = 0.7) -> list[dict]:
+    """
+    Find all qualified queries
+
+    Criteria:
+    1. intent_match = True (mandatory)
+    2. relevance_score >= min_relevance_score
+
+    Returns: sorted by relevance_score, descending
+    """
+    qualified = [
+        e for e in evaluations
+        if e['intent_match'] is True and e['relevance_score'] >= min_relevance_score
+    ]
+
+    # Sort by relevance_score, descending
+    return sorted(qualified, key=lambda x: x['relevance_score'], reverse=True)
+
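+# Usage sketch (hypothetical data): an evaluation survives only if BOTH
+# criteria hold:
+#   evals = [{"intent_match": True,  "relevance_score": 0.9},
+#            {"intent_match": False, "relevance_score": 0.95},
+#            {"intent_match": True,  "relevance_score": 0.6}]
+#   find_qualified_queries(evals)       # → only the 0.9 entry
+#   find_qualified_queries(evals, 0.5)  # → the 0.9 then the 0.6 entry
+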
+# ============================================================================
+# Main flow - v6.4 level-by-level pruning
+# ============================================================================
+
+async def combinatorial_search_with_pruning(
+    context: RunContext,
+    max_combination_size: int = 1,
+    fallback_top_n: int = 2
+) -> dict:
+    """
+    Combinatorial search flow (with level-by-level pruning)
+
+    Strategy:
+    - Level 1: try every single keyword
+    - Level 2 and above (see the walk-through below):
+      1. Prefer keywords that produced at least one intent_match=True suggestion at the previous level
+      2. Failing that, use the top N keywords by relevance_score
+      3. If that cannot be computed either, fall back to all keywords
+
+    Args:
+        context: run context
+        max_combination_size: maximum combination size (N), default 1
+        fallback_top_n: with no intent matches, use the relevance_score top N keywords, default 2
+
+    Return shape:
+        {
+            "success": True/False,
+            "results": [...],
+            "message": "..."
+        }
+    """
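+    # Walk-through (hypothetical numbers): with 6 keywords and
+    # max_combination_size=2, level 1 issues 6 single-word queries. If only
+    # ["川西", "摄影"] yield an intent_match=True suggestion, level 2 builds
+    # its 2-word queries from those two keywords alone (C(2,2) * 2! = 2
+    # queries) instead of C(6,2) * 2! = 30 - that saving is the point of
+    # the pruning.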
+
+    # Step 1: annotate the question (three layers)
+    annotation = await annotate_question(context.q_with_context)
+    context.question_annotation = annotation
+
+    # Step 2: segment into keywords
+    seg_result = await segment_text(context.q)
+    all_keywords = seg_result.words
+    context.keywords = all_keywords
+
+    # Accumulators across levels
+    all_sug_results = []
+    all_evaluations = []
+
+    # Keywords available to the current level (level 1 uses all of them)
+    current_keywords = all_keywords.copy()
+
+    print(f"\n{'='*60}")
+    print(f"Level-pruned search (max level: {max_combination_size})")
+    print(f"{'='*60}")
+
+    # Process level by level
+    for level in range(1, max_combination_size + 1):
+        level_name = f"{level}-word"
+
+        print(f"\n{'='*60}")
+        print(f"Level {level}: {level_name}")
+        print(f"{'='*60}")
+
+        # Any keywords left to work with?
+        if not current_keywords:
+            print(f"⚠️ No keywords available, skipping level {level}")
+            context.pruning_info[level_name] = {
+                "available_keywords": [],
+                "queries_count": 0,
+                "pruned": True,
+                "reason": "no keyword at the previous level had intent_match=True"
+            }
+            break
+
+        # Generate this level's query combinations
+        level_queries = generate_query_combinations_single_level(current_keywords, level)
+
+        if not level_queries:
+            print(f"⚠️ Cannot build {level}-word combinations, skipping")
+            context.pruning_info[level_name] = {
+                "available_keywords": current_keywords,
+                "queries_count": 0,
+                "pruned": True,
+                "reason": f"not enough keywords to build {level}-word combinations"
+            }
+            break
+
+        print(f"Available keywords: {current_keywords}")
+        print(f"Generated queries: {len(level_queries)}")
+
+        # Record this level's query combinations
+        context.query_combinations[level_name] = level_queries
+
+        # Print a sample of the queries
+        print("\nSample queries (first 10):")
+        for i, q in enumerate(level_queries[:10], 1):
+            print(f"  {i}. {q}")
+        if len(level_queries) > 10:
+            print(f"  ... and {len(level_queries) - 10} more")
+
+        # Fetch suggestions and evaluate them
+        print(f"\nProcessing suggestions for level {level}...")
+        level_sug_results, level_evaluations = await fetch_and_evaluate_level(
+            level_queries,
+            context.q,
+            annotation,
+            level_name,
+            context
+        )
+
+        # Accumulate
+        all_sug_results.extend(level_sug_results)
+        all_evaluations.extend(level_evaluations)
+
+        # Intent-match statistics for this level
+        intent_matched_count = sum(1 for e in level_evaluations if e['intent_match'] is True)
+
+        print(f"\nLevel {level} statistics:")
+        print(f"  - queries: {len(level_queries)}")
+        print(f"  - suggestions: {sum(len(r['suggestions']) for r in level_sug_results)}")
+        print(f"  - intent matches: {intent_matched_count}/{len(level_evaluations)}")
+
+        # Record pruning info
+        context.pruning_info[level_name] = {
+            "available_keywords": current_keywords,
+            "queries_count": len(level_queries),
+            "pruned": False,
+            "intent_matched_count": intent_matched_count,
+            "total_evaluations": len(level_evaluations)
+        }
+
+        # If there is a next level, pick the keywords it will use
+        if level < max_combination_size:
+            # The keyword set only needs to be recomputed after level 1
+            if level == 1:
+                matched_keywords = find_intent_matched_keywords(current_keywords, level_evaluations)
+
+                print("\nPruning result:")
+                print(f"  - original keywords: {len(current_keywords)}")
+                print(f"  - intent-matched keywords: {len(matched_keywords)}")
+
+                if matched_keywords:
+                    print("  ✓ strategy: keep the intent-matched keywords")
+                    print(f"  - kept keywords: {sorted(matched_keywords)}")
+                    current_keywords = list(matched_keywords)
+                else:
+                    print("  ⚠️ no keyword produced an intent_match=True suggestion")
+                    # Fall back to the top N keywords by relevance_score
+                    top_keywords = find_top_keywords_by_relevance(current_keywords, level_evaluations, top_n=fallback_top_n)
+
+                    if top_keywords:
+                        print(f"  ✓ strategy: keep the top {fallback_top_n} keywords by relevance_score")
+                        print(f"  - kept keywords: {top_keywords}")
+                        current_keywords = top_keywords
+
+                        # Show score details for the kept keywords
+                        for kw in top_keywords:
+                            kw_evals = [e for e in level_evaluations if e['source_query'] == kw]
+                            if kw_evals:
+                                avg_score = sum(e['relevance_score'] for e in kw_evals) / len(kw_evals)
+                                max_score = max(e['relevance_score'] for e in kw_evals)
+                                print(f"    - {kw}: avg={avg_score:.2f}, max={max_score:.2f}, suggestions={len(kw_evals)}")
+                    else:
+                        print("  ⚠️ cannot compute relevance_score, the next level will use all keywords")
+                        current_keywords = all_keywords.copy()
+
+    # Persist the accumulated results
+    context.all_sug_queries = all_sug_results
+    context.evaluation_results = all_evaluations
+
+    # Filter for qualified queries
+    print(f"\n{'='*60}")
+    print("Selecting final results")
+    print(f"{'='*60}")
+
+    qualified = find_qualified_queries(all_evaluations, min_relevance_score=0.7)
+
+    if qualified:
+        return {
+            "success": True,
+            "results": qualified,
+            "message": f"found {len(qualified)} qualified queries (intent_match=True and relevance>=0.7)"
+        }
+
+    # Lower the bar
+    acceptable = find_qualified_queries(all_evaluations, min_relevance_score=0.5)
+    if acceptable:
+        return {
+            "success": True,
+            "results": acceptable,
+            "message": f"found {len(acceptable)} acceptable queries (intent_match=True and relevance>=0.5)"
+        }
+
+    # Nothing relevant enough: return whatever has intent_match=True
+    intent_matched = [e for e in all_evaluations if e['intent_match'] is True]
+    if intent_matched:
+        intent_matched_sorted = sorted(intent_matched, key=lambda x: x['relevance_score'], reverse=True)
+        return {
+            "success": False,
+            "results": intent_matched_sorted[:10],  # return only the top 10
+            "message": f"no highly relevant query found, but {len(intent_matched)} suggestions match the intent"
+        }
+
+    return {
+        "success": False,
+        "results": [],
+        "message": "no suggestion matched the intent"
+    }
+
+
+# ============================================================================
+# Output formatting
+# ============================================================================
+
+def format_output(optimization_result: dict, context: RunContext) -> str:
+    """Format the final report"""
+    results = optimization_result.get("results", [])
+
+    output = f"Original question: {context.q}\n"
+    output += f"Question annotation: {context.question_annotation}\n"
+    output += f"Extracted keywords: {', '.join(context.keywords or [])}\n"
+    output += f"Keyword count: {len(context.keywords or [])}\n"
+
+    # Per-level pruning info
+    output += f"\n{'='*60}\n"
+    output += "Per-level pruning info:\n"
+    output += f"{'='*60}\n"
+    for level_name, info in context.pruning_info.items():
+        output += f"\n{level_name}:\n"
+        if info.get('pruned'):
+            output += "  status: pruned ✂️\n"
+            output += f"  reason: {info.get('reason', 'unknown')}\n"
+        else:
+            output += "  status: processed ✓\n"
+            output += f"  available keyword count: {len(info['available_keywords'])}\n"
+            output += f"  available keywords: {info['available_keywords']}\n"
+            output += f"  generated queries: {info['queries_count']}\n"
+            output += f"  intent matches: {info.get('intent_matched_count', 0)}/{info.get('total_evaluations', 0)}\n"
+
+    # Query combination statistics
+    output += f"\n{'='*60}\n"
+    output += "Query combination statistics:\n"
+    output += f"{'='*60}\n"
+    for level, queries in context.query_combinations.items():
+        output += f"  - {level}: {len(queries)} queries\n"
+
+    # Exploration statistics
+    total_queries = sum(len(q) for q in context.query_combinations.values())
+    total_sugs = sum(len(r["suggestions"]) for r in context.all_sug_queries)
+    total_evals = len(context.evaluation_results)
+
+    output += "\nExploration statistics:\n"
+    output += f"  - total queries: {total_queries}\n"
+    output += f"  - total suggestions: {total_sugs}\n"
+    output += f"  - total evaluations: {total_evals}\n"
+
+    output += f"\nStatus: {optimization_result['message']}\n\n"
+
+    if optimization_result["success"] and results:
+        output += "=" * 60 + "\n"
+        output += "Qualified suggested queries (by relevance_score, descending):\n"
+        output += "=" * 60 + "\n"
+        for i, result in enumerate(results[:20], 1):  # show only the first 20
+            output += f"\n{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
+            output += f"   source: {result['source_query']}\n"
+            output += f"   intent: {'✓ match' if result['intent_match'] else '✗ no match'}\n"
+            reason = result['reason']
+            if len(reason) > 150:
+                reason = reason[:150] + "..."
+            output += f"   reason: {reason}\n"
+    else:
+        output += "=" * 60 + "\n"
+        output += "Result: no sufficiently relevant suggested query found\n"
+        output += "=" * 60 + "\n"
+        if results:
+            output += "\nClosest suggestions (first 10):\n\n"
+            for i, result in enumerate(results[:10], 1):
+                output += f"{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
+                output += f"   source: {result['source_query']}\n"
+                output += f"   intent: {'✓ match' if result['intent_match'] else '✗ no match'}\n\n"
+
+    # Grouped view by source query
+    output += "\n" + "=" * 60 + "\n"
+    output += "Suggestions grouped by source query:\n"
+    output += "=" * 60 + "\n"
+
+    for sug_data in context.all_sug_queries:
+        source_q = sug_data["query"]
+        sugs = sug_data["suggestions"]
+
+        # All evaluations derived from this source query
+        related_evals = [e for e in context.evaluation_results if e["source_query"] == source_q]
+        intent_match_count = sum(1 for e in related_evals if e["intent_match"])
+        avg_relevance = sum(e["relevance_score"] for e in related_evals) / len(related_evals) if related_evals else 0
+
+        output += f"\nQuery: {source_q}\n"
+        output += f"  suggestions: {len(sugs)}\n"
+        output += f"  intent matches: {intent_match_count}/{len(related_evals)}\n"
+        output += f"  average relevance: {avg_relevance:.2f}\n"
+
+        # Show the first 3 suggestions
+        if sugs:
+            output += "  sample suggestions:\n"
+            for sug in sugs[:3]:
+                eval_item = next((e for e in related_evals if e["sug_query"] == sug), None)
+                if eval_item:
+                    output += f"    - {sug} [intent: {'✓' if eval_item['intent_match'] else '✗'}, relevance: {eval_item['relevance_score']:.2f}]\n"
+                else:
+                    output += f"    - {sug}\n"
+
+    return output.strip()
+
+
+# ============================================================================
+# Main
+# ============================================================================
+
+async def main(
+    input_dir: str,
+    max_combination_size: int = 1,
+    api_concurrency: int = API_CONCURRENCY_LIMIT,
+    model_concurrency: int = MODEL_CONCURRENCY_LIMIT,
+    fallback_top_n: int = 2
+):
+    # Update the global concurrency settings
+    global API_CONCURRENCY_LIMIT, MODEL_CONCURRENCY_LIMIT
+    API_CONCURRENCY_LIMIT = api_concurrency
+    MODEL_CONCURRENCY_LIMIT = model_concurrency
+
+    current_time, log_url = set_trace()
+
+    # Read the fixed file names from the input directory
+    input_context_file = os.path.join(input_dir, 'context.md')
+    input_q_file = os.path.join(input_dir, 'q.md')
+
+    q_context = read_file_as_string(input_context_file)
+    q = read_file_as_string(input_q_file)
+    q_with_context = f"""
+<requirement_context>
+{q_context}
+</requirement_context>
+<current_question>
+{q}
+</current_question>
+""".strip()
+
+    # Use the current file name as the version tag
+    version = os.path.basename(__file__)
+    version_name = os.path.splitext(version)[0]
+
+    # Log directory
+    log_dir = os.path.join(input_dir, "output", version_name, current_time)
+
+    run_context = RunContext(
+        version=version,
+        input_files={
+            "input_dir": input_dir,
+            "context_file": input_context_file,
+            "q_file": input_q_file,
+        },
+        q_with_context=q_with_context,
+        q_context=q_context,
+        q=q,
+        log_dir=log_dir,
+        log_url=log_url,
+    )
+
+    print(f"\n{'='*60}")
+    print("Concurrency settings")
+    print(f"{'='*60}")
+    print(f"API request concurrency: {API_CONCURRENCY_LIMIT}")
+    print(f"Model evaluation concurrency: {MODEL_CONCURRENCY_LIMIT}")
+
+    # Run the level-pruned search
+    optimization_result = await combinatorial_search_with_pruning(
+        run_context,
+        max_combination_size=max_combination_size,
+        fallback_top_n=fallback_top_n
+    )
+
+    # Format the output
+    final_output = format_output(optimization_result, run_context)
+    print(f"\n{'='*60}")
+    print("Final result")
+    print(f"{'='*60}")
+    print(final_output)
+
+    # Store results on the context
+    run_context.optimization_result = optimization_result
+    run_context.final_output = final_output
+
+    # Save the RunContext to log_dir
+    os.makedirs(run_context.log_dir, exist_ok=True)
+    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
+    with open(context_file_path, "w", encoding="utf-8") as f:
+        json.dump(run_context.model_dump(), f, ensure_ascii=False, indent=2)
+    print(f"\nRunContext saved to: {context_file_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Search query optimization tool - v6.4, level-pruned",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  # Defaults (search level 1 only)
+  python sug_v6_4_with_annotation.py
+
+  # 2 levels; level 2 only uses keywords with an intent match at level 1
+  python sug_v6_4_with_annotation.py --max-combo 2
+
+  # 2 levels; if level 1 has no intent match, use the top 3 keywords
+  python sug_v6_4_with_annotation.py --max-combo 2 --fallback-top 3
+
+  # 3 levels, API concurrency 5, model concurrency 20
+  python sug_v6_4_with_annotation.py --max-combo 3 --api-concurrency 5 --model-concurrency 20
+
+  # Custom input directory
+  python sug_v6_4_with_annotation.py --input-dir "input/旅游-逸趣玩旅行/如何获取能体现川西秋季特色的高质量风光摄影素材?"
+  """
+    )
+    parser.add_argument(
+        "--input-dir",
+        type=str,
+        default="input/简单扣图",
+        help="Input directory, default: input/简单扣图"
+    )
+    parser.add_argument(
+        "--max-combo",
+        type=int,
+        default=1,
+        help="Maximum combination size (N), default: 1"
+    )
+    parser.add_argument(
+        "--fallback-top",
+        type=int,
+        default=2,
+        help="When level 1 has no intent match, use the relevance_score top N keywords, default: 2"
+    )
+    parser.add_argument(
+        "--api-concurrency",
+        type=int,
+        default=API_CONCURRENCY_LIMIT,
+        help=f"API request concurrency, default: {API_CONCURRENCY_LIMIT}"
+    )
+    parser.add_argument(
+        "--model-concurrency",
+        type=int,
+        default=MODEL_CONCURRENCY_LIMIT,
+        help=f"Model evaluation concurrency, default: {MODEL_CONCURRENCY_LIMIT}"
+    )
+    args = parser.parse_args()
+
+    asyncio.run(main(
+        args.input_dir,
+        max_combination_size=args.max_combo,
+        api_concurrency=args.api_concurrency,
+        model_concurrency=args.model_concurrency,
+        fallback_top_n=args.fallback_top
+    ))