sug_v6_3_with_annotation.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712
  1. import asyncio
  2. import json
  3. import os
  4. import argparse
  5. from datetime import datetime
  6. from itertools import combinations, permutations
  7. from agents import Agent, Runner
  8. from lib.my_trace import set_trace
  9. from typing import Literal
  10. from pydantic import BaseModel, Field
  11. from lib.utils import read_file_as_string
  12. from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
# ============================================================================
# Concurrency configuration
# ============================================================================
# Concurrency limit for Xiaohongshu API requests
API_CONCURRENCY_LIMIT = 5
# Concurrency limit for model (GPT) evaluations
MODEL_CONCURRENCY_LIMIT = 10
class RunContext(BaseModel):
    """Run-wide state shared across pipeline steps.

    Collects inputs, intermediate artifacts and final results so the whole
    run can be serialized to JSON at the end of main().
    """
    version: str = Field(..., description="当前运行的脚本版本(文件名)")
    input_files: dict[str, str] = Field(..., description="输入文件路径映射")
    q_with_context: str  # question wrapped together with its requirement context
    q_context: str       # raw requirement/background text (context.md)
    q: str               # raw question text (q.md)
    log_url: str
    log_dir: str
    # Question annotation (three layers: essence / hard / soft)
    question_annotation: str | None = Field(default=None, description="问题的标注结果(三层)")
    # Segmentation and query combinations
    keywords: list[str] | None = Field(default=None, description="提取的关键词")
    query_combinations: dict[str, list[str]] = Field(default_factory=dict, description="各层级的query组合")
    # Exploration results (per-query suggestion payloads)
    all_sug_queries: list[dict] = Field(default_factory=list, description="所有获取到的推荐词")
    # Evaluation results (one dict per suggestion)
    evaluation_results: list[dict] = Field(default_factory=list, description="所有推荐词的评估结果")
    optimization_result: dict | None = Field(default=None, description="最终优化结果对象")
    final_output: str | None = Field(default=None, description="最终输出结果(格式化文本)")
# ============================================================================
# Agent 1: Question annotation expert
# ============================================================================
# Prompt: annotate the question in place with three layers — [本质] core
# intent, [硬] hard objective constraints, [软] soft subjective modifiers.
# The agent is expected to return only the annotated string.
question_annotation_instructions = """
你是搜索需求分析专家。给定问题(含需求背景),在原文上标注三层:本质、硬性、软性。
## 判断标准
**[本质]** - 问题的核心意图
- 如何获取、教程、推荐、作品、测评等
**[硬]** - 客观事实性约束(可明确验证、非主观判断)
- 能明确区分类别的:地域、时间、对象、工具、操作类型
- 特征:改变后得到完全不同类别的结果
**[软]** - 主观判断性修饰(因人而异、程度性的)
- 需要主观评价的:质量、速度、美观、特色、程度
- 特征:改变后仍是同类结果,只是满足程度不同
## 输出格式
词语[本质-描述]、词语[硬-描述]、词语[软-描述]
## 注意
- 只输出标注后的字符串
- 结合需求背景判断意图
""".strip()
question_annotator = Agent[None](
    name="问题标注专家",
    instructions=question_annotation_instructions,
)
# ============================================================================
# Agent 2: Word segmentation expert
# ============================================================================
# Prompt: split a Chinese sentence into minimal meaningful units, dropping
# punctuation and function words while keeping content words and
# interrogatives.
segmentation_instructions = """
你是中文分词专家。给定一个句子,将其分词。
## 分词原则
1. 去掉标点符号
2. 拆分成最小的有意义单元
3. 去掉助词、语气词、助动词
4. 保留疑问词
5. 保留实词:名词、动词、形容词、副词
## 输出要求
输出分词列表。
""".strip()
class SegmentationResult(BaseModel):
    """Structured segmentation output returned by the segmenter agent."""
    words: list[str] = Field(..., description="分词列表")
    reasoning: str = Field(..., description="分词说明")
segmenter = Agent[None](
    name="分词专家",
    instructions=segmentation_instructions,
    output_type=SegmentationResult,
)
# ============================================================================
# Agent 3: Evaluation expert (intent match + relevance scoring)
# ============================================================================
# Prompt: given the original question, its three-layer annotation and one
# suggested query, produce a boolean intent_match (based on the [本质] layer)
# and a 0-1 relevance_score (based on the [硬]/[软] layers), plus a reason.
eval_instructions = """
你是搜索query评估专家。给定原始问题、问题标注和推荐query,评估两个维度。
## 输入信息
你会收到:
1. 原始问题:用户的原始表述
2. 问题标注:对原始问题的三层标注(本质、硬性、软性)
3. 推荐query:待评估的推荐词
## 评估目标
用这个推荐query搜索,能否找到满足原始需求的内容?
## 两层评分
### 1. intent_match(意图匹配)= true/false
推荐query的**使用意图**是否与原问题的**本质**一致?
**核心:只关注[本质]标注**
- 问题标注中的 `[本质-XXX]` 标记明确了用户的核心意图
- 判断推荐词是否能达成这个核心意图
**常见本质类型:**
- 找方法/如何获取 → 推荐词应包含方法、途径、网站、渠道等
- 找教程 → 推荐词应是教程、教学相关
- 找资源/素材 → 推荐词应是资源、素材本身
- 找工具 → 推荐词应是工具推荐
- 看作品 → 推荐词应是作品展示
**评分:**
- true = 推荐词的意图与 `[本质]` 一致
- false = 推荐词的意图与 `[本质]` 不一致
### 2. relevance_score(相关性)= 0-1 连续分数
在意图匹配的前提下,推荐query在**主题、要素、属性**上与原问题的相关程度?
**评估维度:**
- 主题相关:核心主题是否匹配?(如:摄影、旅游、美食)
- 要素覆盖:`[硬-XXX]` 标记的硬性约束保留了多少?(地域、时间、对象、工具等)
- 属性匹配:`[软-XXX]` 标记的软性修饰保留了多少?(质量、速度、美观等)
**评分参考:**
- 0.9-1.0 = 几乎完美匹配,[硬]和[软]标注的要素都保留
- 0.7-0.8 = 高度相关,[硬]标注的要素都保留,[软]标注少数缺失
- 0.5-0.6 = 中度相关,[硬]标注的要素保留大部分,[软]标注多数缺失
- 0.3-0.4 = 低度相关,[硬]标注的要素部分缺失
- 0-0.2 = 基本不相关,[硬]标注的要素大量缺失
## 评估策略
1. **先看[本质]判断 intent_match**:意图不匹配直接 false
2. **再看[硬][软]评估 relevance_score**:计算要素和属性的保留程度
## 输出要求
- intent_match: true/false
- relevance_score: 0-1 的浮点数
- reason: 详细的评估理由,需要说明:
- 原问题的[本质]是什么,推荐词是否匹配这个本质
- [硬]约束哪些保留/缺失
- [软]修饰哪些保留/缺失
- 最终相关性分数的依据
""".strip()
class RelevanceEvaluation(BaseModel):
    """Structured evaluation output: intent match + relevance score."""
    intent_match: bool = Field(..., description="意图是否匹配")
    relevance_score: float = Field(..., description="相关性分数 0-1,分数越高越相关")
    reason: str = Field(..., description="评估理由,需说明意图判断和相关性依据")
evaluator = Agent[None](
    name="评估专家",
    instructions=eval_instructions,
    output_type=RelevanceEvaluation,
)
  147. # ============================================================================
  148. # 核心函数
  149. # ============================================================================
  150. async def annotate_question(q_with_context: str) -> str:
  151. """标注问题(三层)"""
  152. print("\n正在标注问题...")
  153. result = await Runner.run(question_annotator, q_with_context)
  154. annotation = str(result.final_output)
  155. print(f"问题标注完成:{annotation}")
  156. return annotation
  157. async def segment_text(q: str) -> SegmentationResult:
  158. """分词"""
  159. print("\n正在分词...")
  160. result = await Runner.run(segmenter, q)
  161. seg_result: SegmentationResult = result.final_output
  162. print(f"分词结果:{seg_result.words}")
  163. print(f"分词说明:{seg_result.reasoning}")
  164. return seg_result
  165. def generate_query_combinations(keywords: list[str], max_combination_size: int) -> dict[str, list[str]]:
  166. """
  167. 生成query组合(考虑词的顺序)
  168. Args:
  169. keywords: 关键词列表
  170. max_combination_size: 最大组合词数(N)
  171. Returns:
  172. {
  173. "1-word": [...],
  174. "2-word": [...],
  175. "3-word": [...],
  176. ...
  177. "N-word": [...]
  178. }
  179. """
  180. result = {}
  181. for size in range(1, max_combination_size + 1):
  182. if size > len(keywords):
  183. break
  184. # 1-word组合:不需要考虑顺序
  185. if size == 1:
  186. queries = keywords.copy()
  187. else:
  188. # 多词组合:先选择size个词(combinations),再排列(permutations)
  189. all_queries = []
  190. combs = list(combinations(keywords, size))
  191. for comb in combs:
  192. # 对每个组合生成所有排列
  193. perms = list(permutations(comb))
  194. for perm in perms:
  195. query = ''.join(perm) # 直接拼接,无空格
  196. all_queries.append(query)
  197. # 去重(虽然理论上不会重复,但保险起见)
  198. queries = list(dict.fromkeys(all_queries))
  199. result[f"{size}-word"] = queries
  200. print(f"\n{size}词组合:{len(queries)} 个")
  201. if len(queries) <= 10:
  202. for q in queries:
  203. print(f" - {q}")
  204. else:
  205. print(f" - {queries[0]}")
  206. print(f" - {queries[1]}")
  207. print(f" ...")
  208. print(f" - {queries[-1]}")
  209. return result
  210. async def fetch_suggestions_for_queries(queries: list[str], context: RunContext) -> list[dict]:
  211. """
  212. 并发获取所有query的推荐词(带并发控制)
  213. Returns:
  214. [
  215. {
  216. "query": "川西",
  217. "suggestions": ["川西旅游", "川西攻略", ...],
  218. "timestamp": "..."
  219. },
  220. ...
  221. ]
  222. """
  223. print(f"\n{'='*60}")
  224. print(f"获取推荐词:{len(queries)} 个query(并发度:{API_CONCURRENCY_LIMIT})")
  225. print(f"{'='*60}")
  226. xiaohongshu_api = XiaohongshuSearchRecommendations()
  227. # 创建信号量控制并发
  228. semaphore = asyncio.Semaphore(API_CONCURRENCY_LIMIT)
  229. async def get_single_sug(query: str):
  230. async with semaphore:
  231. print(f" 查询: {query}")
  232. suggestions = xiaohongshu_api.get_recommendations(keyword=query)
  233. print(f" → {len(suggestions) if suggestions else 0} 个推荐词")
  234. return {
  235. "query": query,
  236. "suggestions": suggestions or [],
  237. "timestamp": datetime.now().isoformat()
  238. }
  239. results = await asyncio.gather(*[get_single_sug(q) for q in queries])
  240. return results
  241. async def evaluate_all_suggestions(
  242. sug_results: list[dict],
  243. original_question: str,
  244. question_annotation: str,
  245. context: RunContext
  246. ) -> list[dict]:
  247. """
  248. 评估所有推荐词(带并发控制)
  249. Args:
  250. sug_results: 所有query的推荐词结果
  251. original_question: 原始问题
  252. question_annotation: 问题标注(三层)
  253. Returns:
  254. [
  255. {
  256. "source_query": "川西秋季",
  257. "sug_query": "川西秋季旅游",
  258. "intent_match": True,
  259. "relevance_score": 0.8,
  260. "reason": "..."
  261. },
  262. ...
  263. ]
  264. """
  265. print(f"\n{'='*60}")
  266. print(f"评估推荐词(并发度:{MODEL_CONCURRENCY_LIMIT})")
  267. print(f"{'='*60}")
  268. # 创建信号量控制并发
  269. semaphore = asyncio.Semaphore(MODEL_CONCURRENCY_LIMIT)
  270. # 收集所有推荐词
  271. all_evaluations = []
  272. async def evaluate_single_sug(source_query: str, sug_query: str):
  273. async with semaphore:
  274. eval_input = f"""
  275. <原始问题>
  276. {original_question}
  277. </原始问题>
  278. <问题标注(三层)>
  279. {question_annotation}
  280. </问题标注(三层)>
  281. <待评估的推荐query>
  282. {sug_query}
  283. </待评估的推荐query>
  284. 请评估该推荐query:
  285. 1. intent_match: 意图是否匹配(true/false)
  286. 2. relevance_score: 相关性分数(0-1)
  287. 3. reason: 详细的评估理由
  288. 评估时请参考问题标注中的[本质]、[硬]、[软]标记。
  289. """
  290. result = await Runner.run(evaluator, eval_input)
  291. evaluation: RelevanceEvaluation = result.final_output
  292. return {
  293. "source_query": source_query,
  294. "sug_query": sug_query,
  295. "intent_match": evaluation.intent_match,
  296. "relevance_score": evaluation.relevance_score,
  297. "reason": evaluation.reason,
  298. }
  299. # 并发评估所有推荐词
  300. tasks = []
  301. for sug_result in sug_results:
  302. source_query = sug_result["query"]
  303. for sug in sug_result["suggestions"]:
  304. tasks.append(evaluate_single_sug(source_query, sug))
  305. if tasks:
  306. print(f" 总共需要评估 {len(tasks)} 个推荐词...")
  307. all_evaluations = await asyncio.gather(*tasks)
  308. context.evaluation_results = all_evaluations
  309. return all_evaluations
  310. def find_qualified_queries(evaluations: list[dict], min_relevance_score: float = 0.7) -> list[dict]:
  311. """
  312. 查找所有合格的query
  313. 筛选标准:
  314. 1. intent_match = True(必须满足)
  315. 2. relevance_score >= min_relevance_score
  316. 返回:按 relevance_score 降序排列
  317. """
  318. qualified = [
  319. e for e in evaluations
  320. if e['intent_match'] is True and e['relevance_score'] >= min_relevance_score
  321. ]
  322. # 按relevance_score降序排列
  323. return sorted(qualified, key=lambda x: x['relevance_score'], reverse=True)
  324. # ============================================================================
  325. # 主流程
  326. # ============================================================================
  327. async def combinatorial_search(context: RunContext, max_combination_size: int = 1) -> dict:
  328. """
  329. 组合式搜索流程(带问题标注)
  330. Args:
  331. context: 运行上下文
  332. max_combination_size: 最大组合词数(N),默认1
  333. 返回格式:
  334. {
  335. "success": True/False,
  336. "results": [...],
  337. "message": "..."
  338. }
  339. """
  340. # 步骤1:标注问题(三层)
  341. annotation = await annotate_question(context.q_with_context)
  342. context.question_annotation = annotation
  343. # 步骤2:分词
  344. seg_result = await segment_text(context.q)
  345. context.keywords = seg_result.words
  346. # 步骤3:生成query组合
  347. print(f"\n{'='*60}")
  348. print(f"生成query组合(最大组合数:{max_combination_size})")
  349. print(f"{'='*60}")
  350. query_combinations = generate_query_combinations(context.keywords, max_combination_size)
  351. context.query_combinations = query_combinations
  352. # 步骤4:获取所有query的推荐词
  353. all_queries = []
  354. for level, queries in query_combinations.items():
  355. all_queries.extend(queries)
  356. sug_results = await fetch_suggestions_for_queries(all_queries, context)
  357. context.all_sug_queries = sug_results
  358. # 统计
  359. total_sugs = sum(len(r["suggestions"]) for r in sug_results)
  360. print(f"\n总共获取到 {total_sugs} 个推荐词")
  361. # 步骤5:评估所有推荐词(使用原始问题和标注)
  362. evaluations = await evaluate_all_suggestions(sug_results, context.q, annotation, context)
  363. # 步骤6:筛选合格query
  364. qualified = find_qualified_queries(evaluations, min_relevance_score=0.7)
  365. if qualified:
  366. return {
  367. "success": True,
  368. "results": qualified,
  369. "message": f"找到 {len(qualified)} 个合格query(intent_match=True 且 relevance>=0.7)"
  370. }
  371. # 降低标准
  372. acceptable = find_qualified_queries(evaluations, min_relevance_score=0.5)
  373. if acceptable:
  374. return {
  375. "success": True,
  376. "results": acceptable,
  377. "message": f"找到 {len(acceptable)} 个可接受query(intent_match=True 且 relevance>=0.5)"
  378. }
  379. # 完全失败:返回所有intent_match=True的
  380. intent_matched = [e for e in evaluations if e['intent_match'] is True]
  381. if intent_matched:
  382. intent_matched_sorted = sorted(intent_matched, key=lambda x: x['relevance_score'], reverse=True)
  383. return {
  384. "success": False,
  385. "results": intent_matched_sorted[:10], # 只返回前10个
  386. "message": f"未找到高相关性query,但有 {len(intent_matched)} 个意图匹配的推荐词"
  387. }
  388. return {
  389. "success": False,
  390. "results": [],
  391. "message": "未找到任何意图匹配的推荐词"
  392. }
  393. # ============================================================================
  394. # 输出格式化
  395. # ============================================================================
  396. def format_output(optimization_result: dict, context: RunContext) -> str:
  397. """格式化输出结果"""
  398. results = optimization_result.get("results", [])
  399. output = f"原始问题:{context.q}\n"
  400. output += f"问题标注:{context.question_annotation}\n"
  401. output += f"提取的关键词:{', '.join(context.keywords or [])}\n"
  402. output += f"关键词数量:{len(context.keywords or [])}\n"
  403. output += f"\nquery组合统计:\n"
  404. for level, queries in context.query_combinations.items():
  405. output += f" - {level}: {len(queries)} 个\n"
  406. # 统计信息
  407. total_queries = sum(len(q) for q in context.query_combinations.values())
  408. total_sugs = sum(len(r["suggestions"]) for r in context.all_sug_queries)
  409. total_evals = len(context.evaluation_results)
  410. output += f"\n探索统计:\n"
  411. output += f" - 总query数:{total_queries}\n"
  412. output += f" - 总推荐词数:{total_sugs}\n"
  413. output += f" - 总评估数:{total_evals}\n"
  414. output += f"\n状态:{optimization_result['message']}\n\n"
  415. if optimization_result["success"] and results:
  416. output += "=" * 60 + "\n"
  417. output += "合格的推荐query(按relevance_score降序):\n"
  418. output += "=" * 60 + "\n"
  419. for i, result in enumerate(results[:20], 1): # 只显示前20个
  420. output += f"\n{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
  421. output += f" 来源:{result['source_query']}\n"
  422. output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n"
  423. output += f" 理由:{result['reason'][:150]}...\n" if len(result['reason']) > 150 else f" 理由:{result['reason']}\n"
  424. else:
  425. output += "=" * 60 + "\n"
  426. output += "结果:未找到足够相关的推荐query\n"
  427. output += "=" * 60 + "\n"
  428. if results:
  429. output += "\n最接近的推荐词(前10个):\n\n"
  430. for i, result in enumerate(results[:10], 1):
  431. output += f"{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
  432. output += f" 来源:{result['source_query']}\n"
  433. output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n\n"
  434. # 按source_query分组显示
  435. output += "\n" + "=" * 60 + "\n"
  436. output += "按查询词分组的推荐词情况:\n"
  437. output += "=" * 60 + "\n"
  438. for sug_data in context.all_sug_queries:
  439. source_q = sug_data["query"]
  440. sugs = sug_data["suggestions"]
  441. # 找到这个source_query对应的所有评估
  442. related_evals = [e for e in context.evaluation_results if e["source_query"] == source_q]
  443. intent_match_count = sum(1 for e in related_evals if e["intent_match"])
  444. avg_relevance = sum(e["relevance_score"] for e in related_evals) / len(related_evals) if related_evals else 0
  445. output += f"\n查询:{source_q}\n"
  446. output += f" 推荐词数:{len(sugs)}\n"
  447. output += f" 意图匹配数:{intent_match_count}/{len(related_evals)}\n"
  448. output += f" 平均相关性:{avg_relevance:.2f}\n"
  449. # 显示前3个推荐词
  450. if sugs:
  451. output += f" 示例推荐词:\n"
  452. for sug in sugs[:3]:
  453. eval_item = next((e for e in related_evals if e["sug_query"] == sug), None)
  454. if eval_item:
  455. output += f" - {sug} [意图:{'✓' if eval_item['intent_match'] else '✗'}, 相关:{eval_item['relevance_score']:.2f}]\n"
  456. else:
  457. output += f" - {sug}\n"
  458. return output.strip()
  459. # ============================================================================
  460. # 主函数
  461. # ============================================================================
  462. async def main(
  463. input_dir: str,
  464. max_combination_size: int = 1,
  465. api_concurrency: int = API_CONCURRENCY_LIMIT,
  466. model_concurrency: int = MODEL_CONCURRENCY_LIMIT
  467. ):
  468. # 更新全局并发配置
  469. global API_CONCURRENCY_LIMIT, MODEL_CONCURRENCY_LIMIT
  470. API_CONCURRENCY_LIMIT = api_concurrency
  471. MODEL_CONCURRENCY_LIMIT = model_concurrency
  472. current_time, log_url = set_trace()
  473. # 从目录中读取固定文件名
  474. input_context_file = os.path.join(input_dir, 'context.md')
  475. input_q_file = os.path.join(input_dir, 'q.md')
  476. q_context = read_file_as_string(input_context_file)
  477. q = read_file_as_string(input_q_file)
  478. q_with_context = f"""
  479. <需求上下文>
  480. {q_context}
  481. </需求上下文>
  482. <当前问题>
  483. {q}
  484. </当前问题>
  485. """.strip()
  486. # 获取当前文件名作为版本
  487. version = os.path.basename(__file__)
  488. version_name = os.path.splitext(version)[0]
  489. # 日志保存目录
  490. log_dir = os.path.join(input_dir, "output", version_name, current_time)
  491. run_context = RunContext(
  492. version=version,
  493. input_files={
  494. "input_dir": input_dir,
  495. "context_file": input_context_file,
  496. "q_file": input_q_file,
  497. },
  498. q_with_context=q_with_context,
  499. q_context=q_context,
  500. q=q,
  501. log_dir=log_dir,
  502. log_url=log_url,
  503. )
  504. print(f"\n{'='*60}")
  505. print(f"并发配置")
  506. print(f"{'='*60}")
  507. print(f"API请求并发度:{API_CONCURRENCY_LIMIT}")
  508. print(f"模型评估并发度:{MODEL_CONCURRENCY_LIMIT}")
  509. # 执行组合式搜索(带问题标注)
  510. optimization_result = await combinatorial_search(run_context, max_combination_size=max_combination_size)
  511. # 格式化输出
  512. final_output = format_output(optimization_result, run_context)
  513. print(f"\n{'='*60}")
  514. print("最终结果")
  515. print(f"{'='*60}")
  516. print(final_output)
  517. # 保存结果
  518. run_context.optimization_result = optimization_result
  519. run_context.final_output = final_output
  520. # 保存 RunContext 到 log_dir
  521. os.makedirs(run_context.log_dir, exist_ok=True)
  522. context_file_path = os.path.join(run_context.log_dir, "run_context.json")
  523. with open(context_file_path, "w", encoding="utf-8") as f:
  524. json.dump(run_context.model_dump(), f, ensure_ascii=False, indent=2)
  525. print(f"\nRunContext saved to: {context_file_path}")
if __name__ == "__main__":
    # CLI entry point: parse arguments, then run the async pipeline.
    parser = argparse.ArgumentParser(
        description="搜索query优化工具 - v6.3 组合式搜索+问题标注版",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 默认参数
python sug_v6_3_with_annotation.py
# 2词组合,API并发5,模型并发20
python sug_v6_3_with_annotation.py --max-combo 2 --api-concurrency 5 --model-concurrency 20
# 3词组合,降低并发度
python sug_v6_3_with_annotation.py --max-combo 3 --api-concurrency 3 --model-concurrency 10
"""
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/简单扣图",
        help="输入目录路径,默认: input/简单扣图"
    )
    parser.add_argument(
        "--max-combo",
        type=int,
        default=1,
        help="最大组合词数(N),默认: 1"
    )
    # Concurrency flags default to the module-level constants; main() writes
    # the parsed values back into those globals before the pipeline starts.
    parser.add_argument(
        "--api-concurrency",
        type=int,
        default=API_CONCURRENCY_LIMIT,
        help=f"API请求并发度,默认: {API_CONCURRENCY_LIMIT}"
    )
    parser.add_argument(
        "--model-concurrency",
        type=int,
        default=MODEL_CONCURRENCY_LIMIT,
        help=f"模型评估并发度,默认: {MODEL_CONCURRENCY_LIMIT}"
    )
    args = parser.parse_args()
    asyncio.run(main(
        args.input_dir,
        max_combination_size=args.max_combo,
        api_concurrency=args.api_concurrency,
        model_concurrency=args.model_concurrency
    ))