import asyncio import json import os import sys import argparse from datetime import datetime from typing import Literal from agents import Agent, Runner from lib.my_trace import set_trace from pydantic import BaseModel, Field from lib.utils import read_file_as_string from lib.client import get_model MODEL_NAME = "google/gemini-2.5-flash" from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations from script.search.xiaohongshu_search import XiaohongshuSearch # ============================================================================ # 数据模型 # ============================================================================ class QueryState(BaseModel): """Query状态跟踪""" query: str level: int # 当前所在层级 no_suggestion_rounds: int = 0 # 连续没有suggestion的轮数 relevance_score: float = 0.0 # 与原始需求的相关度 parent_query: str | None = None # 父query strategy: str | None = None # 生成策略:direct_sug, rewrite, add_word is_terminated: bool = False # 是否已终止(不再处理) class WordLibrary(BaseModel): """动态分词库""" words: set[str] = Field(default_factory=set) word_sources: dict[str, str] = Field(default_factory=dict) # 记录词的来源:word -> source(note_id或"initial") core_words: set[str] = Field(default_factory=set) # 核心词(第一层初始分词) def add_word(self, word: str, source: str = "unknown", is_core: bool = False): """添加单词到分词库""" if word and word.strip(): word = word.strip() self.words.add(word) if word not in self.word_sources: self.word_sources[word] = source if is_core: self.core_words.add(word) def add_words(self, words: list[str], source: str = "unknown", is_core: bool = False): """批量添加单词""" for word in words: self.add_word(word, source, is_core) def get_unused_word(self, current_query: str, prefer_core: bool = True) -> str | None: """获取一个当前query中没有的词 Args: current_query: 当前查询 prefer_core: 是否优先返回核心词(默认True) """ # 优先从核心词中查找 if prefer_core and self.core_words: for word in self.core_words: if word not in current_query: return word # 如果核心词都用完了,或者不优先使用核心词,从所有词中查找 for word in self.words: if word not in current_query: return word return None def model_dump(self): """序列化为dict""" return { "words": list(self.words), "word_sources": self.word_sources, "core_words": list(self.core_words) } class RunContext(BaseModel): """运行上下文""" version: str input_files: dict[str, str] q_with_context: str q_context: str q: str log_url: str log_dir: str # 新增字段 word_library: dict = Field(default_factory=dict) # 使用dict存储,因为set不能直接序列化 query_states: list[dict] = Field(default_factory=list) steps: list[dict] = Field(default_factory=list) # Query演化图 query_graph: dict = Field(default_factory=dict) # 记录Query的演化路径和关系 # 最终结果 satisfied_notes: list[dict] = Field(default_factory=list) final_output: str | None = None # ============================================================================ # Agent 定义 # ============================================================================ # Agent 1: 分词专家 class WordSegmentation(BaseModel): """分词结果""" words: list[str] = Field(..., description="分词结果列表") reasoning: str = Field(..., description="分词理由") word_segmentation_instructions = """ 你是分词专家。给定一个query,将其拆分成有意义的最小单元。 ## 分词原则 1. 保留有搜索意义的词汇 2. 拆分成独立的概念 3. 保留专业术语的完整性 4. 去除虚词(的、吗、呢等) ## 输出要求 返回分词列表和分词理由。 """.strip() word_segmenter = Agent[None]( name="分词专家", instructions=word_segmentation_instructions, model=get_model(MODEL_NAME), output_type=WordSegmentation, ) # Agent 2: Query相关度评估专家 class RelevanceEvaluation(BaseModel): """相关度评估""" relevance_score: float = Field(..., description="相关性分数 0-1") is_improved: bool = Field(..., description="是否比之前更好") reason: str = Field(..., description="评估理由") relevance_evaluation_instructions = """ 你是Query相关度评估专家。 ## 任务 评估当前query与原始需求的匹配程度。 ## 评估标准 - 主题相关性 - 要素覆盖度 - 意图匹配度 ## 输出 - relevance_score: 0-1的相关性分数 - is_improved: 如果提供了previous_score,判断是否有提升 - reason: 详细理由 """.strip() relevance_evaluator = Agent[None]( name="Query相关度评估专家", instructions=relevance_evaluation_instructions, model=get_model(MODEL_NAME), output_type=RelevanceEvaluation, ) # Agent 3: Query改写专家 class QueryRewrite(BaseModel): """Query改写结果""" rewritten_query: str = Field(..., description="改写后的query") rewrite_type: str = Field(..., description="改写类型:abstract或synonym") reasoning: str = Field(..., description="改写理由") query_rewrite_instructions = """ 你是Query改写专家。 ## 改写策略 1. **向上抽象**:将具体概念泛化到更高层次 - 例:iPhone 13 → 智能手机 2. **同义改写**:使用同义词或相关表达 - 例:购买 → 入手、获取 ## 输出要求 返回改写后的query、改写类型和理由。 """.strip() query_rewriter = Agent[None]( name="Query改写专家", instructions=query_rewrite_instructions, model=get_model(MODEL_NAME), output_type=QueryRewrite, ) # Agent 4: 加词位置评估专家 class WordInsertion(BaseModel): """加词结果""" new_query: str = Field(..., description="加词后的新query") insertion_position: str = Field(..., description="插入位置描述") reasoning: str = Field(..., description="插入理由") word_insertion_instructions = """ 你是加词位置评估专家。 ## 任务 将新词加到当前query的最合适位置,保持语义通顺。 ## 原则 1. 保持语法正确 2. 语义连贯 3. 符合搜索习惯 ## 输出 返回新query、插入位置描述和理由。 """.strip() word_inserter = Agent[None]( name="加词位置评估专家", instructions=word_insertion_instructions, model=get_model(MODEL_NAME), output_type=WordInsertion, ) # Agent 5: Result匹配度评估专家 class ResultEvaluation(BaseModel): """Result评估结果""" match_level: str = Field(..., description="匹配等级:satisfied, partial, unsatisfied") relevance_score: float = Field(..., description="相关性分数 0-1") missing_aspects: list[str] = Field(default_factory=list, description="缺失的方面") reason: str = Field(..., description="评估理由") result_evaluation_instructions = """ 你是Result匹配度评估专家。 ## 任务 评估搜索结果(帖子)与原始需求的匹配程度。 ## 评估等级 1. **satisfied**: 完全满足需求 2. **partial**: 部分满足,但有缺失 3. **unsatisfied**: 基本不满足 ## 输出要求 - match_level: 匹配等级 - relevance_score: 相关性分数 - missing_aspects: 如果是partial,列出缺失的方面 - reason: 详细理由 """.strip() result_evaluator = Agent[None]( name="Result匹配度评估专家", instructions=result_evaluation_instructions, model=get_model(MODEL_NAME), output_type=ResultEvaluation, ) # Agent 6: Query改造专家(基于缺失部分) class QueryImprovement(BaseModel): """Query改造结果""" improved_query: str = Field(..., description="改造后的query") added_aspects: list[str] = Field(..., description="添加的方面") reasoning: str = Field(..., description="改造理由") query_improvement_instructions = """ 你是Query改造专家。 ## 任务 根据搜索结果的缺失部分,改造query使其包含这些内容。 ## 原则 1. 针对性补充缺失方面 2. 保持query简洁 3. 符合搜索习惯 ## 输出 返回改造后的query、添加的方面和理由。 """.strip() query_improver = Agent[None]( name="Query改造专家", instructions=query_improvement_instructions, model=get_model(MODEL_NAME), output_type=QueryImprovement, ) # Agent 7: 关键词提取专家 class KeywordExtraction(BaseModel): """关键词提取结果""" keywords: list[str] = Field(..., description="提取的关键词列表") reasoning: str = Field(..., description="提取理由") keyword_extraction_instructions = """ 你是关键词提取专家。 ## 任务 从帖子标题和描述中提取核心关键词。 ## 提取原则 1. 提取有搜索价值的词汇 2. 去除虚词和通用词 3. 保留专业术语 4. 提取3-10个关键词 ## 输出 返回关键词列表和提取理由。 """.strip() keyword_extractor = Agent[None]( name="关键词提取专家", instructions=keyword_extraction_instructions, model=get_model(MODEL_NAME), output_type=KeywordExtraction, ) # ============================================================================ # 辅助函数 # ============================================================================ def add_step(context: RunContext, step_name: str, step_type: str, data: dict): """添加步骤记录""" step = { "step_number": len(context.steps) + 1, "step_name": step_name, "step_type": step_type, "timestamp": datetime.now().isoformat(), "data": data } context.steps.append(step) return step def add_query_to_graph(context: RunContext, query_state: QueryState, iteration: int, evaluation_reason: str = "", is_selected: bool = True, parent_level: int | None = None): """添加Query节点到演化图 Args: context: 运行上下文 query_state: Query状态 iteration: 迭代次数 evaluation_reason: 评估原因(可选) is_selected: 是否被选中进入处理队列(默认True) parent_level: 父节点的层级(用于构造parent_id) """ # 使用 "query_level" 格式作为节点ID query_id = f"{query_state.query}_{query_state.level}" # 初始化图结构 if "nodes" not in context.query_graph: context.query_graph["nodes"] = {} context.query_graph["edges"] = [] context.query_graph["iterations"] = {} # 添加Query节点(type: query) context.query_graph["nodes"][query_id] = { "type": "query", "query": query_state.query, "level": query_state.level, "relevance_score": query_state.relevance_score, "strategy": query_state.strategy, "parent_query": query_state.parent_query, "iteration": iteration, "is_terminated": query_state.is_terminated, "no_suggestion_rounds": query_state.no_suggestion_rounds, "evaluation_reason": evaluation_reason, # 评估原因 "is_selected": is_selected # 是否被选中 } # 添加边(父子关系) if query_state.parent_query and parent_level is not None: # 构造父节点ID: parent_query_parent_level parent_id = f"{query_state.parent_query}_{parent_level}" if parent_id in context.query_graph["nodes"]: context.query_graph["edges"].append({ "from": parent_id, "to": query_id, "edge_type": "query_to_query", "strategy": query_state.strategy, "score_improvement": query_state.relevance_score - context.query_graph["nodes"][parent_id]["relevance_score"] }) # 按迭代分组 if iteration not in context.query_graph["iterations"]: context.query_graph["iterations"][iteration] = [] context.query_graph["iterations"][iteration].append(query_id) def add_note_to_graph(context: RunContext, query: str, query_level: int, note: dict): """添加Note节点到演化图,并连接到对应的Query Args: context: 运行上下文 query: query文本 query_level: query所在层级 note: 帖子数据 """ note_id = note["note_id"] # 初始化图结构 if "nodes" not in context.query_graph: context.query_graph["nodes"] = {} context.query_graph["edges"] = [] context.query_graph["iterations"] = {} # 添加Note节点(type: note),包含完整的元信息 context.query_graph["nodes"][note_id] = { "type": "note", "note_id": note_id, "title": note["title"], "desc": note.get("desc", ""), # 完整描述,不截断 "note_url": note.get("note_url", ""), "image_list": note.get("image_list", []), # 图片列表 "interact_info": note.get("interact_info", {}), # 互动信息(点赞、收藏、评论、分享) "match_level": note["evaluation"]["match_level"], "relevance_score": note["evaluation"]["relevance_score"], "evaluation_reason": note["evaluation"].get("reason", ""), # 评估原因 "found_by_query": query } # 添加边:Query → Note,使用 query_level 格式的ID query_id = f"{query}_{query_level}" if query_id in context.query_graph["nodes"]: context.query_graph["edges"].append({ "from": query_id, "to": note_id, "edge_type": "query_to_note", "match_level": note["evaluation"]["match_level"], "relevance_score": note["evaluation"]["relevance_score"] }) def process_note_data(note: dict) -> dict: """处理搜索接口返回的帖子数据""" note_card = note.get("note_card", {}) image_list = note_card.get("image_list", []) interact_info = note_card.get("interact_info", {}) user_info = note_card.get("user", {}) return { "note_id": note.get("id", ""), "title": note_card.get("display_title", ""), "desc": note_card.get("desc", ""), "image_list": image_list, "interact_info": { "liked_count": interact_info.get("liked_count", 0), "collected_count": interact_info.get("collected_count", 0), "comment_count": interact_info.get("comment_count", 0), "shared_count": interact_info.get("shared_count", 0) }, "user": { "nickname": user_info.get("nickname", ""), "user_id": user_info.get("user_id", "") }, "type": note_card.get("type", "normal"), "note_url": f"https://www.xiaohongshu.com/explore/{note.get('id', '')}" } # ============================================================================ # 核心流程函数 # ============================================================================ async def initialize_word_library(original_query: str, context: RunContext) -> WordLibrary: """初始化分词库""" print("\n[初始化] 创建分词库...") # 使用Agent进行分词 result = await Runner.run(word_segmenter, original_query) segmentation: WordSegmentation = result.final_output word_lib = WordLibrary() # 初始分词标记为核心词(is_core=True) word_lib.add_words(segmentation.words, source="initial", is_core=True) print(f"初始分词库(核心词): {list(word_lib.words)}") print(f"分词理由: {segmentation.reasoning}") # 保存到context context.word_library = word_lib.model_dump() add_step(context, "初始化分词库", "word_library_init", { "agent": "分词专家", "input": original_query, "output": { "words": segmentation.words, "reasoning": segmentation.reasoning }, "result": { "word_library": list(word_lib.words) } }) return word_lib async def evaluate_query_relevance( query: str, original_need: str, previous_score: float | None = None, context: RunContext = None ) -> RelevanceEvaluation: """评估query与原始需求的相关度""" eval_input = f""" <原始需求> {original_need} <当前Query> {query} {"<之前的相关度分数>" + str(previous_score) + "" if previous_score is not None else ""} 请评估当前query与原始需求的相关度。 """ result = await Runner.run(relevance_evaluator, eval_input) evaluation: RelevanceEvaluation = result.final_output return evaluation async def process_suggestions( query: str, query_state: QueryState, original_need: str, word_lib: WordLibrary, context: RunContext, xiaohongshu_api: XiaohongshuSearchRecommendations, iteration: int ) -> list[QueryState]: """处理suggestion分支,返回新的query states""" print(f"\n [Suggestion分支] 处理query: {query}") # 收集本次分支处理中的所有Agent调用 agent_calls = [] # 1. 获取suggestions suggestions = xiaohongshu_api.get_recommendations(keyword=query) if not suggestions or len(suggestions) == 0: print(f" → 没有获取到suggestion") query_state.no_suggestion_rounds += 1 # 记录步骤 add_step(context, f"Suggestion分支 - {query}", "suggestion_branch", { "query": query, "query_level": query_state.level, "suggestions_count": 0, "no_suggestion_rounds": query_state.no_suggestion_rounds, "new_queries_generated": 0 }) return [] print(f" → 获取到 {len(suggestions)} 个suggestions") query_state.no_suggestion_rounds = 0 # 重置计数 # 2. 评估每个suggestion new_queries = [] suggestion_evaluations = [] for sug in suggestions: # 处理所有建议 # 评估sug与原始需求的相关度(注意:这里是与原始需求original_need对比,而非当前query) # 这样可以确保生成的suggestion始终围绕用户的核心需求 sug_eval = await evaluate_query_relevance(sug, original_need, query_state.relevance_score, context) sug_eval_record = { "suggestion": sug, "relevance_score": sug_eval.relevance_score, "is_improved": sug_eval.is_improved, "reason": sug_eval.reason } suggestion_evaluations.append(sug_eval_record) # 创建query state(所有suggestion都作为query节点) sug_state = QueryState( query=sug, level=query_state.level + 1, relevance_score=sug_eval.relevance_score, parent_query=query, strategy="调用sug" ) # 判断是否比当前query更好(只有提升的才加入待处理队列) is_selected = sug_eval.is_improved and sug_eval.relevance_score > query_state.relevance_score # 将所有suggestion添加到演化图(包括未提升的) add_query_to_graph( context, sug_state, iteration, evaluation_reason=sug_eval.reason, is_selected=is_selected, parent_level=query_state.level # 父节点的层级 ) if is_selected: print(f" ✓ {sug} (分数: {sug_eval.relevance_score:.2f}, 提升: {sug_eval.is_improved})") new_queries.append(sug_state) else: print(f" ✗ {sug} (分数: {sug_eval.relevance_score:.2f}, 未提升)") # 3. 改写策略(向上抽象或同义改写) if len(new_queries) < 3: # 如果直接使用sug的数量不够,尝试改写 # 尝试向上抽象 rewrite_input_abstract = f""" <当前Query> {query} <改写要求> 类型: abstract (向上抽象) 请改写这个query。 """ result = await Runner.run(query_rewriter, rewrite_input_abstract) rewrite: QueryRewrite = result.final_output # 收集改写Agent的输入输出 rewrite_agent_call = { "agent": "Query改写专家", "action": "向上抽象改写", "input": { "query": query, "rewrite_type": "abstract" }, "output": { "rewritten_query": rewrite.rewritten_query, "rewrite_type": rewrite.rewrite_type, "reasoning": rewrite.reasoning } } agent_calls.append(rewrite_agent_call) # 评估改写后的query rewrite_eval = await evaluate_query_relevance(rewrite.rewritten_query, original_need, query_state.relevance_score, context) # 创建改写后的query state new_state = QueryState( query=rewrite.rewritten_query, level=query_state.level + 1, relevance_score=rewrite_eval.relevance_score, parent_query=query, strategy="抽象改写" ) # 添加到演化图(无论是否提升) add_query_to_graph( context, new_state, iteration, evaluation_reason=rewrite_eval.reason, is_selected=rewrite_eval.is_improved, parent_level=query_state.level # 父节点的层级 ) if rewrite_eval.is_improved: print(f" ✓ 改写(抽象): {rewrite.rewritten_query} (分数: {rewrite_eval.relevance_score:.2f})") new_queries.append(new_state) else: print(f" ✗ 改写(抽象): {rewrite.rewritten_query} (分数: {rewrite_eval.relevance_score:.2f}, 未提升)") # 3.2. 同义改写策略 if len(new_queries) < 4: # 如果还不够,尝试同义改写 rewrite_input_synonym = f""" <当前Query> {query} <改写要求> 类型: synonym (同义改写) 使用同义词或相关表达来改写query,保持语义相同但表达方式不同。 请改写这个query。 """ result = await Runner.run(query_rewriter, rewrite_input_synonym) rewrite_syn: QueryRewrite = result.final_output # 收集同义改写Agent的输入输出 rewrite_syn_agent_call = { "agent": "Query改写专家", "action": "同义改写", "input": { "query": query, "rewrite_type": "synonym" }, "output": { "rewritten_query": rewrite_syn.rewritten_query, "rewrite_type": rewrite_syn.rewrite_type, "reasoning": rewrite_syn.reasoning } } agent_calls.append(rewrite_syn_agent_call) # 评估改写后的query rewrite_syn_eval = await evaluate_query_relevance(rewrite_syn.rewritten_query, original_need, query_state.relevance_score, context) # 创建改写后的query state new_state = QueryState( query=rewrite_syn.rewritten_query, level=query_state.level + 1, relevance_score=rewrite_syn_eval.relevance_score, parent_query=query, strategy="同义改写" ) # 添加到演化图(无论是否提升) add_query_to_graph( context, new_state, iteration, evaluation_reason=rewrite_syn_eval.reason, is_selected=rewrite_syn_eval.is_improved, parent_level=query_state.level # 父节点的层级 ) if rewrite_syn_eval.is_improved: print(f" ✓ 改写(同义): {rewrite_syn.rewritten_query} (分数: {rewrite_syn_eval.relevance_score:.2f})") new_queries.append(new_state) else: print(f" ✗ 改写(同义): {rewrite_syn.rewritten_query} (分数: {rewrite_syn_eval.relevance_score:.2f}, 未提升)") # 4. 加词策略(优先使用核心词) unused_word = word_lib.get_unused_word(query, prefer_core=True) is_core_word = unused_word in word_lib.core_words if unused_word else False if unused_word and len(new_queries) < 5: word_type = "核心词" if is_core_word else "普通词" insertion_input = f""" <当前Query> {query} <要添加的词> {unused_word} 请将这个词加到query的最合适位置。 """ result = await Runner.run(word_inserter, insertion_input) insertion: WordInsertion = result.final_output # 收集加词Agent的输入输出 insertion_agent_call = { "agent": "加词位置评估专家", "action": f"加词({word_type})", "input": { "query": query, "word_to_add": unused_word, "is_core_word": is_core_word }, "output": { "new_query": insertion.new_query, "insertion_position": insertion.insertion_position, "reasoning": insertion.reasoning } } agent_calls.append(insertion_agent_call) # 评估加词后的query insertion_eval = await evaluate_query_relevance(insertion.new_query, original_need, query_state.relevance_score, context) # 创建加词后的query state new_state = QueryState( query=insertion.new_query, level=query_state.level + 1, relevance_score=insertion_eval.relevance_score, parent_query=query, strategy="加词" ) # 添加到演化图(无论是否提升) add_query_to_graph( context, new_state, iteration, evaluation_reason=insertion_eval.reason, is_selected=insertion_eval.is_improved, parent_level=query_state.level # 父节点的层级 ) if insertion_eval.is_improved: print(f" ✓ 加词({word_type}): {insertion.new_query} [+{unused_word}] (分数: {insertion_eval.relevance_score:.2f})") new_queries.append(new_state) else: print(f" ✗ 加词({word_type}): {insertion.new_query} [+{unused_word}] (分数: {insertion_eval.relevance_score:.2f}, 未提升)") # 记录完整的suggestion分支处理结果(层级化) add_step(context, f"Suggestion分支 - {query}", "suggestion_branch", { "query": query, "query_level": query_state.level, "query_relevance": query_state.relevance_score, "suggestions_count": len(suggestions), "suggestions_evaluated": len(suggestion_evaluations), "suggestion_evaluations": suggestion_evaluations, # 保存所有评估 "agent_calls": agent_calls, # 所有Agent调用的详细记录 "new_queries_generated": len(new_queries), "new_queries": [{"query": nq.query, "score": nq.relevance_score, "strategy": nq.strategy} for nq in new_queries], "no_suggestion_rounds": query_state.no_suggestion_rounds }) return new_queries async def process_search_results( query: str, query_state: QueryState, original_need: str, word_lib: WordLibrary, context: RunContext, xiaohongshu_search: XiaohongshuSearch, relevance_threshold: float, iteration: int ) -> tuple[list[dict], list[QueryState]]: """ 处理搜索结果分支 返回: (满足需求的notes, 需要继续迭代的新queries) """ print(f"\n [Result分支] 搜索query: {query}") # 收集本次分支处理中的所有Agent调用 agent_calls = [] # 1. 判断query相关度是否达到门槛 if query_state.relevance_score < relevance_threshold: print(f" ✗ 相关度 {query_state.relevance_score:.2f} 低于门槛 {relevance_threshold},跳过搜索") return [], [] print(f" ✓ 相关度 {query_state.relevance_score:.2f} 达到门槛,执行搜索") # 2. 执行搜索 try: search_result = xiaohongshu_search.search(keyword=query) result_str = search_result.get("result", "{}") if isinstance(result_str, str): result_data = json.loads(result_str) else: result_data = result_str notes = result_data.get("data", {}).get("data", []) print(f" → 搜索到 {len(notes)} 个帖子") except Exception as e: print(f" ✗ 搜索失败: {e}") return [], [] if not notes: return [], [] # 3. 评估每个帖子 satisfied_notes = [] partial_notes = [] for note in notes: # 评估所有帖子 note_data = process_note_data(note) title = note_data["title"] or "" desc = note_data["desc"] or "" # 跳过空标题和描述的帖子 if not title and not desc: continue # 评估帖子 eval_input = f""" <原始需求> {original_need} <帖子> 标题: {title} 描述: {desc} 请评估这个帖子与原始需求的匹配程度。 """ result = await Runner.run(result_evaluator, eval_input) evaluation: ResultEvaluation = result.final_output # 收集Result评估Agent的输入输出 result_eval_agent_call = { "agent": "Result匹配度评估专家", "action": "评估帖子匹配度", "input": { "note_id": note_data.get("note_id"), "title": title, "desc": desc # 完整描述 }, "output": { "match_level": evaluation.match_level, "relevance_score": evaluation.relevance_score, "missing_aspects": evaluation.missing_aspects, "reason": evaluation.reason } } agent_calls.append(result_eval_agent_call) note_data["evaluation"] = { "match_level": evaluation.match_level, "relevance_score": evaluation.relevance_score, "missing_aspects": evaluation.missing_aspects, "reason": evaluation.reason } # 将所有评估过的帖子添加到演化图(包括satisfied、partial、unsatisfied) add_note_to_graph(context, query, query_state.level, note_data) if evaluation.match_level == "satisfied": satisfied_notes.append(note_data) print(f" ✓ 满足: {title[:30] if len(title) > 30 else title}... (分数: {evaluation.relevance_score:.2f})") elif evaluation.match_level == "partial": partial_notes.append(note_data) print(f" ~ 部分: {title[:30] if len(title) > 30 else title}... (缺失: {', '.join(evaluation.missing_aspects[:2])})") else: # unsatisfied print(f" ✗ 不满足: {title[:30] if len(title) > 30 else title}... (分数: {evaluation.relevance_score:.2f})") # 4. 处理满足的帖子:不再扩充分词库(避免无限扩张) new_queries = [] if satisfied_notes: print(f"\n ✓ 找到 {len(satisfied_notes)} 个满足的帖子,不再提取关键词入库") # 注释掉关键词提取逻辑,保持分词库稳定 # for note in satisfied_notes[:3]: # extract_input = f""" # <帖子> # 标题: {note['title']} # 描述: {note['desc']} # # # 请提取核心关键词。 # """ # result = await Runner.run(keyword_extractor, extract_input) # extraction: KeywordExtraction = result.final_output # # # 添加新词到分词库,标记来源 # note_id = note.get('note_id', 'unknown') # for keyword in extraction.keywords: # if keyword not in word_lib.words: # word_lib.add_word(keyword, source=f"note:{note_id}") # print(f" + 新词入库: {keyword} (来源: {note_id})") # 5. 处理部分匹配的帖子:改造query if partial_notes and len(satisfied_notes) < 5: # 如果满足的不够,基于部分匹配改进 print(f"\n 基于 {len(partial_notes)} 个部分匹配帖子改造query...") # 收集所有缺失方面 all_missing = [] for note in partial_notes: all_missing.extend(note["evaluation"]["missing_aspects"]) if all_missing: improvement_input = f""" <当前Query> {query} <缺失的方面> {', '.join(set(all_missing))} 请改造query使其包含这些缺失的内容。 """ result = await Runner.run(query_improver, improvement_input) improvement: QueryImprovement = result.final_output # 收集Query改造Agent的输入输出 improvement_agent_call = { "agent": "Query改造专家", "action": "基于缺失方面改造Query", "input": { "query": query, "missing_aspects": list(set(all_missing)) }, "output": { "improved_query": improvement.improved_query, "added_aspects": improvement.added_aspects, "reasoning": improvement.reasoning } } agent_calls.append(improvement_agent_call) # 评估改进后的query improved_eval = await evaluate_query_relevance(improvement.improved_query, original_need, query_state.relevance_score, context) # 创建改进后的query state new_state = QueryState( query=improvement.improved_query, level=query_state.level + 1, relevance_score=improved_eval.relevance_score, parent_query=query, strategy="基于部分匹配改进" ) # 添加到演化图(无论是否提升) add_query_to_graph( context, new_state, iteration, evaluation_reason=improved_eval.reason, is_selected=improved_eval.is_improved, parent_level=query_state.level # 父节点的层级 ) if improved_eval.is_improved: print(f" ✓ 改进: {improvement.improved_query} (添加: {', '.join(improvement.added_aspects[:2])})") new_queries.append(new_state) else: print(f" ✗ 改进: {improvement.improved_query} (分数: {improved_eval.relevance_score:.2f}, 未提升)") # 6. Result分支的改写策略(向上抽象和同义改写) # 如果搜索结果不理想且新queries不够,尝试改写当前query if len(satisfied_notes) < 3 and len(new_queries) < 2: print(f"\n 搜索结果不理想,尝试改写query...") # 6.1 向上抽象 if len(new_queries) < 3: rewrite_input_abstract = f""" <当前Query> {query} <改写要求> 类型: abstract (向上抽象) 请改写这个query。 """ result = await Runner.run(query_rewriter, rewrite_input_abstract) rewrite: QueryRewrite = result.final_output # 收集Result分支改写(抽象)Agent的输入输出 rewrite_agent_call = { "agent": "Query改写专家", "action": "向上抽象改写(Result分支)", "input": { "query": query, "rewrite_type": "abstract" }, "output": { "rewritten_query": rewrite.rewritten_query, "rewrite_type": rewrite.rewrite_type, "reasoning": rewrite.reasoning } } agent_calls.append(rewrite_agent_call) # 评估改写后的query rewrite_eval = await evaluate_query_relevance(rewrite.rewritten_query, original_need, query_state.relevance_score, context) # 创建改写后的query state new_state = QueryState( query=rewrite.rewritten_query, level=query_state.level + 1, relevance_score=rewrite_eval.relevance_score, parent_query=query, strategy="结果分支-抽象改写" ) # 添加到演化图(无论是否提升) add_query_to_graph( context, new_state, iteration, evaluation_reason=rewrite_eval.reason, is_selected=rewrite_eval.is_improved, parent_level=query_state.level # 父节点的层级 ) if rewrite_eval.is_improved: print(f" ✓ 改写(抽象): {rewrite.rewritten_query} (分数: {rewrite_eval.relevance_score:.2f})") new_queries.append(new_state) else: print(f" ✗ 改写(抽象): {rewrite.rewritten_query} (分数: {rewrite_eval.relevance_score:.2f}, 未提升)") # 6.2 同义改写 if len(new_queries) < 4: rewrite_input_synonym = f""" <当前Query> {query} <改写要求> 类型: synonym (同义改写) 使用同义词或相关表达来改写query,保持语义相同但表达方式不同。 请改写这个query。 """ result = await Runner.run(query_rewriter, rewrite_input_synonym) rewrite_syn: QueryRewrite = result.final_output # 收集Result分支改写(同义)Agent的输入输出 rewrite_syn_agent_call = { "agent": "Query改写专家", "action": "同义改写(Result分支)", "input": { "query": query, "rewrite_type": "synonym" }, "output": { "rewritten_query": rewrite_syn.rewritten_query, "rewrite_type": rewrite_syn.rewrite_type, "reasoning": rewrite_syn.reasoning } } agent_calls.append(rewrite_syn_agent_call) # 评估改写后的query rewrite_syn_eval = await evaluate_query_relevance(rewrite_syn.rewritten_query, original_need, query_state.relevance_score, context) # 创建改写后的query state new_state = QueryState( query=rewrite_syn.rewritten_query, level=query_state.level + 1, relevance_score=rewrite_syn_eval.relevance_score, parent_query=query, strategy="结果分支-同义改写" ) # 添加到演化图(无论是否提升) add_query_to_graph( context, new_state, iteration, evaluation_reason=rewrite_syn_eval.reason, is_selected=rewrite_syn_eval.is_improved, parent_level=query_state.level # 父节点的层级 ) if rewrite_syn_eval.is_improved: print(f" ✓ 改写(同义): {rewrite_syn.rewritten_query} (分数: {rewrite_syn_eval.relevance_score:.2f})") new_queries.append(new_state) else: print(f" ✗ 改写(同义): {rewrite_syn.rewritten_query} (分数: {rewrite_syn_eval.relevance_score:.2f}, 未提升)") # 记录完整的result分支处理结果(层级化) add_step(context, f"Result分支 - {query}", "result_branch", { "query": query, "query_level": query_state.level, "query_relevance": query_state.relevance_score, "relevance_threshold": relevance_threshold, "passed_threshold": query_state.relevance_score >= relevance_threshold, "notes_count": len(notes) if 'notes' in locals() else 0, "satisfied_count": len(satisfied_notes), "partial_count": len(partial_notes), "satisfied_notes": [ { "note_id": note["note_id"], "title": note["title"], "score": note["evaluation"]["relevance_score"], "match_level": note["evaluation"]["match_level"] } for note in satisfied_notes # 保存所有满足的帖子 ], "agent_calls": agent_calls, # 所有Agent调用的详细记录 "new_queries_generated": len(new_queries), "new_queries": [{"query": nq.query, "score": nq.relevance_score, "strategy": nq.strategy} for nq in new_queries] }) return satisfied_notes, new_queries async def iterative_search_loop( context: RunContext, max_iterations: int = 20, relevance_threshold: float = 0.6 ) -> list[dict]: """ 主循环:迭代搜索(按层级处理) Args: context: 运行上下文 max_iterations: 最大迭代次数(层级数) relevance_threshold: 相关度门槛 Returns: 满足需求的帖子列表 """ print(f"\n{'='*60}") print(f"开始迭代搜索循环") print(f"{'='*60}") # 0. 添加原始问题作为根节点 root_query_state = QueryState( query=context.q, level=0, relevance_score=1.0, # 原始问题本身相关度为1.0 strategy="根节点" ) add_query_to_graph(context, root_query_state, 0, evaluation_reason="原始问题,作为搜索的根节点", is_selected=True) print(f"[根节点] 原始问题: {context.q}") # 1. 初始化分词库 word_lib = await initialize_word_library(context.q, context) # 2. 初始化query队列 - 智能选择最相关的词 all_words = list(word_lib.words) query_queue = [] print(f"\n评估所有初始分词的相关度...") word_scores = [] for word in all_words: # 评估每个词的相关度 eval_result = await evaluate_query_relevance(word, context.q, None, context) word_scores.append({ 'word': word, 'score': eval_result.relevance_score, 'eval': eval_result }) print(f" {word}: {eval_result.relevance_score:.2f}") # 按相关度排序,使用所有分词 word_scores.sort(key=lambda x: x['score'], reverse=True) selected_words = word_scores # 使用所有分词 # 将所有分词添加到演化图(全部被选中) for item in word_scores: is_selected = True # 所有分词都被选中 query_state = QueryState( query=item['word'], level=1, relevance_score=item['score'], strategy="初始分词", parent_query=context.q # 父节点是原始问题 ) # 添加到演化图(会自动创建从parent_query到该query的边) add_query_to_graph(context, query_state, 0, evaluation_reason=item['eval'].reason, is_selected=is_selected, parent_level=0) # 父节点是根节点(level 0) # 只有被选中的才加入队列 if is_selected: query_queue.append(query_state) print(f"\n初始query队列(按相关度排序): {[(q.query, f'{q.relevance_score:.2f}') for q in query_queue]}") print(f" (共评估了 {len(word_scores)} 个分词,全部加入队列)") # 3. API实例 xiaohongshu_api = XiaohongshuSearchRecommendations() xiaohongshu_search = XiaohongshuSearch() # 4. 主循环 all_satisfied_notes = [] iteration = 0 while query_queue and iteration < max_iterations: iteration += 1 # 获取当前层级(队列中最小的level) current_level = min(q.level for q in query_queue) # 提取当前层级的所有query current_batch = [q for q in query_queue if q.level == current_level] query_queue = [q for q in query_queue if q.level != current_level] print(f"\n{'='*60}") print(f"迭代 {iteration}: 处理第 {current_level} 层,共 {len(current_batch)} 个query") print(f"{'='*60}") # 记录本轮处理的queries add_step(context, f"迭代 {iteration}", "iteration", { "iteration": iteration, "current_level": current_level, "current_batch_size": len(current_batch), "remaining_queue_size": len(query_queue), "processing_queries": [{"query": q.query, "level": q.level} for q in current_batch] }) new_queries_from_sug = [] new_queries_from_result = [] # 处理每个query for query_state in current_batch: print(f"\n处理Query [{query_state.level}]: {query_state.query} (分数: {query_state.relevance_score:.2f})") # 检查终止条件 if query_state.is_terminated or query_state.no_suggestion_rounds >= 2: print(f" ✗ 已终止或连续2轮无suggestion,跳过该query") query_state.is_terminated = True continue # 并行处理两个分支 sug_task = process_suggestions( query_state.query, query_state, context.q, word_lib, context, xiaohongshu_api, iteration ) result_task = process_search_results( query_state.query, query_state, context.q, word_lib, context, xiaohongshu_search, relevance_threshold, iteration ) # 等待两个分支完成 sug_queries, (satisfied_notes, result_queries) = await asyncio.gather( sug_task, result_task ) # 如果suggestion分支返回空,说明没有获取到suggestion,需要继承no_suggestion_rounds # 注意:process_suggestions内部已经更新了query_state.no_suggestion_rounds # 所以这里生成的新queries需要继承父query的no_suggestion_rounds(如果sug分支也返回空) if not sug_queries and not result_queries: # 两个分支都没有产生新query,标记当前query为终止 query_state.is_terminated = True print(f" ⚠ 两个分支均未产生新query,标记该query为终止") new_queries_from_sug.extend(sug_queries) new_queries_from_result.extend(result_queries) all_satisfied_notes.extend(satisfied_notes) # 更新队列 all_new_queries = new_queries_from_sug + new_queries_from_result # 注意:不需要在这里再次添加到演化图,因为在 process_suggestions 和 process_search_results 中已经添加过了 # 如果在这里再次调用 add_query_to_graph,会覆盖之前设置的 evaluation_reason 等字段 query_queue.extend(all_new_queries) # 去重(基于query文本)并过滤已终止的query seen = set() unique_queue = [] for q in query_queue: if q.query not in seen and not q.is_terminated: seen.add(q.query) unique_queue.append(q) query_queue = unique_queue # 按相关度排序 query_queue.sort(key=lambda x: x.relevance_score, reverse=True) print(f"\n本轮结果:") print(f" 新增满足帖子: {len(satisfied_notes)}") print(f" 累计满足帖子: {len(all_satisfied_notes)}") print(f" 新增queries: {len(all_new_queries)}") print(f" 队列剩余: {len(query_queue)}") # 更新分词库到context context.word_library = word_lib.model_dump() # 如果满足条件的帖子足够多,可以提前结束 if len(all_satisfied_notes) >= 20: print(f"\n已找到足够的满足帖子 ({len(all_satisfied_notes)}个),提前结束") break print(f"\n{'='*60}") print(f"迭代搜索完成") print(f" 总迭代次数: {iteration}") print(f" 最终满足帖子数: {len(all_satisfied_notes)}") print(f" 核心词库: {list(word_lib.core_words)}") print(f" 最终分词库大小: {len(word_lib.words)}") print(f"{'='*60}") # 保存最终结果 add_step(context, "迭代搜索完成", "loop_complete", { "total_iterations": iteration, "total_satisfied_notes": len(all_satisfied_notes), "core_words": list(word_lib.core_words), "final_word_library_size": len(word_lib.words), "final_word_library": list(word_lib.words) }) return all_satisfied_notes # ============================================================================ # 主函数 # ============================================================================ async def main(input_dir: str, max_iterations: int = 20, visualize: bool = False): """主函数""" current_time, log_url = set_trace() # 读取输入 input_context_file = os.path.join(input_dir, 'context.md') input_q_file = os.path.join(input_dir, 'q.md') q_context = read_file_as_string(input_context_file) q = read_file_as_string(input_q_file) q_with_context = f""" <需求上下文> {q_context} <当前问题> {q} """.strip() # 版本信息 version = os.path.basename(__file__) version_name = os.path.splitext(version)[0] # 日志目录 log_dir = os.path.join(input_dir, "output", version_name, current_time) # 创建运行上下文 run_context = RunContext( version=version, input_files={ "input_dir": input_dir, "context_file": input_context_file, "q_file": input_q_file, }, q_with_context=q_with_context, q_context=q_context, q=q, log_dir=log_dir, log_url=log_url, ) # 执行迭代搜索 satisfied_notes = await iterative_search_loop( run_context, max_iterations=max_iterations, relevance_threshold=0.6 ) # 保存结果 run_context.satisfied_notes = satisfied_notes # 格式化输出 output = f"原始问题:{run_context.q}\n" output += f"找到满足需求的帖子:{len(satisfied_notes)} 个\n" output += f"核心词库:{', '.join(run_context.word_library.get('core_words', []))}\n" output += f"分词库大小:{len(run_context.word_library.get('words', []))} 个词\n" output += "\n" + "="*60 + "\n" if satisfied_notes: output += "【满足需求的帖子】\n\n" for idx, note in enumerate(satisfied_notes, 1): output += f"{idx}. {note['title']}\n" output += f" 相关度: {note['evaluation']['relevance_score']:.2f}\n" output += f" URL: {note['note_url']}\n\n" else: output += "未找到满足需求的帖子\n" run_context.final_output = output print(f"\n{'='*60}") print("最终结果") print(f"{'='*60}") print(output) # 保存日志 os.makedirs(run_context.log_dir, exist_ok=True) context_file_path = os.path.join(run_context.log_dir, "run_context.json") context_dict = run_context.model_dump() with open(context_file_path, "w", encoding="utf-8") as f: json.dump(context_dict, f, ensure_ascii=False, indent=2) print(f"\nRunContext saved to: {context_file_path}") steps_file_path = os.path.join(run_context.log_dir, "steps.json") with open(steps_file_path, "w", encoding="utf-8") as f: json.dump(run_context.steps, f, ensure_ascii=False, indent=2) print(f"Steps log saved to: {steps_file_path}") # 保存Query演化图 query_graph_file_path = os.path.join(run_context.log_dir, "query_graph.json") with open(query_graph_file_path, "w", encoding="utf-8") as f: json.dump(run_context.query_graph, f, ensure_ascii=False, indent=2) print(f"Query graph saved to: {query_graph_file_path}") # 可视化 if visualize: import subprocess output_html = os.path.join(run_context.log_dir, "visualization.html") print(f"\n🎨 生成可视化HTML...") # 获取绝对路径 vis_script = os.path.abspath("visualization/sug_v6_1_2_6/index.js") abs_query_graph = os.path.abspath(query_graph_file_path) abs_output_html = os.path.abspath(output_html) # 在可视化脚本目录中执行,确保使用本地 node_modules result = subprocess.run([ "node", "index.js", abs_query_graph, abs_output_html ], cwd="visualization/sug_v6_1_2_6") if result.returncode == 0: print(f"✅ 可视化已生成: {output_html}") else: print(f"❌ 可视化生成失败") if __name__ == "__main__": parser = argparse.ArgumentParser(description="搜索query优化工具 - v6.1.2.5 迭代循环版") parser.add_argument( "--input-dir", type=str, default="input/简单扣图", help="输入目录路径,默认: input/简单扣图" ) parser.add_argument( "--max-iterations", type=int, default=20, help="最大迭代次数,默认: 20" ) parser.add_argument( "--visualize", action="store_true", default=False, help="运行完成后自动生成可视化HTML" ) parser.add_argument( "--visualize-only", type=str, help="仅生成可视化,指定query_graph.json文件路径" ) args = parser.parse_args() # 如果只是生成可视化 if args.visualize_only: import subprocess query_graph_path = args.visualize_only output_html = os.path.splitext(query_graph_path)[0].replace("query_graph", "visualization") + ".html" if not output_html.endswith(".html"): output_html = os.path.join(os.path.dirname(query_graph_path), "visualization.html") print(f"🎨 生成可视化HTML...") print(f"输入: {query_graph_path}") print(f"输出: {output_html}") # 获取绝对路径 abs_query_graph = os.path.abspath(query_graph_path) abs_output_html = os.path.abspath(output_html) # 在可视化脚本目录中执行,确保使用本地 node_modules result = subprocess.run([ "node", "index.js", abs_query_graph, abs_output_html ], cwd="visualization/sug_v6_1_2_6") if result.returncode == 0: print(f"✅ 可视化已生成: {output_html}") else: print(f"❌ 可视化生成失败") sys.exit(result.returncode) asyncio.run(main(args.input_dir, max_iterations=args.max_iterations, visualize=args.visualize))