# sug_v6_2_combinatorial.py
  1. import asyncio
  2. import json
  3. import os
  4. import argparse
  5. from datetime import datetime
  6. from itertools import combinations
  7. from agents import Agent, Runner
  8. from lib.my_trace import set_trace
  9. from typing import Literal
  10. from pydantic import BaseModel, Field
  11. from lib.utils import read_file_as_string
  12. from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
class RunContext(BaseModel):
    """Run-wide mutable context threaded through the whole pipeline.

    Accumulates the input question, file locations, and every intermediate
    and final artifact produced by the combinatorial search flow; it is
    serialized to ``run_context.json`` at the end of the run.
    """
    version: str = Field(..., description="当前运行的脚本版本(文件名)")
    input_files: dict[str, str] = Field(..., description="输入文件路径映射")
    q_with_context: str  # question wrapped in <需求上下文>/<当前问题> tags (built in main)
    q_context: str  # raw context text read from context.md
    q: str  # raw question text read from q.md
    log_url: str  # trace URL returned by set_trace()
    log_dir: str  # directory where run_context.json is written
    # Segmentation / combination stage
    keywords: list[str] | None = Field(default=None, description="提取的关键词")
    query_combinations: dict[str, list[str]] = Field(default_factory=dict, description="各层级的query组合")
    # Exploration stage (fetched suggestion payloads)
    all_sug_queries: list[dict] = Field(default_factory=list, description="所有获取到的推荐词")
    # Evaluation stage
    evaluation_results: list[dict] = Field(default_factory=list, description="所有推荐词的评估结果")
    optimization_result: dict | None = Field(default=None, description="最终优化结果对象")
    final_output: str | None = Field(default=None, description="最终输出结果(格式化文本)")
# ============================================================================
# Agent 1: Segmentation expert
# ============================================================================
# Prompt for the word-segmentation agent. Chinese text is part of the runtime
# contract with the LLM and is passed verbatim — do not translate or reflow.
segmentation_instructions = """
你是中文分词专家。给定一个句子,将其分词。
## 分词原则
1. 去掉标点符号
2. 拆分成最小的有意义单元
3. 去掉助词、语气词、助动词
4. 保留疑问词
5. 保留实词:名词、动词、形容词、副词
## 输出要求
输出分词列表。
""".strip()
class SegmentationResult(BaseModel):
    """Structured output of the segmentation agent."""
    words: list[str] = Field(..., description="分词列表")
    reasoning: str = Field(..., description="分词说明")
# Agent that splits the input question into minimal meaningful words,
# returning a SegmentationResult.
segmenter = Agent[None](
    name="分词专家",
    instructions=segmentation_instructions,
    output_type=SegmentationResult,
)
# ============================================================================
# Agent 2: Evaluation expert (intent matching + relevance scoring)
# ============================================================================
# Prompt for the evaluator agent. Defines a two-stage rubric: a boolean
# intent_match gate first, then a 0-1 relevance_score, plus a mandatory
# free-text reason. Chinese text is passed verbatim to the LLM — do not edit.
eval_instructions = """
你是搜索query评估专家。给定原始问题和推荐query,评估两个维度。
## 评估目标
用这个推荐query搜索,能否找到满足原始需求的内容?
## 两层评分
### 1. intent_match(意图匹配)= true/false
推荐query的**使用意图**是否与原问题一致?
**核心问题:用户搜索这个推荐词,想做什么?**
**判断标准:**
- 原问题意图:找方法?找教程?找资源/素材?找工具?看作品?
- 推荐词意图:如果用户搜索这个词,他的目的是什么?
**评分:**
- true = 意图一致,搜索推荐词能达到原问题的目的
- false = 意图改变,搜索推荐词无法达到原问题的目的
### 2. relevance_score(相关性)= 0-1 连续分数
推荐query在**主题、要素、属性**上与原问题的相关程度?
**评估维度:**
- 主题相关:核心主题是否匹配?(如:摄影、旅游、美食)
- 要素覆盖:关键要素保留了多少?(如:地域、时间、对象、工具)
- 属性匹配:质量、风格、特色等属性是否保留?
**评分参考:**
- 0.9-1.0 = 几乎完美匹配,所有核心要素都保留
- 0.7-0.8 = 高度相关,核心要素保留,少数次要要素缺失
- 0.5-0.6 = 中度相关,主题匹配但多个要素缺失
- 0.3-0.4 = 低度相关,只有部分主题相关
- 0-0.2 = 基本不相关
## 评估策略
1. **先判断 intent_match**:意图不匹配直接 false,无论相关性多高
2. **再评估 relevance_score**:在意图匹配的前提下,计算相关性
## 输出要求
- intent_match: true/false
- relevance_score: 0-1 的浮点数
- reason: 详细的评估理由,需要说明:
- 原问题的意图是什么
- 推荐词的意图是什么
- 为什么判断意图匹配/不匹配
- 相关性分数的依据(哪些要素保留/缺失)
""".strip()
class RelevanceEvaluation(BaseModel):
    """Structured evaluator output: intent-match flag plus relevance score."""
    intent_match: bool = Field(..., description="意图是否匹配")
    relevance_score: float = Field(..., description="相关性分数 0-1,分数越高越相关")
    reason: str = Field(..., description="评估理由,需说明意图判断和相关性依据")
# Agent that scores each suggested query against the original question,
# returning a RelevanceEvaluation.
evaluator = Agent[None](
    name="评估专家",
    instructions=eval_instructions,
    output_type=RelevanceEvaluation,
)
  104. # ============================================================================
  105. # 核心函数
  106. # ============================================================================
  107. async def segment_text(q: str) -> SegmentationResult:
  108. """分词"""
  109. print("\n正在分词...")
  110. result = await Runner.run(segmenter, q)
  111. seg_result: SegmentationResult = result.final_output
  112. print(f"分词结果:{seg_result.words}")
  113. print(f"分词说明:{seg_result.reasoning}")
  114. return seg_result
  115. def generate_query_combinations(keywords: list[str], max_combination_size: int) -> dict[str, list[str]]:
  116. """
  117. 生成query组合
  118. Args:
  119. keywords: 关键词列表
  120. max_combination_size: 最大组合词数(N)
  121. Returns:
  122. {
  123. "1-word": [...],
  124. "2-word": [...],
  125. "3-word": [...],
  126. ...
  127. "N-word": [...]
  128. }
  129. """
  130. result = {}
  131. for size in range(1, max_combination_size + 1):
  132. if size > len(keywords):
  133. break
  134. combs = list(combinations(keywords, size))
  135. queries = [''.join(comb) for comb in combs] # 直接拼接,无空格
  136. result[f"{size}-word"] = queries
  137. print(f"\n{size}词组合:{len(queries)} 个")
  138. if len(queries) <= 10:
  139. for q in queries:
  140. print(f" - {q}")
  141. else:
  142. print(f" - {queries[0]}")
  143. print(f" - {queries[1]}")
  144. print(f" ...")
  145. print(f" - {queries[-1]}")
  146. return result
  147. async def fetch_suggestions_for_queries(queries: list[str], context: RunContext) -> list[dict]:
  148. """
  149. 并发获取所有query的推荐词
  150. Returns:
  151. [
  152. {
  153. "query": "川西",
  154. "suggestions": ["川西旅游", "川西攻略", ...],
  155. "timestamp": "..."
  156. },
  157. ...
  158. ]
  159. """
  160. print(f"\n{'='*60}")
  161. print(f"获取推荐词:{len(queries)} 个query")
  162. print(f"{'='*60}")
  163. xiaohongshu_api = XiaohongshuSearchRecommendations()
  164. async def get_single_sug(query: str):
  165. print(f" 查询: {query}")
  166. suggestions = xiaohongshu_api.get_recommendations(keyword=query)
  167. print(f" → {len(suggestions) if suggestions else 0} 个推荐词")
  168. return {
  169. "query": query,
  170. "suggestions": suggestions or [],
  171. "timestamp": datetime.now().isoformat()
  172. }
  173. results = await asyncio.gather(*[get_single_sug(q) for q in queries])
  174. return results
  175. async def evaluate_all_suggestions(sug_results: list[dict], original_question: str, context: RunContext) -> list[dict]:
  176. """
  177. 评估所有推荐词
  178. Args:
  179. sug_results: 所有query的推荐词结果
  180. original_question: 原始问题
  181. Returns:
  182. [
  183. {
  184. "source_query": "川西秋季",
  185. "sug_query": "川西秋季旅游",
  186. "intent_match": True,
  187. "relevance_score": 0.8,
  188. "reason": "..."
  189. },
  190. ...
  191. ]
  192. """
  193. print(f"\n{'='*60}")
  194. print(f"评估推荐词")
  195. print(f"{'='*60}")
  196. # 收集所有推荐词
  197. all_evaluations = []
  198. async def evaluate_single_sug(source_query: str, sug_query: str):
  199. eval_input = f"""
  200. <原始问题>
  201. {original_question}
  202. </原始问题>
  203. <待评估的推荐query>
  204. {sug_query}
  205. </待评估的推荐query>
  206. 请评估该推荐query:
  207. 1. intent_match: 意图是否匹配(true/false)
  208. 2. relevance_score: 相关性分数(0-1)
  209. 3. reason: 详细的评估理由
  210. """
  211. result = await Runner.run(evaluator, eval_input)
  212. evaluation: RelevanceEvaluation = result.final_output
  213. return {
  214. "source_query": source_query,
  215. "sug_query": sug_query,
  216. "intent_match": evaluation.intent_match,
  217. "relevance_score": evaluation.relevance_score,
  218. "reason": evaluation.reason,
  219. }
  220. # 并发评估所有推荐词
  221. tasks = []
  222. for sug_result in sug_results:
  223. source_query = sug_result["query"]
  224. for sug in sug_result["suggestions"]:
  225. tasks.append(evaluate_single_sug(source_query, sug))
  226. if tasks:
  227. print(f" 总共需要评估 {len(tasks)} 个推荐词...")
  228. all_evaluations = await asyncio.gather(*tasks)
  229. context.evaluation_results = all_evaluations
  230. return all_evaluations
  231. def find_qualified_queries(evaluations: list[dict], min_relevance_score: float = 0.7) -> list[dict]:
  232. """
  233. 查找所有合格的query
  234. 筛选标准:
  235. 1. intent_match = True(必须满足)
  236. 2. relevance_score >= min_relevance_score
  237. 返回:按 relevance_score 降序排列
  238. """
  239. qualified = [
  240. e for e in evaluations
  241. if e['intent_match'] is True and e['relevance_score'] >= min_relevance_score
  242. ]
  243. # 按relevance_score降序排列
  244. return sorted(qualified, key=lambda x: x['relevance_score'], reverse=True)
  245. # ============================================================================
  246. # 主流程
  247. # ============================================================================
  248. async def combinatorial_search(context: RunContext, max_combination_size: int = 1) -> dict:
  249. """
  250. 组合式搜索流程
  251. Args:
  252. context: 运行上下文
  253. max_combination_size: 最大组合词数(N),默认1
  254. 返回格式:
  255. {
  256. "success": True/False,
  257. "results": [...],
  258. "message": "..."
  259. }
  260. """
  261. # 步骤1:分词
  262. seg_result = await segment_text(context.q)
  263. context.keywords = seg_result.words
  264. # 步骤2:生成query组合
  265. print(f"\n{'='*60}")
  266. print(f"生成query组合(最大组合数:{max_combination_size})")
  267. print(f"{'='*60}")
  268. query_combinations = generate_query_combinations(context.keywords, max_combination_size)
  269. context.query_combinations = query_combinations
  270. # 步骤3:获取所有query的推荐词
  271. all_queries = []
  272. for level, queries in query_combinations.items():
  273. all_queries.extend(queries)
  274. sug_results = await fetch_suggestions_for_queries(all_queries, context)
  275. context.all_sug_queries = sug_results
  276. # 统计
  277. total_sugs = sum(len(r["suggestions"]) for r in sug_results)
  278. print(f"\n总共获取到 {total_sugs} 个推荐词")
  279. # 步骤4:评估所有推荐词
  280. evaluations = await evaluate_all_suggestions(sug_results, context.q, context)
  281. # 步骤5:筛选合格query
  282. qualified = find_qualified_queries(evaluations, min_relevance_score=0.7)
  283. if qualified:
  284. return {
  285. "success": True,
  286. "results": qualified,
  287. "message": f"找到 {len(qualified)} 个合格query(intent_match=True 且 relevance>=0.7)"
  288. }
  289. # 降低标准
  290. acceptable = find_qualified_queries(evaluations, min_relevance_score=0.5)
  291. if acceptable:
  292. return {
  293. "success": True,
  294. "results": acceptable,
  295. "message": f"找到 {len(acceptable)} 个可接受query(intent_match=True 且 relevance>=0.5)"
  296. }
  297. # 完全失败:返回所有intent_match=True的
  298. intent_matched = [e for e in evaluations if e['intent_match'] is True]
  299. if intent_matched:
  300. intent_matched_sorted = sorted(intent_matched, key=lambda x: x['relevance_score'], reverse=True)
  301. return {
  302. "success": False,
  303. "results": intent_matched_sorted[:10], # 只返回前10个
  304. "message": f"未找到高相关性query,但有 {len(intent_matched)} 个意图匹配的推荐词"
  305. }
  306. return {
  307. "success": False,
  308. "results": [],
  309. "message": "未找到任何意图匹配的推荐词"
  310. }
# ============================================================================
# Output formatting
# ============================================================================
def format_output(optimization_result: dict, context: RunContext) -> str:
    """Render the search outcome plus run statistics as a readable text report."""
    results = optimization_result.get("results", [])
    # Header: original question and extracted keywords
    output = f"原始问题:{context.q}\n"
    output += f"提取的关键词:{', '.join(context.keywords or [])}\n"
    output += f"关键词数量:{len(context.keywords or [])}\n"
    output += f"\nquery组合统计:\n"
    for level, queries in context.query_combinations.items():
        output += f" - {level}: {len(queries)} 个\n"
    # Aggregate statistics over the whole exploration
    total_queries = sum(len(q) for q in context.query_combinations.values())
    total_sugs = sum(len(r["suggestions"]) for r in context.all_sug_queries)
    total_evals = len(context.evaluation_results)
    output += f"\n探索统计:\n"
    output += f" - 总query数:{total_queries}\n"
    output += f" - 总推荐词数:{total_sugs}\n"
    output += f" - 总评估数:{total_evals}\n"
    output += f"\n状态:{optimization_result['message']}\n\n"
    if optimization_result["success"] and results:
        output += "=" * 60 + "\n"
        output += "合格的推荐query(按relevance_score降序):\n"
        output += "=" * 60 + "\n"
        for i, result in enumerate(results[:20], 1):  # show at most the top 20
            output += f"\n{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
            output += f" 来源:{result['source_query']}\n"
            output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n"
            # Truncate very long reasons to keep the report compact
            output += f" 理由:{result['reason'][:150]}...\n" if len(result['reason']) > 150 else f" 理由:{result['reason']}\n"
    else:
        output += "=" * 60 + "\n"
        output += "结果:未找到足够相关的推荐query\n"
        output += "=" * 60 + "\n"
        if results:
            # Failure path still lists the closest matches for inspection
            output += "\n最接近的推荐词(前10个):\n\n"
            for i, result in enumerate(results[:10], 1):
                output += f"{i}. [{result['relevance_score']:.2f}] {result['sug_query']}\n"
                output += f" 来源:{result['source_query']}\n"
                output += f" 意图:{'✓ 匹配' if result['intent_match'] else '✗ 不匹配'}\n\n"
    # Per-source-query breakdown of suggestion quality
    output += "\n" + "=" * 60 + "\n"
    output += "按查询词分组的推荐词情况:\n"
    output += "=" * 60 + "\n"
    for sug_data in context.all_sug_queries:
        source_q = sug_data["query"]
        sugs = sug_data["suggestions"]
        # All evaluations originating from this source query
        related_evals = [e for e in context.evaluation_results if e["source_query"] == source_q]
        intent_match_count = sum(1 for e in related_evals if e["intent_match"])
        avg_relevance = sum(e["relevance_score"] for e in related_evals) / len(related_evals) if related_evals else 0
        output += f"\n查询:{source_q}\n"
        output += f" 推荐词数:{len(sugs)}\n"
        output += f" 意图匹配数:{intent_match_count}/{len(related_evals)}\n"
        output += f" 平均相关性:{avg_relevance:.2f}\n"
        # Show up to three sample suggestions with their evaluation, if found
        if sugs:
            output += f" 示例推荐词:\n"
            for sug in sugs[:3]:
                eval_item = next((e for e in related_evals if e["sug_query"] == sug), None)
                if eval_item:
                    output += f" - {sug} [意图:{'✓' if eval_item['intent_match'] else '✗'}, 相关:{eval_item['relevance_score']:.2f}]\n"
                else:
                    output += f" - {sug}\n"
    return output.strip()
  376. # ============================================================================
  377. # 主函数
  378. # ============================================================================
  379. async def main(input_dir: str, max_combination_size: int = 1):
  380. current_time, log_url = set_trace()
  381. # 从目录中读取固定文件名
  382. input_context_file = os.path.join(input_dir, 'context.md')
  383. input_q_file = os.path.join(input_dir, 'q.md')
  384. q_context = read_file_as_string(input_context_file)
  385. q = read_file_as_string(input_q_file)
  386. q_with_context = f"""
  387. <需求上下文>
  388. {q_context}
  389. </需求上下文>
  390. <当前问题>
  391. {q}
  392. </当前问题>
  393. """.strip()
  394. # 获取当前文件名作为版本
  395. version = os.path.basename(__file__)
  396. version_name = os.path.splitext(version)[0]
  397. # 日志保存目录
  398. log_dir = os.path.join(input_dir, "output", version_name, current_time)
  399. run_context = RunContext(
  400. version=version,
  401. input_files={
  402. "input_dir": input_dir,
  403. "context_file": input_context_file,
  404. "q_file": input_q_file,
  405. },
  406. q_with_context=q_with_context,
  407. q_context=q_context,
  408. q=q,
  409. log_dir=log_dir,
  410. log_url=log_url,
  411. )
  412. # 执行组合式搜索
  413. optimization_result = await combinatorial_search(run_context, max_combination_size=max_combination_size)
  414. # 格式化输出
  415. final_output = format_output(optimization_result, run_context)
  416. print(f"\n{'='*60}")
  417. print("最终结果")
  418. print(f"{'='*60}")
  419. print(final_output)
  420. # 保存结果
  421. run_context.optimization_result = optimization_result
  422. run_context.final_output = final_output
  423. # 保存 RunContext 到 log_dir
  424. os.makedirs(run_context.log_dir, exist_ok=True)
  425. context_file_path = os.path.join(run_context.log_dir, "run_context.json")
  426. with open(context_file_path, "w", encoding="utf-8") as f:
  427. json.dump(run_context.model_dump(), f, ensure_ascii=False, indent=2)
  428. print(f"\nRunContext saved to: {context_file_path}")
# Command-line entry point: parse arguments and drive the async pipeline.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="搜索query优化工具 - v6.2 组合式搜索版")
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/简单扣图",
        help="输入目录路径,默认: input/简单扣图"
    )
    parser.add_argument(
        "--max-combo",
        type=int,
        default=1,
        help="最大组合词数(N),默认: 1"
    )
    args = parser.parse_args()
    # asyncio.run creates the event loop and runs main() to completion.
    asyncio.run(main(args.input_dir, max_combination_size=args.max_combo))