# sug_v6_1_2_5.py

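"""
Iterative Xiaohongshu search-query optimization, v6.1.2.5 (loop version).

Summary of the pipeline implemented in this file:
1. Segment the original need into a word library (word_segmenter).
2. Seed the query queue with the most relevant segmented words.
3. For each query, run two branches concurrently:
   - Suggestion branch: fetch search suggestions, keep those that improve
     relevance, and fall back to rewrite / add-word strategies.
   - Result branch: search, grade each note (satisfied / partial /
     unsatisfied), mine keywords from satisfied notes into the word library,
     and improve the query based on partially matching notes.
4. Repeat until the queue empties, max_iterations is reached, or enough
   satisfied notes have been collected.
"""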
import asyncio
import json
import os
import sys
import argparse
from datetime import datetime
from typing import Literal
from agents import Agent, Runner
from lib.my_trace import set_trace
from pydantic import BaseModel, Field
from lib.utils import read_file_as_string
from lib.client import get_model

MODEL_NAME = "google/gemini-2.5-flash"

from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
from script.search.xiaohongshu_search import XiaohongshuSearch

# ============================================================================
# Data models
# ============================================================================
class QueryState(BaseModel):
    """State tracking for a single query."""
    query: str
    level: int  # current level in the search tree
    no_suggestion_rounds: int = 0  # consecutive rounds without any suggestion
    relevance_score: float = 0.0  # relevance to the original need
    parent_query: str | None = None  # parent query
    strategy: str | None = None  # generation strategy: direct_sug, rewrite, add_word


class WordLibrary(BaseModel):
    """Dynamically growing word library."""
    words: set[str] = Field(default_factory=set)

    def add_word(self, word: str):
        """Add a single word to the library."""
        if word and word.strip():
            self.words.add(word.strip())

    def add_words(self, words: list[str]):
        """Add several words at once."""
        for word in words:
            self.add_word(word)

    def get_unused_word(self, current_query: str) -> str | None:
        """Return a library word that does not appear in the current query."""
        for word in self.words:
            if word not in current_query:
                return word
        return None

    def model_dump(self, **kwargs):
        """Serialize to a plain dict (a set is not JSON-serializable)."""
        return {"words": list(self.words)}
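
# Illustrative WordLibrary usage (values are made up; in the pipeline the
# library is seeded from the segmented query and grows as keywords are mined
# from satisfied notes):
#   lib = WordLibrary()
#   lib.add_words(["露营", "新手", "装备"])
#   lib.get_unused_word("新手露营")  # -> "装备" (the only word absent from the query)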

class RunContext(BaseModel):
    """Run context shared across the whole pipeline."""
    version: str
    input_files: dict[str, str]
    q_with_context: str
    q_context: str
    q: str
    log_url: str
    log_dir: str
    # Added fields
    word_library: dict = Field(default_factory=dict)  # stored as a dict because a set cannot be serialized directly
    query_states: list[dict] = Field(default_factory=list)
    steps: list[dict] = Field(default_factory=list)
    # Final results
    satisfied_notes: list[dict] = Field(default_factory=list)
    final_output: str | None = None

# ============================================================================
# Agent definitions
# ============================================================================

# Agent 1: word segmentation expert
class WordSegmentation(BaseModel):
    """Word segmentation result."""
    words: list[str] = Field(..., description="分词结果列表")
    reasoning: str = Field(..., description="分词理由")


word_segmentation_instructions = """
你是分词专家。给定一个query,将其拆分成有意义的最小单元。
## 分词原则
1. 保留有搜索意义的词汇
2. 拆分成独立的概念
3. 保留专业术语的完整性
4. 去除虚词(的、吗、呢等)
## 输出要求
返回分词列表和分词理由。
""".strip()

word_segmenter = Agent[None](
    name="分词专家",
    instructions=word_segmentation_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordSegmentation,
)
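
# Note: every agent in this file is invoked the same way in the pipeline
# functions below: `result = await Runner.run(agent, prompt)`, after which
# `result.final_output` is treated as an instance of the agent's
# `output_type` (WordSegmentation here).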

# Agent 2: query relevance evaluation expert
class RelevanceEvaluation(BaseModel):
    """Relevance evaluation result."""
    relevance_score: float = Field(..., description="相关性分数 0-1")
    is_improved: bool = Field(..., description="是否比之前更好")
    reason: str = Field(..., description="评估理由")


relevance_evaluation_instructions = """
你是Query相关度评估专家。
## 任务
评估当前query与原始需求的匹配程度。
## 评估标准
- 主题相关性
- 要素覆盖度
- 意图匹配度
## 输出
- relevance_score: 0-1的相关性分数
- is_improved: 如果提供了previous_score,判断是否有提升
- reason: 详细理由
""".strip()

relevance_evaluator = Agent[None](
    name="Query相关度评估专家",
    instructions=relevance_evaluation_instructions,
    model=get_model(MODEL_NAME),
    output_type=RelevanceEvaluation,
)

# Agent 3: query rewriting expert
class QueryRewrite(BaseModel):
    """Query rewrite result."""
    rewritten_query: str = Field(..., description="改写后的query")
    rewrite_type: str = Field(..., description="改写类型:abstract或synonym")
    reasoning: str = Field(..., description="改写理由")


query_rewrite_instructions = """
你是Query改写专家。
## 改写策略
1. **向上抽象**:将具体概念泛化到更高层次
   - 例:iPhone 13 → 智能手机
2. **同义改写**:使用同义词或相关表达
   - 例:购买 → 入手、获取
## 输出要求
返回改写后的query、改写类型和理由。
""".strip()

query_rewriter = Agent[None](
    name="Query改写专家",
    instructions=query_rewrite_instructions,
    model=get_model(MODEL_NAME),
    output_type=QueryRewrite,
)

# Agent 4: word-insertion placement expert
class WordInsertion(BaseModel):
    """Word insertion result."""
    new_query: str = Field(..., description="加词后的新query")
    insertion_position: str = Field(..., description="插入位置描述")
    reasoning: str = Field(..., description="插入理由")


word_insertion_instructions = """
你是加词位置评估专家。
## 任务
将新词加到当前query的最合适位置,保持语义通顺。
## 原则
1. 保持语法正确
2. 语义连贯
3. 符合搜索习惯
## 输出
返回新query、插入位置描述和理由。
""".strip()

word_inserter = Agent[None](
    name="加词位置评估专家",
    instructions=word_insertion_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordInsertion,
)

# Agent 5: result relevance evaluation expert
class ResultEvaluation(BaseModel):
    """Result evaluation."""
    match_level: str = Field(..., description="匹配等级:satisfied, partial, unsatisfied")
    relevance_score: float = Field(..., description="相关性分数 0-1")
    missing_aspects: list[str] = Field(default_factory=list, description="缺失的方面")
    reason: str = Field(..., description="评估理由")


result_evaluation_instructions = """
你是Result匹配度评估专家。
## 任务
评估搜索结果(帖子)与原始需求的匹配程度。
## 评估等级
1. **satisfied**: 完全满足需求
2. **partial**: 部分满足,但有缺失
3. **unsatisfied**: 基本不满足
## 输出要求
- match_level: 匹配等级
- relevance_score: 相关性分数
- missing_aspects: 如果是partial,列出缺失的方面
- reason: 详细理由
""".strip()

result_evaluator = Agent[None](
    name="Result匹配度评估专家",
    instructions=result_evaluation_instructions,
    model=get_model(MODEL_NAME),
    output_type=ResultEvaluation,
)

# Agent 6: query improvement expert (driven by missing aspects)
class QueryImprovement(BaseModel):
    """Query improvement result."""
    improved_query: str = Field(..., description="改造后的query")
    added_aspects: list[str] = Field(..., description="添加的方面")
    reasoning: str = Field(..., description="改造理由")


query_improvement_instructions = """
你是Query改造专家。
## 任务
根据搜索结果的缺失部分,改造query使其包含这些内容。
## 原则
1. 针对性补充缺失方面
2. 保持query简洁
3. 符合搜索习惯
## 输出
返回改造后的query、添加的方面和理由。
""".strip()

query_improver = Agent[None](
    name="Query改造专家",
    instructions=query_improvement_instructions,
    model=get_model(MODEL_NAME),
    output_type=QueryImprovement,
)

# Agent 7: keyword extraction expert
class KeywordExtraction(BaseModel):
    """Keyword extraction result."""
    keywords: list[str] = Field(..., description="提取的关键词列表")
    reasoning: str = Field(..., description="提取理由")


keyword_extraction_instructions = """
你是关键词提取专家。
## 任务
从帖子标题和描述中提取核心关键词。
## 提取原则
1. 提取有搜索价值的词汇
2. 去除虚词和通用词
3. 保留专业术语
4. 提取3-10个关键词
## 输出
返回关键词列表和提取理由。
""".strip()

keyword_extractor = Agent[None](
    name="关键词提取专家",
    instructions=keyword_extraction_instructions,
    model=get_model(MODEL_NAME),
    output_type=KeywordExtraction,
)

# ============================================================================
# Helper functions
# ============================================================================
def add_step(context: RunContext, step_name: str, step_type: str, data: dict):
    """Append a step record to the run context."""
    step = {
        "step_number": len(context.steps) + 1,
        "step_name": step_name,
        "step_type": step_type,
        "timestamp": datetime.now().isoformat(),
        "data": data
    }
    context.steps.append(step)
    return step
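
# The raw `note` dict handled below is assumed (inferred from the .get() calls
# in this function, not from API documentation) to look roughly like:
#   {"id": "...",
#    "note_card": {"display_title": "...", "desc": "...", "image_list": [...],
#                  "interact_info": {"liked_count": ..., "collected_count": ...,
#                                    "comment_count": ..., "shared_count": ...},
#                  "user": {"nickname": "...", "user_id": "..."},
#                  "type": "normal"}}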

def process_note_data(note: dict) -> dict:
    """Normalize a note item returned by the search API."""
    note_card = note.get("note_card", {})
    image_list = note_card.get("image_list", [])
    interact_info = note_card.get("interact_info", {})
    user_info = note_card.get("user", {})
    return {
        "note_id": note.get("id", ""),
        "title": note_card.get("display_title", ""),
        "desc": note_card.get("desc", ""),
        "image_list": image_list,
        "interact_info": {
            "liked_count": interact_info.get("liked_count", 0),
            "collected_count": interact_info.get("collected_count", 0),
            "comment_count": interact_info.get("comment_count", 0),
            "shared_count": interact_info.get("shared_count", 0)
        },
        "user": {
            "nickname": user_info.get("nickname", ""),
            "user_id": user_info.get("user_id", "")
        },
        "type": note_card.get("type", "normal"),
        "note_url": f"https://www.xiaohongshu.com/explore/{note.get('id', '')}"
    }

# ============================================================================
# Core pipeline functions
# ============================================================================
async def initialize_word_library(original_query: str, context: RunContext) -> WordLibrary:
    """Initialize the word library from the original query."""
    print("\n[初始化] 创建分词库...")
    # Segment the query with the word-segmentation agent
    result = await Runner.run(word_segmenter, original_query)
    segmentation: WordSegmentation = result.final_output
    word_lib = WordLibrary()
    word_lib.add_words(segmentation.words)
    print(f"初始分词库: {list(word_lib.words)}")
    print(f"分词理由: {segmentation.reasoning}")
    # Persist to the run context
    context.word_library = word_lib.model_dump()
    add_step(context, "初始化分词库", "word_library_init", {
        "original_query": original_query,
        "words": list(word_lib.words),
        "reasoning": segmentation.reasoning
    })
    return word_lib

async def evaluate_query_relevance(
    query: str,
    original_need: str,
    previous_score: float | None = None,
    context: RunContext | None = None
) -> RelevanceEvaluation:
    """Evaluate how well a query matches the original need."""
    eval_input = f"""
<原始需求>
{original_need}
</原始需求>
<当前Query>
{query}
</当前Query>
{"<之前的相关度分数>" + str(previous_score) + "</之前的相关度分数>" if previous_score is not None else ""}
请评估当前query与原始需求的相关度。
"""
    result = await Runner.run(relevance_evaluator, eval_input)
    evaluation: RelevanceEvaluation = result.final_output
    return evaluation

async def process_suggestions(
    query: str,
    query_state: QueryState,
    original_need: str,
    word_lib: WordLibrary,
    context: RunContext,
    xiaohongshu_api: XiaohongshuSearchRecommendations
) -> list[QueryState]:
    """Process the suggestion branch and return the newly generated query states."""
    print(f"\n [Suggestion分支] 处理query: {query}")
    # 1. Fetch suggestions
    suggestions = xiaohongshu_api.get_recommendations(keyword=query)
    if not suggestions or len(suggestions) == 0:
        print(f" → 没有获取到suggestion")
        query_state.no_suggestion_rounds += 1
        # Record the step
        add_step(context, f"Suggestion分支 - {query}", "suggestion_branch", {
            "query": query,
            "query_level": query_state.level,
            "suggestions_count": 0,
            "no_suggestion_rounds": query_state.no_suggestion_rounds,
            "new_queries_generated": 0
        })
        return []
    print(f" → 获取到 {len(suggestions)} 个suggestions")
    query_state.no_suggestion_rounds = 0  # reset the counter
    # 2. Evaluate each suggestion
    new_queries = []
    suggestion_evaluations = []
    for sug in suggestions[:5]:  # cap the number of suggestions processed
        # Evaluate the suggestion's relevance
        sug_eval = await evaluate_query_relevance(sug, original_need, query_state.relevance_score, context)
        sug_eval_record = {
            "suggestion": sug,
            "relevance_score": sug_eval.relevance_score,
            "is_improved": sug_eval.is_improved,
            "reason": sug_eval.reason
        }
        suggestion_evaluations.append(sug_eval_record)
        # Keep the suggestion only if it improves on the current query
        if sug_eval.is_improved and sug_eval.relevance_score > query_state.relevance_score:
            print(f" ✓ {sug} (分数: {sug_eval.relevance_score:.2f}, 提升: {sug_eval.is_improved})")
            # Create a new query state that uses the suggestion directly
            new_state = QueryState(
                query=sug,
                level=query_state.level + 1,
                relevance_score=sug_eval.relevance_score,
                parent_query=query,
                strategy="direct_sug"
            )
            new_queries.append(new_state)
        else:
            print(f" ✗ {sug} (分数: {sug_eval.relevance_score:.2f}, 未提升)")
    # 3. Rewrite strategy (abstraction or synonym rewrite)
    if len(new_queries) < 3:  # not enough direct suggestions, try rewriting
        rewrite_input = f"""
<当前Query>
{query}
</当前Query>
<改写要求>
类型: abstract (向上抽象)
</改写要求>
请改写这个query。
"""
        result = await Runner.run(query_rewriter, rewrite_input)
        rewrite: QueryRewrite = result.final_output
        # Evaluate the rewritten query
        rewrite_eval = await evaluate_query_relevance(rewrite.rewritten_query, original_need, query_state.relevance_score, context)
        if rewrite_eval.is_improved:
            print(f" ✓ 改写: {rewrite.rewritten_query} (分数: {rewrite_eval.relevance_score:.2f})")
            new_state = QueryState(
                query=rewrite.rewritten_query,
                level=query_state.level + 1,
                relevance_score=rewrite_eval.relevance_score,
                parent_query=query,
                strategy="rewrite"
            )
            new_queries.append(new_state)
    # 4. Word-insertion strategy
    unused_word = word_lib.get_unused_word(query)
    if unused_word and len(new_queries) < 5:
        insertion_input = f"""
<当前Query>
{query}
</当前Query>
<要添加的词>
{unused_word}
</要添加的词>
请将这个词加到query的最合适位置。
"""
        result = await Runner.run(word_inserter, insertion_input)
        insertion: WordInsertion = result.final_output
        # Evaluate the query after word insertion
        insertion_eval = await evaluate_query_relevance(insertion.new_query, original_need, query_state.relevance_score, context)
        if insertion_eval.is_improved:
            print(f" ✓ 加词: {insertion.new_query} (分数: {insertion_eval.relevance_score:.2f})")
            new_state = QueryState(
                query=insertion.new_query,
                level=query_state.level + 1,
                relevance_score=insertion_eval.relevance_score,
                parent_query=query,
                strategy="add_word"
            )
            new_queries.append(new_state)
    # Record the full outcome of the suggestion branch
    add_step(context, f"Suggestion分支 - {query}", "suggestion_branch", {
        "query": query,
        "query_level": query_state.level,
        "query_relevance": query_state.relevance_score,
        "suggestions_count": len(suggestions),
        "suggestions_evaluated": len(suggestion_evaluations),
        "suggestion_evaluations": suggestion_evaluations[:10],  # keep only the first 10
        "new_queries_generated": len(new_queries),
        "new_queries": [{"query": nq.query, "score": nq.relevance_score, "strategy": nq.strategy} for nq in new_queries],
        "no_suggestion_rounds": query_state.no_suggestion_rounds
    })
    return new_queries

async def process_search_results(
    query: str,
    query_state: QueryState,
    original_need: str,
    word_lib: WordLibrary,
    context: RunContext,
    xiaohongshu_search: XiaohongshuSearch,
    relevance_threshold: float = 0.6
) -> tuple[list[dict], list[QueryState]]:
    """
    Process the search-result branch.
    Returns: (notes that satisfy the need, new queries to keep iterating on)
    """
    print(f"\n [Result分支] 搜索query: {query}")
    # 1. Check whether the query clears the relevance threshold
    if query_state.relevance_score < relevance_threshold:
        print(f" ✗ 相关度 {query_state.relevance_score:.2f} 低于门槛 {relevance_threshold},跳过搜索")
        return [], []
    print(f" ✓ 相关度 {query_state.relevance_score:.2f} 达到门槛,执行搜索")
    # 2. Run the search
    try:
        search_result = xiaohongshu_search.search(keyword=query)
        result_str = search_result.get("result", "{}")
        if isinstance(result_str, str):
            result_data = json.loads(result_str)
        else:
            result_data = result_str
        notes = result_data.get("data", {}).get("data", [])
        print(f" → 搜索到 {len(notes)} 个帖子")
    except Exception as e:
        print(f" ✗ 搜索失败: {e}")
        return [], []
    if not notes:
        return [], []
    # 3. Evaluate each note
    satisfied_notes = []
    partial_notes = []
    for note in notes[:10]:  # cap the number of notes evaluated
        note_data = process_note_data(note)
        title = note_data["title"] or ""
        desc = note_data["desc"] or ""
        # Skip notes with neither title nor description
        if not title and not desc:
            continue
        # Evaluate the note
        eval_input = f"""
<原始需求>
{original_need}
</原始需求>
<帖子>
标题: {title}
描述: {desc}
</帖子>
请评估这个帖子与原始需求的匹配程度。
"""
        result = await Runner.run(result_evaluator, eval_input)
        evaluation: ResultEvaluation = result.final_output
        note_data["evaluation"] = {
            "match_level": evaluation.match_level,
            "relevance_score": evaluation.relevance_score,
            "missing_aspects": evaluation.missing_aspects,
            "reason": evaluation.reason
        }
        if evaluation.match_level == "satisfied":
            satisfied_notes.append(note_data)
            print(f" ✓ 满足: {title[:30] if len(title) > 30 else title}... (分数: {evaluation.relevance_score:.2f})")
        elif evaluation.match_level == "partial":
            partial_notes.append(note_data)
            print(f" ~ 部分: {title[:30] if len(title) > 30 else title}... (缺失: {', '.join(evaluation.missing_aspects[:2])})")
    # 4. Satisfied notes: extract keywords and expand the word library
    new_queries = []
    initial_word_count = len(word_lib.words)  # used below to report whether the library grew
    if satisfied_notes:
        print(f"\n 从 {len(satisfied_notes)} 个满足的帖子中提取关键词...")
        for note in satisfied_notes[:3]:  # cap the number of notes processed
            extract_input = f"""
<帖子>
标题: {note['title']}
描述: {note['desc']}
</帖子>
请提取核心关键词。
"""
            result = await Runner.run(keyword_extractor, extract_input)
            extraction: KeywordExtraction = result.final_output
            # Add new words to the library
            for keyword in extraction.keywords:
                if keyword not in word_lib.words:
                    word_lib.add_word(keyword)
                    print(f" + 新词入库: {keyword}")
    # 5. Partially matched notes: improve the query
    if partial_notes and len(satisfied_notes) < 5:  # not enough satisfied notes, improve based on partial matches
        print(f"\n 基于 {len(partial_notes)} 个部分匹配帖子改造query...")
        # Collect all missing aspects
        all_missing = []
        for note in partial_notes:
            all_missing.extend(note["evaluation"]["missing_aspects"])
        if all_missing:
            improvement_input = f"""
<当前Query>
{query}
</当前Query>
<缺失的方面>
{', '.join(set(all_missing[:5]))}
</缺失的方面>
请改造query使其包含这些缺失的内容。
"""
            result = await Runner.run(query_improver, improvement_input)
            improvement: QueryImprovement = result.final_output
            # Evaluate the improved query
            improved_eval = await evaluate_query_relevance(improvement.improved_query, original_need, query_state.relevance_score, context)
            if improved_eval.is_improved:
                print(f" ✓ 改进: {improvement.improved_query} (添加: {', '.join(improvement.added_aspects[:2])})")
                new_state = QueryState(
                    query=improvement.improved_query,
                    level=query_state.level + 1,
                    relevance_score=improved_eval.relevance_score,
                    parent_query=query,
                    strategy="improve_from_partial"
                )
                new_queries.append(new_state)
    # Record the full outcome of the result branch
    add_step(context, f"Result分支 - {query}", "result_branch", {
        "query": query,
        "query_level": query_state.level,
        "query_relevance": query_state.relevance_score,
        "relevance_threshold": relevance_threshold,
        "passed_threshold": query_state.relevance_score >= relevance_threshold,
        "notes_count": len(notes),
        "satisfied_count": len(satisfied_notes),
        "partial_count": len(partial_notes),
        "satisfied_notes": [
            {
                "note_id": note["note_id"],
                "title": note["title"],
                "score": note["evaluation"]["relevance_score"],
                "match_level": note["evaluation"]["match_level"]
            }
            for note in satisfied_notes[:10]  # keep only the first 10
        ],
        "new_queries_generated": len(new_queries),
        "new_queries": [{"query": nq.query, "score": nq.relevance_score, "strategy": nq.strategy} for nq in new_queries],
        "word_library_expanded": len(word_lib.words) > initial_word_count  # whether the word library actually grew
    })
    return satisfied_notes, new_queries

async def iterative_search_loop(
    context: RunContext,
    max_iterations: int = 20,
    max_concurrent_queries: int = 5,
    relevance_threshold: float = 0.6
) -> list[dict]:
    """
    Main loop: iterative search.
    Args:
        context: run context
        max_iterations: maximum number of iterations
        max_concurrent_queries: maximum number of queries processed per iteration
        relevance_threshold: relevance threshold required to run a search
    Returns:
        list of notes that satisfy the need
    """
    print(f"\n{'='*60}")
    print(f"开始迭代搜索循环")
    print(f"{'='*60}")
    # 1. Initialize the word library
    word_lib = await initialize_word_library(context.q, context)
    # 2. Initialize the query queue: pick the most relevant words
    all_words = list(word_lib.words)
    query_queue = []
    print(f"\n评估所有初始分词的相关度...")
    word_scores = []
    for word in all_words:
        # Evaluate each word's relevance
        eval_result = await evaluate_query_relevance(word, context.q, None, context)
        word_scores.append({
            'word': word,
            'score': eval_result.relevance_score,
            'eval': eval_result
        })
        print(f" {word}: {eval_result.relevance_score:.2f}")
    # Sort by relevance and keep the top 3
    word_scores.sort(key=lambda x: x['score'], reverse=True)
    selected_words = word_scores[:3]
    for item in selected_words:
        query_queue.append(QueryState(
            query=item['word'],
            level=1,
            relevance_score=item['score'],
            strategy="initial"
        ))
    print(f"\n初始query队列(按相关度选择): {[(q.query, f'{q.relevance_score:.2f}') for q in query_queue]}")
    # 3. API instances
    xiaohongshu_api = XiaohongshuSearchRecommendations()
    xiaohongshu_search = XiaohongshuSearch()
    # 4. Main loop
    all_satisfied_notes = []
    iteration = 0
    while query_queue and iteration < max_iterations:
        iteration += 1
        print(f"\n{'='*60}")
        print(f"迭代 {iteration}: 队列中有 {len(query_queue)} 个query")
        print(f"{'='*60}")
        # Limit the number of queries processed this round
        current_batch = query_queue[:max_concurrent_queries]
        query_queue = query_queue[max_concurrent_queries:]
        # Record the queries handled in this round
        add_step(context, f"迭代 {iteration}", "iteration", {
            "iteration": iteration,
            "queue_size": len(query_queue) + len(current_batch),
            "processing_queries": [q.query for q in current_batch]
        })
        new_queries_from_sug = []
        new_queries_from_result = []
        satisfied_before = len(all_satisfied_notes)  # to report how many notes this round adds
        # Process each query
        for query_state in current_batch:
            print(f"\n处理Query [{query_state.level}]: {query_state.query} (分数: {query_state.relevance_score:.2f})")
            # Check the termination condition for this branch
            if query_state.no_suggestion_rounds >= 2:
                print(f" ✗ 连续2轮无suggestion,终止该分支")
                continue
            # Run the two branches concurrently
            sug_task = process_suggestions(
                query_state.query, query_state, context.q, word_lib, context, xiaohongshu_api
            )
            result_task = process_search_results(
                query_state.query, query_state, context.q, word_lib, context,
                xiaohongshu_search, relevance_threshold
            )
            # Wait for both branches to finish
            sug_queries, (satisfied_notes, result_queries) = await asyncio.gather(
                sug_task,
                result_task
            )
            new_queries_from_sug.extend(sug_queries)
            new_queries_from_result.extend(result_queries)
            all_satisfied_notes.extend(satisfied_notes)
        # Update the queue
        all_new_queries = new_queries_from_sug + new_queries_from_result
        query_queue.extend(all_new_queries)
        # Deduplicate by query text
        seen = set()
        unique_queue = []
        for q in query_queue:
            if q.query not in seen:
                seen.add(q.query)
                unique_queue.append(q)
        query_queue = unique_queue
        # Sort by relevance
        query_queue.sort(key=lambda x: x.relevance_score, reverse=True)
        print(f"\n本轮结果:")
        print(f" 新增满足帖子: {len(all_satisfied_notes) - satisfied_before}")
        print(f" 累计满足帖子: {len(all_satisfied_notes)}")
        print(f" 新增queries: {len(all_new_queries)}")
        print(f" 队列剩余: {len(query_queue)}")
        # Persist the word library back to the context
        context.word_library = word_lib.model_dump()
        # Stop early once enough satisfying notes have been found
        if len(all_satisfied_notes) >= 20:
            print(f"\n已找到足够的满足帖子 ({len(all_satisfied_notes)}个),提前结束")
            break
    print(f"\n{'='*60}")
    print(f"迭代搜索完成")
    print(f" 总迭代次数: {iteration}")
    print(f" 最终满足帖子数: {len(all_satisfied_notes)}")
    print(f" 最终分词库大小: {len(word_lib.words)}")
    print(f"{'='*60}")
    # Save the final result
    add_step(context, "迭代搜索完成", "loop_complete", {
        "total_iterations": iteration,
        "total_satisfied_notes": len(all_satisfied_notes),
        "final_word_library_size": len(word_lib.words),
        "final_word_library": list(word_lib.words)
    })
    return all_satisfied_notes

# ============================================================================
# Main entry point
# ============================================================================
async def main(input_dir: str, max_iterations: int = 20, visualize: bool = False):
    """Main entry point."""
    current_time, log_url = set_trace()
    # Read the inputs
    input_context_file = os.path.join(input_dir, 'context.md')
    input_q_file = os.path.join(input_dir, 'q.md')
    q_context = read_file_as_string(input_context_file)
    q = read_file_as_string(input_q_file)
    q_with_context = f"""
<需求上下文>
{q_context}
</需求上下文>
<当前问题>
{q}
</当前问题>
""".strip()
    # Version info
    version = os.path.basename(__file__)
    version_name = os.path.splitext(version)[0]
    # Log directory
    log_dir = os.path.join(input_dir, "output", version_name, current_time)
    # Create the run context
    run_context = RunContext(
        version=version,
        input_files={
            "input_dir": input_dir,
            "context_file": input_context_file,
            "q_file": input_q_file,
        },
        q_with_context=q_with_context,
        q_context=q_context,
        q=q,
        log_dir=log_dir,
        log_url=log_url,
    )
    # Run the iterative search
    satisfied_notes = await iterative_search_loop(
        run_context,
        max_iterations=max_iterations,
        max_concurrent_queries=3,
        relevance_threshold=0.6
    )
    # Store the results
    run_context.satisfied_notes = satisfied_notes
    # Format the output
    output = f"原始问题:{run_context.q}\n"
    output += f"找到满足需求的帖子:{len(satisfied_notes)} 个\n"
    output += f"分词库大小:{len(run_context.word_library.get('words', []))} 个词\n"
    output += "\n" + "="*60 + "\n"
    if satisfied_notes:
        output += "【满足需求的帖子】\n\n"
        for idx, note in enumerate(satisfied_notes[:10], 1):
            output += f"{idx}. {note['title']}\n"
            output += f" 相关度: {note['evaluation']['relevance_score']:.2f}\n"
            output += f" URL: {note['note_url']}\n\n"
    else:
        output += "未找到满足需求的帖子\n"
    run_context.final_output = output
    print(f"\n{'='*60}")
    print("最终结果")
    print(f"{'='*60}")
    print(output)
    # Save logs
    os.makedirs(run_context.log_dir, exist_ok=True)
    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
    context_dict = run_context.model_dump()
    with open(context_file_path, "w", encoding="utf-8") as f:
        json.dump(context_dict, f, ensure_ascii=False, indent=2)
    print(f"\nRunContext saved to: {context_file_path}")
    steps_file_path = os.path.join(run_context.log_dir, "steps.json")
    with open(steps_file_path, "w", encoding="utf-8") as f:
        json.dump(run_context.steps, f, ensure_ascii=False, indent=2)
    print(f"Steps log saved to: {steps_file_path}")
    # Visualization
    if visualize:
        import subprocess
        output_html = os.path.join(run_context.log_dir, "visualization.html")
        print(f"\n🎨 生成可视化HTML...")
        result = subprocess.run([
            "python", "sug_v6_1_2_3.visualize.py",
            steps_file_path,
            "-o", output_html
        ])
        if result.returncode == 0:
            print(f"✅ 可视化已生成: {output_html}")
        else:
            print(f"❌ 可视化生成失败")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="搜索query优化工具 - v6.1.2.5 迭代循环版")
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/简单扣图",
        help="输入目录路径,默认: input/简单扣图"
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=20,
        help="最大迭代次数,默认: 20"
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        default=False,
        help="运行完成后自动生成可视化HTML"
    )
    args = parser.parse_args()
    asyncio.run(main(args.input_dir, max_iterations=args.max_iterations, visualize=args.visualize))
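
# Example invocation (the input directory is expected to contain context.md and
# q.md, as read in main(); the flag values here are illustrative):
#   python sug_v6_1_2_5.py --input-dir input/简单扣图 --max-iterations 10 --visualize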