sug_v6_4_with_annotation.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927
  1. import asyncio
  2. import json
  3. import os
  4. import argparse
  5. from datetime import datetime
  6. from itertools import combinations, permutations
  7. from agents import Agent, Runner
  8. from lib.my_trace import set_trace
  9. from typing import Literal
  10. from pydantic import BaseModel, Field
  11. from lib.utils import read_file_as_string
  12. from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
  13. # ============================================================================
  14. # 并发控制配置
  15. # ============================================================================
  16. # API请求并发度(小红书接口)
  17. API_CONCURRENCY_LIMIT = 5
  18. # 模型评估并发度(GPT评估)
  19. MODEL_CONCURRENCY_LIMIT = 10
class RunContext(BaseModel):
    """Mutable state of one run of the query-optimization pipeline.

    main() creates it from the input files; the search pipeline fills in the
    annotation, keywords, per-level query combinations, pruning records,
    suggestions and evaluations; main() finally writes ``model_dump()`` to
    ``run_context.json`` inside ``log_dir``. Field order is kept as-is because
    it is the key order of that JSON dump.
    """
    version: str = Field(..., description="当前运行的脚本版本(文件名)")
    input_files: dict[str, str] = Field(..., description="输入文件路径映射")
    # Context + question wrapped in <需求上下文>/<当前问题> tags (built in main).
    q_with_context: str
    # Raw contents of context.md.
    q_context: str
    # Raw contents of q.md (the question being optimized).
    q: str
    # Trace URL returned by set_trace().
    log_url: str
    # Directory where run_context.json is written.
    log_dir: str
    # Question annotation
    question_annotation: str | None = Field(default=None, description="问题的标注结果(三层)")
    # Segmentation and query combinations (keyed by level name, e.g. "1-word")
    keywords: list[str] | None = Field(default=None, description="提取的关键词")
    query_combinations: dict[str, list[str]] = Field(default_factory=dict, description="各层级的query组合")
    # New in v6.4: pruning record per level name
    pruning_info: dict[str, dict] = Field(default_factory=dict, description="各层级的剪枝信息")
    # Exploration results: one dict per source query with its suggestions
    all_sug_queries: list[dict] = Field(default_factory=list, description="所有获取到的推荐词")
    # Evaluation results: one dict per (source query, suggestion) pair
    evaluation_results: list[dict] = Field(default_factory=list, description="所有推荐词的评估结果")
    optimization_result: dict | None = Field(default=None, description="最终优化结果对象")
    final_output: str | None = Field(default=None, description="最终输出结果(格式化文本)")
# ============================================================================
# Agent 1: question annotation expert
# ============================================================================
# System prompt: annotate the question text in place with three tag kinds,
# [本质] (core intent), [硬] (hard, objectively checkable constraints) and
# [软] (soft, subjective modifiers). The evaluator prompt refers to these exact
# tag spellings ([本质-XXX], [硬-XXX], [软-XXX]), so keep the two prompts in sync.
question_annotation_instructions = """
你是搜索需求分析专家。给定问题(含需求背景),在原文上标注三层:本质、硬性、软性。
## 判断标准
**[本质]** - 问题的核心意图
- 如何获取、教程、推荐、作品、测评等
**[硬]** - 客观事实性约束(可明确验证、非主观判断)
- 能明确区分类别的:地域、时间、对象、工具、操作类型
- 特征:改变后得到完全不同类别的结果
**[软]** - 主观判断性修饰(因人而异、程度性的)
- 需要主观评价的:质量、速度、美观、特色、程度
- 特征:改变后仍是同类结果,只是满足程度不同
## 输出格式
词语[本质-描述]、词语[硬-描述]、词语[软-描述]
## 注意
- 只输出标注后的字符串
- 结合需求背景判断意图
""".strip()
# No output_type is configured, so the agent's final output is free text
# (annotate_question converts it with str()).
question_annotator = Agent[None](
    name="问题标注专家",
    instructions=question_annotation_instructions,
)
# ============================================================================
# Agent 2: word segmentation expert
# ============================================================================
# System prompt: split the question into minimal meaningful units, dropping
# punctuation and particles but keeping question words and content words.
segmentation_instructions = """
你是中文分词专家。给定一个句子,将其分词。
## 分词原则
1. 去掉标点符号
2. 拆分成最小的有意义单元
3. 去掉助词、语气词、助动词
4. 保留疑问词
5. 保留实词:名词、动词、形容词、副词
## 输出要求
输出分词列表。
""".strip()
class SegmentationResult(BaseModel):
    """Structured output of the segmentation agent."""
    # The tokens; the search pipeline uses them as its keywords.
    words: list[str] = Field(..., description="分词列表")
    # The model's explanation; only printed by segment_text.
    reasoning: str = Field(..., description="分词说明")
# Agent whose final output is parsed into a SegmentationResult.
segmenter = Agent[None](
    name="分词专家",
    instructions=segmentation_instructions,
    output_type=SegmentationResult,
)
# ============================================================================
# Agent 3: evaluation expert (intent match + relevance score)
# ============================================================================
# System prompt: given the original question, its three-layer annotation and one
# suggested query, judge (1) whether the suggestion's intent matches the [本质]
# tag and (2) a 0-1 relevance score based on how many [硬]/[软] elements remain.
eval_instructions = """
你是搜索query评估专家。给定原始问题、问题标注和推荐query,评估两个维度。
## 输入信息
你会收到:
1. 原始问题:用户的原始表述
2. 问题标注:对原始问题的三层标注(本质、硬性、软性)
3. 推荐query:待评估的推荐词
## 评估目标
用这个推荐query搜索,能否找到满足原始需求的内容?
## 两层评分
### 1. intent_match(意图匹配)= true/false
推荐query的**使用意图**是否与原问题的**本质**一致?
**核心:只关注[本质]标注**
- 问题标注中的 `[本质-XXX]` 标记明确了用户的核心意图
- 判断推荐词是否能达成这个核心意图
**常见本质类型:**
- 找方法/如何获取 → 推荐词应包含方法、途径、网站、渠道等
- 找教程 → 推荐词应是教程、教学相关
- 找资源/素材 → 推荐词应是资源、素材本身
- 找工具 → 推荐词应是工具推荐
- 看作品 → 推荐词应是作品展示
**评分:**
- true = 推荐词的意图与 `[本质]` 一致
- false = 推荐词的意图与 `[本质]` 不一致
### 2. relevance_score(相关性)= 0-1 连续分数
在意图匹配的前提下,推荐query在**主题、要素、属性**上与原问题的相关程度?
**评估维度:**
- 主题相关:核心主题是否匹配?(如:摄影、旅游、美食)
- 要素覆盖:`[硬-XXX]` 标记的硬性约束保留了多少?(地域、时间、对象、工具等)
- 属性匹配:`[软-XXX]` 标记的软性修饰保留了多少?(质量、速度、美观等)
**评分参考:**
- 0.9-1.0 = 几乎完美匹配,[硬]和[软]标注的要素都保留
- 0.7-0.8 = 高度相关,[硬]标注的要素都保留,[软]标注少数缺失
- 0.5-0.6 = 中度相关,[硬]标注的要素保留大部分,[软]标注多数缺失
- 0.3-0.4 = 低度相关,[硬]标注的要素部分缺失
- 0-0.2 = 基本不相关,[硬]标注的要素大量缺失
## 评估策略
1. **先看[本质]判断 intent_match**:意图不匹配直接 false
2. **再看[硬][软]评估 relevance_score**:计算要素和属性的保留程度
## 输出要求
请先思考,再打分。按以下顺序输出:
1. reason: 详细的评估理由(先分析再打分)
- 原问题的[本质]是什么,推荐词是否匹配这个本质
- [硬]约束哪些保留/缺失
- [软]修饰哪些保留/缺失
- 基于以上分析,给出意图匹配判断和相关性分数的依据
2. intent_match: true/false(基于上述分析得出)
3. relevance_score: 0-1 的浮点数(基于上述分析得出)
""".strip()
class RelevanceEvaluation(BaseModel):
    """Structured evaluator output: intent match plus relevance score.

    Field order (reason first, then the two verdicts) mirrors the prompt's
    "think first, then score" instruction; keep it when editing.
    """
    reason: str = Field(..., description="评估理由,需说明意图判断和相关性依据")
    intent_match: bool = Field(..., description="意图是否匹配")
    # Nothing here enforces the 0-1 range; downstream thresholds (0.7 / 0.5)
    # assume the model respects it.
    relevance_score: float = Field(..., description="相关性分数 0-1,分数越高越相关")
# Agent whose final output is parsed into a RelevanceEvaluation.
evaluator = Agent[None](
    name="评估专家",
    instructions=eval_instructions,
    output_type=RelevanceEvaluation,
)
  150. # ============================================================================
  151. # 核心函数
  152. # ============================================================================
  153. async def annotate_question(q_with_context: str) -> str:
  154. """标注问题(三层)"""
  155. print("\n正在标注问题...")
  156. result = await Runner.run(question_annotator, q_with_context)
  157. annotation = str(result.final_output)
  158. print(f"问题标注完成:{annotation}")
  159. return annotation
  160. async def segment_text(q: str) -> SegmentationResult:
  161. """分词"""
  162. print("\n正在分词...")
  163. result = await Runner.run(segmenter, q)
  164. seg_result: SegmentationResult = result.final_output
  165. print(f"分词结果:{seg_result.words}")
  166. print(f"分词说明:{seg_result.reasoning}")
  167. return seg_result
  168. def generate_query_combinations_single_level(
  169. keywords: list[str],
  170. size: int
  171. ) -> list[str]:
  172. """
  173. 生成单个层级的query组合
  174. Args:
  175. keywords: 关键词列表
  176. size: 组合词数
  177. Returns:
  178. 该层级的所有query组合
  179. """
  180. if size > len(keywords):
  181. return []
  182. # 1-word组合:不需要考虑顺序
  183. if size == 1:
  184. return keywords.copy()
  185. # 多词组合:先选择size个词(combinations),再排列(permutations)
  186. all_queries = []
  187. combs = list(combinations(keywords, size))
  188. for comb in combs:
  189. # 对每个组合生成所有排列
  190. perms = list(permutations(comb))
  191. for perm in perms:
  192. query = ''.join(perm) # 直接拼接,无空格
  193. all_queries.append(query)
  194. # 去重
  195. return list(dict.fromkeys(all_queries))
  196. async def evaluate_single_sug_with_semaphore(
  197. source_query: str,
  198. sug_query: str,
  199. original_question: str,
  200. question_annotation: str,
  201. semaphore: asyncio.Semaphore
  202. ) -> dict:
  203. """带信号量的单个推荐词评估"""
  204. async with semaphore:
  205. eval_input = f"""
  206. <原始问题>
  207. {original_question}
  208. </原始问题>
  209. <问题标注(三层)>
  210. {question_annotation}
  211. </问题标注(三层)>
  212. <待评估的推荐query>
  213. {sug_query}
  214. </待评估的推荐query>
  215. 请评估该推荐query(请先分析理由,再给出评分):
  216. 1. reason: 详细的评估理由(先思考分析)
  217. 2. intent_match: 意图是否匹配(true/false)
  218. 3. relevance_score: 相关性分数(0-1)
  219. 评估时请参考问题标注中的[本质]、[硬]、[软]标记。
  220. """
  221. result = await Runner.run(evaluator, eval_input)
  222. evaluation: RelevanceEvaluation = result.final_output
  223. return {
  224. "source_query": source_query,
  225. "sug_query": sug_query,
  226. "intent_match": evaluation.intent_match,
  227. "relevance_score": evaluation.relevance_score,
  228. "reason": evaluation.reason,
  229. }
  230. async def fetch_and_evaluate_level(
  231. queries: list[str],
  232. original_question: str,
  233. question_annotation: str,
  234. level_name: str,
  235. context: RunContext
  236. ) -> tuple[list[dict], list[dict]]:
  237. """
  238. 处理单个层级:获取推荐词并评估
  239. Returns:
  240. (sug_results, evaluations)
  241. """
  242. xiaohongshu_api = XiaohongshuSearchRecommendations()
  243. # 创建信号量
  244. api_semaphore = asyncio.Semaphore(API_CONCURRENCY_LIMIT)
  245. model_semaphore = asyncio.Semaphore(MODEL_CONCURRENCY_LIMIT)
  246. # 结果收集
  247. sug_results = []
  248. all_evaluations = []
  249. # 统计
  250. total_queries = len(queries)
  251. completed_queries = 0
  252. total_sugs = 0
  253. completed_evals = 0
  254. async def get_and_evaluate_single_query(query: str):
  255. nonlocal completed_queries, total_sugs, completed_evals
  256. # 步骤1:获取推荐词
  257. async with api_semaphore:
  258. suggestions = xiaohongshu_api.get_recommendations(keyword=query)
  259. sug_count = len(suggestions) if suggestions else 0
  260. completed_queries += 1
  261. total_sugs += sug_count
  262. print(f" [{completed_queries}/{total_queries}] {query} → {sug_count} 个推荐词")
  263. sug_result = {
  264. "query": query,
  265. "suggestions": suggestions or [],
  266. "timestamp": datetime.now().isoformat()
  267. }
  268. sug_results.append(sug_result)
  269. # 步骤2:立即评估这些推荐词
  270. if suggestions:
  271. eval_tasks = []
  272. for sug in suggestions:
  273. eval_tasks.append(evaluate_single_sug_with_semaphore(
  274. query, sug, original_question, question_annotation, model_semaphore
  275. ))
  276. if eval_tasks:
  277. evals = await asyncio.gather(*eval_tasks)
  278. all_evaluations.extend(evals)
  279. completed_evals += len(evals)
  280. print(f" ↳ 已评估 {len(evals)} 个,累计评估 {completed_evals} 个")
  281. # 并发处理所有query
  282. await asyncio.gather(*[get_and_evaluate_single_query(q) for q in queries])
  283. print(f"\n{level_name} 完成:获取 {total_sugs} 个推荐词,完成 {completed_evals} 个评估")
  284. return sug_results, all_evaluations
  285. def find_intent_matched_keywords(
  286. keywords: list[str],
  287. evaluations: list[dict]
  288. ) -> set[str]:
  289. """
  290. 找出所有至少有一个 intent_match=True 的推荐词的关键词
  291. Args:
  292. keywords: 当前层级使用的关键词列表
  293. evaluations: 该层级的评估结果
  294. Returns:
  295. 有意图匹配的关键词集合
  296. """
  297. matched_keywords = set()
  298. for keyword in keywords:
  299. # 检查这个关键词对应的推荐词中是否有 intent_match=True 的
  300. keyword_evals = [
  301. e for e in evaluations
  302. if e['source_query'] == keyword and e['intent_match'] is True
  303. ]
  304. if keyword_evals:
  305. matched_keywords.add(keyword)
  306. return matched_keywords
  307. def find_top_keywords_by_relevance(
  308. keywords: list[str],
  309. evaluations: list[dict],
  310. top_n: int = 2
  311. ) -> list[str]:
  312. """
  313. 根据 relevance_score 找出表现最好的 top N 关键词
  314. Args:
  315. keywords: 当前层级使用的关键词列表
  316. evaluations: 该层级的评估结果
  317. top_n: 保留的关键词数量
  318. Returns:
  319. 按平均 relevance_score 排序的 top N 关键词
  320. """
  321. keyword_scores = {}
  322. for keyword in keywords:
  323. # 找到这个关键词对应的所有评估
  324. keyword_evals = [
  325. e for e in evaluations
  326. if e['source_query'] == keyword
  327. ]
  328. if keyword_evals:
  329. # 计算平均 relevance_score
  330. avg_score = sum(e['relevance_score'] for e in keyword_evals) / len(keyword_evals)
  331. # 同时记录最高分,作为次要排序依据
  332. max_score = max(e['relevance_score'] for e in keyword_evals)
  333. keyword_scores[keyword] = {
  334. 'avg': avg_score,
  335. 'max': max_score,
  336. 'count': len(keyword_evals)
  337. }
  338. if not keyword_scores:
  339. return []
  340. # 按平均分降序,最高分降序
  341. sorted_keywords = sorted(
  342. keyword_scores.items(),
  343. key=lambda x: (x[1]['avg'], x[1]['max']),
  344. reverse=True
  345. )
  346. # 返回 top N 关键词
  347. return [kw for kw, score in sorted_keywords[:top_n]]
  348. def find_qualified_queries(evaluations: list[dict], min_relevance_score: float = 0.7) -> list[dict]:
  349. """
  350. 查找所有合格的query
  351. 筛选标准:
  352. 1. intent_match = True(必须满足)
  353. 2. relevance_score >= min_relevance_score
  354. 返回:按 relevance_score 降序排列
  355. """
  356. qualified = [
  357. e for e in evaluations
  358. if e['intent_match'] is True and e['relevance_score'] >= min_relevance_score
  359. ]
  360. # 按relevance_score降序排列
  361. return sorted(qualified, key=lambda x: x['relevance_score'], reverse=True)
  362. # ============================================================================
  363. # 主流程 - v6.4 层级剪枝
  364. # ============================================================================
  365. async def combinatorial_search_with_pruning(
  366. context: RunContext,
  367. max_combination_size: int = 1,
  368. fallback_top_n: int = 2
  369. ) -> dict:
  370. """
  371. 组合式搜索流程(带层级剪枝)
  372. 策略:
  373. - 第1层:所有单词都尝试
  374. - 第2层及以上:
  375. 1. 优先使用在上一层中至少有一个 intent_match=True 的关键词
  376. 2. 如果没有,则使用 relevance_score 最高的 top N 关键词
  377. 3. 如果也无法计算,则使用全部关键词
  378. Args:
  379. context: 运行上下文
  380. max_combination_size: 最大组合词数(N),默认1
  381. fallback_top_n: 当没有意图匹配时,使用 relevance_score top N 关键词,默认2
  382. 返回格式:
  383. {
  384. "success": True/False,
  385. "results": [...],
  386. "message": "..."
  387. }
  388. """
  389. # 步骤1:标注问题(三层)
  390. annotation = await annotate_question(context.q_with_context)
  391. context.question_annotation = annotation
  392. # 步骤2:分词
  393. seg_result = await segment_text(context.q)
  394. all_keywords = seg_result.words
  395. context.keywords = all_keywords
  396. # 初始化累积结果
  397. all_sug_results = []
  398. all_evaluations = []
  399. # 当前层可用的关键词(第1层是所有关键词)
  400. current_keywords = all_keywords.copy()
  401. print(f"\n{'='*60}")
  402. print(f"层级剪枝式搜索(最大层级:{max_combination_size})")
  403. print(f"{'='*60}")
  404. # 逐层处理
  405. for level in range(1, max_combination_size + 1):
  406. level_name = f"{level}-word"
  407. print(f"\n{'='*60}")
  408. print(f"第 {level} 层:{level_name}")
  409. print(f"{'='*60}")
  410. # 检查是否有可用关键词
  411. if not current_keywords:
  412. print(f"⚠️ 没有可用的关键词,跳过第 {level} 层")
  413. context.pruning_info[level_name] = {
  414. "available_keywords": [],
  415. "queries_count": 0,
  416. "pruned": True,
  417. "reason": "上一层没有任何 intent_match=True 的关键词"
  418. }
  419. break
  420. # 生成当前层的query组合
  421. level_queries = generate_query_combinations_single_level(current_keywords, level)
  422. if not level_queries:
  423. print(f"⚠️ 无法生成 {level} 词组合,跳过")
  424. context.pruning_info[level_name] = {
  425. "available_keywords": current_keywords,
  426. "queries_count": 0,
  427. "pruned": True,
  428. "reason": f"关键词数量不足以生成 {level} 词组合"
  429. }
  430. break
  431. print(f"可用关键词:{current_keywords}")
  432. print(f"生成的query数:{len(level_queries)}")
  433. # 记录该层的query组合
  434. context.query_combinations[level_name] = level_queries
  435. # 打印部分query示例
  436. print(f"\nquery示例(前10个):")
  437. for i, q in enumerate(level_queries[:10], 1):
  438. print(f" {i}. {q}")
  439. if len(level_queries) > 10:
  440. print(f" ... 还有 {len(level_queries) - 10} 个")
  441. # 获取推荐词并评估
  442. print(f"\n开始处理第 {level} 层的推荐词...")
  443. level_sug_results, level_evaluations = await fetch_and_evaluate_level(
  444. level_queries,
  445. context.q,
  446. annotation,
  447. level_name,
  448. context
  449. )
  450. # 累积结果
  451. all_sug_results.extend(level_sug_results)
  452. all_evaluations.extend(level_evaluations)
  453. # 统计该层的意图匹配情况
  454. intent_matched_count = sum(1 for e in level_evaluations if e['intent_match'] is True)
  455. print(f"\n第 {level} 层统计:")
  456. print(f" - 查询数:{len(level_queries)}")
  457. print(f" - 推荐词数:{sum(len(r['suggestions']) for r in level_sug_results)}")
  458. print(f" - 意图匹配数:{intent_matched_count}/{len(level_evaluations)}")
  459. # 记录剪枝信息
  460. context.pruning_info[level_name] = {
  461. "available_keywords": current_keywords,
  462. "queries_count": len(level_queries),
  463. "pruned": False,
  464. "intent_matched_count": intent_matched_count,
  465. "total_evaluations": len(level_evaluations)
  466. }
  467. # 如果还有下一层,找出有意图匹配的关键词用于下一层
  468. if level < max_combination_size:
  469. # 只在第1层时需要找出有意图匹配的关键词
  470. if level == 1:
  471. matched_keywords = find_intent_matched_keywords(current_keywords, level_evaluations)
  472. print(f"\n剪枝结果:")
  473. print(f" - 原始关键词数:{len(current_keywords)}")
  474. print(f" - 意图匹配关键词数:{len(matched_keywords)}")
  475. if matched_keywords:
  476. print(f" ✓ 策略:使用意图匹配的关键词")
  477. print(f" - 保留的关键词:{sorted(matched_keywords)}")
  478. current_keywords = list(matched_keywords)
  479. else:
  480. print(f" ⚠️ 没有任何关键词产生 intent_match=True 的推荐词")
  481. # 退而求其次:使用 relevance_score 最高的 top N 关键词
  482. top_keywords = find_top_keywords_by_relevance(current_keywords, level_evaluations, top_n=fallback_top_n)
  483. if top_keywords:
  484. print(f" ✓ 策略:使用 relevance_score 最高的 top {fallback_top_n} 关键词")
  485. print(f" - 保留的关键词:{top_keywords}")
  486. current_keywords = top_keywords
  487. # 显示关键词的得分详情
  488. for kw in top_keywords:
  489. kw_evals = [e for e in level_evaluations if e['source_query'] == kw]
  490. if kw_evals:
  491. avg_score = sum(e['relevance_score'] for e in kw_evals) / len(kw_evals)
  492. max_score = max(e['relevance_score'] for e in kw_evals)
  493. print(f" - {kw}: 平均={avg_score:.2f}, 最高={max_score:.2f}, 推荐词数={len(kw_evals)}")
  494. else:
  495. print(f" ⚠️ 无法计算 relevance_score,第2层将使用全部关键词")
  496. current_keywords = all_keywords.copy()
  497. # 保存累积结果
  498. context.all_sug_queries = all_sug_results
  499. context.evaluation_results = all_evaluations
  500. # 筛选合格query
  501. print(f"\n{'='*60}")
  502. print(f"筛选最终结果")
  503. print(f"{'='*60}")
  504. qualified = find_qualified_queries(all_evaluations, min_relevance_score=0.7)
  505. if qualified:
  506. return {
  507. "success": True,
  508. "results": qualified,
  509. "message": f"找到 {len(qualified)} 个合格query(intent_match=True 且 relevance>=0.7)"
  510. }
  511. # 降低标准
  512. acceptable = find_qualified_queries(all_evaluations, min_relevance_score=0.5)
  513. if acceptable:
  514. return {
  515. "success": True,
  516. "results": acceptable,
  517. "message": f"找到 {len(acceptable)} 个可接受query(intent_match=True 且 relevance>=0.5)"
  518. }
  519. # 完全失败:返回所有intent_match=True的
  520. intent_matched = [e for e in all_evaluations if e['intent_match'] is True]
  521. if intent_matched:
  522. intent_matched_sorted = sorted(intent_matched, key=lambda x: x['relevance_score'], reverse=True)
  523. return {
  524. "success": False,
  525. "results": intent_matched_sorted[:10], # 只返回前10个
  526. "message": f"未找到高相关性query,但有 {len(intent_matched)} 个意图匹配的推荐词"
  527. }
  528. return {
  529. "success": False,
  530. "results": [],
  531. "message": "未找到任何意图匹配的推荐词"
  532. }
  533. # ============================================================================
  534. # 输出格式化
  535. # ============================================================================
  536. def format_output(optimization_result: dict, context: RunContext) -> str:
  537. """格式化输出结果"""
  538. results = optimization_result.get("results", [])
  539. output = f"原始问题:{context.q}\n"
  540. output += f"问题标注:{context.question_annotation}\n"
  541. output += f"提取的关键词:{', '.join(context.keywords or [])}\n"
  542. output += f"关键词数量:{len(context.keywords or [])}\n"
  543. # 层级剪枝信息
  544. output += f"\n{'='*60}\n"
  545. output += f"层级剪枝信息:\n"
  546. output += f"{'='*60}\n"
  547. for level_name, info in context.pruning_info.items():
  548. output += f"\n{level_name}:\n"
  549. if info.get('pruned'):
  550. output += f" 状态:已剪枝 ✂️\n"
  551. output += f" 原因:{info.get('reason', '未知')}\n"
  552. else:
  553. output += f" 状态:已处理 ✓\n"
  554. output += f" 可用关键词数:{len(info['available_keywords'])}\n"
  555. output += f" 可用关键词:{info['available_keywords']}\n"
  556. output += f" 生成query数:{info['queries_count']}\n"
  557. output += f" 意图匹配数:{info.get('intent_matched_count', 0)}/{info.get('total_evaluations', 0)}\n"
  558. # query组合统计
  559. output += f"\n{'='*60}\n"
  560. output += f"query组合统计:\n"
  561. output += f"{'='*60}\n"
  562. for level, queries in context.query_combinations.items():
  563. output += f" - {level}: {len(queries)} 个\n"
  564. # 统计信息
  565. total_queries = sum(len(q) for q in context.query_combinations.values())
  566. total_sugs = sum(len(r["suggestions"]) for r in context.all_sug_queries)
  567. total_evals = len(context.evaluation_results)
  568. output += f"\n探索统计:\n"
  569. output += f" - 总query数:{total_queries}\n"
  570. output += f" - 总推荐词数:{total_sugs}\n"
  571. output += f" - 总评估数:{total_evals}\n"
  572. output += f"\n状态:{optimization_result['message']}\n\n"
  573. if optimization_result["success"] and results:
  574. output += "=" * 60 + "\n"
  575. output += "合格的推荐query(按relevance_score降序):\n"
  576. output += "=" * 60 + "\n"
  577. for i, result in enumerate(results[:20], 1): # 只显示前20个
  578. output += f"\n{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
  579. output += f" 来源:{result['source_query']}\n"
  580. output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n"
  581. output += f" 理由:{result['reason'][:150]}...\n" if len(result['reason']) > 150 else f" 理由:{result['reason']}\n"
  582. else:
  583. output += "=" * 60 + "\n"
  584. output += "结果:未找到足够相关的推荐query\n"
  585. output += "=" * 60 + "\n"
  586. if results:
  587. output += "\n最接近的推荐词(前10个):\n\n"
  588. for i, result in enumerate(results[:10], 1):
  589. output += f"{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
  590. output += f" 来源:{result['source_query']}\n"
  591. output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n\n"
  592. # 按source_query分组显示
  593. output += "\n" + "=" * 60 + "\n"
  594. output += "按查询词分组的推荐词情况:\n"
  595. output += "=" * 60 + "\n"
  596. for sug_data in context.all_sug_queries:
  597. source_q = sug_data["query"]
  598. sugs = sug_data["suggestions"]
  599. # 找到这个source_query对应的所有评估
  600. related_evals = [e for e in context.evaluation_results if e["source_query"] == source_q]
  601. intent_match_count = sum(1 for e in related_evals if e["intent_match"])
  602. avg_relevance = sum(e["relevance_score"] for e in related_evals) / len(related_evals) if related_evals else 0
  603. output += f"\n查询:{source_q}\n"
  604. output += f" 推荐词数:{len(sugs)}\n"
  605. output += f" 意图匹配数:{intent_match_count}/{len(related_evals)}\n"
  606. output += f" 平均相关性:{avg_relevance:.2f}\n"
  607. # 显示前3个推荐词
  608. if sugs:
  609. output += f" 示例推荐词:\n"
  610. for sug in sugs[:3]:
  611. eval_item = next((e for e in related_evals if e["sug_query"] == sug), None)
  612. if eval_item:
  613. output += f" - {sug} [意图:{'✓' if eval_item['intent_match'] else '✗'}, 相关:{eval_item['relevance_score']:.2f}]\n"
  614. else:
  615. output += f" - {sug}\n"
  616. return output.strip()
  617. # ============================================================================
  618. # 主函数
  619. # ============================================================================
  620. async def main(
  621. input_dir: str,
  622. max_combination_size: int = 1,
  623. api_concurrency: int = API_CONCURRENCY_LIMIT,
  624. model_concurrency: int = MODEL_CONCURRENCY_LIMIT,
  625. fallback_top_n: int = 2
  626. ):
  627. # 更新全局并发配置
  628. global API_CONCURRENCY_LIMIT, MODEL_CONCURRENCY_LIMIT
  629. API_CONCURRENCY_LIMIT = api_concurrency
  630. MODEL_CONCURRENCY_LIMIT = model_concurrency
  631. current_time, log_url = set_trace()
  632. # 从目录中读取固定文件名
  633. input_context_file = os.path.join(input_dir, 'context.md')
  634. input_q_file = os.path.join(input_dir, 'q.md')
  635. q_context = read_file_as_string(input_context_file)
  636. q = read_file_as_string(input_q_file)
  637. q_with_context = f"""
  638. <需求上下文>
  639. {q_context}
  640. </需求上下文>
  641. <当前问题>
  642. {q}
  643. </当前问题>
  644. """.strip()
  645. # 获取当前文件名作为版本
  646. version = os.path.basename(__file__)
  647. version_name = os.path.splitext(version)[0]
  648. # 日志保存目录
  649. log_dir = os.path.join(input_dir, "output", version_name, current_time)
  650. run_context = RunContext(
  651. version=version,
  652. input_files={
  653. "input_dir": input_dir,
  654. "context_file": input_context_file,
  655. "q_file": input_q_file,
  656. },
  657. q_with_context=q_with_context,
  658. q_context=q_context,
  659. q=q,
  660. log_dir=log_dir,
  661. log_url=log_url,
  662. )
  663. print(f"\n{'='*60}")
  664. print(f"并发配置")
  665. print(f"{'='*60}")
  666. print(f"API请求并发度:{API_CONCURRENCY_LIMIT}")
  667. print(f"模型评估并发度:{MODEL_CONCURRENCY_LIMIT}")
  668. # 执行层级剪枝式搜索
  669. optimization_result = await combinatorial_search_with_pruning(
  670. run_context,
  671. max_combination_size=max_combination_size,
  672. fallback_top_n=fallback_top_n
  673. )
  674. # 格式化输出
  675. final_output = format_output(optimization_result, run_context)
  676. print(f"\n{'='*60}")
  677. print("最终结果")
  678. print(f"{'='*60}")
  679. print(final_output)
  680. # 保存结果
  681. run_context.optimization_result = optimization_result
  682. run_context.final_output = final_output
  683. # 保存 RunContext 到 log_dir
  684. os.makedirs(run_context.log_dir, exist_ok=True)
  685. context_file_path = os.path.join(run_context.log_dir, "run_context.json")
  686. with open(context_file_path, "w", encoding="utf-8") as f:
  687. json.dump(run_context.model_dump(), f, ensure_ascii=False, indent=2)
  688. print(f"\nRunContext saved to: {context_file_path}")
  689. if __name__ == "__main__":
  690. parser = argparse.ArgumentParser(
  691. description="搜索query优化工具 - v6.4 层级剪枝版",
  692. formatter_class=argparse.RawDescriptionHelpFormatter,
  693. epilog="""
  694. 示例:
  695. # 默认参数(只搜索1层)
  696. python sug_v6_4_with_annotation.py
  697. # 2层搜索,第2层只使用第1层中有意图匹配的关键词
  698. python sug_v6_4_with_annotation.py --max-combo 2
  699. # 2层搜索,如果第1层没有意图匹配,则使用 top 3 关键词
  700. python sug_v6_4_with_annotation.py --max-combo 2 --fallback-top 3
  701. # 3层搜索,API并发5,模型并发20
  702. python sug_v6_4_with_annotation.py --max-combo 3 --api-concurrency 5 --model-concurrency 20
  703. # 指定输入目录
  704. python sug_v6_4_with_annotation.py --input-dir "input/旅游-逸趣玩旅行/如何获取能体现川西秋季特色的高质量风光摄影素材?"
  705. """
  706. )
  707. parser.add_argument(
  708. "--input-dir",
  709. type=str,
  710. default="input/简单扣图",
  711. help="输入目录路径,默认: input/简单扣图"
  712. )
  713. parser.add_argument(
  714. "--max-combo",
  715. type=int,
  716. default=1,
  717. help="最大组合词数(N),默认: 1"
  718. )
  719. parser.add_argument(
  720. "--fallback-top",
  721. type=int,
  722. default=2,
  723. help="当第1层没有意图匹配时,使用 relevance_score top N 关键词,默认: 2"
  724. )
  725. parser.add_argument(
  726. "--api-concurrency",
  727. type=int,
  728. default=API_CONCURRENCY_LIMIT,
  729. help=f"API请求并发度,默认: {API_CONCURRENCY_LIMIT}"
  730. )
  731. parser.add_argument(
  732. "--model-concurrency",
  733. type=int,
  734. default=MODEL_CONCURRENCY_LIMIT,
  735. help=f"模型评估并发度,默认: {MODEL_CONCURRENCY_LIMIT}"
  736. )
  737. args = parser.parse_args()
  738. asyncio.run(main(
  739. args.input_dir,
  740. max_combination_size=args.max_combo,
  741. api_concurrency=args.api_concurrency,
  742. model_concurrency=args.model_concurrency,
  743. fallback_top_n=args.fallback_top
  744. ))