run.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. """
  2. 内容寻找 Agent
  3. 使用示例:
  4. python run.py
  5. """
  6. import asyncio
  7. import logging
  8. import sys
  9. import os
  10. from pathlib import Path
  11. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  12. from dotenv import load_dotenv
  13. load_dotenv()
  14. from agent import (
  15. AgentRunner,
  16. RunConfig,
  17. FileSystemTraceStore,
  18. Trace,
  19. Message,
  20. )
  21. from agent.llm import create_openrouter_llm_call
  22. from agent.llm.prompts import SimplePrompt
  23. # 导入工具(确保工具被注册)
  24. from tools import (
  25. douyin_search,
  26. douyin_user_videos,
  27. get_content_fans_portrait,
  28. get_account_fans_portrait,
  29. )
  30. # 配置日志
  31. log_dir = Path(__file__).parent / '.cache'
  32. log_dir.mkdir(exist_ok=True)
  33. logging.basicConfig(
  34. level=logging.INFO,
  35. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  36. handlers=[
  37. logging.FileHandler(log_dir / 'agent.log'),
  38. logging.StreamHandler()
  39. ]
  40. )
  41. logger = logging.getLogger(__name__)
  42. async def generate_fallback_output(store: FileSystemTraceStore, trace_id: str):
  43. """
  44. 当任务未正常输出时,从 trace 中提取数据并生成兜底输出
  45. """
  46. try:
  47. # 读取所有消息
  48. messages_dir = Path(store.base_path) / trace_id / "messages"
  49. if not messages_dir.exists():
  50. print("无法生成摘要:找不到消息目录")
  51. return
  52. # 提取搜索结果和画像数据
  53. search_results = []
  54. portrait_data = {}
  55. import json
  56. import re
  57. for msg_file in sorted(messages_dir.glob("*.json")):
  58. with open(msg_file, 'r', encoding='utf-8') as f:
  59. msg = json.load(f)
  60. # 提取搜索结果(从文本结果中解析)
  61. if msg.get("role") == "tool" and msg.get("content", {}).get("tool_name") == "douyin_search":
  62. result_text = msg.get("content", {}).get("result", "")
  63. # 解析每条搜索结果
  64. lines = result_text.split("\n")
  65. current_item = {}
  66. for line in lines:
  67. line = line.strip()
  68. if not line:
  69. if current_item.get("aweme_id"):
  70. if current_item["aweme_id"] not in [r["aweme_id"] for r in search_results]:
  71. search_results.append(current_item)
  72. current_item = {}
  73. continue
  74. # 解析标题行(以数字开头)
  75. if re.match(r'^\d+\.', line):
  76. current_item["desc"] = line.split(".", 1)[1].strip()[:100]
  77. # 解析 ID
  78. elif line.startswith("ID:"):
  79. current_item["aweme_id"] = line.split("ID:")[1].strip()
  80. # 解析作者
  81. elif line.startswith("作者:"):
  82. author_name = line.split("作者:")[1].strip()
  83. current_item["author"] = {"nickname": author_name}
  84. # 解析 sec_uid
  85. elif line.startswith("sec_uid:"):
  86. sec_uid = line.split("sec_uid:")[1].strip()
  87. if "author" not in current_item:
  88. current_item["author"] = {}
  89. current_item["author"]["sec_uid"] = sec_uid
  90. # 解析数据
  91. elif line.startswith("数据:"):
  92. stats_text = line.split("数据:")[1].strip()
  93. stats = {}
  94. # 解析点赞数
  95. if "点赞" in stats_text:
  96. digg_match = re.search(r'点赞\s+([\d,]+)', stats_text)
  97. if digg_match:
  98. stats["digg_count"] = int(digg_match.group(1).replace(",", ""))
  99. # 解析评论数
  100. if "评论" in stats_text:
  101. comment_match = re.search(r'评论\s+([\d,]+)', stats_text)
  102. if comment_match:
  103. stats["comment_count"] = int(comment_match.group(1).replace(",", ""))
  104. # 解析分享数
  105. if "分享" in stats_text:
  106. share_match = re.search(r'分享\s+([\d,]+)', stats_text)
  107. if share_match:
  108. stats["share_count"] = int(share_match.group(1).replace(",", ""))
  109. current_item["statistics"] = stats
  110. # 添加最后一条
  111. if current_item.get("aweme_id"):
  112. if current_item["aweme_id"] not in [r["aweme_id"] for r in search_results]:
  113. search_results.append(current_item)
  114. # 提取画像数据
  115. elif msg.get("role") == "tool":
  116. tool_name = msg.get("content", {}).get("tool_name", "")
  117. result_text = msg.get("content", {}).get("result", "")
  118. if tool_name in ["get_content_fans_portrait", "get_account_fans_portrait"]:
  119. # 解析画像数据
  120. content_id = None
  121. age_50_plus = None
  122. tgi = None
  123. # 从结果文本中提取 ID
  124. if "内容 " in result_text:
  125. parts = result_text.split("内容 ")[1].split(" ")[0]
  126. content_id = parts
  127. elif "账号 " in result_text:
  128. parts = result_text.split("账号 ")[1].split(" ")[0]
  129. content_id = parts
  130. # 提取50岁以上数据(格式:50-: 48.35% (偏好度: 210.05))
  131. if "【年龄】分布" in result_text:
  132. lines = result_text.split("\n")
  133. for line in lines:
  134. if "50-:" in line:
  135. # 解析: 50-: 48.35% (偏好度: 210.05)
  136. parts = line.split("50-:")[1].strip()
  137. if "%" in parts:
  138. age_50_plus = parts.split("%")[0].strip()
  139. if "偏好度:" in parts:
  140. tgi_part = parts.split("偏好度:")[1].strip()
  141. tgi = tgi_part.replace(")", "").strip()
  142. break
  143. if content_id and age_50_plus:
  144. portrait_data[content_id] = {
  145. "age_50_plus": age_50_plus,
  146. "tgi": tgi,
  147. "source": "内容点赞画像" if tool_name == "get_content_fans_portrait" else "账号粉丝画像"
  148. }
  149. # 生成输出
  150. print("\n" + "="*60)
  151. print("📊 任务执行摘要(兜底输出)")
  152. print("="*60)
  153. print(f"\n搜索情况:找到 {len(search_results)} 条候选内容")
  154. print(f"画像获取:获取了 {len(portrait_data)} 条画像数据")
  155. # 筛选有画像且符合要求的内容
  156. matched_results = []
  157. for result in search_results:
  158. aweme_id = result["aweme_id"]
  159. author_id = result["author"].get("sec_uid", "")
  160. # 查找画像数据(优先内容画像,其次账号画像)
  161. portrait = portrait_data.get(aweme_id) or portrait_data.get(author_id)
  162. if portrait and portrait.get("age_50_plus"):
  163. try:
  164. age_ratio = float(portrait["age_50_plus"])
  165. if age_ratio >= 20: # 50岁以上占比>=20%
  166. matched_results.append({
  167. **result,
  168. "portrait": portrait
  169. })
  170. except:
  171. pass
  172. # 按50岁以上占比排序
  173. matched_results.sort(key=lambda x: float(x["portrait"]["age_50_plus"]), reverse=True)
  174. # 输出推荐结果
  175. print(f"\n符合要求:{len(matched_results)} 条内容(50岁以上占比>=20%)")
  176. print("\n" + "="*60)
  177. print("🎯 推荐结果")
  178. print("="*60)
  179. for i, result in enumerate(matched_results[:10], 1):
  180. aweme_id = result["aweme_id"]
  181. desc = result["desc"]
  182. author = result["author"]
  183. stats = result["statistics"]
  184. portrait = result["portrait"]
  185. print(f"\n{i}. {desc}")
  186. print(f" 链接: https://www.douyin.com/video/{aweme_id}")
  187. print(f" 作者: {author.get('nickname', '未知')}")
  188. print(f" 热度: 👍 {stats.get('digg_count', 0):,} | 💬 {stats.get('comment_count', 0):,} | 🔄 {stats.get('share_count', 0):,}")
  189. print(f" 画像: 50岁以上 {portrait['age_50_plus']}% (tgi: {portrait['tgi']}) - {portrait['source']}")
  190. print("\n" + "="*60)
  191. print(f"✅ 已为您找到 {min(len(matched_results), 10)} 条推荐视频")
  192. print("="*60)
  193. except Exception as e:
  194. logger.error(f"生成兜底输出失败: {e}", exc_info=True)
  195. print(f"\n生成摘要失败: {e}")
  196. async def main():
  197. print("\n" + "=" * 60)
  198. print("内容寻找 Agent")
  199. print("=" * 60)
  200. print("开始执行...\n")
  201. # 加载 prompt
  202. prompt_path = Path(__file__).parent / "content_finder.prompt"
  203. prompt = SimplePrompt(prompt_path)
  204. # 构建消息
  205. messages = prompt.build_messages()
  206. # 初始化
  207. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  208. if not api_key:
  209. raise ValueError("OPEN_ROUTER_API_KEY 未设置,请在 .env 文件中配置")
  210. model = os.getenv("MODEL", f"anthropic/claude-{prompt.config.get('model', 'sonnet-4.6')}")
  211. temperature = float(prompt.config.get("temperature", 0.3))
  212. max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
  213. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  214. skills_dir = str(Path(__file__).parent / "skills")
  215. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  216. store = FileSystemTraceStore(base_path=trace_dir)
  217. # 限制工具范围:只使用抖音相关的4个工具
  218. allowed_tools = [
  219. "douyin_search",
  220. "douyin_user_videos",
  221. "get_content_fans_portrait",
  222. "get_account_fans_portrait",
  223. ]
  224. runner = AgentRunner(
  225. llm_call=create_openrouter_llm_call(model=model),
  226. trace_store=store,
  227. skills_dir=skills_dir,
  228. )
  229. config = RunConfig(
  230. model=model,
  231. temperature=temperature,
  232. max_iterations=max_iterations,
  233. tools=allowed_tools, # 限制工具范围
  234. extra_llm_params={"max_tokens": 8192}, # 增加输出 token 限制,避免被截断
  235. )
  236. # 执行
  237. trace_id = None
  238. has_final_output = False
  239. try:
  240. async for item in runner.run(messages=messages, config=config):
  241. if isinstance(item, Trace):
  242. trace_id = item.trace_id
  243. if item.status == "completed":
  244. print(f"\n[完成] trace_id={item.trace_id}")
  245. # 检查是否有最终输出
  246. if not has_final_output:
  247. print("\n⚠️ 检测到任务未完整输出,正在生成摘要...")
  248. await generate_fallback_output(store, item.trace_id)
  249. elif item.status == "failed":
  250. print(f"\n[失败] {item.error_message}")
  251. elif isinstance(item, Message):
  252. if item.role == "assistant":
  253. content = item.content
  254. if isinstance(content, dict):
  255. text = content.get("text", "")
  256. tool_calls = content.get("tool_calls")
  257. # 输出文本内容
  258. if text:
  259. # 检测是否包含最终推荐结果
  260. if "推荐结果" in text or "推荐内容" in text or "🎯" in text:
  261. has_final_output = True
  262. # 如果文本很长(>500字符)且包含推荐结果标记,输出完整内容
  263. if len(text) > 500 and ("推荐结果" in text or "推荐内容" in text or "🎯" in text):
  264. print(f"\n{text}")
  265. # 如果有工具调用且文本较短,只输出摘要
  266. elif tool_calls and len(text) > 100:
  267. print(f"[思考] {text[:100]}...")
  268. # 其他情况输出完整文本
  269. else:
  270. print(f"\n{text}")
  271. # 输出工具调用信息
  272. if tool_calls:
  273. for tc in tool_calls:
  274. tool_name = tc.get("function", {}).get("name", "unknown")
  275. # 跳过 goal 工具的输出,减少噪音
  276. if tool_name != "goal":
  277. print(f"[工具] {tool_name}")
  278. elif isinstance(content, str) and content:
  279. print(f"\n{content}")
  280. elif item.role == "tool":
  281. content = item.content
  282. if isinstance(content, dict):
  283. tool_name = content.get("tool_name", "unknown")
  284. print(f"[结果] {tool_name} ✓")
  285. except KeyboardInterrupt:
  286. print("\n用户中断")
  287. except Exception as e:
  288. logger.error(f"执行失败: {e}", exc_info=True)
  289. print(f"\n执行失败: {e}")
  290. sys.exit(1)
  291. if __name__ == "__main__":
  292. asyncio.run(main())