# sug_v6_1_2_6.py
import asyncio
import json
import os
import sys
import argparse
from datetime import datetime

from agents import Agent, Runner
from pydantic import BaseModel, Field

from lib.my_trace import set_trace
from lib.utils import read_file_as_string
from lib.client import get_model
from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
from script.search.xiaohongshu_search import XiaohongshuSearch

MODEL_NAME = "google/gemini-2.5-flash"
# ============================================================================
# Data models
# ============================================================================

class QueryState(BaseModel):
    """Tracks the state of a single query."""
    query: str
    level: int  # current level in the query tree
    no_suggestion_rounds: int = 0  # consecutive rounds without suggestions
    relevance_score: float = 0.0  # relevance to the original need
    parent_query: str | None = None  # parent query
    strategy: str | None = None  # generation strategy, e.g. "sug", "abstract_rewrite", "synonym_rewrite", "add_word"
    is_terminated: bool = False  # whether this query is terminated (no longer processed)


class WordLibrary(BaseModel):
    """Dynamic word library built from query segmentation."""
    words: set[str] = Field(default_factory=set)
    word_sources: dict[str, str] = Field(default_factory=dict)  # word -> source (note_id or "initial")
    core_words: set[str] = Field(default_factory=set)  # core words (initial segmentation of the root query)

    def add_word(self, word: str, source: str = "unknown", is_core: bool = False):
        """Add a single word to the library."""
        if word and word.strip():
            word = word.strip()
            self.words.add(word)
            if word not in self.word_sources:
                self.word_sources[word] = source
            if is_core:
                self.core_words.add(word)

    def add_words(self, words: list[str], source: str = "unknown", is_core: bool = False):
        """Add a batch of words."""
        for word in words:
            self.add_word(word, source, is_core)

    def get_unused_word(self, current_query: str, prefer_core: bool = True) -> str | None:
        """Return a word that does not yet appear in the current query.

        Args:
            current_query: the query being extended
            prefer_core: prefer core words first (default True)
        """
        # Try core words first
        if prefer_core and self.core_words:
            for word in self.core_words:
                if word not in current_query:
                    return word
        # Once core words are exhausted (or not preferred), fall back to the full library
        for word in self.words:
            if word not in current_query:
                return word
        return None

    def model_dump(self, **kwargs):
        """Serialize to a JSON-friendly dict (sets are not JSON-serializable)."""
        return {
            "words": list(self.words),
            "word_sources": self.word_sources,
            "core_words": list(self.core_words),
        }
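
# Illustrative sketch (not part of the pipeline): how WordLibrary is meant to be used.
# The word values below are made up for the example.
#
#   lib = WordLibrary()
#   lib.add_words(["cutout", "tutorial"], source="initial", is_core=True)
#   lib.add_word("PhotoShop", source="note:abc123")
#   lib.get_unused_word("cutout apps")   # -> "tutorial" (core words are tried first)
#   lib.model_dump()                     # sets converted to lists for json.dump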

class RunContext(BaseModel):
    """Run context shared across the pipeline."""
    version: str
    input_files: dict[str, str]
    q_with_context: str
    q_context: str
    q: str
    log_url: str
    log_dir: str
    # Added fields
    word_library: dict = Field(default_factory=dict)  # stored as dict, since sets are not JSON-serializable
    query_states: list[dict] = Field(default_factory=list)
    steps: list[dict] = Field(default_factory=list)
    # Query evolution graph
    query_graph: dict = Field(default_factory=dict)  # records how queries evolve and relate
    # Final results
    satisfied_notes: list[dict] = Field(default_factory=list)
    final_output: str | None = None


# ============================================================================
# Agent definitions
# ============================================================================

# Agent 1: word segmentation expert
class WordSegmentation(BaseModel):
    """Segmentation result."""
    words: list[str] = Field(..., description="List of segmented words")
    reasoning: str = Field(..., description="Reasoning behind the segmentation")


word_segmentation_instructions = """
You are a word segmentation expert. Given a query, split it into the smallest meaningful units.

## Segmentation principles
1. Keep words that are useful for search
2. Split into independent concepts
3. Keep technical terms intact
4. Drop function words (particles such as "的", "吗", "呢")

## Output requirements
Return the list of words and the reasoning.
""".strip()

word_segmenter = Agent[None](
    name="Word Segmentation Expert",
    instructions=word_segmentation_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordSegmentation,
)
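
# Illustrative sketch (assumed usage of the `agents` SDK, matching the call sites
# elsewhere in this file): every expert below is invoked the same way — Runner.run()
# with a plain-text input, and `final_output` carries the structured `output_type` instance.
#
#   async def demo():
#       result = await Runner.run(word_segmenter, "深圳周末露营地推荐")
#       seg: WordSegmentation = result.final_output
#       print(seg.words, seg.reasoning)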

# Agent 2: query relevance evaluation expert
class RelevanceEvaluation(BaseModel):
    """Relevance evaluation."""
    relevance_score: float = Field(..., description="Relevance score, 0-1")
    is_improved: bool = Field(..., description="Whether this is better than before")
    reason: str = Field(..., description="Detailed reasoning")


relevance_evaluation_instructions = """
You are a query relevance evaluation expert.

## Task
Evaluate how well the current query matches the original need.

## Criteria
- Topical relevance
- Coverage of required elements
- Intent match

## Output
- relevance_score: relevance score between 0 and 1
- is_improved: if a previous_score is provided, whether this query improves on it
- reason: detailed reasoning
""".strip()

relevance_evaluator = Agent[None](
    name="Query Relevance Evaluation Expert",
    instructions=relevance_evaluation_instructions,
    model=get_model(MODEL_NAME),
    output_type=RelevanceEvaluation,
)

# Agent 3: query rewriting expert
class QueryRewrite(BaseModel):
    """Query rewrite result."""
    rewritten_query: str = Field(..., description="Rewritten query")
    rewrite_type: str = Field(..., description="Rewrite type: abstract or synonym")
    reasoning: str = Field(..., description="Reasoning behind the rewrite")


query_rewrite_instructions = """
You are a query rewriting expert.

## Rewrite strategies
1. **Abstraction**: generalize a specific concept to a higher level
   - e.g. iPhone 13 → smartphone
2. **Synonym rewrite**: use synonyms or related expressions
   - e.g. buy → pick up, get

## Output requirements
Return the rewritten query, the rewrite type, and the reasoning.
""".strip()

query_rewriter = Agent[None](
    name="Query Rewriting Expert",
    instructions=query_rewrite_instructions,
    model=get_model(MODEL_NAME),
    output_type=QueryRewrite,
)

# Agent 4: word insertion expert
class WordInsertion(BaseModel):
    """Word insertion result."""
    new_query: str = Field(..., description="New query after inserting the word")
    insertion_position: str = Field(..., description="Description of the insertion position")
    reasoning: str = Field(..., description="Reasoning behind the insertion")


word_insertion_instructions = """
You are a word insertion expert.

## Task
Insert the new word at the most natural position in the current query, keeping it fluent.

## Principles
1. Keep the grammar correct
2. Keep the semantics coherent
3. Match common search phrasing

## Output
Return the new query, a description of the insertion position, and the reasoning.
""".strip()

word_inserter = Agent[None](
    name="Word Insertion Expert",
    instructions=word_insertion_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordInsertion,
)

# Agent 5: result matching evaluation expert
class ResultEvaluation(BaseModel):
    """Result evaluation."""
    match_level: str = Field(..., description="Match level: satisfied, partial, unsatisfied")
    relevance_score: float = Field(..., description="Relevance score, 0-1")
    missing_aspects: list[str] = Field(default_factory=list, description="Missing aspects")
    reason: str = Field(..., description="Detailed reasoning")


result_evaluation_instructions = """
You are a result matching evaluation expert.

## Task
Evaluate how well a search result (a note) matches the original need.

## Match levels
1. **satisfied**: fully satisfies the need
2. **partial**: partially satisfies it, with gaps
3. **unsatisfied**: essentially does not satisfy it

## Output requirements
- match_level: match level
- relevance_score: relevance score
- missing_aspects: if partial, list the missing aspects
- reason: detailed reasoning
""".strip()

result_evaluator = Agent[None](
    name="Result Matching Evaluation Expert",
    instructions=result_evaluation_instructions,
    model=get_model(MODEL_NAME),
    output_type=ResultEvaluation,
)

# Agent 6: query improvement expert (driven by missing aspects)
class QueryImprovement(BaseModel):
    """Query improvement result."""
    improved_query: str = Field(..., description="Improved query")
    added_aspects: list[str] = Field(..., description="Aspects that were added")
    reasoning: str = Field(..., description="Reasoning behind the improvement")


query_improvement_instructions = """
You are a query improvement expert.

## Task
Given the aspects missing from the search results, rework the query so it covers them.

## Principles
1. Target the missing aspects specifically
2. Keep the query concise
3. Match common search phrasing

## Output
Return the improved query, the added aspects, and the reasoning.
""".strip()

query_improver = Agent[None](
    name="Query Improvement Expert",
    instructions=query_improvement_instructions,
    model=get_model(MODEL_NAME),
    output_type=QueryImprovement,
)

# Agent 7: keyword extraction expert
class KeywordExtraction(BaseModel):
    """Keyword extraction result."""
    keywords: list[str] = Field(..., description="List of extracted keywords")
    reasoning: str = Field(..., description="Reasoning behind the extraction")


keyword_extraction_instructions = """
You are a keyword extraction expert.

## Task
Extract the core keywords from a note's title and description.

## Extraction principles
1. Extract words that are useful for search
2. Drop function words and generic terms
3. Keep technical terms
4. Extract 3-10 keywords

## Output
Return the keyword list and the reasoning.
""".strip()

keyword_extractor = Agent[None](
    name="Keyword Extraction Expert",
    instructions=keyword_extraction_instructions,
    model=get_model(MODEL_NAME),
    output_type=KeywordExtraction,
)

# ============================================================================
# Helper functions
# ============================================================================

def add_step(context: RunContext, step_name: str, step_type: str, data: dict):
    """Append a step record to the run context."""
    step = {
        "step_number": len(context.steps) + 1,
        "step_name": step_name,
        "step_type": step_type,
        "timestamp": datetime.now().isoformat(),
        "data": data,
    }
    context.steps.append(step)
    return step
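
# Illustrative sketch: shape of a step record produced by add_step (values are made up).
#
#   {
#       "step_number": 3,
#       "step_name": "Suggestion branch - camping gear",
#       "step_type": "suggestion_branch",
#       "timestamp": "2024-01-01T12:00:00",
#       "data": {...}   # branch-specific payload; see the call sites below
#   }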

def add_query_to_graph(context: RunContext, query_state: QueryState, iteration: int, evaluation_reason: str = "", is_selected: bool = True, parent_level: int | None = None):
    """Add a query node to the evolution graph.

    Args:
        context: run context
        query_state: query state
        iteration: iteration number
        evaluation_reason: evaluation reason (optional)
        is_selected: whether the query was selected into the processing queue (default True)
        parent_level: level of the parent node (used to build the parent_id)
    """
    # Node IDs use the "query_level" format
    query_id = f"{query_state.query}_{query_state.level}"
    # Initialize the graph structure
    if "nodes" not in context.query_graph:
        context.query_graph["nodes"] = {}
        context.query_graph["edges"] = []
        context.query_graph["iterations"] = {}
    # Add the query node (type: query)
    context.query_graph["nodes"][query_id] = {
        "type": "query",
        "query": query_state.query,
        "level": query_state.level,
        "relevance_score": query_state.relevance_score,
        "strategy": query_state.strategy,
        "parent_query": query_state.parent_query,
        "iteration": iteration,
        "is_terminated": query_state.is_terminated,
        "no_suggestion_rounds": query_state.no_suggestion_rounds,
        "evaluation_reason": evaluation_reason,  # evaluation reason
        "is_selected": is_selected,  # whether the query was selected
    }
    # Add the edge (parent-child relation)
    if query_state.parent_query and parent_level is not None:
        # Build the parent node ID: parent_query_parent_level
        parent_id = f"{query_state.parent_query}_{parent_level}"
        if parent_id in context.query_graph["nodes"]:
            context.query_graph["edges"].append({
                "from": parent_id,
                "to": query_id,
                "edge_type": "query_to_query",
                "strategy": query_state.strategy,
                "score_improvement": query_state.relevance_score - context.query_graph["nodes"][parent_id]["relevance_score"],
            })
    # Group by iteration
    if iteration not in context.query_graph["iterations"]:
        context.query_graph["iterations"][iteration] = []
    context.query_graph["iterations"][iteration].append(query_id)
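
# Illustrative sketch: resulting query_graph layout (trimmed; values are made up).
#
#   {
#       "nodes": {
#           "camping_0":      {"type": "query", "level": 0, ...},
#           "camping gear_1": {"type": "query", "level": 1, ...},
#           "64f0c...":       {"type": "note", "match_level": "partial", ...}
#       },
#       "edges": [
#           {"from": "camping_0", "to": "camping gear_1", "edge_type": "query_to_query", ...},
#           {"from": "camping gear_1", "to": "64f0c...", "edge_type": "query_to_note", ...}
#       ],
#       "iterations": {0: ["camping_0", "camping gear_1"], 1: [...]}
#   }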

def add_note_to_graph(context: RunContext, query: str, query_level: int, note: dict):
    """Add a note node to the evolution graph and connect it to its query.

    Args:
        context: run context
        query: query text
        query_level: level of the query
        note: note data
    """
    note_id = note["note_id"]
    # Initialize the graph structure
    if "nodes" not in context.query_graph:
        context.query_graph["nodes"] = {}
        context.query_graph["edges"] = []
        context.query_graph["iterations"] = {}
    # Add the note node (type: note) with its full metadata
    context.query_graph["nodes"][note_id] = {
        "type": "note",
        "note_id": note_id,
        "title": note["title"],
        "desc": note.get("desc", ""),  # full description, not truncated
        "note_url": note.get("note_url", ""),
        "image_list": note.get("image_list", []),  # image list
        "interact_info": note.get("interact_info", {}),  # interaction stats (likes, collects, comments, shares)
        "match_level": note["evaluation"]["match_level"],
        "relevance_score": note["evaluation"]["relevance_score"],
        "evaluation_reason": note["evaluation"].get("reason", ""),  # evaluation reason
        "found_by_query": query,
    }
    # Add the edge query → note, using the query_level ID format
    query_id = f"{query}_{query_level}"
    if query_id in context.query_graph["nodes"]:
        context.query_graph["edges"].append({
            "from": query_id,
            "to": note_id,
            "edge_type": "query_to_note",
            "match_level": note["evaluation"]["match_level"],
            "relevance_score": note["evaluation"]["relevance_score"],
        })


def process_note_data(note: dict) -> dict:
    """Normalize a note payload returned by the search API."""
    note_card = note.get("note_card", {})
    image_list = note_card.get("image_list", [])
    interact_info = note_card.get("interact_info", {})
    user_info = note_card.get("user", {})
    return {
        "note_id": note.get("id", ""),
        "title": note_card.get("display_title", ""),
        "desc": note_card.get("desc", ""),
        "image_list": image_list,
        "interact_info": {
            "liked_count": interact_info.get("liked_count", 0),
            "collected_count": interact_info.get("collected_count", 0),
            "comment_count": interact_info.get("comment_count", 0),
            "shared_count": interact_info.get("shared_count", 0),
        },
        "user": {
            "nickname": user_info.get("nickname", ""),
            "user_id": user_info.get("user_id", ""),
        },
        "type": note_card.get("type", "normal"),
        "note_url": f"https://www.xiaohongshu.com/explore/{note.get('id', '')}",
    }
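
# Illustrative sketch: the raw payload shape process_note_data assumes
# (inferred from the .get() chains above; field values are made up).
#
#   {
#       "id": "64f0c...",
#       "note_card": {
#           "display_title": "...",
#           "desc": "...",
#           "image_list": [...],
#           "interact_info": {"liked_count": 10, "collected_count": 2, ...},
#           "user": {"nickname": "...", "user_id": "..."},
#           "type": "normal"
#       }
#   }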

# ============================================================================
# Core pipeline functions
# ============================================================================

async def initialize_word_library(original_query: str, context: RunContext) -> WordLibrary:
    """Initialize the word library."""
    print("\n[Init] Building word library...")
    # Segment the query with the agent
    result = await Runner.run(word_segmenter, original_query)
    segmentation: WordSegmentation = result.final_output
    word_lib = WordLibrary()
    # Initial segmentation is marked as core words (is_core=True)
    word_lib.add_words(segmentation.words, source="initial", is_core=True)
    print(f"Initial word library (core words): {list(word_lib.words)}")
    print(f"Segmentation reasoning: {segmentation.reasoning}")
    # Save to the context
    context.word_library = word_lib.model_dump()
    add_step(context, "Initialize word library", "word_library_init", {
        "agent": "Word Segmentation Expert",
        "input": original_query,
        "output": {
            "words": segmentation.words,
            "reasoning": segmentation.reasoning,
        },
        "result": {
            "word_library": list(word_lib.words),
        },
    })
    return word_lib


async def evaluate_query_relevance(
    query: str,
    original_need: str,
    previous_score: float | None = None,
    context: RunContext | None = None,
) -> RelevanceEvaluation:
    """Evaluate the relevance of a query to the original need."""
    eval_input = f"""
<original_need>
{original_need}
</original_need>
<current_query>
{query}
</current_query>
{"<previous_relevance_score>" + str(previous_score) + "</previous_relevance_score>" if previous_score is not None else ""}
Please evaluate the relevance of the current query to the original need.
"""
    result = await Runner.run(relevance_evaluator, eval_input)
    evaluation: RelevanceEvaluation = result.final_output
    return evaluation

async def process_suggestions(
    query: str,
    query_state: QueryState,
    original_need: str,
    word_lib: WordLibrary,
    context: RunContext,
    xiaohongshu_api: XiaohongshuSearchRecommendations,
    iteration: int,
) -> list[QueryState]:
    """Process the suggestion branch; return newly generated query states."""
    print(f"\n [Suggestion branch] Processing query: {query}")
    # Collect every agent call made in this branch
    agent_calls = []
    # 1. Fetch suggestions
    suggestions = xiaohongshu_api.get_recommendations(keyword=query)
    if not suggestions:
        print(" → No suggestions returned")
        query_state.no_suggestion_rounds += 1
        # Record the step
        add_step(context, f"Suggestion branch - {query}", "suggestion_branch", {
            "query": query,
            "query_level": query_state.level,
            "suggestions_count": 0,
            "no_suggestion_rounds": query_state.no_suggestion_rounds,
            "new_queries_generated": 0,
        })
        return []
    print(f" → Got {len(suggestions)} suggestions")
    query_state.no_suggestion_rounds = 0  # reset the counter
    # 2. Evaluate every suggestion
    new_queries = []
    suggestion_evaluations = []
    for sug in suggestions:  # process all suggestions
        # Evaluate the suggestion against the ORIGINAL need (original_need), not the current query,
        # so generated suggestions always stay anchored to the user's core need
        sug_eval = await evaluate_query_relevance(sug, original_need, query_state.relevance_score, context)
        sug_eval_record = {
            "suggestion": sug,
            "relevance_score": sug_eval.relevance_score,
            "is_improved": sug_eval.is_improved,
            "reason": sug_eval.reason,
        }
        suggestion_evaluations.append(sug_eval_record)
        # Create a query state (every suggestion becomes a query node)
        sug_state = QueryState(
            query=sug,
            level=query_state.level + 1,
            relevance_score=sug_eval.relevance_score,
            parent_query=query,
            strategy="sug",
        )
        # Only suggestions that improve on the current query enter the processing queue
        is_selected = sug_eval.is_improved and sug_eval.relevance_score > query_state.relevance_score
        # Add every suggestion to the evolution graph (including non-improving ones)
        add_query_to_graph(
            context,
            sug_state,
            iteration,
            evaluation_reason=sug_eval.reason,
            is_selected=is_selected,
            parent_level=query_state.level,  # level of the parent node
        )
        if is_selected:
            print(f" ✓ {sug} (score: {sug_eval.relevance_score:.2f}, improved: {sug_eval.is_improved})")
            new_queries.append(sug_state)
        else:
            print(f" ✗ {sug} (score: {sug_eval.relevance_score:.2f}, not improved)")
    # 3. Rewrite strategy (abstraction or synonym rewrite)
    if len(new_queries) < 3:  # if direct suggestions are not enough, try rewriting
        # Try abstraction
        rewrite_input_abstract = f"""
<current_query>
{query}
</current_query>
<rewrite_request>
Type: abstract
</rewrite_request>
Please rewrite this query.
"""
        result = await Runner.run(query_rewriter, rewrite_input_abstract)
        rewrite: QueryRewrite = result.final_output
        # Record the rewrite agent's input/output
        rewrite_agent_call = {
            "agent": "Query Rewriting Expert",
            "action": "abstract rewrite",
            "input": {
                "query": query,
                "rewrite_type": "abstract",
            },
            "output": {
                "rewritten_query": rewrite.rewritten_query,
                "rewrite_type": rewrite.rewrite_type,
                "reasoning": rewrite.reasoning,
            },
        }
        agent_calls.append(rewrite_agent_call)
        # Evaluate the rewritten query
        rewrite_eval = await evaluate_query_relevance(rewrite.rewritten_query, original_need, query_state.relevance_score, context)
        # Create the rewritten query state
        new_state = QueryState(
            query=rewrite.rewritten_query,
            level=query_state.level + 1,
            relevance_score=rewrite_eval.relevance_score,
            parent_query=query,
            strategy="abstract_rewrite",
        )
        # Add to the evolution graph (whether improved or not)
        add_query_to_graph(
            context,
            new_state,
            iteration,
            evaluation_reason=rewrite_eval.reason,
            is_selected=rewrite_eval.is_improved,
            parent_level=query_state.level,  # level of the parent node
        )
        if rewrite_eval.is_improved:
            print(f" ✓ Rewrite (abstract): {rewrite.rewritten_query} (score: {rewrite_eval.relevance_score:.2f})")
            new_queries.append(new_state)
        else:
            print(f" ✗ Rewrite (abstract): {rewrite.rewritten_query} (score: {rewrite_eval.relevance_score:.2f}, not improved)")
    # 3.2 Synonym rewrite strategy
    if len(new_queries) < 4:  # if still not enough, try a synonym rewrite
        rewrite_input_synonym = f"""
<current_query>
{query}
</current_query>
<rewrite_request>
Type: synonym
Rewrite the query using synonyms or related expressions, keeping the meaning while changing the wording.
</rewrite_request>
Please rewrite this query.
"""
        result = await Runner.run(query_rewriter, rewrite_input_synonym)
        rewrite_syn: QueryRewrite = result.final_output
        # Record the synonym rewrite agent's input/output
        rewrite_syn_agent_call = {
            "agent": "Query Rewriting Expert",
            "action": "synonym rewrite",
            "input": {
                "query": query,
                "rewrite_type": "synonym",
            },
            "output": {
                "rewritten_query": rewrite_syn.rewritten_query,
                "rewrite_type": rewrite_syn.rewrite_type,
                "reasoning": rewrite_syn.reasoning,
            },
        }
        agent_calls.append(rewrite_syn_agent_call)
        # Evaluate the rewritten query
        rewrite_syn_eval = await evaluate_query_relevance(rewrite_syn.rewritten_query, original_need, query_state.relevance_score, context)
        # Create the rewritten query state
        new_state = QueryState(
            query=rewrite_syn.rewritten_query,
            level=query_state.level + 1,
            relevance_score=rewrite_syn_eval.relevance_score,
            parent_query=query,
            strategy="synonym_rewrite",
        )
        # Add to the evolution graph (whether improved or not)
        add_query_to_graph(
            context,
            new_state,
            iteration,
            evaluation_reason=rewrite_syn_eval.reason,
            is_selected=rewrite_syn_eval.is_improved,
            parent_level=query_state.level,  # level of the parent node
        )
        if rewrite_syn_eval.is_improved:
            print(f" ✓ Rewrite (synonym): {rewrite_syn.rewritten_query} (score: {rewrite_syn_eval.relevance_score:.2f})")
            new_queries.append(new_state)
        else:
            print(f" ✗ Rewrite (synonym): {rewrite_syn.rewritten_query} (score: {rewrite_syn_eval.relevance_score:.2f}, not improved)")
    # 4. Word addition strategy (core words first)
    unused_word = word_lib.get_unused_word(query, prefer_core=True)
    is_core_word = unused_word in word_lib.core_words if unused_word else False
    if unused_word and len(new_queries) < 5:
        word_type = "core word" if is_core_word else "regular word"
        insertion_input = f"""
<current_query>
{query}
</current_query>
<word_to_add>
{unused_word}
</word_to_add>
Please insert this word at the most natural position in the query.
"""
        result = await Runner.run(word_inserter, insertion_input)
        insertion: WordInsertion = result.final_output
        # Record the insertion agent's input/output
        insertion_agent_call = {
            "agent": "Word Insertion Expert",
            "action": f"add word ({word_type})",
            "input": {
                "query": query,
                "word_to_add": unused_word,
                "is_core_word": is_core_word,
            },
            "output": {
                "new_query": insertion.new_query,
                "insertion_position": insertion.insertion_position,
                "reasoning": insertion.reasoning,
            },
        }
        agent_calls.append(insertion_agent_call)
        # Evaluate the extended query
        insertion_eval = await evaluate_query_relevance(insertion.new_query, original_need, query_state.relevance_score, context)
        # Create the extended query state
        new_state = QueryState(
            query=insertion.new_query,
            level=query_state.level + 1,
            relevance_score=insertion_eval.relevance_score,
            parent_query=query,
            strategy="add_word",
        )
        # Add to the evolution graph (whether improved or not)
        add_query_to_graph(
            context,
            new_state,
            iteration,
            evaluation_reason=insertion_eval.reason,
            is_selected=insertion_eval.is_improved,
            parent_level=query_state.level,  # level of the parent node
        )
        if insertion_eval.is_improved:
            print(f" ✓ Add word ({word_type}): {insertion.new_query} [+{unused_word}] (score: {insertion_eval.relevance_score:.2f})")
            new_queries.append(new_state)
        else:
            print(f" ✗ Add word ({word_type}): {insertion.new_query} [+{unused_word}] (score: {insertion_eval.relevance_score:.2f}, not improved)")
    # Record the full result of the suggestion branch (hierarchical)
    add_step(context, f"Suggestion branch - {query}", "suggestion_branch", {
        "query": query,
        "query_level": query_state.level,
        "query_relevance": query_state.relevance_score,
        "suggestions_count": len(suggestions),
        "suggestions_evaluated": len(suggestion_evaluations),
        "suggestion_evaluations": suggestion_evaluations,  # all evaluations
        "agent_calls": agent_calls,  # detailed record of every agent call
        "new_queries_generated": len(new_queries),
        "new_queries": [{"query": nq.query, "score": nq.relevance_score, "strategy": nq.strategy} for nq in new_queries],
        "no_suggestion_rounds": query_state.no_suggestion_rounds,
    })
    return new_queries

async def process_search_results(
    query: str,
    query_state: QueryState,
    original_need: str,
    word_lib: WordLibrary,
    context: RunContext,
    xiaohongshu_search: XiaohongshuSearch,
    relevance_threshold: float,
    iteration: int,
) -> tuple[list[dict], list[QueryState]]:
    """
    Process the search-result branch.

    Returns: (notes that satisfy the need, new queries to keep iterating on)
    """
    print(f"\n [Result branch] Searching query: {query}")
    # Collect every agent call made in this branch
    agent_calls = []
    # 1. Check whether the query's relevance clears the threshold
    if query_state.relevance_score < relevance_threshold:
        print(f" ✗ Relevance {query_state.relevance_score:.2f} below threshold {relevance_threshold}, skipping search")
        return [], []
    print(f" ✓ Relevance {query_state.relevance_score:.2f} clears the threshold, searching")
    # 2. Run the search
    try:
        search_result = xiaohongshu_search.search(keyword=query)
        result_str = search_result.get("result", "{}")
        if isinstance(result_str, str):
            result_data = json.loads(result_str)
        else:
            result_data = result_str
        notes = result_data.get("data", {}).get("data", [])
        print(f" → Search returned {len(notes)} notes")
    except Exception as e:
        print(f" ✗ Search failed: {e}")
        return [], []
    if not notes:
        return [], []
    # 3. Evaluate every note
    satisfied_notes = []
    partial_notes = []
    for note in notes:  # evaluate all notes
        note_data = process_note_data(note)
        title = note_data["title"] or ""
        desc = note_data["desc"] or ""
        # Skip notes with neither title nor description
        if not title and not desc:
            continue
        # Evaluate the note
        eval_input = f"""
<original_need>
{original_need}
</original_need>
<note>
Title: {title}
Description: {desc}
</note>
Please evaluate how well this note matches the original need.
"""
        result = await Runner.run(result_evaluator, eval_input)
        evaluation: ResultEvaluation = result.final_output
        # Record the result evaluation agent's input/output
        result_eval_agent_call = {
            "agent": "Result Matching Evaluation Expert",
            "action": "evaluate note match",
            "input": {
                "note_id": note_data.get("note_id"),
                "title": title,
                "desc": desc,  # full description
            },
            "output": {
                "match_level": evaluation.match_level,
                "relevance_score": evaluation.relevance_score,
                "missing_aspects": evaluation.missing_aspects,
                "reason": evaluation.reason,
            },
        }
        agent_calls.append(result_eval_agent_call)
        note_data["evaluation"] = {
            "match_level": evaluation.match_level,
            "relevance_score": evaluation.relevance_score,
            "missing_aspects": evaluation.missing_aspects,
            "reason": evaluation.reason,
        }
        # Add every evaluated note to the evolution graph (satisfied, partial, and unsatisfied)
        add_note_to_graph(context, query, query_state.level, note_data)
        # Only append "..." when the title was actually truncated
        short_title = title[:30] + "..." if len(title) > 30 else title
        if evaluation.match_level == "satisfied":
            satisfied_notes.append(note_data)
            print(f" ✓ Satisfied: {short_title} (score: {evaluation.relevance_score:.2f})")
        elif evaluation.match_level == "partial":
            partial_notes.append(note_data)
            print(f" ~ Partial: {short_title} (missing: {', '.join(evaluation.missing_aspects[:2])})")
        else:  # unsatisfied
            print(f" ✗ Unsatisfied: {short_title} (score: {evaluation.relevance_score:.2f})")
    # 4. Satisfied notes: no longer expand the word library (avoids unbounded growth)
    new_queries = []
    if satisfied_notes:
        print(f"\n ✓ Found {len(satisfied_notes)} satisfying notes; keyword extraction into the library is disabled")
        # Keyword extraction is commented out to keep the word library stable
        # for note in satisfied_notes[:3]:
        #     extract_input = f"""
        # <note>
        # Title: {note['title']}
        # Description: {note['desc']}
        # </note>
        #
        # Please extract the core keywords.
        # """
        #     result = await Runner.run(keyword_extractor, extract_input)
        #     extraction: KeywordExtraction = result.final_output
        #
        #     # Add new words to the library, tagging their source
        #     note_id = note.get('note_id', 'unknown')
        #     for keyword in extraction.keywords:
        #         if keyword not in word_lib.words:
        #             word_lib.add_word(keyword, source=f"note:{note_id}")
        #             print(f" + New word: {keyword} (source: {note_id})")
    # 5. Partially matching notes: rework the query
    if partial_notes and len(satisfied_notes) < 5:  # if satisfied notes are scarce, improve based on partial matches
        print(f"\n Improving query based on {len(partial_notes)} partially matching notes...")
        # Collect all missing aspects
        all_missing = []
        for note in partial_notes:
            all_missing.extend(note["evaluation"]["missing_aspects"])
        if all_missing:
            improvement_input = f"""
<current_query>
{query}
</current_query>
<missing_aspects>
{', '.join(set(all_missing))}
</missing_aspects>
Please rework the query so it covers these missing aspects.
"""
            result = await Runner.run(query_improver, improvement_input)
            improvement: QueryImprovement = result.final_output
            # Record the improvement agent's input/output
            improvement_agent_call = {
                "agent": "Query Improvement Expert",
                "action": "improve query from missing aspects",
                "input": {
                    "query": query,
                    "missing_aspects": list(set(all_missing)),
                },
                "output": {
                    "improved_query": improvement.improved_query,
                    "added_aspects": improvement.added_aspects,
                    "reasoning": improvement.reasoning,
                },
            }
            agent_calls.append(improvement_agent_call)
            # Evaluate the improved query
            improved_eval = await evaluate_query_relevance(improvement.improved_query, original_need, query_state.relevance_score, context)
            # Create the improved query state
            new_state = QueryState(
                query=improvement.improved_query,
                level=query_state.level + 1,
                relevance_score=improved_eval.relevance_score,
                parent_query=query,
                strategy="partial_improve",
            )
            # Add to the evolution graph (whether improved or not)
            add_query_to_graph(
                context,
                new_state,
                iteration,
                evaluation_reason=improved_eval.reason,
                is_selected=improved_eval.is_improved,
                parent_level=query_state.level,  # level of the parent node
            )
            if improved_eval.is_improved:
                print(f" ✓ Improved: {improvement.improved_query} (added: {', '.join(improvement.added_aspects[:2])})")
                new_queries.append(new_state)
            else:
                print(f" ✗ Improved: {improvement.improved_query} (score: {improved_eval.relevance_score:.2f}, not improved)")
    # 6. Result-branch rewrite strategies (abstraction and synonym rewrite)
    # If results are weak and new queries are scarce, try rewriting the current query
    if len(satisfied_notes) < 3 and len(new_queries) < 2:
        print("\n Weak search results; trying to rewrite the query...")
        # 6.1 Abstraction
        if len(new_queries) < 3:
            rewrite_input_abstract = f"""
<current_query>
{query}
</current_query>
<rewrite_request>
Type: abstract
</rewrite_request>
Please rewrite this query.
"""
            result = await Runner.run(query_rewriter, rewrite_input_abstract)
            rewrite: QueryRewrite = result.final_output
            # Record the result-branch (abstract) rewrite agent's input/output
            rewrite_agent_call = {
                "agent": "Query Rewriting Expert",
                "action": "abstract rewrite (result branch)",
                "input": {
                    "query": query,
                    "rewrite_type": "abstract",
                },
                "output": {
                    "rewritten_query": rewrite.rewritten_query,
                    "rewrite_type": rewrite.rewrite_type,
                    "reasoning": rewrite.reasoning,
                },
            }
            agent_calls.append(rewrite_agent_call)
            # Evaluate the rewritten query
            rewrite_eval = await evaluate_query_relevance(rewrite.rewritten_query, original_need, query_state.relevance_score, context)
            # Create the rewritten query state
            new_state = QueryState(
                query=rewrite.rewritten_query,
                level=query_state.level + 1,
                relevance_score=rewrite_eval.relevance_score,
                parent_query=query,
                strategy="result_abstract_rewrite",
            )
            # Add to the evolution graph (whether improved or not)
            add_query_to_graph(
                context,
                new_state,
                iteration,
                evaluation_reason=rewrite_eval.reason,
                is_selected=rewrite_eval.is_improved,
                parent_level=query_state.level,  # level of the parent node
            )
            if rewrite_eval.is_improved:
                print(f" ✓ Rewrite (abstract): {rewrite.rewritten_query} (score: {rewrite_eval.relevance_score:.2f})")
                new_queries.append(new_state)
            else:
                print(f" ✗ Rewrite (abstract): {rewrite.rewritten_query} (score: {rewrite_eval.relevance_score:.2f}, not improved)")
        # 6.2 Synonym rewrite
        if len(new_queries) < 4:
            rewrite_input_synonym = f"""
<current_query>
{query}
</current_query>
<rewrite_request>
Type: synonym
Rewrite the query using synonyms or related expressions, keeping the meaning while changing the wording.
</rewrite_request>
Please rewrite this query.
"""
            result = await Runner.run(query_rewriter, rewrite_input_synonym)
            rewrite_syn: QueryRewrite = result.final_output
            # Record the result-branch (synonym) rewrite agent's input/output
            rewrite_syn_agent_call = {
                "agent": "Query Rewriting Expert",
                "action": "synonym rewrite (result branch)",
                "input": {
                    "query": query,
                    "rewrite_type": "synonym",
                },
                "output": {
                    "rewritten_query": rewrite_syn.rewritten_query,
                    "rewrite_type": rewrite_syn.rewrite_type,
                    "reasoning": rewrite_syn.reasoning,
                },
            }
            agent_calls.append(rewrite_syn_agent_call)
            # Evaluate the rewritten query
            rewrite_syn_eval = await evaluate_query_relevance(rewrite_syn.rewritten_query, original_need, query_state.relevance_score, context)
            # Create the rewritten query state
            new_state = QueryState(
                query=rewrite_syn.rewritten_query,
                level=query_state.level + 1,
                relevance_score=rewrite_syn_eval.relevance_score,
                parent_query=query,
                strategy="result_synonym_rewrite",
            )
            # Add to the evolution graph (whether improved or not)
            add_query_to_graph(
                context,
                new_state,
                iteration,
                evaluation_reason=rewrite_syn_eval.reason,
                is_selected=rewrite_syn_eval.is_improved,
                parent_level=query_state.level,  # level of the parent node
            )
            if rewrite_syn_eval.is_improved:
                print(f" ✓ Rewrite (synonym): {rewrite_syn.rewritten_query} (score: {rewrite_syn_eval.relevance_score:.2f})")
                new_queries.append(new_state)
            else:
                print(f" ✗ Rewrite (synonym): {rewrite_syn.rewritten_query} (score: {rewrite_syn_eval.relevance_score:.2f}, not improved)")
    # Record the full result of the result branch (hierarchical)
    add_step(context, f"Result branch - {query}", "result_branch", {
        "query": query,
        "query_level": query_state.level,
        "query_relevance": query_state.relevance_score,
        "relevance_threshold": relevance_threshold,
        "passed_threshold": query_state.relevance_score >= relevance_threshold,
        "notes_count": len(notes),  # notes is always bound here; empty results return early above
        "satisfied_count": len(satisfied_notes),
        "partial_count": len(partial_notes),
        "satisfied_notes": [
            {
                "note_id": note["note_id"],
                "title": note["title"],
                "score": note["evaluation"]["relevance_score"],
                "match_level": note["evaluation"]["match_level"],
            }
            for note in satisfied_notes  # keep every satisfying note
        ],
        "agent_calls": agent_calls,  # detailed record of every agent call
        "new_queries_generated": len(new_queries),
        "new_queries": [{"query": nq.query, "score": nq.relevance_score, "strategy": nq.strategy} for nq in new_queries],
    })
    return satisfied_notes, new_queries

async def iterative_search_loop(
    context: RunContext,
    max_iterations: int = 20,
    relevance_threshold: float = 0.6,
) -> list[dict]:
    """
    Main loop: iterative search, processed level by level.

    Args:
        context: run context
        max_iterations: maximum number of iterations (levels)
        relevance_threshold: relevance threshold
    Returns:
        List of notes that satisfy the need
    """
    print(f"\n{'='*60}")
    print("Starting the iterative search loop")
    print(f"{'='*60}")
    # 0. Add the original question as the root node
    root_query_state = QueryState(
        query=context.q,
        level=0,
        relevance_score=1.0,  # the original question is fully relevant by definition
        strategy="root",
    )
    add_query_to_graph(context, root_query_state, 0, evaluation_reason="Original question; root of the search", is_selected=True)
    print(f"[Root] Original question: {context.q}")
    # 1. Initialize the word library
    word_lib = await initialize_word_library(context.q, context)
    # 2. Initialize the query queue - rank the initial words by relevance
    all_words = list(word_lib.words)
    query_queue = []
    print("\nEvaluating the relevance of every initial word...")
    word_scores = []
    for word in all_words:
        # Evaluate each word's relevance
        eval_result = await evaluate_query_relevance(word, context.q, None, context)
        word_scores.append({
            'word': word,
            'score': eval_result.relevance_score,
            'eval': eval_result,
        })
        print(f" {word}: {eval_result.relevance_score:.2f}")
    # Sort by relevance; all words are used
    word_scores.sort(key=lambda x: x['score'], reverse=True)
    # Add every word to the evolution graph (all of them selected)
    for item in word_scores:
        is_selected = True  # every initial word is selected
        query_state = QueryState(
            query=item['word'],
            level=1,
            relevance_score=item['score'],
            strategy="initial_word",
            parent_query=context.q,  # the parent is the original question
        )
        # Add to the evolution graph (automatically creates the parent → query edge)
        add_query_to_graph(context, query_state, 0, evaluation_reason=item['eval'].reason, is_selected=is_selected, parent_level=0)  # the parent is the root (level 0)
        # Only selected words enter the queue
        if is_selected:
            query_queue.append(query_state)
    print(f"\nInitial query queue (sorted by relevance): {[(q.query, f'{q.relevance_score:.2f}') for q in query_queue]}")
    print(f" (evaluated {len(word_scores)} words; all enqueued)")
    # 3. API instances
    xiaohongshu_api = XiaohongshuSearchRecommendations()
    xiaohongshu_search = XiaohongshuSearch()
    # 4. Main loop
    all_satisfied_notes = []
    iteration = 0
    while query_queue and iteration < max_iterations:
        iteration += 1
        # Current level = smallest level in the queue
        current_level = min(q.level for q in query_queue)
        # Pull every query at the current level
        current_batch = [q for q in query_queue if q.level == current_level]
        query_queue = [q for q in query_queue if q.level != current_level]
        print(f"\n{'='*60}")
        print(f"Iteration {iteration}: processing level {current_level}, {len(current_batch)} queries")
        print(f"{'='*60}")
        # Record the queries processed this round
        add_step(context, f"Iteration {iteration}", "iteration", {
            "iteration": iteration,
            "current_level": current_level,
            "current_batch_size": len(current_batch),
            "remaining_queue_size": len(query_queue),
            "processing_queries": [{"query": q.query, "level": q.level} for q in current_batch],
        })
        new_queries_from_sug = []
        new_queries_from_result = []
        new_satisfied_this_round = 0
        # Process every query
        for query_state in current_batch:
            print(f"\nProcessing query [{query_state.level}]: {query_state.query} (score: {query_state.relevance_score:.2f})")
            # Check the termination conditions
            if query_state.is_terminated or query_state.no_suggestion_rounds >= 2:
                print(" ✗ Terminated, or 2 consecutive rounds without suggestions; skipping")
                query_state.is_terminated = True
                continue
            # Run both branches in parallel
            sug_task = process_suggestions(
                query_state.query, query_state, context.q, word_lib, context, xiaohongshu_api, iteration
            )
            result_task = process_search_results(
                query_state.query, query_state, context.q, word_lib, context,
                xiaohongshu_search, relevance_threshold, iteration
            )
            # Wait for both branches
            sug_queries, (satisfied_notes, result_queries) = await asyncio.gather(
                sug_task,
                result_task
            )
            # If the suggestion branch returned nothing, no suggestions were obtained.
            # Note: process_suggestions already updated query_state.no_suggestion_rounds,
            # so new queries inherit the parent's no_suggestion_rounds when both branches are empty.
            if not sug_queries and not result_queries:
                # Neither branch produced new queries: terminate this query
                query_state.is_terminated = True
                print(" ⚠ Neither branch produced new queries; marking this query terminated")
            new_queries_from_sug.extend(sug_queries)
            new_queries_from_result.extend(result_queries)
            all_satisfied_notes.extend(satisfied_notes)
            new_satisfied_this_round += len(satisfied_notes)  # count per round, not just the last query's batch
        # Update the queue
        all_new_queries = new_queries_from_sug + new_queries_from_result
        # Note: do not add to the evolution graph here; process_suggestions and
        # process_search_results already did, and calling add_query_to_graph again
        # would overwrite fields such as evaluation_reason
        query_queue.extend(all_new_queries)
        # Deduplicate (by query text) and drop terminated queries
        seen = set()
        unique_queue = []
        for q in query_queue:
            if q.query not in seen and not q.is_terminated:
                seen.add(q.query)
                unique_queue.append(q)
        query_queue = unique_queue
        # Sort by relevance
        query_queue.sort(key=lambda x: x.relevance_score, reverse=True)
        print("\nRound summary:")
        print(f" New satisfying notes: {new_satisfied_this_round}")
        print(f" Total satisfying notes: {len(all_satisfied_notes)}")
        print(f" New queries: {len(all_new_queries)}")
        print(f" Queue remaining: {len(query_queue)}")
        # Persist the word library into the context
        context.word_library = word_lib.model_dump()
        # Stop early once enough satisfying notes are found
        if len(all_satisfied_notes) >= 20:
            print(f"\nFound enough satisfying notes ({len(all_satisfied_notes)}); stopping early")
            break
    print(f"\n{'='*60}")
    print("Iterative search finished")
    print(f" Total iterations: {iteration}")
    print(f" Final satisfying notes: {len(all_satisfied_notes)}")
    print(f" Core words: {list(word_lib.core_words)}")
    print(f" Final word library size: {len(word_lib.words)}")
    print(f"{'='*60}")
    # Save the final results
    add_step(context, "Iterative search finished", "loop_complete", {
        "total_iterations": iteration,
        "total_satisfied_notes": len(all_satisfied_notes),
        "core_words": list(word_lib.core_words),
        "final_word_library_size": len(word_lib.words),
        "final_word_library": list(word_lib.words),
    })
    return all_satisfied_notes
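
# Illustrative sketch: the loop above is a level-order (BFS-style) expansion of the
# query tree. A made-up trace for two iterations might look like:
#
#   iteration 1: level 1 -> ["camping", "gear"]          (initial words)
#   iteration 2: level 2 -> ["camping gear checklist"]   (sug / rewrite / add_word children)
#
# Each iteration drains exactly one level, so level and iteration number stay aligned
# unless an early termination or the 20-note cap empties the queue first.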

# ============================================================================
# Main entry point
# ============================================================================

async def main(input_dir: str, max_iterations: int = 20, visualize: bool = False):
    """Main entry point."""
    current_time, log_url = set_trace()
    # Read the inputs
    input_context_file = os.path.join(input_dir, 'context.md')
    input_q_file = os.path.join(input_dir, 'q.md')
    q_context = read_file_as_string(input_context_file)
    q = read_file_as_string(input_q_file)
    q_with_context = f"""
<need_context>
{q_context}
</need_context>
<current_question>
{q}
</current_question>
""".strip()
    # Version info
    version = os.path.basename(__file__)
    version_name = os.path.splitext(version)[0]
    # Log directory
    log_dir = os.path.join(input_dir, "output", version_name, current_time)
    # Create the run context
    run_context = RunContext(
        version=version,
        input_files={
            "input_dir": input_dir,
            "context_file": input_context_file,
            "q_file": input_q_file,
        },
        q_with_context=q_with_context,
        q_context=q_context,
        q=q,
        log_dir=log_dir,
        log_url=log_url,
    )
    # Run the iterative search
    satisfied_notes = await iterative_search_loop(
        run_context,
        max_iterations=max_iterations,
        relevance_threshold=0.6
    )
    # Save the results
    run_context.satisfied_notes = satisfied_notes
    # Format the output
    output = f"Original question: {run_context.q}\n"
    output += f"Notes satisfying the need: {len(satisfied_notes)}\n"
    output += f"Core words: {', '.join(run_context.word_library.get('core_words', []))}\n"
    output += f"Word library size: {len(run_context.word_library.get('words', []))} words\n"
    output += "\n" + "="*60 + "\n"
    if satisfied_notes:
        output += "[Notes satisfying the need]\n\n"
        for idx, note in enumerate(satisfied_notes, 1):
            output += f"{idx}. {note['title']}\n"
            output += f"   Relevance: {note['evaluation']['relevance_score']:.2f}\n"
            output += f"   URL: {note['note_url']}\n\n"
    else:
        output += "No notes satisfying the need were found\n"
    run_context.final_output = output
    print(f"\n{'='*60}")
    print("Final results")
    print(f"{'='*60}")
    print(output)
    # Save the logs
    os.makedirs(run_context.log_dir, exist_ok=True)
    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
    context_dict = run_context.model_dump()
    with open(context_file_path, "w", encoding="utf-8") as f:
        json.dump(context_dict, f, ensure_ascii=False, indent=2)
    print(f"\nRunContext saved to: {context_file_path}")
    steps_file_path = os.path.join(run_context.log_dir, "steps.json")
    with open(steps_file_path, "w", encoding="utf-8") as f:
        json.dump(run_context.steps, f, ensure_ascii=False, indent=2)
    print(f"Steps log saved to: {steps_file_path}")
    # Save the query evolution graph
    query_graph_file_path = os.path.join(run_context.log_dir, "query_graph.json")
    with open(query_graph_file_path, "w", encoding="utf-8") as f:
        json.dump(run_context.query_graph, f, ensure_ascii=False, indent=2)
    print(f"Query graph saved to: {query_graph_file_path}")
    # Visualization
    if visualize:
        import subprocess
        output_html = os.path.join(run_context.log_dir, "visualization.html")
        print("\n🎨 Generating visualization HTML...")
        # Use absolute paths, since the node script runs from its own directory
        abs_query_graph = os.path.abspath(query_graph_file_path)
        abs_output_html = os.path.abspath(output_html)
        # Run inside the visualization script's directory so it picks up its local node_modules
        result = subprocess.run([
            "node", "index.js",
            abs_query_graph,
            abs_output_html
        ], cwd="visualization/sug_v6_1_2_6")
        if result.returncode == 0:
            print(f"✅ Visualization generated: {output_html}")
        else:
            print("❌ Visualization generation failed")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Search query optimization tool - v6.1.2.6, iterative-loop version")
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/简单扣图",
        help="Input directory path, default: input/简单扣图"
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=20,
        help="Maximum number of iterations, default: 20"
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        default=False,
        help="Generate the visualization HTML after the run completes"
    )
    parser.add_argument(
        "--visualize-only",
        type=str,
        help="Only generate the visualization; pass the path to a query_graph.json file"
    )
    args = parser.parse_args()
    # Visualization-only mode
    if args.visualize_only:
        import subprocess
        query_graph_path = args.visualize_only
        # Derive the output name from the input; fall back to visualization.html
        # next to the input when the name does not contain "query_graph"
        base, _ = os.path.splitext(query_graph_path)
        if "query_graph" in base:
            output_html = base.replace("query_graph", "visualization") + ".html"
        else:
            output_html = os.path.join(os.path.dirname(query_graph_path), "visualization.html")
        print("🎨 Generating visualization HTML...")
        print(f"Input: {query_graph_path}")
        print(f"Output: {output_html}")
        # Use absolute paths
        abs_query_graph = os.path.abspath(query_graph_path)
        abs_output_html = os.path.abspath(output_html)
        # Run inside the visualization script's directory so it picks up its local node_modules
        result = subprocess.run([
            "node", "index.js",
            abs_query_graph,
            abs_output_html
        ], cwd="visualization/sug_v6_1_2_6")
        if result.returncode == 0:
            print(f"✅ Visualization generated: {output_html}")
        else:
            print("❌ Visualization generation failed")
        sys.exit(result.returncode)
    asyncio.run(main(args.input_dir, max_iterations=args.max_iterations, visualize=args.visualize))