| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- """
- 搜索结果与灵感匹配分析
- 评估搜索到的帖子与当前灵感的匹配度
- - 帖子标题(title)作为匹配要素
- - 帖子描述(desc)作为上下文
- """
- import asyncio
- import json
- import os
- import sys
- from typing import List, Dict, Optional
- from pathlib import Path
- from agents import trace
- from lib.my_trace import set_trace_smith as set_trace
- from lib.async_utils import process_tasks_with_semaphore
- from lib.match_analyzer import match_single
- from lib.data_loader import load_inspiration_list, select_inspiration
- # Model configuration: this name is passed to match_single() as model_name
- # and, with the "google/" prefix removed, becomes part of the output file name.
- MODEL_NAME = "google/gemini-2.5-pro"
async def match_single_note(
    inspiration: str,
    note: dict,
    _index: int
) -> dict:
    """Match one searched post against an inspiration point.

    The post title is the element being matched and the post description is
    passed along as its context; the inspiration text is the other side of
    the comparison.

    Args:
        inspiration: Inspiration text (the "B" side of the match).
        note: Post record. The keys read here are ``title``, ``desc``,
            ``channel_content_id``, ``like_count`` and
            ``channel_account_name``; each has a fallback when absent.
        _index: Task index supplied by the concurrent task runner; unused.

    Returns:
        A dict with three sections: ``输入信息`` (the inputs given to the
        matcher), ``匹配结果`` (the value returned by ``match_single``) and
        ``业务信息`` (post metadata kept for reporting).
    """
    post_title = note.get("title", "")
    post_desc = note.get("desc", "")
    post_id = note.get("channel_content_id", "")

    # Generic matcher convention: B = inspiration, A = post title,
    # A_Context = post description.
    outcome = await match_single(
        b_content=inspiration,
        a_content=post_title,
        model_name=MODEL_NAME,
        a_context=post_desc,
    )

    # Echo of what was sent to the matcher (B has no context here).
    matcher_inputs = {
        "B": inspiration,
        "A": post_title,
        "B_Context": "",
        "A_Context": post_desc,
    }

    # Post metadata carried through so results can be reported later.
    business_info = {
        "灵感": inspiration,
        "channel_content_id": post_id,
        "title": post_title,
        "likes": note.get("like_count", 0),
        "user_nickname": note.get("channel_account_name", ""),
    }

    return {
        "输入信息": matcher_inputs,
        "匹配结果": outcome,
        "业务信息": business_info,
    }
def find_search_result_file(
    persona_dir: str,
    inspiration: str,
    max_tasks: Optional[int] = None,
) -> Optional[str]:
    """Locate the most recently modified search-result file for an inspiration.

    Files are looked up in ``<persona_dir>/how/灵感点/<inspiration>/search``
    and must be named ``<scope>_search_*.json``, where ``<scope>`` is
    ``top<max_tasks>`` when a task limit is given and ``all`` otherwise.

    Args:
        persona_dir: Persona directory that contains the ``how`` tree.
        inspiration: Inspiration name (used as a directory name).
        max_tasks: Task limit that selects the file-name prefix; ``None``
            selects the ``all`` prefix.

    Returns:
        The path of the newest matching file (by modification time) as a
        string, or ``None`` when the search directory does not exist or
        contains no matching file.
    """
    search_dir = Path(persona_dir, "how", "灵感点", inspiration, "search")
    if not search_dir.is_dir():
        return None

    scope_prefix = f"top{max_tasks}" if max_tasks is not None else "all"

    # Only the newest file is needed, so take the maximum by mtime instead of
    # sorting every candidate; ``default`` covers the "no match" case.
    newest = max(
        search_dir.glob(f"{scope_prefix}_search_*.json"),
        key=lambda candidate: candidate.stat().st_mtime,
        default=None,
    )
    return str(newest) if newest is not None else None
def _match_score(result) -> float:
    """Return the numeric match score of one result, or 0.0 when unusable.

    A score that is missing, null or not convertible to a number is treated
    as 0.0 so a single malformed match result cannot abort sorting, display
    or statistics after all matching work has already been done.
    """
    match_info = result.get("匹配结果") if isinstance(result, dict) else None
    raw_score = match_info.get("score") if isinstance(match_info, dict) else None
    try:
        return float(raw_score)
    except (TypeError, ValueError):
        return 0.0


async def main(
    current_time: Optional[str] = None,
    log_url: Optional[str] = None,
    force: bool = False,
):
    """Match the search results of one inspiration and save a ranked report.

    Command-line arguments are read from ``sys.argv``:
    ``<persona_dir> <inspiration> [max_tasks]``. A third argument of
    ``all`` (or no third argument) means no task limit.

    The function loads the newest search-result file for the selected
    inspiration, matches every post in it against the inspiration
    concurrently, sorts the results by score (highest first), prints the
    top five and a score distribution, and writes everything to a JSON
    file in the inspiration's ``search`` directory.

    Args:
        current_time: Timestamp stored in the output metadata.
        log_url: Log link stored in the output metadata.
        force: Re-run the analysis even if the output file already exists.

    Raises:
        SystemExit: With code 1 when arguments are missing or no search
            result file is found; with code 0 when the search result file
            contains no posts.
    """
    # Parse command-line arguments.
    if len(sys.argv) < 3:
        print("用法: python step4_search_result_match.py <persona_dir> <inspiration> [max_tasks]")
        print("\n示例:")
        print(" python step4_search_result_match.py data/阿里多多酱/out/人设_1110 内容植入品牌推广")
        print(" python step4_search_result_match.py data/阿里多多酱/out/人设_1110 0 20")
        sys.exit(1)

    persona_dir = sys.argv[1]
    inspiration_arg = sys.argv[2]
    max_tasks = int(sys.argv[3]) if len(sys.argv) > 3 and sys.argv[3] != "all" else None

    # Load the inspiration list and pick the requested inspiration.
    inspiration_list = load_inspiration_list(persona_dir)
    inspiration = select_inspiration(inspiration_arg, inspiration_list)

    banner = "=" * 80
    print(banner)
    print("Step4: 搜索结果与灵感匹配分析")
    print(banner)
    print(f"人设目录: {persona_dir}")
    print(f"灵感: {inspiration}")
    print(f"模型: {MODEL_NAME}")
    print()

    # Locate the search-result file produced by the search step.
    search_file = find_search_result_file(persona_dir, inspiration, max_tasks)
    if not search_file:
        print("❌ 错误: 找不到搜索结果文件")
        print("请先运行搜索步骤: python run_inspiration_analysis.py --search-only --count 1")
        sys.exit(1)

    print(f"搜索结果文件: {search_file}\n")

    # Read the search results.
    with open(search_file, 'r', encoding='utf-8') as f:
        search_data = json.load(f)

    notes = search_data.get("notes", [])
    search_keyword = search_data.get("search_params", {}).get("keyword", "")

    if not notes:
        print("⚠️ 警告: 搜索结果为空")
        sys.exit(0)

    print(f"搜索关键词: {search_keyword}")
    print(f"搜索结果数: {len(notes)}")
    print()

    # The report is written next to the search results, in search/.
    output_dir = os.path.join(persona_dir, "how", "灵感点", inspiration, "search")
    os.makedirs(output_dir, exist_ok=True)

    scope_prefix = f"top{max_tasks}" if max_tasks is not None else "all"
    model_short = MODEL_NAME.replace("google/", "").replace("/", "_")
    output_file = os.path.join(output_dir, f"{scope_prefix}_step4_搜索结果匹配_{model_short}.json")

    # Skip the matching work when a report already exists.
    if os.path.exists(output_file) and not force:
        print(f"✓ 输出文件已存在: {output_file}")
        print("使用 force=True 可强制重新执行")
        return

    # Run the match analysis.
    rule = "─" * 80
    print(rule)
    print("开始匹配分析...")
    print(f"{rule}\n")

    # One task per post, all against the same inspiration.
    tasks = [
        {"inspiration": inspiration, "note": note}
        for note in notes
    ]

    # Run the match tasks concurrently.
    results = await process_tasks_with_semaphore(
        tasks=tasks,
        process_func=lambda task, idx: match_single_note(
            inspiration=task["inspiration"],
            note=task["note"],
            _index=idx
        ),
        max_concurrent=10,
        show_progress=True
    )

    # Highest score first; unusable scores rank as 0.0 instead of raising.
    results_sorted = sorted(results, key=_match_score, reverse=True)

    print(f"\n{rule}")
    print("匹配完成")
    print(f"{rule}\n")

    # Show the top five results.
    print("Top 5 匹配结果:")
    for i, result in enumerate(results_sorted[:5], 1):
        business = result.get("业务信息") or {}
        # A null title must not break the slice below.
        title = business.get("title") or ""
        channel_content_id = business.get("channel_content_id", "")
        print(f" {i}. [score={_match_score(result):.2f}] {title[:50]}... (ID: {channel_content_id})")
    print()

    # Save the report.
    output_data = {
        "元数据": {
            "current_time": current_time,
            "log_url": log_url,
            "model": MODEL_NAME,
            "step": "step4_搜索结果匹配"
        },
        "输入信息": {
            "灵感": inspiration,
            "搜索关键词": search_keyword,
            "搜索结果数": len(notes),
            "搜索结果文件": search_file
        },
        "匹配结果列表": results_sorted
    }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)

    print(f"✓ 结果已保存: {output_file}")
    print()

    # Score distribution; each score is extracted once and reused.
    scores = [_match_score(result) for result in results_sorted]
    high_score_count = sum(1 for score in scores if score >= 0.7)
    medium_score_count = sum(1 for score in scores if 0.4 <= score < 0.7)
    low_score_count = sum(1 for score in scores if score < 0.4)

    print("匹配统计:")
    print(f" 高匹配 (≥0.7): {high_score_count} 个")
    print(f" 中匹配 (0.4-0.7): {medium_score_count} 个")
    print(f" 低匹配 (<0.4): {low_score_count} 个")
if __name__ == "__main__":
    # Set up tracing first; it yields the timestamp and log link that are
    # recorded in the output metadata.
    run_timestamp, run_log_url = set_trace()

    # Run the whole step inside a single named trace.
    with trace("Step4: 搜索结果匹配"):
        asyncio.run(main(current_time=run_timestamp, log_url=run_log_url))
|