- import asyncio
- import json
- import os
- import argparse
- from datetime import datetime
- from itertools import combinations, permutations
- from agents import Agent, Runner
- from lib.my_trace import set_trace
- from typing import Literal
- from pydantic import BaseModel, Field
- from lib.utils import read_file_as_string
- from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
- # ============================================================================
- # Concurrency control configuration
- # ============================================================================
- # Concurrency for API requests (Xiaohongshu endpoint)
- API_CONCURRENCY_LIMIT = 5
- # Concurrency for model evaluations (GPT-based)
- MODEL_CONCURRENCY_LIMIT = 10
- class RunContext(BaseModel):
- version: str = Field(..., description="Script version of this run (file name)")
- input_files: dict[str, str] = Field(..., description="Mapping of input file paths")
- q_with_context: str
- q_context: str
- q: str
- log_url: str
- log_dir: str
- # Question annotation
- question_annotation: str | None = Field(default=None, description="Annotation of the question (three layers)")
- # Segmentation and combinations
- keywords: list[str] | None = Field(default=None, description="Extracted keywords")
- query_combinations: dict[str, list[str]] = Field(default_factory=dict, description="Query combinations per level")
- # Exploration results
- all_sug_queries: list[dict] = Field(default_factory=list, description="All fetched suggestion queries")
- # Evaluation results
- evaluation_results: list[dict] = Field(default_factory=list, description="Evaluation results for all suggestions")
- optimization_result: dict | None = Field(default=None, description="Final optimization result object")
- final_output: str | None = Field(default=None, description="Final output (formatted text)")
- # ============================================================================
- # Agent 1: Question annotation expert
- # ============================================================================
- question_annotation_instructions = """
- You are a search-intent analysis expert. Given a question (with its requirement context), annotate the original text with three layers: essence, hard, soft.
- ## Criteria
- **[essence]** - the core intent of the question
- - e.g. how to obtain, tutorial, recommendation, showcase, review
- **[hard]** - objective, factual constraints (clearly verifiable, not a matter of opinion)
- - category-defining aspects: region, time, subject, tool, type of operation
- - signature: changing one produces a result of an entirely different category
- **[soft]** - subjective, qualifying modifiers (person-dependent, a matter of degree)
- - aspects requiring subjective judgment: quality, speed, aesthetics, distinctiveness, degree
- - signature: changing one still produces the same category of result, just satisfied to a different degree
- ## Output format
- word[essence-description]、word[hard-description]、word[soft-description]
- ## Notes
- - Output only the annotated string
- - Use the requirement context to judge the intent
- """.strip()
- question_annotator = Agent[None](
- name="Question Annotation Expert",
- instructions=question_annotation_instructions,
- )
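- # A hypothetical annotation, to illustrate the format the evaluator expects
- # downstream (the description wording is up to the model):
- #   input : "哪里能找到好看的川西秋季摄影机位?"
- #   output: "哪里能找到[essence-how to obtain]、好看的[soft-aesthetics]、川西[hard-region]、秋季[hard-time]、摄影机位[hard-subject]"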
- # ============================================================================
- # Agent 2: Word segmentation expert
- # ============================================================================
- segmentation_instructions = """
- You are a Chinese word segmentation expert. Given a sentence, segment it into words.
- ## Segmentation principles
- 1. Remove punctuation
- 2. Split into the smallest meaningful units
- 3. Drop particles, modal words, and auxiliary verbs
- 4. Keep interrogative words
- 5. Keep content words: nouns, verbs, adjectives, adverbs
- ## Output requirements
- Output the list of segmented words.
- """.strip()
- class SegmentationResult(BaseModel):
- """Word segmentation result"""
- words: list[str] = Field(..., description="List of segmented words")
- reasoning: str = Field(..., description="Explanation of the segmentation")
- segmenter = Agent[None](
- name="Word Segmentation Expert",
- instructions=segmentation_instructions,
- output_type=SegmentationResult,
- )
- # ============================================================================
- # Agent 3: Evaluation expert (intent match + relevance score)
- # ============================================================================
- eval_instructions = """
- You are a search-query evaluation expert. Given the original question, its annotation, and a suggested query, evaluate two dimensions.
- ## Input
- You will receive:
- 1. Original question: the user's original wording
- 2. Question annotation: the three-layer annotation of the original question (essence, hard, soft)
- 3. Suggested query: the candidate query to evaluate
- ## Goal
- If we search with this suggested query, can we find content that satisfies the original need?
- ## Two-level scoring
- ### 1. intent_match = true/false
- Does the **usage intent** of the suggested query match the **essence** of the original question?
- **Core rule: consider only the [essence] annotation**
- - The `[essence-XXX]` markers in the annotation pin down the user's core intent
- - Judge whether the suggested query can fulfil that core intent
- **Common essence types:**
- - Find a method / how to obtain → the suggestion should mention methods, channels, sites, sources
- - Find a tutorial → the suggestion should be tutorial/teaching related
- - Find resources / assets → the suggestion should be the resources or assets themselves
- - Find a tool → the suggestion should recommend tools
- - Browse works → the suggestion should showcase works
- **Scoring:**
- - true = the suggestion's intent matches the `[essence]`
- - false = the suggestion's intent does not match the `[essence]`
- ### 2. relevance_score = continuous score in 0-1
- Given a matching intent, how related is the suggested query to the original question in **topic, elements, and attributes**?
- **Dimensions:**
- - Topic: does the core topic match? (e.g. photography, travel, food)
- - Element coverage: how many of the `[hard-XXX]` constraints are preserved? (region, time, subject, tool, ...)
- - Attribute match: how many of the `[soft-XXX]` modifiers are preserved? (quality, speed, aesthetics, ...)
- **Score reference:**
- - 0.9-1.0 = near-perfect match; all [hard] and [soft] elements preserved
- - 0.7-0.8 = highly relevant; all [hard] elements preserved, a few [soft] ones missing
- - 0.5-0.6 = moderately relevant; most [hard] elements preserved, most [soft] ones missing
- - 0.3-0.4 = weakly relevant; some [hard] elements missing
- - 0-0.2 = essentially unrelated; most [hard] elements missing
- ## Strategy
- 1. **First check [essence] to decide intent_match**: if the intent does not match, it is simply false
- 2. **Then check [hard]/[soft] to score relevance_score**: assess how well elements and attributes are preserved
- ## Output requirements
- - intent_match: true/false
- - relevance_score: a float in 0-1
- - reason: a detailed justification that states:
- - what the [essence] of the original question is, and whether the suggestion matches it
- - which [hard] constraints are preserved or missing
- - which [soft] modifiers are preserved or missing
- - the basis for the final relevance score
- """.strip()
- class RelevanceEvaluation(BaseModel):
- """Evaluation feedback model - intent match + relevance"""
- intent_match: bool = Field(..., description="Whether the intent matches")
- relevance_score: float = Field(..., description="Relevance score in 0-1; higher is more relevant")
- reason: str = Field(..., description="Justification covering the intent judgment and the relevance basis")
- evaluator = Agent[None](
- name="Evaluation Expert",
- instructions=eval_instructions,
- output_type=RelevanceEvaluation,
- )
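- # Illustrative scoring under the rubric above (hypothetical values): for a
- # question whose [essence] is "how to obtain", with [hard] elements
- # 川西/秋季/摄影机位 and the [soft] modifier 好看的, the suggestion
- # "川西摄影机位推荐" would plausibly get intent_match=True (it still serves
- # "how to obtain") and a relevance_score in the 0.5-0.6 band: most [hard]
- # elements survive (秋季 is dropped) and the [soft] modifier is lost.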
- # ============================================================================
- # Core functions
- # ============================================================================
- async def annotate_question(q_with_context: str) -> str:
- """Annotate the question (three layers)"""
- print("\nAnnotating question...")
- result = await Runner.run(question_annotator, q_with_context)
- annotation = str(result.final_output)
- print(f"Annotation done: {annotation}")
- return annotation
- async def segment_text(q: str) -> SegmentationResult:
- """Segment the question into words"""
- print("\nSegmenting...")
- result = await Runner.run(segmenter, q)
- seg_result: SegmentationResult = result.final_output
- print(f"Segmented words: {seg_result.words}")
- print(f"Segmentation notes: {seg_result.reasoning}")
- return seg_result
- def generate_query_combinations(keywords: list[str], max_combination_size: int) -> dict[str, list[str]]:
- """
- Generate query combinations (word order matters)
- Args:
- keywords: list of keywords
- max_combination_size: maximum number of words per combination (N)
- Returns:
- {
- "1-word": [...],
- "2-word": [...],
- "3-word": [...],
- ...
- "N-word": [...]
- }
- """
- result = {}
- for size in range(1, max_combination_size + 1):
- if size > len(keywords):
- break
- # 1-word combinations: order does not matter
- if size == 1:
- queries = keywords.copy()
- else:
- # Multi-word combinations: choose size words (combinations), then order them (permutations)
- all_queries = []
- combs = list(combinations(keywords, size))
- for comb in combs:
- # Generate every ordering of this combination
- perms = list(permutations(comb))
- for perm in perms:
- query = ''.join(perm)  # Concatenate directly, no spaces
- all_queries.append(query)
- # Deduplicate (duplicates should not occur in theory, but be safe)
- queries = list(dict.fromkeys(all_queries))
- result[f"{size}-word"] = queries
- print(f"\n{size}-word combinations: {len(queries)}")
- if len(queries) <= 10:
- for q in queries:
- print(f" - {q}")
- else:
- print(f" - {queries[0]}")
- print(f" - {queries[1]}")
- print(" ...")
- print(f" - {queries[-1]}")
- return result
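- # Sanity check of the combinatorics (hypothetical keywords): for
- # keywords = ["川西", "秋季", "摄影"] and max_combination_size = 2, the
- # 1-word level yields 3 queries and the 2-word level yields
- # C(3,2) * 2! = 6 ordered queries: 川西秋季, 秋季川西, 川西摄影, 摄影川西,
- # 秋季摄影, 摄影秋季. Note that combinations() followed by permutations()
- # over each subset is equivalent to itertools.permutations(keywords, size)
- # in a single call.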
- async def fetch_suggestions_for_queries(queries: list[str], context: RunContext) -> list[dict]:
- """
- Fetch suggestions for all queries concurrently (with a concurrency cap)
- Returns:
- [
- {
- "query": "川西",
- "suggestions": ["川西旅游", "川西攻略", ...],
- "timestamp": "..."
- },
- ...
- ]
- """
- print(f"\n{'='*60}")
- print(f"Fetching suggestions: {len(queries)} queries (concurrency: {API_CONCURRENCY_LIMIT})")
- print(f"{'='*60}")
- xiaohongshu_api = XiaohongshuSearchRecommendations()
- # Semaphore caps the number of in-flight requests
- semaphore = asyncio.Semaphore(API_CONCURRENCY_LIMIT)
- async def get_single_sug(query: str):
- async with semaphore:
- print(f" Query: {query}")
- # get_recommendations is called synchronously here; run it in a worker
- # thread so it does not block the event loop (otherwise the gather
- # below would effectively run serially despite the semaphore)
- suggestions = await asyncio.to_thread(xiaohongshu_api.get_recommendations, keyword=query)
- print(f" → {len(suggestions) if suggestions else 0} suggestions")
- return {
- "query": query,
- "suggestions": suggestions or [],
- "timestamp": datetime.now().isoformat()
- }
- results = await asyncio.gather(*[get_single_sug(q) for q in queries])
- return results
- async def evaluate_all_suggestions(
- sug_results: list[dict],
- original_question: str,
- question_annotation: str,
- context: RunContext
- ) -> list[dict]:
- """
- Evaluate all suggestions (with a concurrency cap)
- Args:
- sug_results: suggestion results for all queries
- original_question: the original question
- question_annotation: the question annotation (three layers)
- Returns:
- [
- {
- "source_query": "川西秋季",
- "sug_query": "川西秋季旅游",
- "intent_match": True,
- "relevance_score": 0.8,
- "reason": "..."
- },
- ...
- ]
- """
- print(f"\n{'='*60}")
- print(f"Evaluating suggestions (concurrency: {MODEL_CONCURRENCY_LIMIT})")
- print(f"{'='*60}")
- # Semaphore caps concurrent model calls
- semaphore = asyncio.Semaphore(MODEL_CONCURRENCY_LIMIT)
- # Holds the evaluation results
- all_evaluations = []
- async def evaluate_single_sug(source_query: str, sug_query: str):
- async with semaphore:
- eval_input = f"""
- <original_question>
- {original_question}
- </original_question>
- <question_annotation>
- {question_annotation}
- </question_annotation>
- <suggested_query>
- {sug_query}
- </suggested_query>
- Evaluate this suggested query:
- 1. intent_match: does the intent match (true/false)
- 2. relevance_score: relevance score (0-1)
- 3. reason: detailed justification
- Refer to the [essence], [hard] and [soft] markers in the question annotation.
- """
- result = await Runner.run(evaluator, eval_input)
- evaluation: RelevanceEvaluation = result.final_output
- return {
- "source_query": source_query,
- "sug_query": sug_query,
- "intent_match": evaluation.intent_match,
- "relevance_score": evaluation.relevance_score,
- "reason": evaluation.reason,
- }
- # Evaluate all suggestions concurrently
- tasks = []
- for sug_result in sug_results:
- source_query = sug_result["query"]
- for sug in sug_result["suggestions"]:
- tasks.append(evaluate_single_sug(source_query, sug))
- if tasks:
- print(f" {len(tasks)} suggestions to evaluate...")
- all_evaluations = await asyncio.gather(*tasks)
- context.evaluation_results = all_evaluations
- return all_evaluations
- def find_qualified_queries(evaluations: list[dict], min_relevance_score: float = 0.7) -> list[dict]:
- """
- Find all qualified queries
- Criteria:
- 1. intent_match = True (mandatory)
- 2. relevance_score >= min_relevance_score
- Returns: sorted by relevance_score, descending
- """
- qualified = [
- e for e in evaluations
- if e['intent_match'] is True and e['relevance_score'] >= min_relevance_score
- ]
- # Sort by relevance_score, descending
- return sorted(qualified, key=lambda x: x['relevance_score'], reverse=True)
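- # Minimal usage sketch with made-up evaluations (illustrative only):
- #   evals = [
- #       {"intent_match": True,  "relevance_score": 0.9,  "sug_query": "a"},
- #       {"intent_match": False, "relevance_score": 0.95, "sug_query": "b"},
- #       {"intent_match": True,  "relevance_score": 0.6,  "sug_query": "c"},
- #   ]
- #   find_qualified_queries(evals, 0.7) returns only "a": "b" is gated out by
- #   intent_match despite the highest score, and "c" falls below the threshold.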
- # ============================================================================
- # Main flow
- # ============================================================================
- async def combinatorial_search(context: RunContext, max_combination_size: int = 1) -> dict:
- """
- Combinatorial search flow (with question annotation)
- Args:
- context: run context
- max_combination_size: maximum words per combination (N), default 1
- Return format:
- {
- "success": True/False,
- "results": [...],
- "message": "..."
- }
- """
- # Step 1: annotate the question (three layers)
- annotation = await annotate_question(context.q_with_context)
- context.question_annotation = annotation
- # Step 2: segment into keywords
- seg_result = await segment_text(context.q)
- context.keywords = seg_result.words
- # Step 3: generate query combinations
- print(f"\n{'='*60}")
- print(f"Generating query combinations (max size: {max_combination_size})")
- print(f"{'='*60}")
- query_combinations = generate_query_combinations(context.keywords, max_combination_size)
- context.query_combinations = query_combinations
- # Step 4: fetch suggestions for every query
- all_queries = []
- for level, queries in query_combinations.items():
- all_queries.extend(queries)
- sug_results = await fetch_suggestions_for_queries(all_queries, context)
- context.all_sug_queries = sug_results
- # Stats
- total_sugs = sum(len(r["suggestions"]) for r in sug_results)
- print(f"\nFetched {total_sugs} suggestions in total")
- # Step 5: evaluate all suggestions (against the original question and its annotation)
- evaluations = await evaluate_all_suggestions(sug_results, context.q, annotation, context)
- # Step 6: filter for qualified queries
- qualified = find_qualified_queries(evaluations, min_relevance_score=0.7)
- if qualified:
- return {
- "success": True,
- "results": qualified,
- "message": f"Found {len(qualified)} qualified queries (intent_match=True and relevance>=0.7)"
- }
- # Relax the threshold
- acceptable = find_qualified_queries(evaluations, min_relevance_score=0.5)
- if acceptable:
- return {
- "success": True,
- "results": acceptable,
- "message": f"Found {len(acceptable)} acceptable queries (intent_match=True and relevance>=0.5)"
- }
- # Last resort: return everything with intent_match=True
- intent_matched = [e for e in evaluations if e['intent_match'] is True]
- if intent_matched:
- intent_matched_sorted = sorted(intent_matched, key=lambda x: x['relevance_score'], reverse=True)
- return {
- "success": False,
- "results": intent_matched_sorted[:10],  # Return only the top 10
- "message": f"No highly relevant query found, but {len(intent_matched)} suggestions match the intent"
- }
- return {
- "success": False,
- "results": [],
- "message": "No intent-matching suggestion found"
- }
- # ============================================================================
- # Output formatting
- # ============================================================================
- def format_output(optimization_result: dict, context: RunContext) -> str:
- """Format the final results"""
- results = optimization_result.get("results", [])
- output = f"Original question: {context.q}\n"
- output += f"Question annotation: {context.question_annotation}\n"
- output += f"Extracted keywords: {', '.join(context.keywords or [])}\n"
- output += f"Keyword count: {len(context.keywords or [])}\n"
- output += "\nQuery combination stats:\n"
- for level, queries in context.query_combinations.items():
- output += f" - {level}: {len(queries)}\n"
- # Statistics
- total_queries = sum(len(q) for q in context.query_combinations.values())
- total_sugs = sum(len(r["suggestions"]) for r in context.all_sug_queries)
- total_evals = len(context.evaluation_results)
- output += "\nExploration stats:\n"
- output += f" - total queries: {total_queries}\n"
- output += f" - total suggestions: {total_sugs}\n"
- output += f" - total evaluations: {total_evals}\n"
- output += f"\nStatus: {optimization_result['message']}\n\n"
- if optimization_result["success"] and results:
- output += "=" * 60 + "\n"
- output += "Qualified suggestions (by relevance_score, descending):\n"
- output += "=" * 60 + "\n"
- for i, result in enumerate(results[:20], 1):  # Show only the top 20
- output += f"\n{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
- output += f" Source: {result['source_query']}\n"
- output += f" Intent: {'✓ match' if result['intent_match'] else '✗ mismatch'}\n"
- reason_text = result['reason'][:150] + "..." if len(result['reason']) > 150 else result['reason']
- output += f" Reason: {reason_text}\n"
- else:
- output += "=" * 60 + "\n"
- output += "Result: no sufficiently relevant suggestion found\n"
- output += "=" * 60 + "\n"
- if results:
- output += "\nClosest suggestions (top 10):\n\n"
- for i, result in enumerate(results[:10], 1):
- output += f"{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
- output += f" Source: {result['source_query']}\n"
- output += f" Intent: {'✓ match' if result['intent_match'] else '✗ mismatch'}\n\n"
- # Group suggestions by source query
- output += "\n" + "=" * 60 + "\n"
- output += "Suggestions grouped by source query:\n"
- output += "=" * 60 + "\n"
- for sug_data in context.all_sug_queries:
- source_q = sug_data["query"]
- sugs = sug_data["suggestions"]
- # All evaluations belonging to this source query
- related_evals = [e for e in context.evaluation_results if e["source_query"] == source_q]
- intent_match_count = sum(1 for e in related_evals if e["intent_match"])
- avg_relevance = sum(e["relevance_score"] for e in related_evals) / len(related_evals) if related_evals else 0
- output += f"\nQuery: {source_q}\n"
- output += f" Suggestions: {len(sugs)}\n"
- output += f" Intent matches: {intent_match_count}/{len(related_evals)}\n"
- output += f" Average relevance: {avg_relevance:.2f}\n"
- # Show the first 3 suggestions
- if sugs:
- output += " Sample suggestions:\n"
- for sug in sugs[:3]:
- eval_item = next((e for e in related_evals if e["sug_query"] == sug), None)
- if eval_item:
- output += f" - {sug} [intent: {'✓' if eval_item['intent_match'] else '✗'}, relevance: {eval_item['relevance_score']:.2f}]\n"
- else:
- output += f" - {sug}\n"
- return output.strip()
- # ============================================================================
- # Entry point
- # ============================================================================
- async def main(
- input_dir: str,
- max_combination_size: int = 1,
- api_concurrency: int = API_CONCURRENCY_LIMIT,
- model_concurrency: int = MODEL_CONCURRENCY_LIMIT
- ):
- # Update the global concurrency settings
- global API_CONCURRENCY_LIMIT, MODEL_CONCURRENCY_LIMIT
- API_CONCURRENCY_LIMIT = api_concurrency
- MODEL_CONCURRENCY_LIMIT = model_concurrency
- current_time, log_url = set_trace()
- # Read the fixed file names from the input directory
- input_context_file = os.path.join(input_dir, 'context.md')
- input_q_file = os.path.join(input_dir, 'q.md')
- q_context = read_file_as_string(input_context_file)
- q = read_file_as_string(input_q_file)
- q_with_context = f"""
- <requirement_context>
- {q_context}
- </requirement_context>
- <current_question>
- {q}
- </current_question>
- """.strip()
- # Use the current file name as the version
- version = os.path.basename(__file__)
- version_name = os.path.splitext(version)[0]
- # Directory for logs
- log_dir = os.path.join(input_dir, "output", version_name, current_time)
- run_context = RunContext(
- version=version,
- input_files={
- "input_dir": input_dir,
- "context_file": input_context_file,
- "q_file": input_q_file,
- },
- q_with_context=q_with_context,
- q_context=q_context,
- q=q,
- log_dir=log_dir,
- log_url=log_url,
- )
- print(f"\n{'='*60}")
- print("Concurrency settings")
- print(f"{'='*60}")
- print(f"API request concurrency: {API_CONCURRENCY_LIMIT}")
- print(f"Model evaluation concurrency: {MODEL_CONCURRENCY_LIMIT}")
- # Run the combinatorial search (with question annotation)
- optimization_result = await combinatorial_search(run_context, max_combination_size=max_combination_size)
- # Format the output
- final_output = format_output(optimization_result, run_context)
- print(f"\n{'='*60}")
- print("Final results")
- print(f"{'='*60}")
- print(final_output)
- # Store the results on the context
- run_context.optimization_result = optimization_result
- run_context.final_output = final_output
- # Save the RunContext to log_dir
- os.makedirs(run_context.log_dir, exist_ok=True)
- context_file_path = os.path.join(run_context.log_dir, "run_context.json")
- with open(context_file_path, "w", encoding="utf-8") as f:
- json.dump(run_context.model_dump(), f, ensure_ascii=False, indent=2)
- print(f"\nRunContext saved to: {context_file_path}")
- if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description="Search query optimization tool - v6.3, combinatorial search + question annotation",
- formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog="""
- Examples:
- # Default parameters
- python sug_v6_3_with_annotation.py
- # 2-word combinations, API concurrency 5, model concurrency 20
- python sug_v6_3_with_annotation.py --max-combo 2 --api-concurrency 5 --model-concurrency 20
- # 3-word combinations, reduced concurrency
- python sug_v6_3_with_annotation.py --max-combo 3 --api-concurrency 3 --model-concurrency 10
- """
- )
- parser.add_argument(
- "--input-dir",
- type=str,
- default="input/简单扣图",
- help="Input directory path, default: input/简单扣图"
- )
- parser.add_argument(
- "--max-combo",
- type=int,
- default=1,
- help="Maximum words per combination (N), default: 1"
- )
- parser.add_argument(
- "--api-concurrency",
- type=int,
- default=API_CONCURRENCY_LIMIT,
- help=f"API request concurrency, default: {API_CONCURRENCY_LIMIT}"
- )
- parser.add_argument(
- "--model-concurrency",
- type=int,
- default=MODEL_CONCURRENCY_LIMIT,
- help=f"Model evaluation concurrency, default: {MODEL_CONCURRENCY_LIMIT}"
- )
- args = parser.parse_args()
- asyncio.run(main(
- args.input_dir,
- max_combination_size=args.max_combo,
- api_concurrency=args.api_concurrency,
- model_concurrency=args.model_concurrency
- ))
|