"""Combinatorial search-query optimizer (v6.2).

Pipeline:
    1. Segment the input question into keywords (LLM agent).
    2. Generate keyword combinations up to a configurable size N.
    3. Fetch Xiaohongshu search suggestions for every combination.
    4. Score each suggestion for intent match and relevance (LLM agent).
    5. Report qualified suggestions and persist the full run context.
"""

import argparse
import asyncio
import json
import os
from datetime import datetime
from itertools import combinations
from typing import Literal

from pydantic import BaseModel, Field

from agents import Agent, Runner
from lib.my_trace import set_trace
from lib.utils import read_file_as_string
from script.search_recommendations.xiaohongshu_search_recommendations import (
    XiaohongshuSearchRecommendations,
)


class RunContext(BaseModel):
    # Mutable run state; serialized to run_context.json at the end of main().
    version: str = Field(..., description="当前运行的脚本版本(文件名)")
    input_files: dict[str, str] = Field(..., description="输入文件路径映射")
    q_with_context: str
    q_context: str
    q: str
    log_url: str
    log_dir: str

    # Segmentation and combination state
    keywords: list[str] | None = Field(default=None, description="提取的关键词")
    query_combinations: dict[str, list[str]] = Field(default_factory=dict, description="各层级的query组合")

    # Exploration results
    all_sug_queries: list[dict] = Field(default_factory=list, description="所有获取到的推荐词")

    # Evaluation results
    evaluation_results: list[dict] = Field(default_factory=list, description="所有推荐词的评估结果")
    optimization_result: dict | None = Field(default=None, description="最终优化结果对象")
    final_output: str | None = Field(default=None, description="最终输出结果(格式化文本)")


# ============================================================================
# Agent 1: segmentation expert
# ============================================================================

segmentation_instructions = """
你是中文分词专家。给定一个句子,将其分词。

## 分词原则

1. 去掉标点符号
2. 拆分成最小的有意义单元
3. 去掉助词、语气词、助动词
4. 保留疑问词
5. 保留实词:名词、动词、形容词、副词

## 输出要求

输出分词列表。
""".strip()


class SegmentationResult(BaseModel):
    """分词结果"""
    words: list[str] = Field(..., description="分词列表")
    reasoning: str = Field(..., description="分词说明")


segmenter = Agent[None](
    name="分词专家",
    instructions=segmentation_instructions,
    output_type=SegmentationResult,
)


# ============================================================================
# Agent 2: evaluation expert (intent match + relevance scoring)
# ============================================================================

eval_instructions = """
你是搜索query评估专家。给定原始问题和推荐query,评估两个维度。

## 评估目标

用这个推荐query搜索,能否找到满足原始需求的内容?

## 两层评分

### 1. intent_match(意图匹配)= true/false

推荐query的**使用意图**是否与原问题一致?

**核心问题:用户搜索这个推荐词,想做什么?**

**判断标准:**
- 原问题意图:找方法?找教程?找资源/素材?找工具?看作品?
- 推荐词意图:如果用户搜索这个词,他的目的是什么?

**评分:**
- true = 意图一致,搜索推荐词能达到原问题的目的
- false = 意图改变,搜索推荐词无法达到原问题的目的

### 2. relevance_score(相关性)= 0-1 连续分数

推荐query在**主题、要素、属性**上与原问题的相关程度?

**评估维度:**
- 主题相关:核心主题是否匹配?(如:摄影、旅游、美食)
- 要素覆盖:关键要素保留了多少?(如:地域、时间、对象、工具)
- 属性匹配:质量、风格、特色等属性是否保留?

**评分参考:**
- 0.9-1.0 = 几乎完美匹配,所有核心要素都保留
- 0.7-0.8 = 高度相关,核心要素保留,少数次要要素缺失
- 0.5-0.6 = 中度相关,主题匹配但多个要素缺失
- 0.3-0.4 = 低度相关,只有部分主题相关
- 0-0.2 = 基本不相关

## 评估策略

1. **先判断 intent_match**:意图不匹配直接 false,无论相关性多高
2. **再评估 relevance_score**:在意图匹配的前提下,计算相关性

## 输出要求

- intent_match: true/false
- relevance_score: 0-1 的浮点数
- reason: 详细的评估理由,需要说明:
  - 原问题的意图是什么
  - 推荐词的意图是什么
  - 为什么判断意图匹配/不匹配
  - 相关性分数的依据(哪些要素保留/缺失)
""".strip()


class RelevanceEvaluation(BaseModel):
    """评估反馈模型 - 意图匹配 + 相关性"""
    intent_match: bool = Field(..., description="意图是否匹配")
    relevance_score: float = Field(..., description="相关性分数 0-1,分数越高越相关")
    reason: str = Field(..., description="评估理由,需说明意图判断和相关性依据")


evaluator = Agent[None](
    name="评估专家",
    instructions=eval_instructions,
    output_type=RelevanceEvaluation,
)


# ============================================================================
# Core functions
# ============================================================================

async def segment_text(q: str) -> SegmentationResult:
    """Segment the question into keywords via the segmentation agent."""
    print("\n正在分词...")
    result = await Runner.run(segmenter, q)
    seg_result: SegmentationResult = result.final_output
    print(f"分词结果:{seg_result.words}")
    print(f"分词说明:{seg_result.reasoning}")
    return seg_result


def generate_query_combinations(keywords: list[str], max_combination_size: int) -> dict[str, list[str]]:
    """Generate candidate queries from keyword combinations.

    Args:
        keywords: Keyword list produced by segmentation.
        max_combination_size: Maximum number of keywords per combination (N).

    Returns:
        {
            "1-word": [...],
            "2-word": [...],
            ...
            "N-word": [...]
        }
    """
    result = {}
    for size in range(1, max_combination_size + 1):
        # No combinations of a size larger than the keyword count exist.
        if size > len(keywords):
            break
        combs = list(combinations(keywords, size))
        queries = [''.join(comb) for comb in combs]  # concatenate directly, no separator
        result[f"{size}-word"] = queries

        print(f"\n{size}词组合:{len(queries)} 个")
        if len(queries) <= 10:
            for q in queries:
                print(f"  - {q}")
        else:
            # Too many to list: show the first two and the last.
            print(f"  - {queries[0]}")
            print(f"  - {queries[1]}")
            print(f"  ...")
            print(f"  - {queries[-1]}")
    return result


async def fetch_suggestions_for_queries(queries: list[str], context: RunContext) -> list[dict]:
    """Fetch search suggestions for all queries concurrently.

    Args:
        queries: Candidate query strings.
        context: Run context (kept for signature compatibility; not read here).

    Returns:
        [
            {
                "query": "川西",
                "suggestions": ["川西旅游", "川西攻略", ...],
                "timestamp": "..."
            },
            ...
        ]
    """
    print(f"\n{'='*60}")
    print(f"获取推荐词:{len(queries)} 个query")
    print(f"{'='*60}")

    xiaohongshu_api = XiaohongshuSearchRecommendations()

    async def get_single_sug(query: str):
        print(f"  查询: {query}")
        # get_recommendations is a blocking synchronous call; run it in a worker
        # thread so asyncio.gather actually overlaps the requests instead of
        # serializing them on the event loop.
        suggestions = await asyncio.to_thread(
            xiaohongshu_api.get_recommendations, keyword=query
        )
        print(f"    → {len(suggestions) if suggestions else 0} 个推荐词")
        return {
            "query": query,
            "suggestions": suggestions or [],
            "timestamp": datetime.now().isoformat(),
        }

    results = await asyncio.gather(*[get_single_sug(q) for q in queries])
    return results


async def evaluate_all_suggestions(sug_results: list[dict], original_question: str, context: RunContext) -> list[dict]:
    """Evaluate every fetched suggestion against the original question.

    Args:
        sug_results: Suggestion results from fetch_suggestions_for_queries.
        original_question: The original question text.
        context: Run context; evaluation_results is written here.

    Returns:
        [
            {
                "source_query": "川西秋季",
                "sug_query": "川西秋季旅游",
                "intent_match": True,
                "relevance_score": 0.8,
                "reason": "..."
            },
            ...
        ]
    """
    print(f"\n{'='*60}")
    print(f"评估推荐词")
    print(f"{'='*60}")

    all_evaluations = []

    async def evaluate_single_sug(source_query: str, sug_query: str):
        eval_input = f"""
<原始问题>
{original_question}

<待评估的推荐query>
{sug_query}

请评估该推荐query:
1. intent_match: 意图是否匹配(true/false)
2. relevance_score: 相关性分数(0-1)
3. reason: 详细的评估理由
"""
        result = await Runner.run(evaluator, eval_input)
        evaluation: RelevanceEvaluation = result.final_output
        return {
            "source_query": source_query,
            "sug_query": sug_query,
            "intent_match": evaluation.intent_match,
            "relevance_score": evaluation.relevance_score,
            "reason": evaluation.reason,
        }

    # Evaluate every (source query, suggestion) pair concurrently.
    tasks = []
    for sug_result in sug_results:
        source_query = sug_result["query"]
        for sug in sug_result["suggestions"]:
            tasks.append(evaluate_single_sug(source_query, sug))

    if tasks:
        print(f"  总共需要评估 {len(tasks)} 个推荐词...")
        all_evaluations = await asyncio.gather(*tasks)

    context.evaluation_results = all_evaluations
    return all_evaluations


def find_qualified_queries(evaluations: list[dict], min_relevance_score: float = 0.7) -> list[dict]:
    """Filter evaluations to the qualified ones.

    Criteria:
        1. intent_match is True (mandatory).
        2. relevance_score >= min_relevance_score.

    Returns:
        Qualified evaluations sorted by relevance_score, descending.
    """
    qualified = [
        e for e in evaluations
        if e['intent_match'] is True and e['relevance_score'] >= min_relevance_score
    ]
    return sorted(qualified, key=lambda x: x['relevance_score'], reverse=True)


# ============================================================================
# Main pipeline
# ============================================================================

async def combinatorial_search(context: RunContext, max_combination_size: int = 1) -> dict:
    """Run the combinatorial search pipeline.

    Args:
        context: Run context; intermediate state is stored on it.
        max_combination_size: Maximum combination size (N), default 1.

    Returns:
        {
            "success": True/False,
            "results": [...],
            "message": "..."
        }
    """
    # Step 1: segmentation
    seg_result = await segment_text(context.q)
    context.keywords = seg_result.words

    # Step 2: generate query combinations
    print(f"\n{'='*60}")
    print(f"生成query组合(最大组合数:{max_combination_size})")
    print(f"{'='*60}")
    query_combinations = generate_query_combinations(context.keywords, max_combination_size)
    context.query_combinations = query_combinations

    # Step 3: fetch suggestions for every combined query
    all_queries = []
    for level, queries in query_combinations.items():
        all_queries.extend(queries)

    sug_results = await fetch_suggestions_for_queries(all_queries, context)
    context.all_sug_queries = sug_results

    total_sugs = sum(len(r["suggestions"]) for r in sug_results)
    print(f"\n总共获取到 {total_sugs} 个推荐词")

    # Step 4: evaluate every suggestion
    evaluations = await evaluate_all_suggestions(sug_results, context.q, context)

    # Step 5: filter, progressively relaxing the threshold
    qualified = find_qualified_queries(evaluations, min_relevance_score=0.7)
    if qualified:
        return {
            "success": True,
            "results": qualified,
            "message": f"找到 {len(qualified)} 个合格query(intent_match=True 且 relevance>=0.7)"
        }

    acceptable = find_qualified_queries(evaluations, min_relevance_score=0.5)
    if acceptable:
        return {
            "success": True,
            "results": acceptable,
            "message": f"找到 {len(acceptable)} 个可接受query(intent_match=True 且 relevance>=0.5)"
        }

    # Fallback: report intent-matched suggestions even though relevance is low.
    intent_matched = [e for e in evaluations if e['intent_match'] is True]
    if intent_matched:
        intent_matched_sorted = sorted(intent_matched, key=lambda x: x['relevance_score'], reverse=True)
        return {
            "success": False,
            "results": intent_matched_sorted[:10],  # top 10 only
            "message": f"未找到高相关性query,但有 {len(intent_matched)} 个意图匹配的推荐词"
        }

    return {
        "success": False,
        "results": [],
        "message": "未找到任何意图匹配的推荐词"
    }


# ============================================================================
# Output formatting
# ============================================================================

def format_output(optimization_result: dict, context: RunContext) -> str:
    """Build the final human-readable report from the run results."""
    results = optimization_result.get("results", [])

    output = f"原始问题:{context.q}\n"
    output += f"提取的关键词:{', '.join(context.keywords or [])}\n"
    output += f"关键词数量:{len(context.keywords or [])}\n"
    output += f"\nquery组合统计:\n"
    for level, queries in context.query_combinations.items():
        output += f" - {level}: {len(queries)} 个\n"

    # Summary statistics
    total_queries = sum(len(q) for q in context.query_combinations.values())
    total_sugs = sum(len(r["suggestions"]) for r in context.all_sug_queries)
    total_evals = len(context.evaluation_results)
    output += f"\n探索统计:\n"
    output += f" - 总query数:{total_queries}\n"
    output += f" - 总推荐词数:{total_sugs}\n"
    output += f" - 总评估数:{total_evals}\n"
    output += f"\n状态:{optimization_result['message']}\n\n"

    if optimization_result["success"] and results:
        output += "=" * 60 + "\n"
        output += "合格的推荐query(按relevance_score降序):\n"
        output += "=" * 60 + "\n"
        for i, result in enumerate(results[:20], 1):  # show at most the top 20
            output += f"\n{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
            output += f" 来源:{result['source_query']}\n"
            output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n"
            output += f" 理由:{result['reason'][:150]}...\n" if len(result['reason']) > 150 else f" 理由:{result['reason']}\n"
    else:
        output += "=" * 60 + "\n"
        output += "结果:未找到足够相关的推荐query\n"
        output += "=" * 60 + "\n"
        if results:
            output += "\n最接近的推荐词(前10个):\n\n"
            for i, result in enumerate(results[:10], 1):
                output += f"{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
                output += f" 来源:{result['source_query']}\n"
                output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n\n"

    # Grouped view: suggestions per source query
    output += "\n" + "=" * 60 + "\n"
    output += "按查询词分组的推荐词情况:\n"
    output += "=" * 60 + "\n"

    for sug_data in context.all_sug_queries:
        source_q = sug_data["query"]
        sugs = sug_data["suggestions"]
        # All evaluations that came from this source query
        related_evals = [e for e in context.evaluation_results if e["source_query"] == source_q]
        intent_match_count = sum(1 for e in related_evals if e["intent_match"])
        avg_relevance = sum(e["relevance_score"] for e in related_evals) / len(related_evals) if related_evals else 0

        output += f"\n查询:{source_q}\n"
        output += f" 推荐词数:{len(sugs)}\n"
        output += f" 意图匹配数:{intent_match_count}/{len(related_evals)}\n"
        output += f" 平均相关性:{avg_relevance:.2f}\n"

        # Show up to three example suggestions with their scores
        if sugs:
            output += f" 示例推荐词:\n"
            for sug in sugs[:3]:
                eval_item = next((e for e in related_evals if e["sug_query"] == sug), None)
                if eval_item:
                    output += f" - {sug} [意图:{'✓' if eval_item['intent_match'] else '✗'}, 相关:{eval_item['relevance_score']:.2f}]\n"
                else:
                    output += f" - {sug}\n"

    return output.strip()


# ============================================================================
# Entry point
# ============================================================================

async def main(input_dir: str, max_combination_size: int = 1):
    """Load inputs, run the search, print the report, and persist the context.

    Args:
        input_dir: Directory containing ``context.md`` and ``q.md``.
        max_combination_size: Maximum keywords per combined query.
    """
    current_time, log_url = set_trace()

    # Fixed file names inside the input directory
    input_context_file = os.path.join(input_dir, 'context.md')
    input_q_file = os.path.join(input_dir, 'q.md')

    q_context = read_file_as_string(input_context_file)
    q = read_file_as_string(input_q_file)
    q_with_context = f"""
<需求上下文>
{q_context}

<当前问题>
{q}
""".strip()

    # Use the current file name as the version tag
    version = os.path.basename(__file__)
    version_name = os.path.splitext(version)[0]

    # Directory for this run's logs and serialized context
    log_dir = os.path.join(input_dir, "output", version_name, current_time)

    run_context = RunContext(
        version=version,
        input_files={
            "input_dir": input_dir,
            "context_file": input_context_file,
            "q_file": input_q_file,
        },
        q_with_context=q_with_context,
        q_context=q_context,
        q=q,
        log_dir=log_dir,
        log_url=log_url,
    )

    # Run the combinatorial search
    optimization_result = await combinatorial_search(run_context, max_combination_size=max_combination_size)

    # Format and print the report
    final_output = format_output(optimization_result, run_context)
    print(f"\n{'='*60}")
    print("最终结果")
    print(f"{'='*60}")
    print(final_output)

    # Persist results on the context
    run_context.optimization_result = optimization_result
    run_context.final_output = final_output

    # Save the RunContext into log_dir
    os.makedirs(run_context.log_dir, exist_ok=True)
    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
    with open(context_file_path, "w", encoding="utf-8") as f:
        json.dump(run_context.model_dump(), f, ensure_ascii=False, indent=2)
    print(f"\nRunContext saved to: {context_file_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="搜索query优化工具 - v6.2 组合式搜索版")
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/简单扣图",
        help="输入目录路径,默认: input/简单扣图",
    )
    parser.add_argument(
        "--max-combo",
        type=int,
        default=1,
        help="最大组合词数(N),默认: 1",
    )
    args = parser.parse_args()
    asyncio.run(main(args.input_dir, max_combination_size=args.max_combo))