import asyncio import json import os import sys import argparse import time import re from datetime import datetime from typing import Literal, TypeVar, Type from agents import Agent, Runner from lib.my_trace import set_trace from pydantic import BaseModel, Field from lib.utils import read_file_as_string from lib.client import get_model MODEL_NAME = "google/gemini-2.5-flash" from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations from script.search.xiaohongshu_search import XiaohongshuSearch # ============================================================================ # 数据模型 # ============================================================================ class Seg(BaseModel): """分词结果""" text: str score_with_o: float from_o: str class Word(BaseModel): """词库中的词""" text: str score_with_o: float from_o: str class Q(BaseModel): """查询""" text: str score_with_o: float from_source: str # "seg" | "sug" | "add" class Sug(BaseModel): """建议查询""" text: str score_with_o: float from_q: dict # {"text": str, "score_with_o": float} evaluation_reason: str | None = None # 评估理由 class Seed(BaseModel): """种子查询(用于加词探索)""" text: str added_words: list[str] = Field(default_factory=list) from_type: str # "seg" | "sug" class Post(BaseModel): """帖子""" note_id: str = "" title: str = "" body_text: str = "" type: str = "normal" # "video" | "normal" images: list[str] = Field(default_factory=list) video: str = "" interact_info: dict = Field(default_factory=dict) note_url: str = "" class Search(BaseModel): """搜索结果(继承自Sug)""" text: str score_with_o: float from_q: dict post_list: list[Post] = Field(default_factory=list) class RunContext(BaseModel): """运行上下文""" version: str input_files: dict[str, str] c: str # 原始需求(context) o: str # 原始问题 log_url: str log_dir: str # 核心数据 seg_list: list[dict] = Field(default_factory=list) word_lists: dict[int, list[dict]] = Field(default_factory=dict) # {round: word_list} q_lists: dict[int, list[dict]] = Field(default_factory=dict) # {round: q_list} sug_list_lists: dict[int, list[list[dict]]] = Field(default_factory=dict) # {round: [[sug, sug], [sug]]} search_lists: dict[int, list[dict]] = Field(default_factory=dict) # {round: search_list} seed_lists: dict[int, list[dict]] = Field(default_factory=dict) # {round: seed_list} steps: list[dict] = Field(default_factory=list) # 新增:详细的操作记录(中文命名,但数据结构保留英文) 轮次记录: dict[int, dict] = Field(default_factory=dict) # 最终结果 all_posts: list[dict] = Field(default_factory=list) final_output: str | None = None # ============================================================================ # 辅助函数:记录操作 # ============================================================================ def init_round_record(run_context: RunContext, round_num: int, round_name: str): """初始化一个轮次记录""" run_context.轮次记录[round_num] = { "轮次": round_num, "名称": round_name, "操作列表": [] } def add_operation_record( run_context: RunContext, round_num: int, 操作名称: str, 输入: dict, 处理过程: dict, 输出: dict ): """添加一条操作记录""" from datetime import datetime operation = { "操作名称": 操作名称, "轮次": round_num, "时间": datetime.now().isoformat(), "输入": 输入, "处理过程": 处理过程, "输出": 输出 } if round_num not in run_context.轮次记录: init_round_record(run_context, round_num, f"第{round_num}轮" if round_num > 0 else "初始化阶段") run_context.轮次记录[round_num]["操作列表"].append(operation) def record_agent_call( agent_name: str, model: str, instructions: str, user_message: str, raw_output: dict | str, parsed: bool, validation_error: str | None = None, input_schema: dict | None = None ) -> dict: """记录单次Agent调用""" return { "Agent名称": agent_name, "模型": model, "系统提示词": instructions, "输入Schema": input_schema, "用户消息": user_message, "原始输出": raw_output, "解析成功": parsed, "验证错误": validation_error } # ============================================================================ # JSON后处理:处理markdown包裹的JSON响应 # ============================================================================ def clean_json_response(text: str) -> str: """清理可能包含markdown代码块包裹的JSON 模型可能返回: ```json {"key": "value"} ``` 需要清理为: {"key": "value"} """ text = text.strip() # 移除开头的 ```json 或 ``` if text.startswith('```json'): text = text[7:] elif text.startswith('```'): text = text[3:] # 移除结尾的 ``` if text.endswith('```'): text = text[:-3] return text.strip() T = TypeVar('T', bound=BaseModel) async def run_agent_with_json_cleanup( agent: Agent, input_text: str, output_type: Type[T] ) -> T: """运行Agent并处理可能的JSON包裹问题 如果Agent返回被markdown包裹的JSON,自动清理后重新解析 """ try: result = await Runner.run(agent, input_text) return result.final_output except Exception as e: error_msg = str(e) # 检查是否是JSON解析错误 if "Invalid JSON when parsing" in error_msg: # 尝试从错误消息中提取JSON # 错误格式: "Invalid JSON when parsing ```json\n{...}\n``` for TypeAdapter(...)" match = re.search(r'when parsing (.+?) for TypeAdapter', error_msg, re.DOTALL) if match: json_text = match.group(1) cleaned_json = clean_json_response(json_text) try: # 手动解析JSON并创建Pydantic对象 parsed_data = json.loads(cleaned_json) return output_type(**parsed_data) except Exception as parse_error: print(f"⚠️ JSON清理后仍无法解析: {parse_error}") print(f" 清理后的JSON: {cleaned_json}") raise ValueError(f"无法解析JSON: {parse_error}\n原始错误: {error_msg}") # 如果不是JSON解析错误,或清理失败,重新抛出原始错误 raise # ============================================================================ # Agent 定义 # ============================================================================ # Agent 1: 分词专家 class WordSegmentation(BaseModel): """分词结果""" words: list[str] = Field(..., description="分词结果列表") reasoning: str = Field(..., description="分词理由") word_segmentation_instructions = """ 你是分词专家。给定一个query,将其拆分成有意义的最小单元。 ## 分词原则 1. 保留有搜索意义的词汇 2. 拆分成独立的概念 3. 保留专业术语的完整性 4. 去除虚词(的、吗、呢等) ## 输出要求 返回分词列表和分词理由。 IMPORTANT: 直接返回纯JSON对象,不要使用markdown代码块标记(不要用```json...```包裹)。 """.strip() word_segmenter = Agent[None]( name="分词专家", instructions=word_segmentation_instructions, model=get_model(MODEL_NAME), output_type=WordSegmentation, ) # Agent 2: Query相关度评估专家 class RelevanceEvaluation(BaseModel): """相关度评估""" relevance_score: float = Field(..., description="相关性分数 0-1") reason: str = Field(..., description="评估理由") relevance_evaluation_instructions = """ 你是Query相关度评估专家。 ## 任务 评估当前query与原始问题的匹配程度。 ## 评估标准 - 主题相关性 - 要素覆盖度 - 意图匹配度 ## 输出 - relevance_score: 0-1的相关性分数 - reason: 详细理由 IMPORTANT: 直接返回纯JSON对象,不要使用markdown代码块标记(不要用```json...```包裹)。 """.strip() relevance_evaluator = Agent[None]( name="Query相关度评估专家", instructions=relevance_evaluation_instructions, model=get_model(MODEL_NAME), output_type=RelevanceEvaluation, ) # Agent 3: Word选择专家 class WordSelection(BaseModel): """Word选择结果""" selected_word: str = Field(..., description="选中的词") reasoning: str = Field(..., description="选择理由") word_selection_instructions = """ 你是Word选择专家。 ## 任务 从候选词列表中选择一个最适合与当前seed组合的词,用于探索新的搜索query。 ## 选择原则 1. 与seed的语义相关性 2. 组合后的搜索价值 3. 能拓展搜索范围 ## 输出 返回选中的词和选择理由。 """.strip() word_selector = Agent[None]( name="Word选择专家", instructions=word_selection_instructions, model=get_model(MODEL_NAME), output_type=WordSelection, ) # Agent 4: 加词位置评估专家 class WordInsertion(BaseModel): """加词结果""" new_query: str = Field(..., description="加词后的新query") insertion_position: str = Field(..., description="插入位置描述") reasoning: str = Field(..., description="插入理由") word_insertion_instructions = """ 你是加词位置评估专家。 ## 任务 将新词加到当前query的最合适位置,保持语义通顺。 ## 原则 1. 保持语法正确 2. 语义连贯 3. 符合搜索习惯 ## 输出 返回新query、插入位置描述和理由。 """.strip() word_inserter = Agent[None]( name="加词位置评估专家", instructions=word_insertion_instructions, model=get_model(MODEL_NAME), output_type=WordInsertion, ) # ============================================================================ # 辅助函数 # ============================================================================ def add_step(context: RunContext, step_name: str, step_type: str, data: dict): """添加步骤记录""" step = { "step_number": len(context.steps) + 1, "step_name": step_name, "step_type": step_type, "timestamp": datetime.now().isoformat(), "data": data } context.steps.append(step) return step def process_note_data(note: dict) -> Post: """处理搜索接口返回的帖子数据,转换为Post对象""" note_card = note.get("note_card", {}) image_list = note_card.get("image_list", []) interact_info = note_card.get("interact_info", {}) # 提取图片URLs - 使用 image_url 字段 images = [] for img in image_list: if "image_url" in img: images.append(img["image_url"]) # 判断是否是视频 note_type = note_card.get("type", "normal") video_url = "" if note_type == "video": # 视频类型可能有不同的结构,这里先留空 # 如果需要可以后续补充 pass return Post( note_id=note.get("id") or "", title=note_card.get("display_title") or "", body_text=note_card.get("desc") or "", type=note_type, images=images, video=video_url, interact_info={ "liked_count": interact_info.get("liked_count", 0), "collected_count": interact_info.get("collected_count", 0), "comment_count": interact_info.get("comment_count", 0), "shared_count": interact_info.get("shared_count", 0) }, note_url=f"https://www.xiaohongshu.com/explore/{note.get('id') or ''}" ) # ============================================================================ # 核心流程函数 # ============================================================================ async def evaluate_query_with_o(query_text: str, original_o: str) -> tuple[float, str]: """评估query与原始问题o的相关度 Returns: (score, reason) """ eval_input = f""" <原始问题> {original_o} <当前Query> {query_text} 请评估当前query与原始问题的相关度。 """ evaluation = await run_agent_with_json_cleanup( relevance_evaluator, eval_input, RelevanceEvaluation ) return evaluation.relevance_score, evaluation.reason async def initialize(context: RunContext): """初始化:分词 → seg_list → word_list_1, q_list_1, seed_list_1""" print("\n" + "="*60) print("初始化阶段") print("="*60) # 初始化轮次0 init_round_record(context, 0, "初始化阶段") # 1. 分词 print(f"\n[1/4] 分词原始问题: {context.o}") segmentation = await run_agent_with_json_cleanup( word_segmenter, context.o, WordSegmentation ) print(f" 分词结果: {segmentation.words}") print(f" 分词理由: {segmentation.reasoning}") # 2. 分词评估(并发) print(f"\n[2/4] 评估每个seg与原始问题的相关度...") seg_list = [] agent_calls_seg_eval = [] # 并发评估所有分词 eval_tasks = [evaluate_query_with_o(word, context.o) for word in segmentation.words] eval_results = await asyncio.gather(*eval_tasks) for word, (score, reason) in zip(segmentation.words, eval_results): seg = Seg(text=word, score_with_o=score, from_o=context.o) seg_list.append(seg.model_dump()) print(f" {word}: {score:.2f}") # 记录每个seg的评估 agent_calls_seg_eval.append( record_agent_call( agent_name="Query相关度评估专家", model=MODEL_NAME, instructions=relevance_evaluation_instructions, user_message=f"评估query与原始问题的相关度:\n\nQuery: {word}\n原始问题: {context.o}", raw_output={"score": score, "reason": reason}, parsed=True ) ) context.seg_list = seg_list # 记录分词操作 add_operation_record( context, round_num=0, 操作名称="分词", 输入={"原始问题": context.o}, 处理过程={ "Agent调用": record_agent_call( agent_name="分词专家", model=MODEL_NAME, instructions=word_segmentation_instructions, user_message=f"请对以下query进行分词:{context.o}", raw_output={"words": segmentation.words, "reasoning": segmentation.reasoning}, parsed=True, input_schema={"type": "WordSegmentation", "fields": {"words": "list[str]", "reasoning": "str"}} ), "seg评估Agent调用列表": agent_calls_seg_eval }, 输出={"seg_list": seg_list} ) # 3. 构建 word_list_1(直接从seg_list复制) print(f"\n[3/4] 构建 word_list_1...") word_list_1 = [] for seg in seg_list: word = Word(text=seg["text"], score_with_o=seg["score_with_o"], from_o=seg["from_o"]) word_list_1.append(word.model_dump()) context.word_lists[1] = word_list_1 print(f" word_list_1 大小: {len(word_list_1)}") # 4. 构建 q_list_1 和 seed_list_1 print(f"\n[4/4] 构建 q_list_1 和 seed_list_1...") q_list_1 = [] seed_list_1 = [] for seg in seg_list: # q_list_1: seg作为q q = Q(text=seg["text"], score_with_o=seg["score_with_o"], from_source="seg") q_list_1.append(q.model_dump()) # seed_list_1: seg作为seed seed = Seed(text=seg["text"], added_words=[], from_type="seg") seed_list_1.append(seed.model_dump()) context.q_lists[1] = q_list_1 context.seed_lists[1] = seed_list_1 print(f" q_list_1 大小: {len(q_list_1)}") print(f" seed_list_1 大小: {len(seed_list_1)}") # 记录初始化操作 add_operation_record( context, round_num=0, 操作名称="初始化", 输入={"seg_list": seg_list}, 处理过程={"说明": "从seg_list构建初始q_list和seed_list"}, 输出={ "word_list_1": word_list_1, "q_list_1": q_list_1, "seed_list_1": seed_list_1 } ) add_step(context, "初始化完成", "initialize", { "seg_count": len(seg_list), "word_list_1_count": len(word_list_1), "q_list_1_count": len(q_list_1), "seed_list_1_count": len(seed_list_1) }) async def process_round(round_num: int, context: RunContext, xiaohongshu_api: XiaohongshuSearchRecommendations, xiaohongshu_search: XiaohongshuSearch, sug_threshold: float = 0.7): """处理一轮迭代 Args: round_num: 当前轮数 context: 运行上下文 xiaohongshu_api: sug API xiaohongshu_search: search API sug_threshold: sug评分阈值(默认0.7) """ print(f"\n" + "="*60) print(f"第 {round_num} 轮") print("="*60) # 初始化轮次记录 init_round_record(context, round_num, f"第{round_num}轮迭代") q_list_n = context.q_lists.get(round_num, []) if not q_list_n: print(f" q_list_{round_num} 为空,跳过本轮") return print(f" 处理 {len(q_list_n)} 个query") # 1. 请求sug print(f"\n[1/5] 请求sug...") sug_list_list_n = [] api_calls_detail = [] for q_data in q_list_n: q_text = q_data["text"] suggestions = xiaohongshu_api.get_recommendations(keyword=q_text) if not suggestions: print(f" {q_text}: 无sug") sug_list_list_n.append([]) api_calls_detail.append({ "query": q_text, "sug_count": 0 }) continue print(f" {q_text}: 获取 {len(suggestions)} 个sug") sug_list_list_n.append(suggestions) api_calls_detail.append({ "query": q_text, "sug_count": len(suggestions) }) # 记录请求sug操作 total_sugs = sum(len(sl) for sl in sug_list_list_n) add_operation_record( context, round_num=round_num, 操作名称="请求推荐词", 输入={"q_list": [{"text": q["text"], "score": q["score_with_o"]} for q in q_list_n]}, 处理过程={"API调用": api_calls_detail}, 输出={ "sug_list_list": [[{"text": s, "from_q": q_list_n[i]["text"]} for s in sl] for i, sl in enumerate(sug_list_list_n)], "总推荐词数": total_sugs } ) # 2. sug评估(批量并发,限制并发数为10) print(f"\n[2/5] 评估sug...") sug_list_list_evaluated = [] # 收集所有需要评估的sug及其上下文 all_sug_tasks = [] sug_contexts = [] # 记录每个sug对应的q_data和位置 for i, sug_list in enumerate(sug_list_list_n): q_data = q_list_n[i] for sug_text in sug_list: all_sug_tasks.append(evaluate_query_with_o(sug_text, context.o)) sug_contexts.append((i, q_data, sug_text)) # 批量并发评估(每批10个) batch_size = 10 all_results = [] batches_detail = [] for batch_idx in range(0, len(all_sug_tasks), batch_size): batch_tasks = all_sug_tasks[batch_idx:batch_idx+batch_size] batch_results = await asyncio.gather(*batch_tasks) all_results.extend(batch_results) # 记录这个批次的Agent调用 batch_agent_calls = [] start_idx = batch_idx for j, (score, reason) in enumerate(batch_results): if start_idx + j < len(sug_contexts): _, _, sug_text = sug_contexts[start_idx + j] batch_agent_calls.append( record_agent_call( agent_name="Query相关度评估专家", model=MODEL_NAME, instructions=relevance_evaluation_instructions, user_message=f"评估query与原始问题的相关度:\n\nQuery: {sug_text}\n原始问题: {context.o}", raw_output={"score": score, "reason": reason}, parsed=True ) ) batches_detail.append({ "批次ID": len(batches_detail), "并发执行": True, "Agent调用列表": batch_agent_calls }) # 组织结果 result_index = 0 current_list_index = -1 evaluated_sugs = [] for list_idx, q_data, sug_text in sug_contexts: if list_idx != current_list_index: if evaluated_sugs: sug_list_list_evaluated.append(evaluated_sugs) evaluated_sugs = [] current_list_index = list_idx score, reason = all_results[result_index] result_index += 1 sug = Sug( text=sug_text, score_with_o=score, from_q={"text": q_data["text"], "score_with_o": q_data["score_with_o"]}, evaluation_reason=reason ) evaluated_sugs.append(sug.model_dump()) print(f" {sug_text}: {score:.2f}") # 添加最后一批 if evaluated_sugs: sug_list_list_evaluated.append(evaluated_sugs) context.sug_list_lists[round_num] = sug_list_list_evaluated # 记录评估sug操作 add_operation_record( context, round_num=round_num, 操作名称="评估推荐词", 输入={ "待评估推荐词": [[s for s in sl] for sl in sug_list_list_n], "总数": len(all_sug_tasks) }, 处理过程={"批次列表": batches_detail}, 输出={"已评估推荐词": sug_list_list_evaluated} ) # 3. 构建search_list_n(阈值>= 0.7的sug) print(f"\n[3/5] 构建search_list并执行搜索...") search_list_n = [] filter_comparisons = [] search_details = [] for sug_list_evaluated in sug_list_list_evaluated: for sug_data in sug_list_evaluated: # 记录筛选比较 passed = sug_data["score_with_o"] >= sug_threshold filter_comparisons.append({ "文本": sug_data["text"], "分数": sug_data["score_with_o"], "阈值": sug_threshold, "通过": passed }) if passed: print(f" 搜索: {sug_data['text']} (分数: {sug_data['score_with_o']:.2f})") try: # 执行搜索 search_result = xiaohongshu_search.search(keyword=sug_data["text"]) result_str = search_result.get("result", "{}") if isinstance(result_str, str): result_data = json.loads(result_str) else: result_data = result_str notes = result_data.get("data", {}).get("data", []) print(f" → 搜索到 {len(notes)} 个帖子") # 转换为Post对象 post_list = [] for note in notes: post = process_note_data(note) post_list.append(post.model_dump()) context.all_posts.append(post.model_dump()) # 创建Search对象 search = Search( text=sug_data["text"], score_with_o=sug_data["score_with_o"], from_q=sug_data["from_q"], post_list=post_list ) search_list_n.append(search.model_dump()) # 记录搜索详情 search_details.append({ "查询": sug_data["text"], "分数": sug_data["score_with_o"], "成功": True, "帖子数量": len(post_list), "错误": None }) except Exception as e: print(f" ✗ 搜索失败: {e}") search_details.append({ "查询": sug_data["text"], "分数": sug_data["score_with_o"], "成功": False, "帖子数量": 0, "错误": str(e) }) context.search_lists[round_num] = search_list_n print(f" 本轮搜索到 {len(search_list_n)} 个有效结果") # 记录构建search和执行搜索操作(合并为一个操作) total_posts = sum(len(s["post_list"]) for s in search_list_n) add_operation_record( context, round_num=round_num, 操作名称="筛选并执行搜索", 输入={"已评估推荐词": sug_list_list_evaluated}, 处理过程={ "筛选条件": f"分数 >= {sug_threshold}", "筛选比较": filter_comparisons, "搜索详情": search_details }, 输出={ "search_list": search_list_n, "成功搜索数": len(search_list_n), "总帖子数": total_posts } ) # 4. 构建word_list_(n+1)(先直接复制) print(f"\n[4/5] 构建word_list_{round_num+1}...") word_list_n = context.word_lists.get(round_num, []) word_list_next = word_list_n.copy() context.word_lists[round_num + 1] = word_list_next print(f" word_list_{round_num+1} 大小: {len(word_list_next)}") # 5. 构建q_list_(n+1)和更新seed_list print(f"\n[5/5] 构建q_list_{round_num+1}和更新seed_list...") q_list_next = [] seed_list_n = context.seed_lists.get(round_num, []) seed_list_next = seed_list_n.copy() # 5.1 从seed加词(串行处理,避免重复) print(f" [5.1] 从seed加词生成新q(串行处理,去重)...") add_word_attempts = [] # 记录所有尝试 new_queries_from_add = [] generated_query_texts = set() # 记录已生成的查询文本 for seed_data in seed_list_n: seed_text = seed_data["text"] added_words = seed_data["added_words"] # 过滤出未使用的词 candidate_words = [] for word_data in word_list_next: word_text = word_data["text"] # 简单字符串过滤 if word_text not in seed_text and word_text not in added_words: candidate_words.append(word_data) if not candidate_words: print(f" {seed_text}: 无可用词") continue attempt = { "种子": {"text": seed_text, "已添加词": added_words}, "候选词": [w["text"] for w in candidate_words[:10]] } # 使用agent选择词(提供已生成的查询列表) already_generated_str = "" if generated_query_texts: already_generated_str = f""" <已生成的查询> {', '.join(sorted(generated_query_texts))} 注意:请避免生成与上述已存在的查询重复或过于相似的新查询。 """ selection_input = f""" <当前Seed> {seed_text} <候选词列表> {', '.join([w['text'] for w in candidate_words[:10]])} {already_generated_str} 请从候选词中选择一个最适合与seed组合的词。 """ selection = await run_agent_with_json_cleanup( word_selector, selection_input, WordSelection ) selected_word = selection.selected_word # 确保选中的词在候选列表中 if selected_word not in [w["text"] for w in candidate_words]: # 如果agent选择的词不在候选列表中,使用第一个候选词 selected_word = candidate_words[0]["text"] # 记录选词 attempt["步骤1_选词"] = record_agent_call( agent_name="Word选择专家", model=MODEL_NAME, instructions=word_selection_instructions, user_message=selection_input, raw_output={"selected_word": selection.selected_word, "reasoning": selection.reasoning}, parsed=True, input_schema={"type": "WordSelection", "fields": {"selected_word": "str", "reasoning": "str"}} ) # 使用加词agent insertion_input = f""" <当前Query> {seed_text} <要添加的词> {selected_word} 请将这个词加到query的最合适位置。 """ insertion = await run_agent_with_json_cleanup( word_inserter, insertion_input, WordInsertion ) new_query_text = insertion.new_query # 记录插入位置 attempt["步骤2_插入位置"] = record_agent_call( agent_name="加词位置评估专家", model=MODEL_NAME, instructions=word_insertion_instructions, user_message=insertion_input, raw_output={"new_query": insertion.new_query, "reasoning": insertion.reasoning}, parsed=True, input_schema={"type": "WordInsertion", "fields": {"new_query": "str", "reasoning": "str"}} ) # 检查是否重复 if new_query_text in generated_query_texts: print(f" {seed_text} + {selected_word} → {new_query_text} (重复,跳过)") attempt["跳过原因"] = "查询重复" add_word_attempts.append(attempt) continue # 立即评估新query score, reason = await evaluate_query_with_o(new_query_text, context.o) # 记录评估 attempt["步骤3_评估新查询"] = record_agent_call( agent_name="Query相关度评估专家", model=MODEL_NAME, instructions=relevance_evaluation_instructions, user_message=f"评估新query的相关度:\n\nQuery: {new_query_text}\n原始问题: {context.o}", raw_output={"score": score, "reason": reason}, parsed=True ) add_word_attempts.append(attempt) # 创建新q并加入列表 new_q = Q(text=new_query_text, score_with_o=score, from_source="add") q_list_next.append(new_q.model_dump()) new_queries_from_add.append(new_q.model_dump()) generated_query_texts.add(new_query_text) # 更新seed的added_words for seed in seed_list_next: if seed["text"] == seed_text: seed["added_words"].append(selected_word) break print(f" {seed_text} + {selected_word} → {new_query_text} (分数: {score:.2f})") # 记录加词操作 add_operation_record( context, round_num=round_num, 操作名称="加词生成新查询", 输入={ "seed_list": seed_list_n, "word_list": word_list_next }, 处理过程={"尝试列表": add_word_attempts}, 输出={"新查询列表": new_queries_from_add} ) # 5.2 从sug加入q_list(条件:sug分数 > from_q分数) print(f" [5.2] 从sug加入q_list_{round_num+1}(条件:sug分数 > from_q分数)...") sug_added_count = 0 sug_filter_comparisons = [] selected_sugs = [] for sug_list_evaluated in sug_list_list_evaluated: for sug_data in sug_list_evaluated: # 新条件:sug的分数 > 其来源query的分数 from_q_score = sug_data["from_q"]["score_with_o"] passed = sug_data["score_with_o"] > from_q_score sug_filter_comparisons.append({ "推荐词": sug_data["text"], "推荐词分数": sug_data["score_with_o"], "来源查询分数": from_q_score, "通过": passed, "原因": f"{sug_data['score_with_o']:.2f} > {from_q_score:.2f}" if passed else f"{sug_data['score_with_o']:.2f} <= {from_q_score:.2f}" }) if passed: # 检查是否已存在 if sug_data["text"] not in [q["text"] for q in q_list_next]: new_q = Q(text=sug_data["text"], score_with_o=sug_data["score_with_o"], from_source="sug") q_list_next.append(new_q.model_dump()) selected_sugs.append(new_q.model_dump()) sug_added_count += 1 print(f" ✓ {sug_data['text']} ({sug_data['score_with_o']:.2f} > {from_q_score:.2f})") print(f" 添加 {sug_added_count} 个sug到q_list_{round_num+1}") # 记录筛选sug操作 add_operation_record( context, round_num=round_num, 操作名称="筛选推荐词进入下轮", 输入={"已评估推荐词": sug_list_list_evaluated}, 处理过程={ "筛选条件": "推荐词分数 > 来源查询分数", "比较结果": sug_filter_comparisons }, 输出={"选中推荐词": selected_sugs} ) # 5.3 更新seed_list(从sug中添加新seed,条件:sug分数 > from_q分数) print(f" [5.3] 更新seed_list_{round_num+1}(条件:sug分数 > from_q分数)...") seed_texts_existing = [s["text"] for s in seed_list_next] new_seed_count = 0 for sug_list_evaluated in sug_list_list_evaluated: for sug_data in sug_list_evaluated: from_q_score = sug_data["from_q"]["score_with_o"] # 新条件:sug的分数 > 其来源query的分数 if sug_data["score_with_o"] > from_q_score and sug_data["text"] not in seed_texts_existing: new_seed = Seed(text=sug_data["text"], added_words=[], from_type="sug") seed_list_next.append(new_seed.model_dump()) seed_texts_existing.append(sug_data["text"]) new_seed_count += 1 print(f" 添加 {new_seed_count} 个sug到seed_list_{round_num+1}") context.q_lists[round_num + 1] = q_list_next context.seed_lists[round_num + 1] = seed_list_next print(f"\n q_list_{round_num+1} 大小: {len(q_list_next)}") print(f" seed_list_{round_num+1} 大小: {len(seed_list_next)}") # 记录构建下一轮操作 add_operation_record( context, round_num=round_num, 操作名称="构建下一轮", 输入={ "加词新查询": new_queries_from_add, "选中推荐词": selected_sugs }, 处理过程={ "合并": { "来自加词": len(new_queries_from_add), "来自推荐词": len(selected_sugs), "合并前总数": len(new_queries_from_add) + len(selected_sugs) }, "去重": { "唯一数": len(q_list_next) } }, 输出={ "下轮查询列表": q_list_next, "下轮种子列表": seed_list_next } ) add_step(context, f"第{round_num}轮完成", "round", { "round": round_num, "q_list_count": len(q_list_n), "sug_total_count": sum(len(s) for s in sug_list_list_evaluated), "search_count": len(search_list_n), "posts_found": sum(len(s["post_list"]) for s in search_list_n), "q_list_next_count": len(q_list_next), "seed_list_next_count": len(seed_list_next) }) async def main_loop(context: RunContext, max_rounds: int = 2): """主循环 Args: context: 运行上下文 max_rounds: 最大轮数(默认2) """ print("\n" + "="*60) print("开始主循环") print("="*60) # 初始化 await initialize(context) # API实例 xiaohongshu_api = XiaohongshuSearchRecommendations() xiaohongshu_search = XiaohongshuSearch() # 迭代 for round_num in range(1, max_rounds + 1): await process_round(round_num, context, xiaohongshu_api, xiaohongshu_search) # 检查终止条件 q_list_next = context.q_lists.get(round_num + 1, []) if not q_list_next: print(f"\n q_list_{round_num + 1} 为空,提前结束") break print("\n" + "="*60) print("主循环完成") print("="*60) print(f" 总共收集 {len(context.all_posts)} 个帖子") # ============================================================================ # 主函数 # ============================================================================ async def main(input_dir: str, max_rounds: int = 2, visualize: bool = False): """主函数""" current_time, log_url = set_trace() # 读取输入 input_context_file = os.path.join(input_dir, 'context.md') input_q_file = os.path.join(input_dir, 'q.md') c = read_file_as_string(input_context_file) o = read_file_as_string(input_q_file) # 版本信息 version = os.path.basename(__file__) version_name = os.path.splitext(version)[0] # 日志目录 log_dir = os.path.join(input_dir, "output", version_name, current_time) # 创建运行上下文 run_context = RunContext( version=version, input_files={ "input_dir": input_dir, "context_file": input_context_file, "q_file": input_q_file, }, c=c, o=o, log_dir=log_dir, log_url=log_url, ) # 执行主循环 await main_loop(run_context, max_rounds=max_rounds) # 格式化输出 output = f"原始需求:{run_context.c}\n" output += f"原始问题:{run_context.o}\n" output += f"收集帖子:{len(run_context.all_posts)} 个\n" output += "\n" + "="*60 + "\n" if run_context.all_posts: output += "【收集到的帖子】\n\n" for idx, post in enumerate(run_context.all_posts[:20], 1): # 只显示前20个 output += f"{idx}. {post['title']}\n" output += f" 类型: {post['type']}\n" output += f" URL: {post['note_url']}\n\n" else: output += "未收集到帖子\n" run_context.final_output = output print(f"\n{'='*60}") print("最终结果") print(f"{'='*60}") print(output) # 保存日志 os.makedirs(run_context.log_dir, exist_ok=True) context_file_path = os.path.join(run_context.log_dir, "run_context.json") context_dict = run_context.model_dump() with open(context_file_path, "w", encoding="utf-8") as f: json.dump(context_dict, f, ensure_ascii=False, indent=2) print(f"\nRunContext saved to: {context_file_path}") steps_file_path = os.path.join(run_context.log_dir, "steps.json") with open(steps_file_path, "w", encoding="utf-8") as f: json.dump(run_context.steps, f, ensure_ascii=False, indent=2) print(f"Steps log saved to: {steps_file_path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="搜索query优化工具 - v6.1.2.7 基于seed的迭代版") parser.add_argument( "--input-dir", type=str, default="input/简单扣图", help="输入目录路径,默认: input/简单扣图" ) parser.add_argument( "--max-rounds", type=int, default=2, help="最大轮数,默认: 2" ) parser.add_argument( "--visualize", action="store_true", default=True, help="运行完成后自动生成可视化HTML" ) args = parser.parse_args() asyncio.run(main(args.input_dir, max_rounds=args.max_rounds, visualize=args.visualize))