step4_search_result_match.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. """
  2. 搜索结果与灵感匹配分析
  3. 评估搜索到的帖子与当前灵感的匹配度
  4. - 帖子标题(title)作为匹配要素
  5. - 帖子描述(desc)作为上下文
  6. """
  7. import asyncio
  8. import json
  9. import os
  10. import sys
  11. from typing import List, Dict, Optional
  12. from pathlib import Path
  13. from agents import trace
  14. from lib.my_trace import set_trace_smith as set_trace
  15. from lib.async_utils import process_tasks_with_semaphore
  16. from lib.match_analyzer import match_single
  17. from lib.data_loader import load_inspiration_list, select_inspiration
  18. # 模型配置
  19. MODEL_NAME = "google/gemini-2.5-pro"
  20. async def match_single_note(
  21. inspiration: str,
  22. note: dict,
  23. _index: int
  24. ) -> dict:
  25. """匹配单个帖子与灵感
  26. Args:
  27. inspiration: 灵感点文本
  28. note: 帖子数据,包含 title, desc, channel_content_id 等
  29. _index: 任务索引(由 async_utils 传入)
  30. Returns:
  31. 匹配结果
  32. """
  33. title = note.get("title", "") or ""
  34. desc = note.get("desc", "") or ""
  35. channel_content_id = note.get("channel_content_id", "") or ""
  36. # 调用通用匹配模块
  37. # B = 灵感, A = 帖子标题, A_Context = 帖子描述
  38. match_result = await match_single(
  39. b_content=inspiration,
  40. a_content=title,
  41. model_name=MODEL_NAME,
  42. a_context=desc
  43. )
  44. # 构建完整结果
  45. full_result = {
  46. "输入信息": {
  47. "B": inspiration,
  48. "A": title,
  49. "B_Context": "",
  50. "A_Context": desc
  51. },
  52. "匹配结果": match_result,
  53. "业务信息": {
  54. "灵感": inspiration,
  55. "channel_content_id": channel_content_id,
  56. "title": title,
  57. "likes": note.get("like_count", 0),
  58. "user_nickname": note.get("channel_account_name", "")
  59. }
  60. }
  61. return full_result
  62. def find_search_result_file(persona_dir: str, inspiration: str, max_tasks: int = None) -> Optional[str]:
  63. """查找搜索结果文件
  64. Args:
  65. persona_dir: 人设目录
  66. inspiration: 灵感点名称
  67. max_tasks: 任务数限制(用于确定文件前缀)
  68. Returns:
  69. 搜索结果文件路径,如果未找到返回 None
  70. """
  71. search_dir = os.path.join(persona_dir, "how", "灵感点", inspiration, "search")
  72. if not os.path.exists(search_dir):
  73. return None
  74. scope_prefix = f"top{max_tasks}" if max_tasks is not None else "all"
  75. search_pattern = f"{scope_prefix}_search_*.json"
  76. search_files = list(Path(search_dir).glob(search_pattern))
  77. if not search_files:
  78. return None
  79. # 返回最新的文件
  80. return str(sorted(search_files, key=lambda x: x.stat().st_mtime, reverse=True)[0])
  81. async def main(current_time: str = None, log_url: str = None, force: bool = False):
  82. """主函数
  83. Args:
  84. current_time: 当前时间戳
  85. log_url: 日志链接
  86. force: 是否强制重新执行
  87. """
  88. # 解析命令行参数
  89. if len(sys.argv) < 3:
  90. print("用法: python step4_search_result_match.py <persona_dir> <inspiration> [max_tasks]")
  91. print("\n示例:")
  92. print(" python step4_search_result_match.py data/阿里多多酱/out/人设_1110 内容植入品牌推广")
  93. print(" python step4_search_result_match.py data/阿里多多酱/out/人设_1110 0 20")
  94. sys.exit(1)
  95. persona_dir = sys.argv[1]
  96. inspiration_arg = sys.argv[2]
  97. max_tasks = int(sys.argv[3]) if len(sys.argv) > 3 and sys.argv[3] != "all" else None
  98. # 加载灵感列表
  99. inspiration_list = load_inspiration_list(persona_dir)
  100. # 选择灵感
  101. inspiration = select_inspiration(inspiration_arg, inspiration_list)
  102. print(f"{'=' * 80}")
  103. print(f"Step4: 搜索结果与灵感匹配分析")
  104. print(f"{'=' * 80}")
  105. print(f"人设目录: {persona_dir}")
  106. print(f"灵感: {inspiration}")
  107. print(f"模型: {MODEL_NAME}")
  108. print()
  109. # 查找搜索结果文件
  110. search_file = find_search_result_file(persona_dir, inspiration, max_tasks)
  111. if not search_file:
  112. print(f"❌ 错误: 找不到搜索结果文件")
  113. print(f"请先运行搜索步骤: python run_inspiration_analysis.py --search-only --count 1")
  114. sys.exit(1)
  115. print(f"搜索结果文件: {search_file}\n")
  116. # 读取搜索结果
  117. with open(search_file, 'r', encoding='utf-8') as f:
  118. search_data = json.load(f)
  119. notes = search_data.get("notes", [])
  120. search_keyword = search_data.get("search_params", {}).get("keyword", "")
  121. if not notes:
  122. print(f"⚠️ 警告: 搜索结果为空")
  123. sys.exit(0)
  124. print(f"搜索关键词: {search_keyword}")
  125. print(f"搜索结果数: {len(notes)}")
  126. print()
  127. # 检查输出文件是否存在
  128. # 输出到 search/ 目录下
  129. output_dir = os.path.join(persona_dir, "how", "灵感点", inspiration, "search")
  130. os.makedirs(output_dir, exist_ok=True)
  131. scope_prefix = f"top{max_tasks}" if max_tasks is not None else "all"
  132. model_short = MODEL_NAME.replace("google/", "").replace("/", "_")
  133. output_file = os.path.join(output_dir, f"{scope_prefix}_step4_搜索结果匹配_{model_short}.json")
  134. if os.path.exists(output_file) and not force:
  135. print(f"✓ 输出文件已存在: {output_file}")
  136. print(f"使用 force=True 可强制重新执行")
  137. return
  138. # 执行匹配分析
  139. print(f"{'─' * 80}")
  140. print(f"开始匹配分析...")
  141. print(f"{'─' * 80}\n")
  142. # 构建匹配任务
  143. tasks = [
  144. {"inspiration": inspiration, "note": note}
  145. for note in notes
  146. ]
  147. # 并发执行匹配任务
  148. results = await process_tasks_with_semaphore(
  149. tasks=tasks,
  150. process_func=lambda task, idx: match_single_note(
  151. inspiration=task["inspiration"],
  152. note=task["note"],
  153. _index=idx
  154. ),
  155. max_concurrent=10,
  156. show_progress=True
  157. )
  158. # 按匹配分数排序
  159. results_sorted = sorted(
  160. results,
  161. key=lambda x: x.get("匹配结果", {}).get("score", 0),
  162. reverse=True
  163. )
  164. print(f"\n{'─' * 80}")
  165. print(f"匹配完成")
  166. print(f"{'─' * 80}\n")
  167. # 显示 Top 5 结果
  168. print("Top 5 匹配结果:")
  169. for i, result in enumerate(results_sorted[:5], 1):
  170. score = result.get("匹配结果", {}).get("score", 0)
  171. title = result.get("业务信息", {}).get("title", "") or ""
  172. channel_content_id = result.get("业务信息", {}).get("channel_content_id", "")
  173. # 安全地截取标题
  174. title_display = title[:50] if title else "(无标题)"
  175. print(f" {i}. [score={score:.2f}] {title_display}... (ID: {channel_content_id})")
  176. print()
  177. # 保存结果
  178. output_data = {
  179. "元数据": {
  180. "current_time": current_time,
  181. "log_url": log_url,
  182. "model": MODEL_NAME,
  183. "step": "step4_搜索结果匹配"
  184. },
  185. "输入信息": {
  186. "灵感": inspiration,
  187. "搜索关键词": search_keyword,
  188. "搜索结果数": len(notes),
  189. "搜索结果文件": search_file
  190. },
  191. "匹配结果列表": results_sorted
  192. }
  193. with open(output_file, 'w', encoding='utf-8') as f:
  194. json.dump(output_data, f, ensure_ascii=False, indent=2)
  195. print(f"✓ 结果已保存: {output_file}")
  196. print()
  197. # 统计信息
  198. high_score_count = sum(1 for r in results_sorted if r.get("匹配结果", {}).get("score", 0) >= 0.7)
  199. medium_score_count = sum(1 for r in results_sorted if 0.4 <= r.get("匹配结果", {}).get("score", 0) < 0.7)
  200. low_score_count = sum(1 for r in results_sorted if r.get("匹配结果", {}).get("score", 0) < 0.4)
  201. print(f"匹配统计:")
  202. print(f" 高匹配 (≥0.7): {high_score_count} 个")
  203. print(f" 中匹配 (0.4-0.7): {medium_score_count} 个")
  204. print(f" 低匹配 (<0.4): {low_score_count} 个")
  205. if __name__ == "__main__":
  206. # 设置 trace
  207. current_time, log_url = set_trace()
  208. # 使用 trace 包装运行
  209. with trace("Step4: 搜索结果匹配"):
  210. asyncio.run(main(current_time, log_url))