import asyncio
import json
import os
import argparse
import re
from datetime import datetime
from typing import TypeVar, Type

from agents import Agent, Runner
from pydantic import BaseModel, Field

from lib.my_trace import set_trace
from lib.utils import read_file_as_string
from lib.client import get_model
from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
from script.search.xiaohongshu_search import XiaohongshuSearch

MODEL_NAME = "google/gemini-2.5-flash"
# ============================================================================
# Data models
# ============================================================================

class Seg(BaseModel):
    """A single word-segmentation result."""
    text: str
    score_with_o: float
    from_o: str


class Word(BaseModel):
    """A word in the word pool."""
    text: str
    score_with_o: float
    from_o: str


class Q(BaseModel):
    """A query."""
    text: str
    score_with_o: float
    from_source: str  # "seg" | "sug" | "add"


class Sug(BaseModel):
    """A suggested query."""
    text: str
    score_with_o: float
    from_q: dict  # {"text": str, "score_with_o": float}
    evaluation_reason: str | None = None  # reason given by the evaluator


class Seed(BaseModel):
    """A seed query (used for word-addition exploration)."""
    text: str
    added_words: list[str] = Field(default_factory=list)
    from_type: str  # "seg" | "sug"


class Post(BaseModel):
    """A post."""
    note_id: str = ""
    title: str = ""
    body_text: str = ""
    type: str = "normal"  # "video" | "normal"
    images: list[str] = Field(default_factory=list)
    video: str = ""
    interact_info: dict = Field(default_factory=dict)
    note_url: str = ""


class Search(BaseModel):
    """A search result (extends the fields of Sug)."""
    text: str
    score_with_o: float
    from_q: dict
    post_list: list[Post] = Field(default_factory=list)


class RunContext(BaseModel):
    """Run context."""
    version: str
    input_files: dict[str, str]
    c: str  # original requirement (context)
    o: str  # original question
    log_url: str
    log_dir: str

    # Core data
    seg_list: list[dict] = Field(default_factory=list)
    word_lists: dict[int, list[dict]] = Field(default_factory=dict)  # {round: word_list}
    q_lists: dict[int, list[dict]] = Field(default_factory=dict)  # {round: q_list}
    sug_list_lists: dict[int, list[list[dict]]] = Field(default_factory=dict)  # {round: [[sug, sug], [sug]]}
    search_lists: dict[int, list[dict]] = Field(default_factory=dict)  # {round: search_list}
    seed_lists: dict[int, list[dict]] = Field(default_factory=dict)  # {round: seed_list}
    steps: list[dict] = Field(default_factory=list)

    # New: detailed operation records (named in Chinese, but data structures keep English)
    轮次记录: dict[int, dict] = Field(default_factory=dict)

    # Final results
    all_posts: list[dict] = Field(default_factory=list)
    final_output: str | None = None
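
# Shape sketch (values illustrative, not real output): after round 1 the
# round-keyed fields look like
#   run_context.q_lists == {1: [{"text": "...", "score_with_o": 0.85, "from_source": "seg"}]}
#   run_context.seed_lists == {1: [{"text": "...", "added_words": [], "from_type": "seg"}]}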
# ============================================================================
# Helpers: operation logging
# ============================================================================

def init_round_record(run_context: RunContext, round_num: int, round_name: str):
    """Initialize the record for one round."""
    run_context.轮次记录[round_num] = {
        "轮次": round_num,
        "名称": round_name,
        "操作列表": []
    }


def add_operation_record(
    run_context: RunContext,
    round_num: int,
    操作名称: str,
    输入: dict,
    处理过程: dict,
    输出: dict
):
    """Append one operation record to the current round."""
    operation = {
        "操作名称": 操作名称,
        "轮次": round_num,
        "时间": datetime.now().isoformat(),
        "输入": 输入,
        "处理过程": 处理过程,
        "输出": 输出
    }
    if round_num not in run_context.轮次记录:
        init_round_record(run_context, round_num, f"第{round_num}轮" if round_num > 0 else "初始化阶段")
    run_context.轮次记录[round_num]["操作列表"].append(operation)


def record_agent_call(
    agent_name: str,
    model: str,
    instructions: str,
    user_message: str,
    raw_output: dict | str,
    parsed: bool,
    validation_error: str | None = None,
    input_schema: dict | None = None
) -> dict:
    """Record a single agent call."""
    return {
        "Agent名称": agent_name,
        "模型": model,
        "系统提示词": instructions,
        "输入Schema": input_schema,
        "用户消息": user_message,
        "原始输出": raw_output,
        "解析成功": parsed,
        "验证错误": validation_error
    }
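
# Record shape sketch (field values illustrative): one entry of 轮次记录 looks like
#   {"轮次": 1, "名称": "第1轮迭代", "操作列表": [
#       {"操作名称": "...", "轮次": 1, "时间": "<ISO timestamp>",
#        "输入": {...}, "处理过程": {...}, "输出": {...}}]}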
# ============================================================================
# JSON post-processing: handle JSON responses wrapped in markdown
# ============================================================================

def clean_json_response(text: str) -> str:
    """Strip markdown code-block fences that may wrap a JSON response.

    The model may return:
    ```json
    {"key": "value"}
    ```
    which needs to be cleaned to:
    {"key": "value"}
    """
    text = text.strip()
    # Remove a leading ```json or ```
    if text.startswith('```json'):
        text = text[7:]
    elif text.startswith('```'):
        text = text[3:]
    # Remove a trailing ```
    if text.endswith('```'):
        text = text[:-3]
    return text.strip()
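
# Example: clean_json_response('```json\n{"key": "value"}\n```') -> '{"key": "value"}'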
T = TypeVar('T', bound=BaseModel)


async def run_agent_with_json_cleanup(
    agent: Agent,
    input_text: str,
    output_type: Type[T]
) -> T:
    """Run an agent and recover from markdown-wrapped JSON output.

    If the agent returns JSON wrapped in a markdown code block, clean it up
    and re-parse it manually.
    """
    try:
        result = await Runner.run(agent, input_text)
        return result.final_output
    except Exception as e:
        error_msg = str(e)
        # Check whether this is a JSON parsing error
        if "Invalid JSON when parsing" in error_msg:
            # Try to extract the JSON from the error message.
            # Error format: "Invalid JSON when parsing ```json\n{...}\n``` for TypeAdapter(...)"
            match = re.search(r'when parsing (.+?) for TypeAdapter', error_msg, re.DOTALL)
            if match:
                json_text = match.group(1)
                cleaned_json = clean_json_response(json_text)
                try:
                    # Parse the JSON manually and build the Pydantic object
                    parsed_data = json.loads(cleaned_json)
                    return output_type(**parsed_data)
                except Exception as parse_error:
                    print(f"⚠️ JSON清理后仍无法解析: {parse_error}")
                    print(f" 清理后的JSON: {cleaned_json}")
                    raise ValueError(f"无法解析JSON: {parse_error}\n原始错误: {error_msg}")
        # Not a JSON parsing error, or cleanup failed: re-raise the original error
        raise
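
# Usage sketch (word_segmenter is defined below; the query string is illustrative):
#   segmentation = await run_agent_with_json_cleanup(
#       word_segmenter, "如何快速抠图", WordSegmentation)
#   print(segmentation.words, segmentation.reasoning)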
# ============================================================================
# Agent definitions
# ============================================================================

# Agent 1: word segmentation expert
class WordSegmentation(BaseModel):
    """Word segmentation result."""
    words: list[str] = Field(..., description="分词结果列表")
    reasoning: str = Field(..., description="分词理由")


word_segmentation_instructions = """
你是分词专家。给定一个query,将其拆分成有意义的最小单元。
## 分词原则
1. 保留有搜索意义的词汇
2. 拆分成独立的概念
3. 保留专业术语的完整性
4. 去除虚词(的、吗、呢等)
## 输出要求
返回分词列表和分词理由。
IMPORTANT: 直接返回纯JSON对象,不要使用markdown代码块标记(不要用```json...```包裹)。
""".strip()

word_segmenter = Agent[None](
    name="分词专家",
    instructions=word_segmentation_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordSegmentation,
)
# Agent 2: query relevance evaluation expert
class RelevanceEvaluation(BaseModel):
    """Relevance evaluation."""
    relevance_score: float = Field(..., description="相关性分数 0-1")
    reason: str = Field(..., description="评估理由")


relevance_evaluation_instructions = """
你是Query相关度评估专家。
## 任务
评估当前query与原始问题的匹配程度。
## 评估标准
- 主题相关性
- 要素覆盖度
- 意图匹配度
## 输出
- relevance_score: 0-1的相关性分数
- reason: 详细理由
IMPORTANT: 直接返回纯JSON对象,不要使用markdown代码块标记(不要用```json...```包裹)。
""".strip()

relevance_evaluator = Agent[None](
    name="Query相关度评估专家",
    instructions=relevance_evaluation_instructions,
    model=get_model(MODEL_NAME),
    output_type=RelevanceEvaluation,
)
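
# Expected model output for RelevanceEvaluation (values illustrative):
#   {"relevance_score": 0.85, "reason": "主题高度相关,覆盖核心要素"}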
# Agent 3: word selection expert
class WordSelection(BaseModel):
    """Word selection result."""
    selected_word: str = Field(..., description="选中的词")
    reasoning: str = Field(..., description="选择理由")


word_selection_instructions = """
你是Word选择专家。
## 任务
从候选词列表中选择一个最适合与当前seed组合的词,用于探索新的搜索query。
## 选择原则
1. 与seed的语义相关性
2. 组合后的搜索价值
3. 能拓展搜索范围
## 输出
返回选中的词和选择理由。
""".strip()

word_selector = Agent[None](
    name="Word选择专家",
    instructions=word_selection_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordSelection,
)
# Agent 4: word insertion (placement) expert
class WordInsertion(BaseModel):
    """Word insertion result."""
    new_query: str = Field(..., description="加词后的新query")
    insertion_position: str = Field(..., description="插入位置描述")
    reasoning: str = Field(..., description="插入理由")


word_insertion_instructions = """
你是加词位置评估专家。
## 任务
将新词加到当前query的最合适位置,保持语义通顺。
## 原则
1. 保持语法正确
2. 语义连贯
3. 符合搜索习惯
## 输出
返回新query、插入位置描述和理由。
""".strip()

word_inserter = Agent[None](
    name="加词位置评估专家",
    instructions=word_insertion_instructions,
    model=get_model(MODEL_NAME),
    output_type=WordInsertion,
)
# ============================================================================
# Helper functions
# ============================================================================

def add_step(context: RunContext, step_name: str, step_type: str, data: dict):
    """Append a step record."""
    step = {
        "step_number": len(context.steps) + 1,
        "step_name": step_name,
        "step_type": step_type,
        "timestamp": datetime.now().isoformat(),
        "data": data
    }
    context.steps.append(step)
    return step
def process_note_data(note: dict) -> Post:
    """Convert a raw note returned by the search API into a Post object."""
    note_card = note.get("note_card", {})
    image_list = note_card.get("image_list", [])
    interact_info = note_card.get("interact_info", {})

    # Extract image URLs from the image_url field
    images = []
    for img in image_list:
        if "image_url" in img:
            images.append(img["image_url"])

    # Detect videos
    note_type = note_card.get("type", "normal")
    video_url = ""
    if note_type == "video":
        # Video notes may use a different structure; left empty for now
        # and can be filled in later if needed.
        pass

    return Post(
        note_id=note.get("id") or "",
        title=note_card.get("display_title") or "",
        body_text=note_card.get("desc") or "",
        type=note_type,
        images=images,
        video=video_url,
        interact_info={
            "liked_count": interact_info.get("liked_count", 0),
            "collected_count": interact_info.get("collected_count", 0),
            "comment_count": interact_info.get("comment_count", 0),
            "shared_count": interact_info.get("shared_count", 0)
        },
        note_url=f"https://www.xiaohongshu.com/explore/{note.get('id') or ''}"
    )
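
# Mapping sketch (hypothetical payload; real field coverage may differ):
#   process_note_data({"id": "abc123", "note_card": {
#       "display_title": "标题", "desc": "正文", "type": "normal",
#       "image_list": [{"image_url": "https://example.com/1.jpg"}],
#       "interact_info": {"liked_count": 10}}})
# -> Post(note_id="abc123", title="标题", images=["https://example.com/1.jpg"],
#         note_url="https://www.xiaohongshu.com/explore/abc123", ...)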
# ============================================================================
# Core pipeline functions
# ============================================================================

async def evaluate_query_with_o(query_text: str, original_o: str) -> tuple[float, str]:
    """Evaluate how relevant a query is to the original question o.

    Returns:
        (score, reason)
    """
    eval_input = f"""
<原始问题>
{original_o}
</原始问题>
<当前Query>
{query_text}
</当前Query>
请评估当前query与原始问题的相关度。
"""
    evaluation = await run_agent_with_json_cleanup(
        relevance_evaluator,
        eval_input,
        RelevanceEvaluation
    )
    return evaluation.relevance_score, evaluation.reason
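
# Usage sketch (strings are illustrative):
#   score, reason = await evaluate_query_with_o("一键抠图 app", "如何快速抠图")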
async def initialize(context: RunContext):
    """Initialization: segment -> seg_list -> word_list_1, q_list_1, seed_list_1."""
    print("\n" + "="*60)
    print("初始化阶段")
    print("="*60)

    # Round 0 record
    init_round_record(context, 0, "初始化阶段")

    # 1. Segmentation
    print(f"\n[1/4] 分词原始问题: {context.o}")
    segmentation = await run_agent_with_json_cleanup(
        word_segmenter,
        context.o,
        WordSegmentation
    )
    print(f" 分词结果: {segmentation.words}")
    print(f" 分词理由: {segmentation.reasoning}")

    # 2. Evaluate each segment (concurrently)
    print(f"\n[2/4] 评估每个seg与原始问题的相关度...")
    seg_list = []
    agent_calls_seg_eval = []
    # Evaluate all segments concurrently
    eval_tasks = [evaluate_query_with_o(word, context.o) for word in segmentation.words]
    eval_results = await asyncio.gather(*eval_tasks)
    for word, (score, reason) in zip(segmentation.words, eval_results):
        seg = Seg(text=word, score_with_o=score, from_o=context.o)
        seg_list.append(seg.model_dump())
        print(f" {word}: {score:.2f}")
        # Record the evaluation of each segment
        agent_calls_seg_eval.append(
            record_agent_call(
                agent_name="Query相关度评估专家",
                model=MODEL_NAME,
                instructions=relevance_evaluation_instructions,
                user_message=f"评估query与原始问题的相关度:\n\nQuery: {word}\n原始问题: {context.o}",
                raw_output={"score": score, "reason": reason},
                parsed=True
            )
        )
    context.seg_list = seg_list

    # Record the segmentation operation
    add_operation_record(
        context,
        round_num=0,
        操作名称="分词",
        输入={"原始问题": context.o},
        处理过程={
            "Agent调用": record_agent_call(
                agent_name="分词专家",
                model=MODEL_NAME,
                instructions=word_segmentation_instructions,
                user_message=f"请对以下query进行分词:{context.o}",
                raw_output={"words": segmentation.words, "reasoning": segmentation.reasoning},
                parsed=True,
                input_schema={"type": "WordSegmentation", "fields": {"words": "list[str]", "reasoning": "str"}}
            ),
            "seg评估Agent调用列表": agent_calls_seg_eval
        },
        输出={"seg_list": seg_list}
    )

    # 3. Build word_list_1 (copied directly from seg_list)
    print(f"\n[3/4] 构建 word_list_1...")
    word_list_1 = []
    for seg in seg_list:
        word = Word(text=seg["text"], score_with_o=seg["score_with_o"], from_o=seg["from_o"])
        word_list_1.append(word.model_dump())
    context.word_lists[1] = word_list_1
    print(f" word_list_1 大小: {len(word_list_1)}")

    # 4. Build q_list_1 and seed_list_1
    print(f"\n[4/4] 构建 q_list_1 和 seed_list_1...")
    q_list_1 = []
    seed_list_1 = []
    for seg in seg_list:
        # q_list_1: each seg becomes a q
        q = Q(text=seg["text"], score_with_o=seg["score_with_o"], from_source="seg")
        q_list_1.append(q.model_dump())
        # seed_list_1: each seg becomes a seed
        seed = Seed(text=seg["text"], added_words=[], from_type="seg")
        seed_list_1.append(seed.model_dump())
    context.q_lists[1] = q_list_1
    context.seed_lists[1] = seed_list_1
    print(f" q_list_1 大小: {len(q_list_1)}")
    print(f" seed_list_1 大小: {len(seed_list_1)}")

    # Record the initialization operation
    add_operation_record(
        context,
        round_num=0,
        操作名称="初始化",
        输入={"seg_list": seg_list},
        处理过程={"说明": "从seg_list构建初始q_list和seed_list"},
        输出={
            "word_list_1": word_list_1,
            "q_list_1": q_list_1,
            "seed_list_1": seed_list_1
        }
    )

    add_step(context, "初始化完成", "initialize", {
        "seg_count": len(seg_list),
        "word_list_1_count": len(word_list_1),
        "q_list_1_count": len(q_list_1),
        "seed_list_1_count": len(seed_list_1)
    })
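
# Flow sketch (values illustrative): for o = "如何快速抠图" the segmenter might
# return ["快速", "抠图"]; each segment is scored against o, then copied into
# word_list_1 (word pool), q_list_1 (round-1 queries) and seed_list_1 (seeds).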
async def process_round(round_num: int, context: RunContext, xiaohongshu_api: XiaohongshuSearchRecommendations, xiaohongshu_search: XiaohongshuSearch, sug_threshold: float = 0.7):
    """Process one iteration round.

    Args:
        round_num: current round number
        context: run context
        xiaohongshu_api: suggestion (sug) API
        xiaohongshu_search: search API
        sug_threshold: score threshold for searching a sug (default 0.7)
    """
    print(f"\n" + "="*60)
    print(f"第 {round_num} 轮")
    print("="*60)

    # Initialize the record for this round
    init_round_record(context, round_num, f"第{round_num}轮迭代")

    q_list_n = context.q_lists.get(round_num, [])
    if not q_list_n:
        print(f" q_list_{round_num} 为空,跳过本轮")
        return
    print(f" 处理 {len(q_list_n)} 个query")

    # 1. Request suggestions
    print(f"\n[1/5] 请求sug...")
    sug_list_list_n = []
    api_calls_detail = []
    for q_data in q_list_n:
        q_text = q_data["text"]
        suggestions = xiaohongshu_api.get_recommendations(keyword=q_text)
        if not suggestions:
            print(f" {q_text}: 无sug")
            sug_list_list_n.append([])
            api_calls_detail.append({
                "query": q_text,
                "sug_count": 0
            })
            continue
        print(f" {q_text}: 获取 {len(suggestions)} 个sug")
        sug_list_list_n.append(suggestions)
        api_calls_detail.append({
            "query": q_text,
            "sug_count": len(suggestions)
        })

    # Record the suggestion-request operation
    total_sugs = sum(len(sl) for sl in sug_list_list_n)
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="请求推荐词",
        输入={"q_list": [{"text": q["text"], "score": q["score_with_o"]} for q in q_list_n]},
        处理过程={"API调用": api_calls_detail},
        输出={
            "sug_list_list": [[{"text": s, "from_q": q_list_n[i]["text"]} for s in sl] for i, sl in enumerate(sug_list_list_n)],
            "总推荐词数": total_sugs
        }
    )
    # 2. Evaluate suggestions (batched concurrency, at most 10 at a time)
    print(f"\n[2/5] 评估sug...")
    sug_list_list_evaluated = []
    # Collect every sug to evaluate together with its context
    all_sug_tasks = []
    sug_contexts = []  # (list index, q_data, sug text) for each sug
    for i, sug_list in enumerate(sug_list_list_n):
        q_data = q_list_n[i]
        for sug_text in sug_list:
            all_sug_tasks.append(evaluate_query_with_o(sug_text, context.o))
            sug_contexts.append((i, q_data, sug_text))

    # Evaluate in concurrent batches of 10
    batch_size = 10
    all_results = []
    batches_detail = []
    for batch_idx in range(0, len(all_sug_tasks), batch_size):
        batch_tasks = all_sug_tasks[batch_idx:batch_idx+batch_size]
        batch_results = await asyncio.gather(*batch_tasks)
        all_results.extend(batch_results)
        # Record the agent calls in this batch
        batch_agent_calls = []
        start_idx = batch_idx
        for j, (score, reason) in enumerate(batch_results):
            if start_idx + j < len(sug_contexts):
                _, _, sug_text = sug_contexts[start_idx + j]
                batch_agent_calls.append(
                    record_agent_call(
                        agent_name="Query相关度评估专家",
                        model=MODEL_NAME,
                        instructions=relevance_evaluation_instructions,
                        user_message=f"评估query与原始问题的相关度:\n\nQuery: {sug_text}\n原始问题: {context.o}",
                        raw_output={"score": score, "reason": reason},
                        parsed=True
                    )
                )
        batches_detail.append({
            "批次ID": len(batches_detail),
            "并发执行": True,
            "Agent调用列表": batch_agent_calls
        })

    # Regroup the flat results by source query
    result_index = 0
    current_list_index = -1
    evaluated_sugs = []
    for list_idx, q_data, sug_text in sug_contexts:
        if list_idx != current_list_index:
            if evaluated_sugs:
                sug_list_list_evaluated.append(evaluated_sugs)
            evaluated_sugs = []
            current_list_index = list_idx
        score, reason = all_results[result_index]
        result_index += 1
        sug = Sug(
            text=sug_text,
            score_with_o=score,
            from_q={"text": q_data["text"], "score_with_o": q_data["score_with_o"]},
            evaluation_reason=reason
        )
        evaluated_sugs.append(sug.model_dump())
        print(f" {sug_text}: {score:.2f}")
    # Append the last group
    if evaluated_sugs:
        sug_list_list_evaluated.append(evaluated_sugs)
    context.sug_list_lists[round_num] = sug_list_list_evaluated

    # Record the evaluation operation
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="评估推荐词",
        输入={
            "待评估推荐词": [[s for s in sl] for sl in sug_list_list_n],
            "总数": len(all_sug_tasks)
        },
        处理过程={"批次列表": batches_detail},
        输出={"已评估推荐词": sug_list_list_evaluated}
    )
    # 3. Build search_list_n (sugs with score >= threshold) and run searches
    print(f"\n[3/5] 构建search_list并执行搜索...")
    search_list_n = []
    filter_comparisons = []
    search_details = []
    for sug_list_evaluated in sug_list_list_evaluated:
        for sug_data in sug_list_evaluated:
            # Record the filter comparison
            passed = sug_data["score_with_o"] >= sug_threshold
            filter_comparisons.append({
                "文本": sug_data["text"],
                "分数": sug_data["score_with_o"],
                "阈值": sug_threshold,
                "通过": passed
            })
            if passed:
                print(f" 搜索: {sug_data['text']} (分数: {sug_data['score_with_o']:.2f})")
                try:
                    # Run the search
                    search_result = xiaohongshu_search.search(keyword=sug_data["text"])
                    result_str = search_result.get("result", "{}")
                    if isinstance(result_str, str):
                        result_data = json.loads(result_str)
                    else:
                        result_data = result_str
                    notes = result_data.get("data", {}).get("data", [])
                    print(f" → 搜索到 {len(notes)} 个帖子")
                    # Convert to Post objects
                    post_list = []
                    for note in notes:
                        post = process_note_data(note)
                        post_list.append(post.model_dump())
                        context.all_posts.append(post.model_dump())
                    # Create the Search object
                    search = Search(
                        text=sug_data["text"],
                        score_with_o=sug_data["score_with_o"],
                        from_q=sug_data["from_q"],
                        post_list=post_list
                    )
                    search_list_n.append(search.model_dump())
                    # Record the search detail
                    search_details.append({
                        "查询": sug_data["text"],
                        "分数": sug_data["score_with_o"],
                        "成功": True,
                        "帖子数量": len(post_list),
                        "错误": None
                    })
                except Exception as e:
                    print(f" ✗ 搜索失败: {e}")
                    search_details.append({
                        "查询": sug_data["text"],
                        "分数": sug_data["score_with_o"],
                        "成功": False,
                        "帖子数量": 0,
                        "错误": str(e)
                    })
    context.search_lists[round_num] = search_list_n
    print(f" 本轮搜索到 {len(search_list_n)} 个有效结果")

    # Record filtering and searching as one combined operation
    total_posts = sum(len(s["post_list"]) for s in search_list_n)
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="筛选并执行搜索",
        输入={"已评估推荐词": sug_list_list_evaluated},
        处理过程={
            "筛选条件": f"分数 >= {sug_threshold}",
            "筛选比较": filter_comparisons,
            "搜索详情": search_details
        },
        输出={
            "search_list": search_list_n,
            "成功搜索数": len(search_list_n),
            "总帖子数": total_posts
        }
    )
    # 4. Build word_list_(n+1) (a direct copy for now)
    print(f"\n[4/5] 构建word_list_{round_num+1}...")
    word_list_n = context.word_lists.get(round_num, [])
    word_list_next = word_list_n.copy()
    context.word_lists[round_num + 1] = word_list_next
    print(f" word_list_{round_num+1} 大小: {len(word_list_next)}")

    # 5. Build q_list_(n+1) and update the seed list
    print(f"\n[5/5] 构建q_list_{round_num+1}和更新seed_list...")
    q_list_next = []
    seed_list_n = context.seed_lists.get(round_num, [])
    # Per-seed copies: a plain .copy() would share the seed dicts with round n,
    # so appending to added_words below would silently rewrite last round's record.
    seed_list_next = [dict(s, added_words=list(s["added_words"])) for s in seed_list_n]

    # 5.1 Generate new queries by adding words to seeds (serial, to avoid duplicates)
    print(f" [5.1] 从seed加词生成新q(串行处理,去重)...")
    add_word_attempts = []  # record every attempt
    new_queries_from_add = []
    generated_query_texts = set()  # query texts generated so far
    for seed_data in seed_list_n:
        seed_text = seed_data["text"]
        added_words = seed_data["added_words"]
        # Keep only words not yet used with this seed
        candidate_words = []
        for word_data in word_list_next:
            word_text = word_data["text"]
            # Simple substring filter
            if word_text not in seed_text and word_text not in added_words:
                candidate_words.append(word_data)
        if not candidate_words:
            print(f" {seed_text}: 无可用词")
            continue

        attempt = {
            "种子": {"text": seed_text, "已添加词": added_words},
            "候选词": [w["text"] for w in candidate_words[:10]]
        }

        # Let the agent pick a word (passing the queries generated so far)
        already_generated_str = ""
        if generated_query_texts:
            already_generated_str = f"""
<已生成的查询>
{', '.join(sorted(generated_query_texts))}
</已生成的查询>
注意:请避免生成与上述已存在的查询重复或过于相似的新查询。
"""
        selection_input = f"""
<当前Seed>
{seed_text}
</当前Seed>
<候选词列表>
{', '.join([w['text'] for w in candidate_words[:10]])}
</候选词列表>
{already_generated_str}
请从候选词中选择一个最适合与seed组合的词。
"""
        selection = await run_agent_with_json_cleanup(
            word_selector,
            selection_input,
            WordSelection
        )
        selected_word = selection.selected_word
        # Make sure the chosen word is actually a candidate
        if selected_word not in [w["text"] for w in candidate_words]:
            # Fall back to the first candidate if the agent picked something else
            selected_word = candidate_words[0]["text"]

        # Record the word selection
        attempt["步骤1_选词"] = record_agent_call(
            agent_name="Word选择专家",
            model=MODEL_NAME,
            instructions=word_selection_instructions,
            user_message=selection_input,
            raw_output={"selected_word": selection.selected_word, "reasoning": selection.reasoning},
            parsed=True,
            input_schema={"type": "WordSelection", "fields": {"selected_word": "str", "reasoning": "str"}}
        )

        # Ask the insertion agent to place the word
        insertion_input = f"""
<当前Query>
{seed_text}
</当前Query>
<要添加的词>
{selected_word}
</要添加的词>
请将这个词加到query的最合适位置。
"""
        insertion = await run_agent_with_json_cleanup(
            word_inserter,
            insertion_input,
            WordInsertion
        )
        new_query_text = insertion.new_query

        # Record the insertion
        attempt["步骤2_插入位置"] = record_agent_call(
            agent_name="加词位置评估专家",
            model=MODEL_NAME,
            instructions=word_insertion_instructions,
            user_message=insertion_input,
            raw_output={"new_query": insertion.new_query, "reasoning": insertion.reasoning},
            parsed=True,
            input_schema={"type": "WordInsertion", "fields": {"new_query": "str", "reasoning": "str"}}
        )

        # Skip duplicates
        if new_query_text in generated_query_texts:
            print(f" {seed_text} + {selected_word} → {new_query_text} (重复,跳过)")
            attempt["跳过原因"] = "查询重复"
            add_word_attempts.append(attempt)
            continue

        # Evaluate the new query immediately
        score, reason = await evaluate_query_with_o(new_query_text, context.o)

        # Record the evaluation
        attempt["步骤3_评估新查询"] = record_agent_call(
            agent_name="Query相关度评估专家",
            model=MODEL_NAME,
            instructions=relevance_evaluation_instructions,
            user_message=f"评估新query的相关度:\n\nQuery: {new_query_text}\n原始问题: {context.o}",
            raw_output={"score": score, "reason": reason},
            parsed=True
        )
        add_word_attempts.append(attempt)

        # Create the new q and add it to the lists
        new_q = Q(text=new_query_text, score_with_o=score, from_source="add")
        q_list_next.append(new_q.model_dump())
        new_queries_from_add.append(new_q.model_dump())
        generated_query_texts.add(new_query_text)

        # Update the seed's added_words
        for seed in seed_list_next:
            if seed["text"] == seed_text:
                seed["added_words"].append(selected_word)
                break
        print(f" {seed_text} + {selected_word} → {new_query_text} (分数: {score:.2f})")

    # Record the word-addition operation
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="加词生成新查询",
        输入={
            "seed_list": seed_list_n,
            "word_list": word_list_next
        },
        处理过程={"尝试列表": add_word_attempts},
        输出={"新查询列表": new_queries_from_add}
    )
    # 5.2 Promote sugs into q_list (condition: sug score > its source query's score)
    print(f" [5.2] 从sug加入q_list_{round_num+1}(条件:sug分数 > from_q分数)...")
    sug_added_count = 0
    sug_filter_comparisons = []
    selected_sugs = []
    for sug_list_evaluated in sug_list_list_evaluated:
        for sug_data in sug_list_evaluated:
            # Condition: the sug scores higher than the query it came from
            from_q_score = sug_data["from_q"]["score_with_o"]
            passed = sug_data["score_with_o"] > from_q_score
            sug_filter_comparisons.append({
                "推荐词": sug_data["text"],
                "推荐词分数": sug_data["score_with_o"],
                "来源查询分数": from_q_score,
                "通过": passed,
                "原因": f"{sug_data['score_with_o']:.2f} > {from_q_score:.2f}" if passed else f"{sug_data['score_with_o']:.2f} <= {from_q_score:.2f}"
            })
            if passed:
                # Skip if the query already exists
                if sug_data["text"] not in [q["text"] for q in q_list_next]:
                    new_q = Q(text=sug_data["text"], score_with_o=sug_data["score_with_o"], from_source="sug")
                    q_list_next.append(new_q.model_dump())
                    selected_sugs.append(new_q.model_dump())
                    sug_added_count += 1
                    print(f" ✓ {sug_data['text']} ({sug_data['score_with_o']:.2f} > {from_q_score:.2f})")
    print(f" 添加 {sug_added_count} 个sug到q_list_{round_num+1}")

    # Record the sug-filtering operation
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="筛选推荐词进入下轮",
        输入={"已评估推荐词": sug_list_list_evaluated},
        处理过程={
            "筛选条件": "推荐词分数 > 来源查询分数",
            "比较结果": sug_filter_comparisons
        },
        输出={"选中推荐词": selected_sugs}
    )

    # 5.3 Update the seed list (add new seeds from sugs, same condition: sug score > from_q score)
    print(f" [5.3] 更新seed_list_{round_num+1}(条件:sug分数 > from_q分数)...")
    seed_texts_existing = [s["text"] for s in seed_list_next]
    new_seed_count = 0
    for sug_list_evaluated in sug_list_list_evaluated:
        for sug_data in sug_list_evaluated:
            from_q_score = sug_data["from_q"]["score_with_o"]
            # Condition: the sug scores higher than the query it came from
            if sug_data["score_with_o"] > from_q_score and sug_data["text"] not in seed_texts_existing:
                new_seed = Seed(text=sug_data["text"], added_words=[], from_type="sug")
                seed_list_next.append(new_seed.model_dump())
                seed_texts_existing.append(sug_data["text"])
                new_seed_count += 1
    print(f" 添加 {new_seed_count} 个sug到seed_list_{round_num+1}")

    context.q_lists[round_num + 1] = q_list_next
    context.seed_lists[round_num + 1] = seed_list_next
    print(f"\n q_list_{round_num+1} 大小: {len(q_list_next)}")
    print(f" seed_list_{round_num+1} 大小: {len(seed_list_next)}")

    # Record the build-next-round operation
    add_operation_record(
        context,
        round_num=round_num,
        操作名称="构建下一轮",
        输入={
            "加词新查询": new_queries_from_add,
            "选中推荐词": selected_sugs
        },
        处理过程={
            "合并": {
                "来自加词": len(new_queries_from_add),
                "来自推荐词": len(selected_sugs),
                "合并前总数": len(new_queries_from_add) + len(selected_sugs)
            },
            "去重": {
                "唯一数": len(q_list_next)
            }
        },
        输出={
            "下轮查询列表": q_list_next,
            "下轮种子列表": seed_list_next
        }
    )

    add_step(context, f"第{round_num}轮完成", "round", {
        "round": round_num,
        "q_list_count": len(q_list_n),
        "sug_total_count": sum(len(s) for s in sug_list_list_evaluated),
        "search_count": len(search_list_n),
        "posts_found": sum(len(s["post_list"]) for s in search_list_n),
        "q_list_next_count": len(q_list_next),
        "seed_list_next_count": len(seed_list_next)
    })
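
# Round recap (summary of the steps above): q_list_n -> sug API -> scored sugs;
# sugs scoring >= sug_threshold are searched; seeds + word pool generate "add"
# queries; sugs that outscore their source query join q_list_(n+1) and seed_list_(n+1).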
async def main_loop(context: RunContext, max_rounds: int = 2):
    """Main loop.

    Args:
        context: run context
        max_rounds: maximum number of rounds (default 2)
    """
    print("\n" + "="*60)
    print("开始主循环")
    print("="*60)

    # Initialization
    await initialize(context)

    # API instances
    xiaohongshu_api = XiaohongshuSearchRecommendations()
    xiaohongshu_search = XiaohongshuSearch()

    # Iterate
    for round_num in range(1, max_rounds + 1):
        await process_round(round_num, context, xiaohongshu_api, xiaohongshu_search)
        # Check the termination condition
        q_list_next = context.q_lists.get(round_num + 1, [])
        if not q_list_next:
            print(f"\n q_list_{round_num + 1} 为空,提前结束")
            break

    print("\n" + "="*60)
    print("主循环完成")
    print("="*60)
    print(f" 总共收集 {len(context.all_posts)} 个帖子")
# ============================================================================
# Main entry point
# ============================================================================

async def main(input_dir: str, max_rounds: int = 2, visualize: bool = False):
    """Main entry point. (visualize is accepted but not used here yet.)"""
    current_time, log_url = set_trace()

    # Read inputs
    input_context_file = os.path.join(input_dir, 'context.md')
    input_q_file = os.path.join(input_dir, 'q.md')
    c = read_file_as_string(input_context_file)
    o = read_file_as_string(input_q_file)

    # Version info
    version = os.path.basename(__file__)
    version_name = os.path.splitext(version)[0]

    # Log directory
    log_dir = os.path.join(input_dir, "output", version_name, current_time)

    # Create the run context
    run_context = RunContext(
        version=version,
        input_files={
            "input_dir": input_dir,
            "context_file": input_context_file,
            "q_file": input_q_file,
        },
        c=c,
        o=o,
        log_dir=log_dir,
        log_url=log_url,
    )

    # Run the main loop
    await main_loop(run_context, max_rounds=max_rounds)

    # Format the output
    output = f"原始需求:{run_context.c}\n"
    output += f"原始问题:{run_context.o}\n"
    output += f"收集帖子:{len(run_context.all_posts)} 个\n"
    output += "\n" + "="*60 + "\n"
    if run_context.all_posts:
        output += "【收集到的帖子】\n\n"
        for idx, post in enumerate(run_context.all_posts[:20], 1):  # show only the first 20
            output += f"{idx}. {post['title']}\n"
            output += f" 类型: {post['type']}\n"
            output += f" URL: {post['note_url']}\n\n"
    else:
        output += "未收集到帖子\n"
    run_context.final_output = output

    print(f"\n{'='*60}")
    print("最终结果")
    print(f"{'='*60}")
    print(output)

    # Save logs
    os.makedirs(run_context.log_dir, exist_ok=True)
    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
    context_dict = run_context.model_dump()
    with open(context_file_path, "w", encoding="utf-8") as f:
        json.dump(context_dict, f, ensure_ascii=False, indent=2)
    print(f"\nRunContext saved to: {context_file_path}")

    steps_file_path = os.path.join(run_context.log_dir, "steps.json")
    with open(steps_file_path, "w", encoding="utf-8") as f:
        json.dump(run_context.steps, f, ensure_ascii=False, indent=2)
    print(f"Steps log saved to: {steps_file_path}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="搜索query优化工具 - v6.1.2.7 基于seed的迭代版")
    parser.add_argument(
        "--input-dir",
        type=str,
        default="input/简单扣图",
        help="输入目录路径,默认: input/简单扣图"
    )
    parser.add_argument(
        "--max-rounds",
        type=int,
        default=2,
        help="最大轮数,默认: 2"
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="运行完成后自动生成可视化HTML"
    )
    args = parser.parse_args()
    asyncio.run(main(args.input_dir, max_rounds=args.max_rounds, visualize=args.visualize))
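
# Invocation sketch (the script filename is illustrative):
#   python v6_1_2_7_seed_iteration.py --input-dir input/简单扣图 --max-rounds 2 --visualize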