sug_v6_1_2_8.py

import asyncio
import json
import os
import sys
import argparse
from datetime import datetime
from typing import Literal

from agents import Agent, Runner
from pydantic import BaseModel, Field

from lib.my_trace import set_trace
from lib.utils import read_file_as_string
from lib.client import get_model
from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
from script.search.xiaohongshu_search import XiaohongshuSearch

MODEL_NAME = "google/gemini-2.5-flash"
# ============================================================================
# Data models
# ============================================================================

class Seg(BaseModel):
    """A word segment"""
    text: str
    score_with_o: float = 0.0  # relevance score against the original question
    reason: str = ""  # scoring rationale
    from_o: str = ""  # the original question it came from


class Word(BaseModel):
    """A word"""
    text: str
    score_with_o: float = 0.0  # relevance score against the original question
    from_o: str = ""  # the original question it came from


class QFromQ(BaseModel):
    """Source-Q info (recorded inside a Sug)"""
    text: str
    score_with_o: float = 0.0


class Q(BaseModel):
    """A query"""
    text: str
    score_with_o: float = 0.0  # relevance score against the original question
    reason: str = ""  # scoring rationale
    from_source: str = ""  # seg / sug / add (word addition)


class Sug(BaseModel):
    """A suggested query"""
    text: str
    score_with_o: float = 0.0  # relevance score against the original question
    reason: str = ""  # scoring rationale
    from_q: QFromQ | None = None  # the q this suggestion came from


class Seed(BaseModel):
    """A seed"""
    text: str
    added_words: list[str] = Field(default_factory=list)  # words already added to this seed
    from_type: str = ""  # seg / sug
    score_with_o: float = 0.0  # relevance score against the original question


class Post(BaseModel):
    """A post"""
    title: str = ""
    body_text: str = ""
    type: str = "normal"  # video / normal
    images: list[str] = Field(default_factory=list)  # image URLs; the first one is the cover
    video: str = ""  # video URL
    interact_info: dict = Field(default_factory=dict)  # interaction stats
    note_id: str = ""
    note_url: str = ""


class Search(Sug):
    """A search result (extends Sug)"""
    post_list: list[Post] = Field(default_factory=list)  # posts returned by the search


class RunContext(BaseModel):
    """Run context"""
    version: str
    input_files: dict[str, str]
    c: str  # original requirement
    o: str  # original question
    log_url: str
    log_dir: str
    # per-round data
    rounds: list[dict] = Field(default_factory=list)  # detailed data for each round
    # final result
    final_output: str | None = None
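
# Documentation-only sketch (hypothetical values) of how these models nest:
# a Sug records the Q it was suggested from via QFromQ, and Search extends
# Sug with the posts fetched for it. This helper is never called at runtime.
def _example_model_nesting() -> Search:
    q = Q(text="川西 秋季", score_with_o=0.8, from_source="seg")
    sug = Sug(text="川西 秋季 摄影", score_with_o=0.85,
              from_q=QFromQ(text=q.text, score_with_o=q.score_with_o))
    return Search(**sug.model_dump(), post_list=[Post(title="示例帖子")])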
# ============================================================================
# Agent definitions
# ============================================================================

# Agent 1: word segmentation expert
class WordSegmentation(BaseModel):
    """Segmentation result"""
    words: list[str] = Field(..., description="List of segmented words")
    reasoning: str = Field(..., description="Segmentation rationale")


word_segmentation_instructions = """
You are a word segmentation expert. Given a query, split it into the smallest meaningful units.

## Segmentation principles
1. Keep terms that are meaningful for search
2. Split into independent concepts
3. Keep technical terms intact
4. Drop function words (的, 吗, 呢, etc.)

## Output requirements
Return the list of segments and the segmentation rationale.
""".strip()

word_segmenter = Agent[None](
    name="Word Segmenter",
    instructions=word_segmentation_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordSegmentation,
)


# Agent 2: relevance evaluation expert
class RelevanceEvaluation(BaseModel):
    """Relevance evaluation"""
    relevance_score: float = Field(..., description="Relevance score, 0-1")
    reason: str = Field(..., description="Evaluation rationale")


relevance_evaluation_instructions = """
You are a relevance evaluation expert.

## Task
Assess how well the current text matches the original question.

## Criteria
- Topical relevance
- Coverage of key elements
- Intent match

## Output
- relevance_score: a relevance score between 0 and 1
- reason: a detailed rationale
""".strip()

relevance_evaluator = Agent[None](
    name="Relevance Evaluator",
    instructions=relevance_evaluation_instructions,
    model=get_model(MODEL_NAME),
    output_type=RelevanceEvaluation,
)


# Agent 3: word-addition selection expert
class WordSelection(BaseModel):
    """Word selection result"""
    selected_word: str = Field(..., description="The chosen word")
    combined_query: str = Field(..., description="The new combined query")
    reasoning: str = Field(..., description="Selection rationale")


word_selection_instructions = """
You are a word-addition selection expert.

## Task
Pick the single most suitable word from the candidate list and combine it with the current seed into a new query.

## Principles
1. Choose the word most relevant to the current seed
2. The combined query must read naturally
3. Match common search phrasing
4. Prefer words that broaden the search scope

## Output
- selected_word: the chosen word
- combined_query: the new combined query
- reasoning: the selection rationale
""".strip()

word_selector = Agent[None](
    name="Word Selector",
    instructions=word_selection_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordSelection,
)
# ============================================================================
# Helper functions
# ============================================================================

def process_note_data(note: dict) -> Post:
    """Convert a note payload returned by the search API into a Post."""
    note_card = note.get("note_card", {})
    image_list = note_card.get("image_list", [])
    interact_info = note_card.get("interact_info", {})

    # Extract image URLs - prefer the new field name image_url
    images = []
    for img in image_list:
        if isinstance(img, dict):
            # try the new field name image_url, fall back to the old url_default
            img_url = img.get("image_url") or img.get("url_default")
            if img_url:
                images.append(img_url)

    # Determine the note type
    note_type = note_card.get("type", "normal")
    video_url = ""
    if note_type == "video":
        video_info = note_card.get("video", {})
        if isinstance(video_info, dict):
            # try to extract the video URL; guard against an empty stream list
            h264_streams = video_info.get("media", {}).get("stream", {}).get("h264") or [{}]
            video_url = h264_streams[0].get("master_url", "")

    return Post(
        note_id=note.get("id", ""),
        title=note_card.get("display_title", ""),
        body_text=note_card.get("desc", ""),
        type=note_type,
        images=images,
        video=video_url,
        interact_info={
            "liked_count": interact_info.get("liked_count", 0),
            "collected_count": interact_info.get("collected_count", 0),
            "comment_count": interact_info.get("comment_count", 0),
            "shared_count": interact_info.get("shared_count", 0)
        },
        note_url=f"https://www.xiaohongshu.com/explore/{note.get('id', '')}"
    )
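
# Documentation-only sketch of the note payload shape process_note_data
# expects. Field names are inferred from the accesses above, not from a
# documented API contract, so treat this structure as an assumption.
# This helper is never called at runtime.
def _example_note_payload() -> dict:
    """Return a hypothetical note dict in the shape this parser handles."""
    return {
        "id": "hypothetical_note_id",
        "note_card": {
            "display_title": "示例标题",
            "desc": "示例正文",
            "type": "video",  # or "normal"
            "image_list": [{"image_url": "https://example.invalid/cover.jpg"}],
            "interact_info": {"liked_count": 0, "collected_count": 0,
                              "comment_count": 0, "shared_count": 0},
            "video": {"media": {"stream": {"h264": [
                {"master_url": "https://example.invalid/v.mp4"}
            ]}}},
        },
    }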
async def evaluate_with_o(text: str, o: str) -> tuple[float, str]:
    """Score a text against the original question o.

    Returns:
        tuple[float, str]: (relevance score, evaluation rationale)
    """
    eval_input = f"""
<original_question>
{o}
</original_question>
<current_text>
{text}
</current_text>
Please assess the relevance of the current text to the original question.
"""
    result = await Runner.run(relevance_evaluator, eval_input)
    evaluation: RelevanceEvaluation = result.final_output
    return evaluation.relevance_score, evaluation.reason
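
# Documentation-only usage sketch: evaluate_with_o can be exercised on its
# own, assuming the model client above is configured. Values are hypothetical.
#
#   score, reason = asyncio.run(
#       evaluate_with_o("川西秋季风光摄影", o="如何获取能体现川西秋季特色的素材?"))
#   print(score, reason)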
# ============================================================================
# Core pipeline functions
# ============================================================================

async def initialize(o: str, context: RunContext) -> tuple[list[Seg], list[Word], list[Q], list[Seed]]:
    """
    Initialization phase

    Returns:
        (seg_list, word_list_1, q_list_1, seed_list)
    """
    print(f"\n{'='*60}")
    print("Initialization phase")
    print(f"{'='*60}")

    # 1. Segmentation: original question (o) -> segment -> seg_list
    print("\n[Step 1] Segmenting...")
    result = await Runner.run(word_segmenter, o)
    segmentation: WordSegmentation = result.final_output
    seg_list = []
    for word in segmentation.words:
        seg_list.append(Seg(text=word, from_o=o))
    print(f"Segments: {[s.text for s in seg_list]}")
    print(f"Segmentation rationale: {segmentation.reasoning}")

    # 2. Segment scoring: score each seg against o (concurrently)
    print("\n[Step 2] Scoring each segment against the original question...")

    async def evaluate_seg(seg: Seg) -> Seg:
        seg.score_with_o, seg.reason = await evaluate_with_o(seg.text, o)
        return seg

    if seg_list:
        eval_tasks = [evaluate_seg(seg) for seg in seg_list]
        await asyncio.gather(*eval_tasks)
    for seg in seg_list:
        print(f"  {seg.text}: {seg.score_with_o:.2f}")

    # 3. Build word_list_1: seg_list -> word_list_1
    print("\n[Step 3] Building word_list_1...")
    word_list_1 = []
    for seg in seg_list:
        word_list_1.append(Word(
            text=seg.text,
            score_with_o=seg.score_with_o,
            from_o=o
        ))
    print(f"word_list_1: {[w.text for w in word_list_1]}")

    # 4. Build q_list_1: seg_list becomes q_list_1
    print("\n[Step 4] Building q_list_1...")
    q_list_1 = []
    for seg in seg_list:
        q_list_1.append(Q(
            text=seg.text,
            score_with_o=seg.score_with_o,
            reason=seg.reason,
            from_source="seg"
        ))
    print(f"q_list_1: {[q.text for q in q_list_1]}")

    # 5. Build seed_list: seg_list -> seed_list
    print("\n[Step 5] Building seed_list...")
    seed_list = []
    for seg in seg_list:
        seed_list.append(Seed(
            text=seg.text,
            added_words=[],
            from_type="seg",
            score_with_o=seg.score_with_o
        ))
    print(f"seed_list: {[s.text for s in seed_list]}")

    return seg_list, word_list_1, q_list_1, seed_list
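
# Initialization data flow (summary of the steps above):
#   o --segment--> seg_list --score vs o--> word_list_1 / q_list_1 / seed_list
# All three derived lists start as copies of the scored segments; in later
# rounds the words feed the add-word combinations, the q's feed the
# suggestion API, and the seeds accumulate the words already added to them.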
async def run_round(
    round_num: int,
    q_list: list[Q],
    word_list: list[Word],
    seed_list: list[Seed],
    o: str,
    context: RunContext,
    xiaohongshu_api: XiaohongshuSearchRecommendations,
    xiaohongshu_search: XiaohongshuSearch,
    sug_threshold: float = 0.7
) -> tuple[list[Word], list[Q], list[Seed], list[Search]]:
    """
    Run one round.

    Args:
        round_num: round number
        q_list: this round's q list
        word_list: current word list
        seed_list: current seed list
        o: original question
        context: run context
        xiaohongshu_api: suggestion API
        xiaohongshu_search: search API
        sug_threshold: threshold for acting on a suggestion

    Returns:
        (word_list_next, q_list_next, seed_list_next, search_list)
    """
    print(f"\n{'='*60}")
    print(f"Round {round_num}")
    print(f"{'='*60}")

    round_data = {
        "round_num": round_num,
        "input_q_list": [{"text": q.text, "score": q.score_with_o} for q in q_list],
        "input_word_list_size": len(word_list),
        "input_seed_list_size": len(seed_list)
    }

    # 1. Fetch suggestions: q_list -> suggestion API per q -> sug_list_list
    print("\n[Step 1] Fetching suggestions for each q...")
    sug_list_list = []  # list of lists
    for q in q_list:
        print(f"\n  Processing q: {q.text}")
        suggestions = xiaohongshu_api.get_recommendations(keyword=q.text)
        q_sug_list = []
        if suggestions:
            print(f"    Got {len(suggestions)} suggestions")
            for sug_text in suggestions:
                sug = Sug(
                    text=sug_text,
                    from_q=QFromQ(text=q.text, score_with_o=q.score_with_o)
                )
                q_sug_list.append(sug)
        else:
            print("    No suggestions returned")
        sug_list_list.append(q_sug_list)

    # 2. Suggestion scoring: score each sug against o (concurrently)
    print("\n[Step 2] Scoring each suggestion against the original question...")
    # 2.1 Collect all suggestions that need scoring (each sug already carries
    #     its source q via from_q)
    all_sugs = []
    for q_sug_list in sug_list_list:
        all_sugs.extend(q_sug_list)

    # 2.2 Score all suggestions concurrently
    async def evaluate_sug(sug: Sug) -> Sug:
        sug.score_with_o, sug.reason = await evaluate_with_o(sug.text, o)
        return sug

    if all_sugs:
        eval_tasks = [evaluate_sug(sug) for sug in all_sugs]
        await asyncio.gather(*eval_tasks)

    # 2.3 Print results and organize them into sug_details
    sug_details = {}  # suggestions grouped by source Q
    for i, q_sug_list in enumerate(sug_list_list):
        if q_sug_list:
            q_text = q_list[i].text
            print(f"\n  Suggestions from q '{q_text}':")
            sug_details[q_text] = []
            for sug in q_sug_list:
                print(f"    {sug.text}: {sug.score_with_o:.2f}")
                sug_details[q_text].append({
                    "text": sug.text,
                    "score": sug.score_with_o,
                    "reason": sug.reason
                })

    # 3. Build search_list
    print(f"\n[Step 3] Building search_list (threshold > {sug_threshold})...")
    search_list = []
    high_score_sugs = [sug for sug in all_sugs if sug.score_with_o > sug_threshold]
    if high_score_sugs:
        print(f"  Found {len(high_score_sugs)} high-scoring suggestions")

        # search concurrently
        async def search_for_sug(sug: Sug) -> Search:
            print(f"  Searching: {sug.text}")
            try:
                search_result = xiaohongshu_search.search(keyword=sug.text)
                result_str = search_result.get("result", "{}")
                if isinstance(result_str, str):
                    result_data = json.loads(result_str)
                else:
                    result_data = result_str
                notes = result_data.get("data", {}).get("data", [])
                post_list = []
                for note in notes[:10]:  # keep only the first 10
                    post = process_note_data(note)
                    post_list.append(post)
                print(f"    → Found {len(post_list)} posts")
                return Search(
                    text=sug.text,
                    score_with_o=sug.score_with_o,
                    from_q=sug.from_q,
                    post_list=post_list
                )
            except Exception as e:
                print(f"    ✗ Search failed: {e}")
                return Search(
                    text=sug.text,
                    score_with_o=sug.score_with_o,
                    from_q=sug.from_q,
                    post_list=[]
                )

        search_tasks = [search_for_sug(sug) for sug in high_score_sugs]
        search_list = await asyncio.gather(*search_tasks)
    else:
        print("  No high-scoring suggestions; search_list is empty")

    # 4. Build word_list_next: word_list -> word_list_next (plain copy for now)
    print("\n[Step 4] Building word_list_next (plain copy for now)...")
    word_list_next = word_list.copy()

    # 5. Build q_list_next
    print("\n[Step 5] Building q_list_next...")
    q_list_next = []
    add_word_details = {}  # combined queries per seed

    # 5.1 For each seed, pick one not-yet-added word from word_list_next
    print("\n  5.1 Adding a word to each seed...")
    for seed in seed_list:
        print(f"\n  Processing seed: {seed.text}")
        # Simple filter: words not already contained in seed.text and not added before
        candidate_words = []
        for word in word_list_next:
            # skip words already present in the seed
            if word.text in seed.text:
                continue
            # skip words that were already added
            if word.text in seed.added_words:
                continue
            candidate_words.append(word)
        if not candidate_words:
            print("    No candidate words available")
            continue
        print(f"    Candidates: {[w.text for w in candidate_words]}")

        # Use the agent to pick the most suitable word
        selection_input = f"""
<original_question>
{o}
</original_question>
<current_seed>
{seed.text}
</current_seed>
<candidate_words>
{', '.join([w.text for w in candidate_words])}
</candidate_words>
Please pick the single most suitable candidate word and combine it with the current seed into a new query.
"""
        result = await Runner.run(word_selector, selection_input)
        selection: WordSelection = result.final_output

        # Verify that the chosen word is actually in the candidate list
        if selection.selected_word not in [w.text for w in candidate_words]:
            print(f"    ✗ Agent chose '{selection.selected_word}', which is not a candidate; skipping")
            continue
        print(f"    ✓ Chosen word: {selection.selected_word}")
        print(f"    ✓ New query: {selection.combined_query}")
        print(f"    Rationale: {selection.reasoning}")

        # Score the new query
        new_q_score, new_q_reason = await evaluate_with_o(selection.combined_query, o)
        print(f"    New query score: {new_q_score:.2f}")

        # Create the new q
        new_q = Q(
            text=selection.combined_query,
            score_with_o=new_q_score,
            reason=new_q_reason,
            from_source="add"
        )
        q_list_next.append(new_q)

        # Record the added word on the seed
        seed.added_words.append(selection.selected_word)

        # Save to add_word_details
        if seed.text not in add_word_details:
            add_word_details[seed.text] = []
        add_word_details[seed.text].append({
            "text": selection.combined_query,
            "score": new_q_score,
            "reason": new_q_reason,
            "selected_word": selection.selected_word
        })

    # 5.2 Promote suggestions that outscore their source query into q_list_next
    print("\n  5.2 Promoting high-scoring suggestions into q_list_next...")
    for sug in all_sugs:
        if sug.from_q and sug.score_with_o > sug.from_q.score_with_o:
            new_q = Q(
                text=sug.text,
                score_with_o=sug.score_with_o,
                reason=sug.reason,
                from_source="sug"
            )
            q_list_next.append(new_q)
            print(f"    ✓ {sug.text} (score: {sug.score_with_o:.2f} > {sug.from_q.score_with_o:.2f})")

    # 6. Update seed_list
    print("\n[Step 6] Updating seed_list...")
    seed_list_next = seed_list.copy()  # keep the existing seeds
    # Add each sug that outscores its source query and is not yet in seed_list
    existing_seed_texts = {seed.text for seed in seed_list_next}
    for sug in all_sugs:
        if sug.from_q and sug.score_with_o > sug.from_q.score_with_o and sug.text not in existing_seed_texts:
            new_seed = Seed(
                text=sug.text,
                added_words=[],
                from_type="sug",
                score_with_o=sug.score_with_o
            )
            seed_list_next.append(new_seed)
            existing_seed_texts.add(sug.text)
            print(f"    ✓ New seed: {sug.text} (score: {sug.score_with_o:.2f} > source q: {sug.from_q.score_with_o:.2f})")

    # Record this round's data
    round_data.update({
        "sug_count": len(all_sugs),
        "high_score_sug_count": len(high_score_sugs),
        "search_count": len(search_list),
        "total_posts": sum(len(s.post_list) for s in search_list),
        "q_list_next_size": len(q_list_next),
        "seed_list_next_size": len(seed_list_next),
        "word_list_next_size": len(word_list_next),
        "output_q_list": [{"text": q.text, "score": q.score_with_o, "reason": q.reason, "from": q.from_source} for q in q_list_next],
        "seed_list_next": [{"text": seed.text, "from": seed.from_type, "score": seed.score_with_o} for seed in seed_list_next],  # next round's seeds
        "sug_details": sug_details,  # suggestions grouped by source Q
        "add_word_details": add_word_details  # combined queries per seed
    })
    context.rounds.append(round_data)

    print("\nRound summary:")
    print(f"  Suggestions: {len(all_sugs)}")
    print(f"  High-scoring suggestions: {len(high_score_sugs)}")
    print(f"  Searches: {len(search_list)}")
    print(f"  Total posts: {sum(len(s.post_list) for s in search_list)}")
    print(f"  Next round q count: {len(q_list_next)}")
    print(f"  Seed count: {len(seed_list_next)}")

    return word_list_next, q_list_next, seed_list_next, search_list
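
# Promotion rules applied each round (as implemented above):
#   * every q fans out to suggestions, each scored against o
#   * suggestions scoring above sug_threshold trigger an actual search
#   * a suggestion enters next round's q_list (and seed_list, if unseen)
#     only when it outscores the q it came from
#   * each seed contributes at most one new "seed + word" combined query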
async def iterative_loop(
    context: RunContext,
    max_rounds: int = 2,
    sug_threshold: float = 0.7
):
    """Main iteration loop."""
    print(f"\n{'='*60}")
    print("Starting the iteration loop")
    print(f"Max rounds: {max_rounds}")
    print(f"Suggestion threshold: {sug_threshold}")
    print(f"{'='*60}")

    # Initialization
    seg_list, word_list, q_list, seed_list = await initialize(context.o, context)

    # API instances
    xiaohongshu_api = XiaohongshuSearchRecommendations()
    xiaohongshu_search = XiaohongshuSearch()

    # Save initialization data
    context.rounds.append({
        "round_num": 0,
        "type": "initialization",
        "seg_list": [{"text": s.text, "score": s.score_with_o, "reason": s.reason} for s in seg_list],
        "word_list_1": [{"text": w.text, "score": w.score_with_o} for w in word_list],
        "q_list_1": [{"text": q.text, "score": q.score_with_o, "reason": q.reason} for q in q_list],
        "seed_list": [{"text": s.text, "from_type": s.from_type, "score": s.score_with_o} for s in seed_list]
    })

    # Collect all search results
    all_search_list = []

    # Iterate
    round_num = 1
    while q_list and round_num <= max_rounds:
        word_list, q_list, seed_list, search_list = await run_round(
            round_num=round_num,
            q_list=q_list,
            word_list=word_list,
            seed_list=seed_list,
            o=context.o,
            context=context,
            xiaohongshu_api=xiaohongshu_api,
            xiaohongshu_search=xiaohongshu_search,
            sug_threshold=sug_threshold
        )
        all_search_list.extend(search_list)
        round_num += 1

    print(f"\n{'='*60}")
    print("Iteration finished")
    print(f"  Rounds run: {round_num - 1}")
    print(f"  Total searches: {len(all_search_list)}")
    print(f"  Total posts: {sum(len(s.post_list) for s in all_search_list)}")
    print(f"{'='*60}")

    return all_search_list
# ============================================================================
# Main entry point
# ============================================================================

async def main(input_dir: str, max_rounds: int = 2, sug_threshold: float = 0.7, visualize: bool = False):
    """Main entry point."""
    current_time, log_url = set_trace()

    # Read inputs
    input_context_file = os.path.join(input_dir, 'context.md')
    input_q_file = os.path.join(input_dir, 'q.md')
    c = read_file_as_string(input_context_file)  # original requirement
    o = read_file_as_string(input_q_file)  # original question

    # Version info
    version = os.path.basename(__file__)
    version_name = os.path.splitext(version)[0]

    # Log directory
    log_dir = os.path.join(input_dir, "output", version_name, current_time)

    # Build the run context
    run_context = RunContext(
        version=version,
        input_files={
            "input_dir": input_dir,
            "context_file": input_context_file,
            "q_file": input_q_file,
        },
        c=c,
        o=o,
        log_dir=log_dir,
        log_url=log_url,
    )

    # Run the iteration
    all_search_list = await iterative_loop(
        run_context,
        max_rounds=max_rounds,
        sug_threshold=sug_threshold
    )

    # Format the output
    output = f"Original requirement: {run_context.c}\n"
    output += f"Original question: {run_context.o}\n"
    output += f"Total searches: {len(all_search_list)}\n"
    output += f"Total posts: {sum(len(s.post_list) for s in all_search_list)}\n"
    output += "\n" + "=" * 60 + "\n"
    if all_search_list:
        output += "[Search results]\n\n"
        for idx, search in enumerate(all_search_list, 1):
            output += f"{idx}. Search term: {search.text} (score: {search.score_with_o:.2f})\n"
            output += f"   Posts: {len(search.post_list)}\n"
            if search.post_list:
                for post_idx, post in enumerate(search.post_list[:3], 1):  # show only the first 3
                    output += f"   {post_idx}) {post.title}\n"
                    output += f"      URL: {post.note_url}\n"
            output += "\n"
    else:
        output += "No search results found\n"
    run_context.final_output = output

    print(f"\n{'='*60}")
    print("Final result")
    print(f"{'='*60}")
    print(output)

    # Save the run log
    os.makedirs(run_context.log_dir, exist_ok=True)
    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
    context_dict = run_context.model_dump()
    with open(context_file_path, "w", encoding="utf-8") as f:
        json.dump(context_dict, f, ensure_ascii=False, indent=2)
    print(f"\nRunContext saved to: {context_file_path}")

    # Save detailed search results
    search_results_path = os.path.join(run_context.log_dir, "search_results.json")
    search_results_data = [s.model_dump() for s in all_search_list]
    with open(search_results_path, "w", encoding="utf-8") as f:
        json.dump(search_results_data, f, ensure_ascii=False, indent=2)
    print(f"Search results saved to: {search_results_path}")

    # Visualization
    if visualize:
        import subprocess
        output_html = os.path.join(run_context.log_dir, "visualization.html")
        print("\n🎨 Generating visualization HTML...")
        # resolve absolute paths
        abs_context_file = os.path.abspath(context_file_path)
        abs_output_html = os.path.abspath(output_html)
        # run the visualization script
        result = subprocess.run([
            "node",
            "visualization/sug_v6_1_2_8/index.js",
            abs_context_file,
            abs_output_html
        ])
        if result.returncode == 0:
            print(f"✅ Visualization written to: {output_html}")
        else:
            print("❌ Visualization generation failed")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Search query optimization tool - v6.1.2.8, round-based iteration")
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/旅游-逸趣玩旅行/如何获取能体现川西秋季特色的高质量风光摄影素材?",
        help="Input directory (default: input/旅游-逸趣玩旅行/如何获取能体现川西秋季特色的高质量风光摄影素材?)"
    )
    parser.add_argument(
        "--max-rounds",
        type=int,
        default=4,
        help="Maximum number of rounds (default: 4)"
    )
    parser.add_argument(
        "--sug-threshold",
        type=float,
        default=0.7,
        help="Suggestion threshold (default: 0.7)"
    )
    # Note: the original passed default=True alongside store_true, which made
    # the flag a no-op; dropping it restores the opt-in behavior the help implies.
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate the visualization HTML after the run"
    )
    args = parser.parse_args()
    asyncio.run(main(args.input_dir, max_rounds=args.max_rounds, sug_threshold=args.sug_threshold, visualize=args.visualize))
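
# Example invocations (paths are placeholders):
#   python sug_v6_1_2_8.py --input-dir "input/<主题>/<问题>" --max-rounds 4
#   python sug_v6_1_2_8.py --sug-threshold 0.8 --visualize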