# sug_v6_4_with_annotation.py

import asyncio
import json
import os
import argparse
from datetime import datetime
from itertools import combinations, permutations
from agents import Agent, Runner
from lib.my_trace import set_trace
from lib.client import get_model
from pydantic import BaseModel, Field
from lib.utils import read_file_as_string
from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations

# ============================================================================
# Concurrency control configuration
# ============================================================================
# Concurrency for requests to the Xiaohongshu suggestion API
API_CONCURRENCY_LIMIT = 5
# Concurrency for model-based evaluation (GPT scoring)
MODEL_CONCURRENCY_LIMIT = 10

class RunContext(BaseModel):
    version: str = Field(..., description="Version of this run (script filename)")
    input_files: dict[str, str] = Field(..., description="Mapping of input file paths")
    q_with_context: str
    q_context: str
    q: str
    log_url: str
    log_dir: str
    # Question annotation
    question_annotation: str | None = Field(default=None, description="Three-layer annotation of the question")
    # Segmentation and query combinations
    keywords: list[str] | None = Field(default=None, description="Extracted keywords")
    query_combinations: dict[str, list[str]] = Field(default_factory=dict, description="Query combinations per level")
    # New in v6.4: pruning records
    pruning_info: dict[str, dict] = Field(default_factory=dict, description="Pruning information per level")
    # Exploration results
    all_sug_queries: list[dict] = Field(default_factory=list, description="All fetched suggestion queries")
    # Evaluation results
    evaluation_results: list[dict] = Field(default_factory=list, description="Evaluation results for all suggestions")
    optimization_result: dict | None = Field(default=None, description="Final optimization result object")
    final_output: str | None = Field(default=None, description="Final output (formatted text)")

# ============================================================================
# Agent 1: Question annotation expert
# ============================================================================
question_annotation_instructions = """
你是搜索需求分析专家。给定问题(含需求背景),在原文上标注三层:本质、硬性、软性。
## 判断标准
**[本质]** - 问题的核心意图
- 如何获取、教程、推荐、作品、测评等
**[硬]** - 客观事实性约束(可明确验证、非主观判断)
- 能明确区分类别的:地域、时间、对象、工具、操作类型
- 特征:改变后得到完全不同类别的结果
**[软]** - 主观判断性修饰(因人而异、程度性的)
- 需要主观评价的:质量、速度、美观、特色、程度
- 特征:改变后仍是同类结果,只是满足程度不同
## 输出格式
词语[本质-描述]、词语[硬-描述]、词语[软-描述]
## 注意
- 只输出标注后的字符串
- 结合需求背景判断意图
""".strip()

question_annotator = Agent[None](
    name="问题标注专家",
    instructions=question_annotation_instructions,
    model=get_model(),
)
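
# Illustrative annotation (hypothetical output, not taken from a real run): for the question
# 「如何获取能体现川西秋季特色的高质量风光摄影素材?」 a plausible result would be
# 「获取[本质-找获取途径]、川西[硬-地域]、秋季[硬-时间]、高质量[软-质量]、风光摄影素材[硬-对象]」.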

# ============================================================================
# Agent 2: Word segmentation expert
# ============================================================================
segmentation_instructions = """
你是中文分词专家。给定一个句子,将其分词。
## 分词原则
1. 去掉标点符号
2. 拆分成最小的有意义单元
3. 去掉助词、语气词、助动词
4. 保留疑问词
5. 保留实词:名词、动词、形容词、副词
## 输出要求
输出分词列表。
""".strip()

class SegmentationResult(BaseModel):
    """Word segmentation result."""
    words: list[str] = Field(..., description="List of segmented words")
    reasoning: str = Field(..., description="Explanation of the segmentation")

segmenter = Agent[None](
    name="分词专家",
    instructions=segmentation_instructions,
    model=get_model(),
    output_type=SegmentationResult,
)
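
# Illustrative segmentation (hypothetical output): 「如何获取川西秋季风光摄影素材」
# → words=["如何", "获取", "川西", "秋季", "风光", "摄影", "素材"], following the rules
# above (punctuation removed, question words and content words kept).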

# ============================================================================
# Agent 3: Evaluation expert (intent match + relevance scoring)
# ============================================================================
eval_instructions = """
你是搜索query评估专家。给定原始问题、问题标注和推荐query,评估两个维度。
## 输入信息
你会收到:
1. 原始问题:用户的原始表述
2. 问题标注:对原始问题的三层标注(本质、硬性、软性)
3. 推荐query:待评估的推荐词
## 评估目标
用这个推荐query搜索,能否找到满足原始需求的内容?
## 两层评分
### 1. intent_match(意图匹配)= true/false
推荐query的**使用意图**是否与原问题的**本质**一致?
**核心:只关注[本质]标注**
- 问题标注中的 `[本质-XXX]` 标记明确了用户的核心意图
- 判断推荐词是否能达成这个核心意图
**常见本质类型:**
- 找方法/如何获取 → 推荐词应包含方法、途径、网站、渠道等
- 找教程 → 推荐词应是教程、教学相关
- 找资源/素材 → 推荐词应是资源、素材本身
- 找工具 → 推荐词应是工具推荐
- 看作品 → 推荐词应是作品展示
**评分:**
- true = 推荐词的意图与 `[本质]` 一致
- false = 推荐词的意图与 `[本质]` 不一致
### 2. relevance_score(相关性)= 0-1 连续分数
在意图匹配的前提下,推荐query在**主题、要素、属性**上与原问题的相关程度?
**评估维度:**
- 主题相关:核心主题是否匹配?(如:摄影、旅游、美食)
- 要素覆盖:`[硬-XXX]` 标记的硬性约束保留了多少?(地域、时间、对象、工具等)
- 属性匹配:`[软-XXX]` 标记的软性修饰保留了多少?(质量、速度、美观等)
**评分参考:**
- 0.9-1.0 = 几乎完美匹配,[硬]和[软]标注的要素都保留
- 0.7-0.8 = 高度相关,[硬]标注的要素都保留,[软]标注少数缺失
- 0.5-0.6 = 中度相关,[硬]标注的要素保留大部分,[软]标注多数缺失
- 0.3-0.4 = 低度相关,[硬]标注的要素部分缺失
- 0-0.2 = 基本不相关,[硬]标注的要素大量缺失
## 评估策略
1. **先看[本质]判断 intent_match**:意图不匹配直接 false
2. **再看[硬][软]评估 relevance_score**:计算要素和属性的保留程度
## 输出要求
请先思考,再打分。按以下顺序输出:
1. reason: 详细的评估理由(先分析再打分)
- 原问题的[本质]是什么,推荐词是否匹配这个本质
- [硬]约束哪些保留/缺失
- [软]修饰哪些保留/缺失
- 基于以上分析,给出意图匹配判断和相关性分数的依据
2. intent_match: true/false(基于上述分析得出)
3. relevance_score: 0-1 的浮点数(基于上述分析得出)
""".strip()

class RelevanceEvaluation(BaseModel):
    """Evaluation feedback model: intent match + relevance."""
    reason: str = Field(..., description="Evaluation rationale; must justify both the intent judgment and the relevance score")
    intent_match: bool = Field(..., description="Whether the intent matches")
    relevance_score: float = Field(..., description="Relevance score in [0, 1]; higher means more relevant")

evaluator = Agent[None](
    name="评估专家",
    instructions=eval_instructions,
    model=get_model(),
    output_type=RelevanceEvaluation,
)
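
# Illustrative scoring (hypothetical, for intuition only): if the annotated essence is
# 获取[本质-找获取途径] for 「川西秋季风光摄影素材」, then 「风光摄影素材网站」 would
# plausibly get intent_match=True with relevance_score around 0.5-0.6 (hard constraints
# 川西/秋季 missing), while 「川西摄影教程」 would get intent_match=False: it serves a
# "find a tutorial" intent, not the "obtain material" essence.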

# ============================================================================
# Core functions
# ============================================================================
async def annotate_question(q_with_context: str) -> str:
    """Annotate the question (three layers)."""
    print("\n正在标注问题...")
    result = await Runner.run(question_annotator, q_with_context)
    annotation = str(result.final_output)
    print(f"问题标注完成:{annotation}")
    return annotation

async def segment_text(q: str) -> SegmentationResult:
    """Segment the question into keywords."""
    print("\n正在分词...")
    result = await Runner.run(segmenter, q)
    seg_result: SegmentationResult = result.final_output
    print(f"分词结果:{seg_result.words}")
    print(f"分词说明:{seg_result.reasoning}")
    return seg_result

def generate_query_combinations_single_level(
    keywords: list[str],
    size: int,
    must_include_keywords: set[str] | None = None
) -> list[str]:
    """
    Generate the query combinations for a single level.

    Args:
        keywords: List of keywords.
        size: Number of words per combination.
        must_include_keywords: Keywords that must appear (used for pruning).

    Returns:
        All query combinations for this level.
    """
    if size > len(keywords):
        return []
    # 1-word combinations: order is irrelevant
    if size == 1:
        if must_include_keywords:
            # Only return the keywords that must be included
            return [kw for kw in keywords if kw in must_include_keywords]
        return keywords.copy()
    # Multi-word combinations: choose `size` words (combinations), then order them (permutations)
    all_queries = []
    combs = list(combinations(keywords, size))
    for comb in combs:
        # If must-include keywords are set, the combination must contain at least one of them
        if must_include_keywords:
            if not any(kw in must_include_keywords for kw in comb):
                continue  # Skip combinations without any must-include keyword
        # Generate every ordering of the combination
        perms = list(permutations(comb))
        for perm in perms:
            query = ''.join(perm)  # Concatenate directly, no spaces
            all_queries.append(query)
    # Deduplicate while preserving order
    return list(dict.fromkeys(all_queries))
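
# Quick sanity check of the enumeration logic (hypothetical keywords):
#   >>> generate_query_combinations_single_level(["a", "b", "c"], 2)
#   ['ab', 'ba', 'ac', 'ca', 'bc', 'cb']
#   >>> generate_query_combinations_single_level(["a", "b", "c"], 2, must_include_keywords={"a"})
#   ['ab', 'ba', 'ac', 'ca']
# A level of size k over n keywords yields up to n!/(n-k)! queries, which is why pruning
# matters as levels grow.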

async def evaluate_single_sug_with_semaphore(
    source_query: str,
    sug_query: str,
    original_question: str,
    question_annotation: str,
    semaphore: asyncio.Semaphore
) -> dict:
    """Evaluate a single suggestion, gated by a semaphore."""
    async with semaphore:
        eval_input = f"""
<原始问题>
{original_question}
</原始问题>
<问题标注(三层)>
{question_annotation}
</问题标注(三层)>
<待评估的推荐query>
{sug_query}
</待评估的推荐query>
请评估该推荐query(请先分析理由,再给出评分):
1. reason: 详细的评估理由(先思考分析)
2. intent_match: 意图是否匹配(true/false)
3. relevance_score: 相关性分数(0-1)
评估时请参考问题标注中的[本质]、[硬]、[软]标记。
"""
        result = await Runner.run(evaluator, eval_input)
        evaluation: RelevanceEvaluation = result.final_output
        return {
            "source_query": source_query,
            "sug_query": sug_query,
            "intent_match": evaluation.intent_match,
            "relevance_score": evaluation.relevance_score,
            "reason": evaluation.reason,
        }

async def fetch_and_evaluate_level(
    queries: list[str],
    original_question: str,
    question_annotation: str,
    level_name: str,
    context: RunContext
) -> tuple[list[dict], list[dict]]:
    """
    Process a single level: fetch suggestions and evaluate them.

    Returns:
        (sug_results, evaluations)
    """
    xiaohongshu_api = XiaohongshuSearchRecommendations()
    # Semaphores limiting API and model concurrency
    api_semaphore = asyncio.Semaphore(API_CONCURRENCY_LIMIT)
    model_semaphore = asyncio.Semaphore(MODEL_CONCURRENCY_LIMIT)
    # Result accumulators
    sug_results = []
    all_evaluations = []
    # Progress counters
    total_queries = len(queries)
    completed_queries = 0
    total_sugs = 0
    completed_evals = 0

    async def get_and_evaluate_single_query(query: str):
        nonlocal completed_queries, total_sugs, completed_evals
        # Step 1: fetch suggestions. The client call is synchronous, so run it in a
        # worker thread to avoid blocking the event loop while the semaphore is held.
        async with api_semaphore:
            suggestions = await asyncio.to_thread(
                xiaohongshu_api.get_recommendations, keyword=query
            )
            sug_count = len(suggestions) if suggestions else 0
            completed_queries += 1
            total_sugs += sug_count
            print(f" [{completed_queries}/{total_queries}] {query} → {sug_count} 个推荐词")
            sug_result = {
                "query": query,
                "suggestions": suggestions or [],
                "timestamp": datetime.now().isoformat()
            }
            sug_results.append(sug_result)
        # Step 2: evaluate these suggestions right away
        if suggestions:
            eval_tasks = [
                evaluate_single_sug_with_semaphore(
                    query, sug, original_question, question_annotation, model_semaphore
                )
                for sug in suggestions
            ]
            evals = await asyncio.gather(*eval_tasks)
            all_evaluations.extend(evals)
            completed_evals += len(evals)
            print(f" ↳ 已评估 {len(evals)} 个,累计评估 {completed_evals} 个")

    # Process all queries concurrently
    await asyncio.gather(*[get_and_evaluate_single_query(q) for q in queries])
    print(f"\n{level_name} 完成:获取 {total_sugs} 个推荐词,完成 {completed_evals} 个评估")
    return sug_results, all_evaluations

def find_intent_matched_keywords(
    keywords: list[str],
    evaluations: list[dict]
) -> set[str]:
    """
    Find all keywords that produced at least one suggestion with intent_match=True.

    Args:
        keywords: Keywords used at the current level.
        evaluations: Evaluation results for that level.

    Returns:
        Set of keywords with an intent match.
    """
    matched_keywords = set()
    for keyword in keywords:
        # Check whether any suggestion derived from this keyword has intent_match=True
        keyword_evals = [
            e for e in evaluations
            if e['source_query'] == keyword and e['intent_match'] is True
        ]
        if keyword_evals:
            matched_keywords.add(keyword)
    return matched_keywords
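
# Quick sanity check (hypothetical evaluations):
#   >>> find_intent_matched_keywords(
#   ...     ["素材", "教程"],
#   ...     [{"source_query": "素材", "intent_match": True},
#   ...      {"source_query": "教程", "intent_match": False}])
#   {'素材'}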

def find_top_keywords_by_relevance(
    keywords: list[str],
    evaluations: list[dict],
    top_n: int = 2
) -> list[str]:
    """
    Find the top N keywords by relevance_score.

    Args:
        keywords: Keywords used at the current level.
        evaluations: Evaluation results for that level.
        top_n: Number of keywords to keep.

    Returns:
        Top N keywords, ranked by average relevance_score.
    """
    keyword_scores = {}
    for keyword in keywords:
        # Collect all evaluations derived from this keyword
        keyword_evals = [
            e for e in evaluations
            if e['source_query'] == keyword
        ]
        if keyword_evals:
            # Average relevance_score as the primary ranking signal
            avg_score = sum(e['relevance_score'] for e in keyword_evals) / len(keyword_evals)
            # Maximum score as a secondary tie-breaker
            max_score = max(e['relevance_score'] for e in keyword_evals)
            keyword_scores[keyword] = {
                'avg': avg_score,
                'max': max_score,
                'count': len(keyword_evals)
            }
    if not keyword_scores:
        return []
    # Sort by average score, then by maximum score, both descending
    sorted_keywords = sorted(
        keyword_scores.items(),
        key=lambda x: (x[1]['avg'], x[1]['max']),
        reverse=True
    )
    # Return the top N keywords
    return [kw for kw, score in sorted_keywords[:top_n]]
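
# Quick sanity check (hypothetical scores): 「川西」 averages 0.7 with max 0.8, 「素材」
# averages 0.7 with max 0.7, 「秋季」 averages 0.6. With top_n=2 the max score breaks the
# tie between the equal averages:
#   >>> find_top_keywords_by_relevance(
#   ...     ["川西", "素材", "秋季"],
#   ...     [{"source_query": "川西", "relevance_score": 0.8},
#   ...      {"source_query": "川西", "relevance_score": 0.6},
#   ...      {"source_query": "素材", "relevance_score": 0.7},
#   ...      {"source_query": "秋季", "relevance_score": 0.6}],
#   ...     top_n=2)
#   ['川西', '素材']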

def find_qualified_queries(evaluations: list[dict], min_relevance_score: float = 0.7) -> list[dict]:
    """
    Find all qualified queries.

    Criteria:
        1. intent_match = True (required)
        2. relevance_score >= min_relevance_score

    Returns:
        Matches sorted by relevance_score, descending.
    """
    qualified = [
        e for e in evaluations
        if e['intent_match'] is True and e['relevance_score'] >= min_relevance_score
    ]
    # Sort by relevance_score, descending
    return sorted(qualified, key=lambda x: x['relevance_score'], reverse=True)
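
# Quick sanity check (hypothetical evaluations): intent_match=False entries are excluded
# no matter how high they score:
#   >>> find_qualified_queries([
#   ...     {"intent_match": True, "relevance_score": 0.62},
#   ...     {"intent_match": True, "relevance_score": 0.85},
#   ...     {"intent_match": False, "relevance_score": 0.95},
#   ... ], min_relevance_score=0.7)
#   [{'intent_match': True, 'relevance_score': 0.85}]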

# ============================================================================
# Main flow - v6.4 level-based pruning
# ============================================================================
async def combinatorial_search_with_pruning(
    context: RunContext,
    max_combination_size: int = 1,
    fallback_top_n: int = 2,
    pruning_start_level: int = 3
) -> dict:
    """
    Combinatorial search flow (with level-based pruning).

    Strategy:
    - Levels 1-2: explore fully, using all keywords.
    - Level 3 and above: prune.
        1. Prefer keywords that produced at least one intent_match=True suggestion at level 1.
        2. If there are none, use the top N keywords by relevance_score.
        3. If that cannot be computed either, use all keywords.

    Args:
        context: Run context.
        max_combination_size: Maximum combination size (N), default 1.
        fallback_top_n: When nothing matched the intent, keep the top N keywords by
            relevance_score, default 2.
        pruning_start_level: Level at which pruning starts, default 3 (i.e. levels 1-2
            explore fully).

    Return format:
        {
            "success": True/False,
            "results": [...],
            "message": "..."
        }
    """
    # Step 1: annotate the question (three layers)
    annotation = await annotate_question(context.q_with_context)
    context.question_annotation = annotation
    # Step 2: segment the question into keywords
    seg_result = await segment_text(context.q)
    all_keywords = seg_result.words
    context.keywords = all_keywords
    # Accumulated results
    all_sug_results = []
    all_evaluations = []
    # Level-1 evaluations, kept for later pruning decisions
    level_1_evaluations = []
    # Keywords from level 1 with an intent match (used for pruning)
    must_include_keywords = None
    print(f"\n{'='*60}")
    print(f"层级剪枝式搜索(最大层级:{max_combination_size},第{pruning_start_level}层开始剪枝)")
    print(f"{'='*60}")
    # Process level by level
    for level in range(1, max_combination_size + 1):
        level_name = f"{level}-word"
        print(f"\n{'='*60}")
        print(f"第 {level} 层:{level_name}")
        print(f"{'='*60}")
        # Decide whether this level should be pruned
        should_prune = level >= pruning_start_level and must_include_keywords is not None
        if should_prune:
            print("剪枝模式:只生成包含以下关键词之一的组合")
            print(f" 必须包含的关键词:{sorted(must_include_keywords)}")
        # Generate this level's query combinations
        level_queries = generate_query_combinations_single_level(
            all_keywords,
            level,
            must_include_keywords=must_include_keywords if should_prune else None
        )
        if not level_queries:
            print(f"⚠️ 无法生成 {level} 词组合,跳过")
            context.pruning_info[level_name] = {
                "available_keywords": list(all_keywords),
                "queries_count": 0,
                "pruned": should_prune,
                "reason": "剪枝后无有效组合" if should_prune else f"关键词数量不足以生成 {level} 词组合"
            }
            break
        print(f"可用关键词:{all_keywords}")
        if should_prune:
            total_possible = len(generate_query_combinations_single_level(all_keywords, level))
            print(f"剪枝效果:{len(level_queries)}/{total_possible} (保留 {len(level_queries)/total_possible*100:.1f}%)")
        print(f"生成的query数:{len(level_queries)}")
        # Record this level's query combinations
        context.query_combinations[level_name] = level_queries
        # Print a few sample queries
        print("\nquery示例(前10个):")
        for i, q in enumerate(level_queries[:10], 1):
            print(f" {i}. {q}")
        if len(level_queries) > 10:
            print(f" ... 还有 {len(level_queries) - 10} 个")
        # Fetch suggestions and evaluate them
        print(f"\n开始处理第 {level} 层的推荐词...")
        level_sug_results, level_evaluations = await fetch_and_evaluate_level(
            level_queries,
            context.q,
            annotation,
            level_name,
            context
        )
        # Accumulate results
        all_sug_results.extend(level_sug_results)
        all_evaluations.extend(level_evaluations)
        # Keep the level-1 evaluations for pruning
        if level == 1:
            level_1_evaluations = level_evaluations.copy()
        # Intent-match statistics for this level
        intent_matched_count = sum(1 for e in level_evaluations if e['intent_match'] is True)
        print(f"\n第 {level} 层统计:")
        print(f" - 查询数:{len(level_queries)}")
        print(f" - 推荐词数:{sum(len(r['suggestions']) for r in level_sug_results)}")
        print(f" - 意图匹配数:{intent_matched_count}/{len(level_evaluations)}")
        # Record pruning info
        context.pruning_info[level_name] = {
            "available_keywords": list(all_keywords),
            "must_include_keywords": list(must_include_keywords) if must_include_keywords else None,
            "queries_count": len(level_queries),
            "pruned": should_prune,
            "intent_matched_count": intent_matched_count,
            "total_evaluations": len(level_evaluations)
        }
        # After level 1, prepare the pruning keywords
        if level == 1 and pruning_start_level > 1:
            # Determine which keywords later levels must include, based on level-1 results
            matched_keywords = find_intent_matched_keywords(all_keywords, level_1_evaluations)
            print(f"\n准备剪枝策略(将用于第{pruning_start_level}层及以上):")
            print(f" - 原始关键词数:{len(all_keywords)}")
            print(f" - 意图匹配关键词数(基于第1层):{len(matched_keywords)}")
            if matched_keywords:
                print(f" ✓ 后续层将只保留包含以下关键词之一的组合:{sorted(matched_keywords)}")
                must_include_keywords = matched_keywords
            else:
                print(" ⚠️ 第1层没有任何关键词产生 intent_match=True 的推荐词")
                # Fallback: use the top N keywords by relevance_score
                top_keywords = find_top_keywords_by_relevance(all_keywords, level_1_evaluations, top_n=fallback_top_n)
                if top_keywords:
                    print(f" ✓ 后续层将只保留包含以下关键词之一的组合(基于 relevance top {fallback_top_n}):{top_keywords}")
                    must_include_keywords = set(top_keywords)
                    # Show score details for the kept keywords
                    for kw in top_keywords:
                        kw_evals = [e for e in level_1_evaluations if e['source_query'] == kw]
                        if kw_evals:
                            avg_score = sum(e['relevance_score'] for e in kw_evals) / len(kw_evals)
                            max_score = max(e['relevance_score'] for e in kw_evals)
                            print(f" - {kw}: 平均={avg_score:.2f}, 最高={max_score:.2f}, 推荐词数={len(kw_evals)}")
                else:
                    print(" ⚠️ 无法计算 relevance_score,后续层将不剪枝")
                    must_include_keywords = None
    # Persist accumulated results
    context.all_sug_queries = all_sug_results
    context.evaluation_results = all_evaluations
    # Filter for qualified queries
    print(f"\n{'='*60}")
    print("筛选最终结果")
    print(f"{'='*60}")
    qualified = find_qualified_queries(all_evaluations, min_relevance_score=0.7)
    if qualified:
        return {
            "success": True,
            "results": qualified,
            "message": f"找到 {len(qualified)} 个合格query(intent_match=True 且 relevance>=0.7)"
        }
    # Relax the threshold
    acceptable = find_qualified_queries(all_evaluations, min_relevance_score=0.5)
    if acceptable:
        return {
            "success": True,
            "results": acceptable,
            "message": f"找到 {len(acceptable)} 个可接受query(intent_match=True 且 relevance>=0.5)"
        }
    # Last resort: return everything with intent_match=True
    intent_matched = [e for e in all_evaluations if e['intent_match'] is True]
    if intent_matched:
        intent_matched_sorted = sorted(intent_matched, key=lambda x: x['relevance_score'], reverse=True)
        return {
            "success": False,
            "results": intent_matched_sorted[:10],  # Return at most the top 10
            "message": f"未找到高相关性query,但有 {len(intent_matched)} 个意图匹配的推荐词"
        }
    return {
        "success": False,
        "results": [],
        "message": "未找到任何意图匹配的推荐词"
    }
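
# Illustrative pruning effect (hypothetical numbers): with 7 keywords, level 3 would emit
# 7!/(7-3)! = 210 queries unpruned. If level 1 leaves 2 must-include keywords, combinations
# drawn entirely from the other 5 are dropped: (C(7,3) - C(5,3)) * 3! = (35 - 10) * 6 = 150
# queries remain, and the reduction deepens as more keywords fail the intent check.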

# ============================================================================
# Output formatting
# ============================================================================
def format_output(optimization_result: dict, context: RunContext) -> str:
    """Format the final result as readable text."""
    results = optimization_result.get("results", [])
    output = f"原始问题:{context.q}\n"
    output += f"问题标注:{context.question_annotation}\n"
    output += f"提取的关键词:{', '.join(context.keywords or [])}\n"
    output += f"关键词数量:{len(context.keywords or [])}\n"
    # Per-level pruning info
    output += f"\n{'='*60}\n"
    output += "层级剪枝信息:\n"
    output += f"{'='*60}\n"
    for level_name, info in context.pruning_info.items():
        output += f"\n{level_name}:\n"
        if info.get('pruned'):
            if info.get('reason'):
                output += " 状态:剪枝失败 ⚠️\n"
                output += f" 原因:{info.get('reason')}\n"
            else:
                output += " 状态:已剪枝 ✂️\n"
                output += f" 必须包含关键词:{info.get('must_include_keywords', [])}\n"
                output += f" 生成query数:{info['queries_count']}\n"
        else:
            output += " 状态:充分探索 ✓\n"
            output += f" 可用关键词数:{len(info['available_keywords'])}\n"
            output += f" 可用关键词:{info['available_keywords']}\n"
            output += f" 生成query数:{info['queries_count']}\n"
        output += f" 意图匹配数:{info.get('intent_matched_count', 0)}/{info.get('total_evaluations', 0)}\n"
    # Query combination statistics
    output += f"\n{'='*60}\n"
    output += "query组合统计:\n"
    output += f"{'='*60}\n"
    for level, queries in context.query_combinations.items():
        output += f" - {level}: {len(queries)} 个\n"
    # Overall statistics
    total_queries = sum(len(q) for q in context.query_combinations.values())
    total_sugs = sum(len(r["suggestions"]) for r in context.all_sug_queries)
    total_evals = len(context.evaluation_results)
    output += "\n探索统计:\n"
    output += f" - 总query数:{total_queries}\n"
    output += f" - 总推荐词数:{total_sugs}\n"
    output += f" - 总评估数:{total_evals}\n"
    output += f"\n状态:{optimization_result['message']}\n\n"
    if optimization_result["success"] and results:
        output += "=" * 60 + "\n"
        output += "合格的推荐query(按relevance_score降序):\n"
        output += "=" * 60 + "\n"
        for i, result in enumerate(results[:20], 1):  # Show at most the first 20
            output += f"\n{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
            output += f" 来源:{result['source_query']}\n"
            output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n"
            # Truncate long rationales to 150 characters
            if len(result['reason']) > 150:
                output += f" 理由:{result['reason'][:150]}...\n"
            else:
                output += f" 理由:{result['reason']}\n"
    else:
        output += "=" * 60 + "\n"
        output += "结果:未找到足够相关的推荐query\n"
        output += "=" * 60 + "\n"
        if results:
            output += "\n最接近的推荐词(前10个):\n\n"
            for i, result in enumerate(results[:10], 1):
                output += f"{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
                output += f" 来源:{result['source_query']}\n"
                output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n\n"
    # Group suggestions by source_query
    output += "\n" + "=" * 60 + "\n"
    output += "按查询词分组的推荐词情况:\n"
    output += "=" * 60 + "\n"
    for sug_data in context.all_sug_queries:
        source_q = sug_data["query"]
        sugs = sug_data["suggestions"]
        # All evaluations belonging to this source query
        related_evals = [e for e in context.evaluation_results if e["source_query"] == source_q]
        intent_match_count = sum(1 for e in related_evals if e["intent_match"])
        avg_relevance = sum(e["relevance_score"] for e in related_evals) / len(related_evals) if related_evals else 0
        output += f"\n查询:{source_q}\n"
        output += f" 推荐词数:{len(sugs)}\n"
        output += f" 意图匹配数:{intent_match_count}/{len(related_evals)}\n"
        output += f" 平均相关性:{avg_relevance:.2f}\n"
        # Show up to 3 example suggestions
        if sugs:
            output += " 示例推荐词:\n"
            for sug in sugs[:3]:
                eval_item = next((e for e in related_evals if e["sug_query"] == sug), None)
                if eval_item:
                    output += f" - {sug} [意图:{'✓' if eval_item['intent_match'] else '✗'}, 相关:{eval_item['relevance_score']:.2f}]\n"
                else:
                    output += f" - {sug}\n"
    return output.strip()

# ============================================================================
# Main function
# ============================================================================
async def main(
    input_dir: str,
    max_combination_size: int = 1,
    api_concurrency: int = API_CONCURRENCY_LIMIT,
    model_concurrency: int = MODEL_CONCURRENCY_LIMIT,
    fallback_top_n: int = 2,
    pruning_start_level: int = 3
):
    # Update the global concurrency configuration
    global API_CONCURRENCY_LIMIT, MODEL_CONCURRENCY_LIMIT
    API_CONCURRENCY_LIMIT = api_concurrency
    MODEL_CONCURRENCY_LIMIT = model_concurrency
    current_time, log_url = set_trace()
    # Read the fixed file names from the input directory
    input_context_file = os.path.join(input_dir, 'context.md')
    input_q_file = os.path.join(input_dir, 'q.md')
    q_context = read_file_as_string(input_context_file)
    q = read_file_as_string(input_q_file)
    q_with_context = f"""
<需求上下文>
{q_context}
</需求上下文>
<当前问题>
{q}
</当前问题>
""".strip()
    # Use the current file name as the version
    version = os.path.basename(__file__)
    version_name = os.path.splitext(version)[0]
    # Directory for saved logs
    log_dir = os.path.join(input_dir, "output", version_name, current_time)
    run_context = RunContext(
        version=version,
        input_files={
            "input_dir": input_dir,
            "context_file": input_context_file,
            "q_file": input_q_file,
        },
        q_with_context=q_with_context,
        q_context=q_context,
        q=q,
        log_dir=log_dir,
        log_url=log_url,
    )
    print(f"\n{'='*60}")
    print("并发配置")
    print(f"{'='*60}")
    print(f"API请求并发度:{API_CONCURRENCY_LIMIT}")
    print(f"模型评估并发度:{MODEL_CONCURRENCY_LIMIT}")
    # Run the level-pruned combinatorial search
    optimization_result = await combinatorial_search_with_pruning(
        run_context,
        max_combination_size=max_combination_size,
        fallback_top_n=fallback_top_n,
        pruning_start_level=pruning_start_level
    )
    # Format the output
    final_output = format_output(optimization_result, run_context)
    print(f"\n{'='*60}")
    print("最终结果")
    print(f"{'='*60}")
    print(final_output)
    # Store results on the context
    run_context.optimization_result = optimization_result
    run_context.final_output = final_output
    # Save the RunContext to log_dir
    os.makedirs(run_context.log_dir, exist_ok=True)
    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
    with open(context_file_path, "w", encoding="utf-8") as f:
        json.dump(run_context.model_dump(), f, ensure_ascii=False, indent=2)
    print(f"\nRunContext saved to: {context_file_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="搜索query优化工具 - v6.4 层级剪枝版",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 默认参数(只搜索1层)
  python sug_v6_4_with_annotation.py
  # 2层搜索,第2层只使用第1层中有意图匹配的关键词
  python sug_v6_4_with_annotation.py --max-combo 2
  # 2层搜索,如果第1层没有意图匹配,则使用 top 3 关键词
  python sug_v6_4_with_annotation.py --max-combo 2 --fallback-top 3
  # 3层搜索,API并发5,模型并发20
  python sug_v6_4_with_annotation.py --max-combo 3 --api-concurrency 5 --model-concurrency 20
  # 指定输入目录
  python sug_v6_4_with_annotation.py --input-dir "input/旅游-逸趣玩旅行/如何获取能体现川西秋季特色的高质量风光摄影素材?"
"""
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/简单扣图",
        help="输入目录路径,默认: input/简单扣图"
    )
    parser.add_argument(
        "--max-combo",
        type=int,
        default=1,
        help="最大组合词数(N),默认: 1"
    )
    parser.add_argument(
        "--fallback-top",
        type=int,
        default=2,
        help="当第1层没有意图匹配时,使用 relevance_score top N 关键词,默认: 2"
    )
    parser.add_argument(
        "--pruning-start",
        type=int,
        default=3,
        help="从第几层开始剪枝,默认: 3(即第1-2层充分探索,第3层及以上剪枝)"
    )
    parser.add_argument(
        "--api-concurrency",
        type=int,
        default=API_CONCURRENCY_LIMIT,
        help=f"API请求并发度,默认: {API_CONCURRENCY_LIMIT}"
    )
    parser.add_argument(
        "--model-concurrency",
        type=int,
        default=MODEL_CONCURRENCY_LIMIT,
        help=f"模型评估并发度,默认: {MODEL_CONCURRENCY_LIMIT}"
    )
    args = parser.parse_args()
    asyncio.run(main(
        args.input_dir,
        max_combination_size=args.max_combo,
        api_concurrency=args.api_concurrency,
        model_concurrency=args.model_concurrency,
        fallback_top_n=args.fallback_top,
        pruning_start_level=args.pruning_start
    ))