# research.py
import asyncio
import json
import logging
import sys
from pathlib import Path
from typing import Dict, Any, Optional

# Make the project root importable (three levels up from this file).
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from agent.core.runner import AgentRunner
from agent.trace import FileSystemTraceStore, Trace, Message
from agent.llm import create_qwen_llm_call
from agent.llm.prompts import SimplePrompt

logger = logging.getLogger(__name__)

# Persisted caller-trace-id -> research-trace-id mapping, used to resume runs.
TRACE_MAP_FILE = Path(".cache/research_trace_map.json")
  16. def _load_trace_map() -> Dict[str, str]:
  17. if TRACE_MAP_FILE.exists():
  18. return json.loads(TRACE_MAP_FILE.read_text(encoding="utf-8"))
  19. return {}
  20. def _save_trace_map(mapping: Dict[str, str]):
  21. TRACE_MAP_FILE.parent.mkdir(parents=True, exist_ok=True)
  22. TRACE_MAP_FILE.write_text(json.dumps(mapping, indent=2, ensure_ascii=False), encoding="utf-8")
  23. def get_research_trace_id(caller_trace_id: str) -> Optional[str]:
  24. """根据调用方 trace_id 查找对应的 Research trace_id"""
  25. if not caller_trace_id:
  26. return None
  27. mapping = _load_trace_map()
  28. return mapping.get(caller_trace_id)
  29. def set_research_trace_id(caller_trace_id: str, research_trace_id: str):
  30. """记录映射"""
  31. if not caller_trace_id:
  32. return
  33. mapping = _load_trace_map()
  34. mapping[caller_trace_id] = research_trace_id
  35. _save_trace_map(mapping)
# ===== Singleton runner =====
_runner: Optional[AgentRunner] = None  # lazily-created runner shared by all calls
_prompt_messages = None                # prebuilt prompt messages ([] if prompt file missing)
_initialized = False                   # guards the one-time setup below

def _ensure_initialized():
    """Lazily initialize the Runner and Prompt (executed on first call)."""
    global _runner, _prompt_messages, _initialized
    if _initialized:
        return
    _initialized = True
    # NOTE(review): built-in tools are presumably auto-loaded by AgentRunner
    # from agent/tools — confirm against AgentRunner's loading behavior.
    # skills_dir points at an optional sibling "skills" directory for
    # module-specific skills; it is only passed when it actually exists.
    skills_dir = Path(__file__).parent / "skills"
    _runner = AgentRunner(
        trace_store=FileSystemTraceStore(base_path=".trace"),
        llm_call=create_qwen_llm_call(model="qwen3.5-plus"),  # default model; may be overridden by prompt meta below
        skills_dir=str(skills_dir) if skills_dir.exists() else None,
        debug=True,
        logger_name="agents.research",
    )
    prompt_path = Path(__file__).parent / "research_agent.prompt"
    if prompt_path.exists():
        prompt = SimplePrompt(prompt_path)
        _prompt_messages = prompt.build_messages()
        # If the prompt file declares a model in its metadata, prefer it
        # over the default configured above.
        if getattr(prompt, "meta", None) and prompt.meta.get("model"):
            model_name = prompt.meta["model"]
            _runner.llm_call = create_qwen_llm_call(model=model_name)
    else:
        # Missing prompt file is tolerated: run with no prompt preamble.
        _prompt_messages = []
        logger.warning(f"Research prompt 文件不存在: {prompt_path}")
    logger.info("✓ Research Agent 已初始化")
# ===== Core API =====
async def research(query: str, caller_trace_id: str = "") -> Dict[str, Any]:
    """
    Run the Research Agent to completion and return its findings.

    Awaits the full agent run (streams internally, returns once finished).

    Args:
        query: the research topic or question set by the user
        caller_trace_id: the caller's trace_id, used to resume a prior run

    Returns:
        {"response": str, "source_ids": [str], "sources": [dict]}
    """
    _ensure_initialized()
    # Initialize the cloud headless browser (deployed online, so each run
    # needs an isolated environment and must not hang). Best-effort: a
    # failure here is logged and the run proceeds without the browser.
    try:
        from agent.tools.builtin.browser import init_browser_session
        await init_browser_session(browser_type="cloud")
    except Exception as e:
        logger.warning(f"Failed to init cloud browser: {e}")
    # Find an existing trace for this caller, or start fresh.
    research_trace_id = get_research_trace_id(caller_trace_id)
    from agent.core.runner import RunConfig
    config = RunConfig(
        model="qwen3.5-plus",
        temperature=0.3,
        max_iterations=200,
        tool_groups=["core", "content", "browser"],
        skills=["planning", "research", "browser"],
    )
    config.trace_id = research_trace_id  # None = new trace, non-None = resume
    # Build messages: prompt preamble is only prepended on a fresh trace,
    # since a resumed trace already contains it.
    content = f"[RESEARCH TASK] {query}"
    if research_trace_id is None:
        messages = _prompt_messages + [{"role": "user", "content": content}]
    else:
        messages = [{"role": "user", "content": content}]
    # Run the agent; keep the latest non-empty assistant text as the answer.
    response_text = ""
    actual_trace_id = None
    async for item in _runner.run(
        messages=messages,
        config=config,
    ):
        if isinstance(item, Trace):
            actual_trace_id = item.trace_id
        elif isinstance(item, Message):
            if item.role == "assistant":
                msg_content = item.content
                if isinstance(msg_content, dict):
                    text = msg_content.get("text", "")
                    if text:
                        response_text = text
                elif isinstance(msg_content, str) and msg_content:
                    response_text = msg_content
    # Record the trace mapping so this caller can resume later.
    if actual_trace_id and caller_trace_id:
        set_research_trace_id(caller_trace_id, actual_trace_id)
    # NOTE(review): source_ids/sources are always empty here — confirm
    # whether downstream consumers expect them to be populated.
    return {
        "response": response_text,
        "source_ids": [],
        "sources": [],
    }