# sug_v6_1_2_7.py
import asyncio
import json
import os
import argparse
import re
import copy
from datetime import datetime
from typing import TypeVar, Type

from agents import Agent, Runner
from pydantic import BaseModel, Field

from lib.my_trace import set_trace
from lib.utils import read_file_as_string
from lib.client import get_model
from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
from script.search.xiaohongshu_search import XiaohongshuSearch

MODEL_NAME = "google/gemini-2.5-flash"

# ============================================================================
# Data models
# ============================================================================
class Seg(BaseModel):
    """A segmented word from the original question."""
    text: str
    score_with_o: float
    from_o: str


class Word(BaseModel):
    """A word in the word pool."""
    text: str
    score_with_o: float
    from_o: str


class Q(BaseModel):
    """A search query."""
    text: str
    score_with_o: float
    from_source: str  # "seg" | "sug" | "add"


class Sug(BaseModel):
    """A suggested query returned by the sug API."""
    text: str
    score_with_o: float
    from_q: dict  # {"text": str, "score_with_o": float}
    evaluation_reason: str | None = None


class Seed(BaseModel):
    """A seed query (used for add-word exploration)."""
    text: str
    added_words: list[str] = Field(default_factory=list)
    from_type: str  # "seg" | "sug"


class Post(BaseModel):
    """A post (note)."""
    note_id: str = ""
    title: str = ""
    body_text: str = ""
    type: str = "normal"  # "video" | "normal"
    images: list[str] = Field(default_factory=list)
    video: str = ""
    interact_info: dict = Field(default_factory=dict)
    note_url: str = ""


class Search(BaseModel):
    """A search result (same shape as Sug, plus the fetched posts)."""
    text: str
    score_with_o: float
    from_q: dict
    post_list: list[Post] = Field(default_factory=list)


class RunContext(BaseModel):
    """Run context carried through the whole pipeline."""
    version: str
    input_files: dict[str, str]
    c: str  # original requirement (context)
    o: str  # original question
    log_url: str
    log_dir: str
    # Core data
    seg_list: list[dict] = Field(default_factory=list)
    word_lists: dict[int, list[dict]] = Field(default_factory=dict)  # {round: word_list}
    q_lists: dict[int, list[dict]] = Field(default_factory=dict)  # {round: q_list}
    sug_list_lists: dict[int, list[list[dict]]] = Field(default_factory=dict)  # {round: [[sug, sug], [sug]]}
    search_lists: dict[int, list[dict]] = Field(default_factory=dict)  # {round: search_list}
    seed_lists: dict[int, list[dict]] = Field(default_factory=dict)  # {round: seed_list}
    steps: list[dict] = Field(default_factory=list)
    # Detailed operation records (field named in Chinese; inner data keys stay Chinese too)
    轮次记录: dict[int, dict] = Field(default_factory=dict)
    # Final results
    all_posts: list[dict] = Field(default_factory=list)
    final_output: str | None = None
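
# Shape of the round-keyed fields after a two-round run (illustrative values):
#     context.q_lists  == {1: [...], 2: [...], 3: [...]}   # round -> list of Q dumps
#     context.轮次记录  == {0: {...}, 1: {...}, 2: {...}}    # round -> operation record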


# ============================================================================
# Helpers: operation recording
# ============================================================================
def init_round_record(run_context: RunContext, round_num: int, round_name: str):
    """Initialize the record for one round."""
    run_context.轮次记录[round_num] = {
        "轮次": round_num,
        "名称": round_name,
        "操作列表": []
    }


def add_operation_record(
    run_context: RunContext,
    round_num: int,
    操作名称: str,
    输入: dict,
    处理过程: dict,
    输出: dict
):
    """Append one operation record to the current round."""
    operation = {
        "操作名称": 操作名称,
        "轮次": round_num,
        "时间": datetime.now().isoformat(),
        "输入": 输入,
        "处理过程": 处理过程,
        "输出": 输出
    }
    if round_num not in run_context.轮次记录:
        init_round_record(run_context, round_num, f"第{round_num}轮" if round_num > 0 else "初始化阶段")
    run_context.轮次记录[round_num]["操作列表"].append(operation)


def record_agent_call(
    agent_name: str,
    model: str,
    instructions: str,
    user_message: str,
    raw_output: dict | str,
    parsed: bool,
    validation_error: str | None = None,
    input_schema: dict | None = None
) -> dict:
    """Record a single agent call."""
    return {
        "Agent名称": agent_name,
        "模型": model,
        "系统提示词": instructions,
        "输入Schema": input_schema,
        "用户消息": user_message,
        "原始输出": raw_output,
        "解析成功": parsed,
        "验证错误": validation_error
    }


# ============================================================================
# JSON post-processing: handle JSON responses wrapped in markdown fences
# ============================================================================
def clean_json_response(text: str) -> str:
    """Strip a markdown code fence that may wrap a JSON payload.

    The model may return:
    ```json
    {"key": "value"}
    ```
    which needs to be cleaned to:
    {"key": "value"}
    """
    text = text.strip()
    # Remove a leading ```json or ```
    if text.startswith('```json'):
        text = text[7:]
    elif text.startswith('```'):
        text = text[3:]
    # Remove a trailing ```
    if text.endswith('```'):
        text = text[:-3]
    return text.strip()
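
# Illustrative behavior (example values, not from a real run):
#     clean_json_response('```json\n{"key": "value"}\n```')  ->  '{"key": "value"}'
#     clean_json_response('{"key": "value"}')                ->  '{"key": "value"}'  (unchanged)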


T = TypeVar('T', bound=BaseModel)


async def run_agent_with_json_cleanup(
    agent: Agent,
    input_text: str,
    output_type: Type[T]
) -> T:
    """Run an agent and recover from markdown-wrapped JSON output.

    If the agent returns JSON wrapped in a markdown fence, clean it up and
    re-parse it into the expected Pydantic model.
    """
    try:
        result = await Runner.run(agent, input_text)
        return result.final_output
    except Exception as e:
        error_msg = str(e)
        # Is this a JSON parsing error?
        if "Invalid JSON when parsing" in error_msg:
            # Try to extract the JSON from the error message.
            # Error format: "Invalid JSON when parsing ```json\n{...}\n``` for TypeAdapter(...)"
            match = re.search(r'when parsing (.+?) for TypeAdapter', error_msg, re.DOTALL)
            if match:
                json_text = match.group(1)
                cleaned_json = clean_json_response(json_text)
                try:
                    # Parse the JSON manually and build the Pydantic object
                    parsed_data = json.loads(cleaned_json)
                    return output_type(**parsed_data)
                except Exception as parse_error:
                    print(f"⚠️ JSON清理后仍无法解析: {parse_error}")
                    print(f" 清理后的JSON: {cleaned_json}")
                    raise ValueError(f"无法解析JSON: {parse_error}\n原始错误: {error_msg}")
        # Not a JSON parsing error, or cleanup failed: re-raise the original error
        raise
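
# Usage sketch (illustrative input; word_segmenter is defined below):
#     segmentation = await run_agent_with_json_cleanup(
#         word_segmenter, "如何用python简单扣图", WordSegmentation
#     )
#     segmentation.words  # -> list[str]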


# ============================================================================
# Agent definitions
# ============================================================================
# Agent 1: word segmentation expert
class WordSegmentation(BaseModel):
    """Segmentation result."""
    words: list[str] = Field(..., description="分词结果列表")
    reasoning: str = Field(..., description="分词理由")


word_segmentation_instructions = """
你是分词专家。给定一个query,将其拆分成有意义的最小单元。
## 分词原则
1. 保留有搜索意义的词汇
2. 拆分成独立的概念
3. 保留专业术语的完整性
4. 去除虚词(的、吗、呢等)
## 输出要求
返回分词列表和分词理由。
IMPORTANT: 直接返回纯JSON对象,不要使用markdown代码块标记(不要用```json...```包裹)。
""".strip()

word_segmenter = Agent[None](
    name="分词专家",
    instructions=word_segmentation_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordSegmentation,
)


# Agent 2: query relevance evaluation expert
class RelevanceEvaluation(BaseModel):
    """Relevance evaluation."""
    relevance_score: float = Field(..., description="相关性分数 0-1")
    reason: str = Field(..., description="评估理由")


relevance_evaluation_instructions = """
你是Query相关度评估专家。
## 任务
评估当前query与原始问题的匹配程度。
## 评估标准
- 主题相关性
- 要素覆盖度
- 意图匹配度
## 输出
- relevance_score: 0-1的相关性分数
- reason: 详细理由
IMPORTANT: 直接返回纯JSON对象,不要使用markdown代码块标记(不要用```json...```包裹)。
""".strip()

relevance_evaluator = Agent[None](
    name="Query相关度评估专家",
    instructions=relevance_evaluation_instructions,
    model=get_model(MODEL_NAME),
    output_type=RelevanceEvaluation,
)


# Agent 3: word selection expert
class WordSelection(BaseModel):
    """Word selection result."""
    selected_word: str = Field(..., description="选中的词")
    reasoning: str = Field(..., description="选择理由")


word_selection_instructions = """
你是Word选择专家。
## 任务
从候选词列表中选择一个最适合与当前seed组合的词,用于探索新的搜索query。
## 选择原则
1. 与seed的语义相关性
2. 组合后的搜索价值
3. 能拓展搜索范围
## 输出
返回选中的词和选择理由。
""".strip()

word_selector = Agent[None](
    name="Word选择专家",
    instructions=word_selection_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordSelection,
)


# Agent 4: word insertion position expert
class WordInsertion(BaseModel):
    """Word insertion result."""
    new_query: str = Field(..., description="加词后的新query")
    insertion_position: str = Field(..., description="插入位置描述")
    reasoning: str = Field(..., description="插入理由")


word_insertion_instructions = """
你是加词位置评估专家。
## 任务
将新词加到当前query的最合适位置,保持语义通顺。
## 原则
1. 保持语法正确
2. 语义连贯
3. 符合搜索习惯
## 输出
返回新query、插入位置描述和理由。
""".strip()

word_inserter = Agent[None](
    name="加词位置评估专家",
    instructions=word_insertion_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordInsertion,
)
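
# Together, the four agents form the pipeline: segmentation (once, at
# initialization) -> relevance scoring (segs, sugs, and new queries) ->
# word selection -> word insertion.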


# ============================================================================
# Helper functions
# ============================================================================
def add_step(context: RunContext, step_name: str, step_type: str, data: dict):
    """Append a step record."""
    step = {
        "step_number": len(context.steps) + 1,
        "step_name": step_name,
        "step_type": step_type,
        "timestamp": datetime.now().isoformat(),
        "data": data
    }
    context.steps.append(step)
    return step


def process_note_data(note: dict) -> Post:
    """Convert a raw note dict from the search API into a Post object."""
    note_card = note.get("note_card", {})
    image_list = note_card.get("image_list", [])
    interact_info = note_card.get("interact_info", {})
    # Extract image URLs from the image_url field
    images = []
    for img in image_list:
        if "image_url" in img:
            images.append(img["image_url"])
    # Detect videos
    note_type = note_card.get("type", "normal")
    video_url = ""
    if note_type == "video":
        # Video notes may use a different structure; left empty for now
        # and can be filled in later if needed.
        pass
    return Post(
        note_id=note.get("id") or "",
        title=note_card.get("display_title") or "",
        body_text=note_card.get("desc") or "",
        type=note_type,
        images=images,
        video=video_url,
        interact_info={
            "liked_count": interact_info.get("liked_count", 0),
            "collected_count": interact_info.get("collected_count", 0),
            "comment_count": interact_info.get("comment_count", 0),
            "shared_count": interact_info.get("shared_count", 0)
        },
        note_url=f"https://www.xiaohongshu.com/explore/{note.get('id') or ''}"
    )
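
# Assumed input shape, inferred from the lookups above (other fields ignored):
#     {
#         "id": "...",
#         "note_card": {
#             "display_title": "...", "desc": "...", "type": "normal",
#             "image_list": [{"image_url": "..."}],
#             "interact_info": {"liked_count": 0, "collected_count": 0, ...}
#         }
#     }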


# ============================================================================
# Core pipeline functions
# ============================================================================
async def evaluate_query_with_o(query_text: str, original_o: str) -> tuple[float, str]:
    """Evaluate how relevant a query is to the original question o.

    Returns:
        (score, reason)
    """
    eval_input = f"""
<原始问题>
{original_o}
</原始问题>
<当前Query>
{query_text}
</当前Query>
请评估当前query与原始问题的相关度。
"""
    evaluation = await run_agent_with_json_cleanup(
        relevance_evaluator,
        eval_input,
        RelevanceEvaluation
    )
    return evaluation.relevance_score, evaluation.reason
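
# Usage sketch (illustrative values):
#     score, reason = await evaluate_query_with_o("python 扣图 教程", context.o)
#     # score is a float in [0, 1]; reason is the evaluator's free-text explanation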


async def initialize(context: RunContext):
    """Initialization: segment o → seg_list → word_list_1, q_list_1, seed_list_1."""
    print("\n" + "="*60)
    print("初始化阶段")
    print("="*60)
    # Initialize round 0
    init_round_record(context, 0, "初始化阶段")
    # 1. Word segmentation
    print(f"\n[1/4] 分词原始问题: {context.o}")
    segmentation = await run_agent_with_json_cleanup(
        word_segmenter,
        context.o,
        WordSegmentation
    )
    print(f" 分词结果: {segmentation.words}")
    print(f" 分词理由: {segmentation.reasoning}")
    # 2. Evaluate the segments (concurrently)
    print(f"\n[2/4] 评估每个seg与原始问题的相关度...")
    seg_list = []
    agent_calls_seg_eval = []
    # Evaluate all segments concurrently
    eval_tasks = [evaluate_query_with_o(word, context.o) for word in segmentation.words]
    eval_results = await asyncio.gather(*eval_tasks)
    for word, (score, reason) in zip(segmentation.words, eval_results):
        seg = Seg(text=word, score_with_o=score, from_o=context.o)
        seg_list.append(seg.model_dump())
        print(f" {word}: {score:.2f}")
        # Record each segment's evaluation
        agent_calls_seg_eval.append(
            record_agent_call(
                agent_name="Query相关度评估专家",
                model=MODEL_NAME,
                instructions=relevance_evaluation_instructions,
                user_message=f"评估query与原始问题的相关度:\n\nQuery: {word}\n原始问题: {context.o}",
                raw_output={"score": score, "reason": reason},
                parsed=True
            )
        )
    context.seg_list = seg_list
    # Record the segmentation operation
    add_operation_record(
        context,
        round_num=0,
        操作名称="分词",
        输入={"原始问题": context.o},
        处理过程={
            "Agent调用": record_agent_call(
                agent_name="分词专家",
                model=MODEL_NAME,
                instructions=word_segmentation_instructions,
                user_message=f"请对以下query进行分词:{context.o}",
                raw_output={"words": segmentation.words, "reasoning": segmentation.reasoning},
                parsed=True,
                input_schema={"type": "WordSegmentation", "fields": {"words": "list[str]", "reasoning": "str"}}
            ),
            "seg评估Agent调用列表": agent_calls_seg_eval
        },
        输出={"seg_list": seg_list}
    )
    # 3. Build word_list_1 (copied directly from seg_list)
    print(f"\n[3/4] 构建 word_list_1...")
    word_list_1 = []
    for seg in seg_list:
        word = Word(text=seg["text"], score_with_o=seg["score_with_o"], from_o=seg["from_o"])
        word_list_1.append(word.model_dump())
    context.word_lists[1] = word_list_1
    print(f" word_list_1 大小: {len(word_list_1)}")
    # 4. Build q_list_1 and seed_list_1
    print(f"\n[4/4] 构建 q_list_1 和 seed_list_1...")
    q_list_1 = []
    seed_list_1 = []
    for seg in seg_list:
        # q_list_1: each seg becomes a q
        q = Q(text=seg["text"], score_with_o=seg["score_with_o"], from_source="seg")
        q_list_1.append(q.model_dump())
        # seed_list_1: each seg becomes a seed
        seed = Seed(text=seg["text"], added_words=[], from_type="seg")
        seed_list_1.append(seed.model_dump())
    context.q_lists[1] = q_list_1
    context.seed_lists[1] = seed_list_1
    print(f" q_list_1 大小: {len(q_list_1)}")
    print(f" seed_list_1 大小: {len(seed_list_1)}")
    # Record the initialization operation
    add_operation_record(
        context,
        round_num=0,
        操作名称="初始化",
        输入={"seg_list": seg_list},
        处理过程={"说明": "从seg_list构建初始q_list和seed_list"},
        输出={
            "word_list_1": word_list_1,
            "q_list_1": q_list_1,
            "seed_list_1": seed_list_1
        }
    )
    add_step(context, "初始化完成", "initialize", {
        "seg_count": len(seg_list),
        "word_list_1_count": len(word_list_1),
        "q_list_1_count": len(q_list_1),
        "seed_list_1_count": len(seed_list_1)
    })
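
# After initialize(): context.seg_list holds the scored segments, and
# word_lists[1], q_lists[1], seed_lists[1] all start out as copies of them.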


async def process_round(round_num: int, context: RunContext, xiaohongshu_api: XiaohongshuSearchRecommendations, xiaohongshu_search: XiaohongshuSearch, sug_threshold: float = 0.7):
    """Process one iteration round.

    Args:
        round_num: current round number
        context: run context
        xiaohongshu_api: sug API
        xiaohongshu_search: search API
        sug_threshold: score threshold for searching a sug (default 0.7)
    """
    print(f"\n" + "="*60)
    print(f"第 {round_num} 轮")
    print("="*60)
    # Initialize the round record
    init_round_record(context, round_num, f"第{round_num}轮迭代")
    q_list_n = context.q_lists.get(round_num, [])
    if not q_list_n:
        print(f" q_list_{round_num} 为空,跳过本轮")
        return
    print(f" 处理 {len(q_list_n)} 个query")
    # 1. Fetch suggestions (sug)
    print(f"\n[1/5] 请求sug...")
    sug_list_list_n = []
    api_calls_detail = []
    for q_data in q_list_n:
        q_text = q_data["text"]
        suggestions = xiaohongshu_api.get_recommendations(keyword=q_text)
        if not suggestions:
            print(f" {q_text}: 无sug")
            sug_list_list_n.append([])
            api_calls_detail.append({
                "query": q_text,
                "sug_count": 0
            })
            continue
        print(f" {q_text}: 获取 {len(suggestions)} 个sug")
        sug_list_list_n.append(suggestions)
        api_calls_detail.append({
            "query": q_text,
            "sug_count": len(suggestions)
        })
    # Record the sug-fetch operation
    total_sugs = sum(len(sl) for sl in sug_list_list_n)
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="请求推荐词",
        输入={"q_list": [{"text": q["text"], "score": q["score_with_o"]} for q in q_list_n]},
        处理过程={"API调用": api_calls_detail},
        输出={
            "sug_list_list": [[{"text": s, "from_q": q_list_n[i]["text"]} for s in sl] for i, sl in enumerate(sug_list_list_n)],
            "总推荐词数": total_sugs
        }
    )
    # 2. Evaluate sugs (concurrently, in batches of 10)
    print(f"\n[2/5] 评估sug...")
    sug_list_list_evaluated = []
    # Collect every sug to evaluate, together with its context
    all_sug_tasks = []
    sug_contexts = []  # (list index, q_data, sug text) for each sug
    for i, sug_list in enumerate(sug_list_list_n):
        q_data = q_list_n[i]
        for sug_text in sug_list:
            all_sug_tasks.append(evaluate_query_with_o(sug_text, context.o))
            sug_contexts.append((i, q_data, sug_text))
    # Evaluate in concurrent batches (10 per batch)
    batch_size = 10
    all_results = []
    batches_detail = []
    for batch_idx in range(0, len(all_sug_tasks), batch_size):
        batch_tasks = all_sug_tasks[batch_idx:batch_idx+batch_size]
        batch_results = await asyncio.gather(*batch_tasks)
        all_results.extend(batch_results)
        # Record the agent calls in this batch
        batch_agent_calls = []
        start_idx = batch_idx
        for j, (score, reason) in enumerate(batch_results):
            if start_idx + j < len(sug_contexts):
                _, _, sug_text = sug_contexts[start_idx + j]
                batch_agent_calls.append(
                    record_agent_call(
                        agent_name="Query相关度评估专家",
                        model=MODEL_NAME,
                        instructions=relevance_evaluation_instructions,
                        user_message=f"评估query与原始问题的相关度:\n\nQuery: {sug_text}\n原始问题: {context.o}",
                        raw_output={"score": score, "reason": reason},
                        parsed=True
                    )
                )
        batches_detail.append({
            "批次ID": len(batches_detail),
            "并发执行": True,
            "Agent调用列表": batch_agent_calls
        })
    # Group the flat results back into per-query lists
    result_index = 0
    current_list_index = -1
    evaluated_sugs = []
    for list_idx, q_data, sug_text in sug_contexts:
        if list_idx != current_list_index:
            if evaluated_sugs:
                sug_list_list_evaluated.append(evaluated_sugs)
            evaluated_sugs = []
            current_list_index = list_idx
        score, reason = all_results[result_index]
        result_index += 1
        sug = Sug(
            text=sug_text,
            score_with_o=score,
            from_q={"text": q_data["text"], "score_with_o": q_data["score_with_o"]},
            evaluation_reason=reason
        )
        evaluated_sugs.append(sug.model_dump())
        print(f" {sug_text}: {score:.2f}")
    # Flush the last group (queries with no sugs contribute no group at all)
    if evaluated_sugs:
        sug_list_list_evaluated.append(evaluated_sugs)
    context.sug_list_lists[round_num] = sug_list_list_evaluated
    # Record the sug evaluation operation
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="评估推荐词",
        输入={
            "待评估推荐词": [[s for s in sl] for sl in sug_list_list_n],
            "总数": len(all_sug_tasks)
        },
        处理过程={"批次列表": batches_detail},
        输出={"已评估推荐词": sug_list_list_evaluated}
    )
    # 3. Build search_list_n (sugs scoring >= sug_threshold) and run searches
    print(f"\n[3/5] 构建search_list并执行搜索...")
    search_list_n = []
    filter_comparisons = []
    search_details = []
    for sug_list_evaluated in sug_list_list_evaluated:
        for sug_data in sug_list_evaluated:
            # Record the filter comparison
            passed = sug_data["score_with_o"] >= sug_threshold
            filter_comparisons.append({
                "文本": sug_data["text"],
                "分数": sug_data["score_with_o"],
                "阈值": sug_threshold,
                "通过": passed
            })
            if passed:
                print(f" 搜索: {sug_data['text']} (分数: {sug_data['score_with_o']:.2f})")
                try:
                    # Run the search
                    search_result = xiaohongshu_search.search(keyword=sug_data["text"])
                    result_str = search_result.get("result", "{}")
                    if isinstance(result_str, str):
                        result_data = json.loads(result_str)
                    else:
                        result_data = result_str
                    notes = result_data.get("data", {}).get("data", [])
                    print(f" → 搜索到 {len(notes)} 个帖子")
                    # Convert to Post objects
                    post_list = []
                    for note in notes:
                        post = process_note_data(note)
                        post_list.append(post.model_dump())
                        context.all_posts.append(post.model_dump())
                    # Build the Search object
                    search = Search(
                        text=sug_data["text"],
                        score_with_o=sug_data["score_with_o"],
                        from_q=sug_data["from_q"],
                        post_list=post_list
                    )
                    search_list_n.append(search.model_dump())
                    # Record search details
                    search_details.append({
                        "查询": sug_data["text"],
                        "分数": sug_data["score_with_o"],
                        "成功": True,
                        "帖子数量": len(post_list),
                        "错误": None
                    })
                except Exception as e:
                    print(f" ✗ 搜索失败: {e}")
                    search_details.append({
                        "查询": sug_data["text"],
                        "分数": sug_data["score_with_o"],
                        "成功": False,
                        "帖子数量": 0,
                        "错误": str(e)
                    })
    context.search_lists[round_num] = search_list_n
    print(f" 本轮搜索到 {len(search_list_n)} 个有效结果")
    # Record the filter-and-search operation (merged into one record)
    total_posts = sum(len(s["post_list"]) for s in search_list_n)
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="筛选并执行搜索",
        输入={"已评估推荐词": sug_list_list_evaluated},
        处理过程={
            "筛选条件": f"分数 >= {sug_threshold}",
            "筛选比较": filter_comparisons,
            "搜索详情": search_details
        },
        输出={
            "search_list": search_list_n,
            "成功搜索数": len(search_list_n),
            "总帖子数": total_posts
        }
    )
    # 4. Build word_list_(n+1) (a plain copy for now)
    print(f"\n[4/5] 构建word_list_{round_num+1}...")
    word_list_n = context.word_lists.get(round_num, [])
    word_list_next = word_list_n.copy()
    context.word_lists[round_num + 1] = word_list_next
    print(f" word_list_{round_num+1} 大小: {len(word_list_next)}")
    # 5. Build q_list_(n+1) and update the seed list
    print(f"\n[5/5] 构建q_list_{round_num+1}和更新seed_list...")
    q_list_next = []
    seed_list_n = context.seed_lists.get(round_num, [])
    # Deep-copy so that appending to added_words on next-round seeds does not
    # mutate the round-n records already stored in context.seed_lists.
    seed_list_next = copy.deepcopy(seed_list_n)
    # 5.1 Generate new queries by adding a word to each seed (serial, deduplicated)
    print(f" [5.1] 从seed加词生成新q(串行处理,去重)...")
    add_word_attempts = []  # every attempt, for the operation record
    new_queries_from_add = []
    generated_query_texts = set()  # query texts generated so far
    for seed_data in seed_list_n:
        seed_text = seed_data["text"]
        added_words = seed_data["added_words"]
        # Keep only words not yet used with this seed
        candidate_words = []
        for word_data in word_list_next:
            word_text = word_data["text"]
            # Simple substring filter
            if word_text not in seed_text and word_text not in added_words:
                candidate_words.append(word_data)
        if not candidate_words:
            print(f" {seed_text}: 无可用词")
            continue
        attempt = {
            "种子": {"text": seed_text, "已添加词": added_words},
            "候选词": [w["text"] for w in candidate_words[:10]]
        }
        # Let the agent pick a word (showing it the queries generated so far)
        already_generated_str = ""
        if generated_query_texts:
            already_generated_str = f"""
<已生成的查询>
{', '.join(sorted(generated_query_texts))}
</已生成的查询>
注意:请避免生成与上述已存在的查询重复或过于相似的新查询。
"""
        selection_input = f"""
<当前Seed>
{seed_text}
</当前Seed>
<候选词列表>
{', '.join([w['text'] for w in candidate_words[:10]])}
</候选词列表>
{already_generated_str}
请从候选词中选择一个最适合与seed组合的词。
"""
        selection = await run_agent_with_json_cleanup(
            word_selector,
            selection_input,
            WordSelection
        )
        selected_word = selection.selected_word
        # Make sure the chosen word is actually a candidate
        if selected_word not in [w["text"] for w in candidate_words]:
            # Fall back to the first candidate if the agent picked something else
            selected_word = candidate_words[0]["text"]
        # Record the word selection
        attempt["步骤1_选词"] = record_agent_call(
            agent_name="Word选择专家",
            model=MODEL_NAME,
            instructions=word_selection_instructions,
            user_message=selection_input,
            raw_output={"selected_word": selection.selected_word, "reasoning": selection.reasoning},
            parsed=True,
            input_schema={"type": "WordSelection", "fields": {"selected_word": "str", "reasoning": "str"}}
        )
        # Let the insertion agent place the word
        insertion_input = f"""
<当前Query>
{seed_text}
</当前Query>
<要添加的词>
{selected_word}
</要添加的词>
请将这个词加到query的最合适位置。
"""
        insertion = await run_agent_with_json_cleanup(
            word_inserter,
            insertion_input,
            WordInsertion
        )
        new_query_text = insertion.new_query
        # Record the insertion
        attempt["步骤2_插入位置"] = record_agent_call(
            agent_name="加词位置评估专家",
            model=MODEL_NAME,
            instructions=word_insertion_instructions,
            user_message=insertion_input,
            raw_output={"new_query": insertion.new_query, "reasoning": insertion.reasoning},
            parsed=True,
            input_schema={"type": "WordInsertion", "fields": {"new_query": "str", "reasoning": "str"}}
        )
        # Skip duplicates
        if new_query_text in generated_query_texts:
            print(f" {seed_text} + {selected_word} → {new_query_text} (重复,跳过)")
            attempt["跳过原因"] = "查询重复"
            add_word_attempts.append(attempt)
            continue
        # Evaluate the new query immediately
        score, reason = await evaluate_query_with_o(new_query_text, context.o)
        # Record the evaluation
        attempt["步骤3_评估新查询"] = record_agent_call(
            agent_name="Query相关度评估专家",
            model=MODEL_NAME,
            instructions=relevance_evaluation_instructions,
            user_message=f"评估新query的相关度:\n\nQuery: {new_query_text}\n原始问题: {context.o}",
            raw_output={"score": score, "reason": reason},
            parsed=True
        )
        add_word_attempts.append(attempt)
        # Create the new q and add it to the next round
        new_q = Q(text=new_query_text, score_with_o=score, from_source="add")
        q_list_next.append(new_q.model_dump())
        new_queries_from_add.append(new_q.model_dump())
        generated_query_texts.add(new_query_text)
        # Update the seed's added_words
        for seed in seed_list_next:
            if seed["text"] == seed_text:
                seed["added_words"].append(selected_word)
                break
        print(f" {seed_text} + {selected_word} → {new_query_text} (分数: {score:.2f})")
    # Record the add-word operation
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="加词生成新查询",
        输入={
            "seed_list": seed_list_n,
            "word_list": word_list_next
        },
        处理过程={"尝试列表": add_word_attempts},
        输出={"新查询列表": new_queries_from_add}
    )
    # 5.2 Promote sugs into q_list (condition: sug score > its source query's score)
    print(f" [5.2] 从sug加入q_list_{round_num+1}(条件:sug分数 > from_q分数)...")
    sug_added_count = 0
    sug_filter_comparisons = []
    selected_sugs = []
    for sug_list_evaluated in sug_list_list_evaluated:
        for sug_data in sug_list_evaluated:
            # Condition: the sug scores higher than the query it came from
            from_q_score = sug_data["from_q"]["score_with_o"]
            passed = sug_data["score_with_o"] > from_q_score
            sug_filter_comparisons.append({
                "推荐词": sug_data["text"],
                "推荐词分数": sug_data["score_with_o"],
                "来源查询分数": from_q_score,
                "通过": passed,
                "原因": f"{sug_data['score_with_o']:.2f} > {from_q_score:.2f}" if passed else f"{sug_data['score_with_o']:.2f} <= {from_q_score:.2f}"
            })
            if passed:
                # Skip queries that already exist
                if sug_data["text"] not in [q["text"] for q in q_list_next]:
                    new_q = Q(text=sug_data["text"], score_with_o=sug_data["score_with_o"], from_source="sug")
                    q_list_next.append(new_q.model_dump())
                    selected_sugs.append(new_q.model_dump())
                    sug_added_count += 1
                    print(f" ✓ {sug_data['text']} ({sug_data['score_with_o']:.2f} > {from_q_score:.2f})")
    print(f" 添加 {sug_added_count} 个sug到q_list_{round_num+1}")
    # Record the sug filtering operation
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="筛选推荐词进入下轮",
        输入={"已评估推荐词": sug_list_list_evaluated},
        处理过程={
            "筛选条件": "推荐词分数 > 来源查询分数",
            "比较结果": sug_filter_comparisons
        },
        输出={"选中推荐词": selected_sugs}
    )
    # 5.3 Update the seed list (add sugs as new seeds when sug score > from_q score)
    print(f" [5.3] 更新seed_list_{round_num+1}(条件:sug分数 > from_q分数)...")
    seed_texts_existing = [s["text"] for s in seed_list_next]
    new_seed_count = 0
    for sug_list_evaluated in sug_list_list_evaluated:
        for sug_data in sug_list_evaluated:
            from_q_score = sug_data["from_q"]["score_with_o"]
            # Same condition: the sug scores higher than its source query
            if sug_data["score_with_o"] > from_q_score and sug_data["text"] not in seed_texts_existing:
                new_seed = Seed(text=sug_data["text"], added_words=[], from_type="sug")
                seed_list_next.append(new_seed.model_dump())
                seed_texts_existing.append(sug_data["text"])
                new_seed_count += 1
    print(f" 添加 {new_seed_count} 个sug到seed_list_{round_num+1}")
    context.q_lists[round_num + 1] = q_list_next
    context.seed_lists[round_num + 1] = seed_list_next
    print(f"\n q_list_{round_num+1} 大小: {len(q_list_next)}")
    print(f" seed_list_{round_num+1} 大小: {len(seed_list_next)}")
    # Record the next-round build
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="构建下一轮",
        输入={
            "加词新查询": new_queries_from_add,
            "选中推荐词": selected_sugs
        },
        处理过程={
            "合并": {
                "来自加词": len(new_queries_from_add),
                "来自推荐词": len(selected_sugs),
                "合并前总数": len(new_queries_from_add) + len(selected_sugs)
            },
            "去重": {
                "唯一数": len(q_list_next)
            }
        },
        输出={
            "下轮查询列表": q_list_next,
            "下轮种子列表": seed_list_next
        }
    )
    add_step(context, f"第{round_num}轮完成", "round", {
        "round": round_num,
        "q_list_count": len(q_list_n),
        "sug_total_count": sum(len(s) for s in sug_list_list_evaluated),
        "search_count": len(search_list_n),
        "posts_found": sum(len(s["post_list"]) for s in search_list_n),
        "q_list_next_count": len(q_list_next),
        "seed_list_next_count": len(seed_list_next)
    })
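
# Round data flow, in short:
#   q_list_n -> sug API -> relevance scoring
#     score >= sug_threshold        -> searched; posts collected into all_posts
#     score > source query's score  -> promoted into q_list_(n+1) and seed_list_(n+1)
#   seeds + word pool -> add-word queries -> also into q_list_(n+1)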


async def main_loop(context: RunContext, max_rounds: int = 2):
    """Main loop.

    Args:
        context: run context
        max_rounds: maximum number of rounds (default 2)
    """
    print("\n" + "="*60)
    print("开始主循环")
    print("="*60)
    # Initialization
    await initialize(context)
    # API instances
    xiaohongshu_api = XiaohongshuSearchRecommendations()
    xiaohongshu_search = XiaohongshuSearch()
    # Iterate
    for round_num in range(1, max_rounds + 1):
        await process_round(round_num, context, xiaohongshu_api, xiaohongshu_search)
        # Check the termination condition
        q_list_next = context.q_lists.get(round_num + 1, [])
        if not q_list_next:
            print(f"\n q_list_{round_num + 1} 为空,提前结束")
            break
    print("\n" + "="*60)
    print("主循环完成")
    print("="*60)
    print(f" 总共收集 {len(context.all_posts)} 个帖子")


# ============================================================================
# Main
# ============================================================================
async def main(input_dir: str, max_rounds: int = 2, visualize: bool = False):
    """Entry point. Note: visualize is accepted but not used in this version."""
    current_time, log_url = set_trace()
    # Read inputs
    input_context_file = os.path.join(input_dir, 'context.md')
    input_q_file = os.path.join(input_dir, 'q.md')
    c = read_file_as_string(input_context_file)
    o = read_file_as_string(input_q_file)
    # Version info
    version = os.path.basename(__file__)
    version_name = os.path.splitext(version)[0]
    # Log directory
    log_dir = os.path.join(input_dir, "output", version_name, current_time)
    # Build the run context
    run_context = RunContext(
        version=version,
        input_files={
            "input_dir": input_dir,
            "context_file": input_context_file,
            "q_file": input_q_file,
        },
        c=c,
        o=o,
        log_dir=log_dir,
        log_url=log_url,
    )
    # Run the main loop
    await main_loop(run_context, max_rounds=max_rounds)
    # Format the output
    output = f"原始需求:{run_context.c}\n"
    output += f"原始问题:{run_context.o}\n"
    output += f"收集帖子:{len(run_context.all_posts)} 个\n"
    output += "\n" + "="*60 + "\n"
    if run_context.all_posts:
        output += "【收集到的帖子】\n\n"
        for idx, post in enumerate(run_context.all_posts[:20], 1):  # show the first 20 only
            output += f"{idx}. {post['title']}\n"
            output += f" 类型: {post['type']}\n"
            output += f" URL: {post['note_url']}\n\n"
    else:
        output += "未收集到帖子\n"
    run_context.final_output = output
    print(f"\n{'='*60}")
    print("最终结果")
    print(f"{'='*60}")
    print(output)
    # Save logs
    os.makedirs(run_context.log_dir, exist_ok=True)
    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
    context_dict = run_context.model_dump()
    with open(context_file_path, "w", encoding="utf-8") as f:
        json.dump(context_dict, f, ensure_ascii=False, indent=2)
    print(f"\nRunContext saved to: {context_file_path}")
    steps_file_path = os.path.join(run_context.log_dir, "steps.json")
    with open(steps_file_path, "w", encoding="utf-8") as f:
        json.dump(run_context.steps, f, ensure_ascii=False, indent=2)
    print(f"Steps log saved to: {steps_file_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="搜索query优化工具 - v6.1.2.7 基于seed的迭代版")
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/简单扣图",
        help="输入目录路径,默认: input/简单扣图"
    )
    parser.add_argument(
        "--max-rounds",
        type=int,
        default=2,
        help="最大轮数,默认: 2"
    )
    # NOTE: with action="store_true" the default is already False; the original
    # default=True made the flag impossible to turn off. main() does not use
    # visualize in this version.
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="运行完成后自动生成可视化HTML"
    )
    args = parser.parse_args()
    asyncio.run(main(args.input_dir, max_rounds=args.max_rounds, visualize=args.visualize))
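
# Usage:
#     python sug_v6_1_2_7.py --input-dir input/简单扣图 --max-rounds 2
# Expects <input-dir>/context.md and <input-dir>/q.md; logs are written to
# <input-dir>/output/sug_v6_1_2_7/<timestamp>/.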