# run_single.py (8.3 KB)
# Load .env before the project imports below. Presumably `tools` / `agent`
# read environment variables at import time (verify); in any case
# OPEN_ROUTER_API_KEY and MODEL are read from the environment in run_agent().
from dotenv import load_dotenv
load_dotenv()
from typing import Dict, Any, Optional
import os
from pathlib import Path
import json
# NOTE(review): these three names are not referenced in this file; the import may
# exist to register the tools as a side effect — confirm before removing.
from tools import fetch_account_article_list, fetch_weixin_account, weixin_search
from agent import AgentRunner, RunConfig, FileSystemTraceStore, Trace, Message
from agent.llm import create_openrouter_llm_call
from agent.llm.prompts import SimplePrompt
from agent.tools.builtin.knowledge import KnowledgeConfig
# Default search query, used by run_agent() when no query is given.
DEFAULT_QUERY = "伊朗、以色列、和平是永恒的主题"
# Fallback demand id (links to the demand_content table) used by run_agent().
DEFAULT_DEMAND_ID = 1
import logging
logger = logging.getLogger(__name__)
# Directory containing this script; prompt, output, traces and skills paths
# are all resolved relative to it.
PROJECT_ROOT = Path(__file__).resolve().parent
  18. def _normalize_ascii_double_quotes(text: str) -> str:
  19. """将字符串中的 ASCII 双引号 `"` 规范化为中文双引号 `“`、`”`。"""
  20. if '"' not in text:
  21. return text
  22. chars: list[str] = []
  23. open_quote = True
  24. for ch in text:
  25. if ch == '"':
  26. chars.append("“" if open_quote else "”")
  27. open_quote = not open_quote
  28. else:
  29. chars.append(ch)
  30. return "".join(chars)
  31. def _sanitize_json_strings(value: Any) -> Any:
  32. if isinstance(value, str):
  33. return _normalize_ascii_double_quotes(value)
  34. if isinstance(value, list):
  35. return [_sanitize_json_strings(v) for v in value]
  36. if isinstance(value, dict):
  37. return {k: _sanitize_json_strings(v) for k, v in value.items()}
  38. return value
  39. def _sanitize_output_json(output_json_path: Path) -> None:
  40. """
  41. 任务完成后对 output.json 做后处理:
  42. - 递归清洗所有字符串值中的英文双引号 `"`
  43. - 保持合法 JSON
  44. """
  45. if not output_json_path.exists():
  46. logger.warning(f"未找到 output.json,跳过清洗: {output_json_path}")
  47. return
  48. try:
  49. data = json.loads(output_json_path.read_text(encoding="utf-8"))
  50. except Exception as e:
  51. logger.warning(f"output.json 解析失败,跳过清洗: {e}")
  52. return
  53. cleaned = _sanitize_json_strings(data)
  54. output_json_path.write_text(
  55. json.dumps(cleaned, ensure_ascii=False, indent=2),
  56. encoding="utf-8"
  57. )
  58. logger.info(f"已完成 output.json 引号清洗: {output_json_path}")
  59. async def run_agent(
  60. query: Optional[str] = None,
  61. demand_id: Optional[int] = None,
  62. stream_output: bool = True,
  63. ) -> Dict[str, Any]:
  64. """
  65. 执行 agent 任务
  66. Args:
  67. query: 查询内容(搜索词),None 则使用默认值
  68. demand_id: 本次搜索任务 id(int,关联 demand_content 表)
  69. stream_output: 是否流式输出到 stdout(run.py 需要,server.py 不需要)
  70. Returns:
  71. {
  72. "trace_id": "20260317_103046_xyz789",
  73. "status": "completed" | "failed",
  74. "error": "错误信息" # 失败时
  75. }
  76. """
  77. query = query or DEFAULT_QUERY
  78. demand_id = demand_id or DEFAULT_DEMAND_ID
  79. # 加载 prompt
  80. prompt_path = PROJECT_ROOT / "content_finder.prompt"
  81. prompt = SimplePrompt(prompt_path)
  82. # output 目录
  83. output_dir = str(PROJECT_ROOT / "output")
  84. # 构建消息(替换 %query%、%output_dir%、%demand_id%)
  85. demand_id_str = str(demand_id) if demand_id is not None else ""
  86. messages = prompt.build_messages(query=query, output_dir=output_dir, demand_id=demand_id_str)
  87. # 初始化配置
  88. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  89. if not api_key:
  90. raise ValueError("OPEN_ROUTER_API_KEY 未设置")
  91. model_name = prompt.config.get("model", "sonnet-4.6")
  92. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  93. temperature = float(prompt.config.get("temperature", 0.3))
  94. max_iterations = 30
  95. trace_dir = str(PROJECT_ROOT / "traces")
  96. skills_dir = str(PROJECT_ROOT / "skills")
  97. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  98. store = FileSystemTraceStore(base_path=trace_dir)
  99. allowed_tools = [
  100. "weixin_search",
  101. "fetch_weixin_account",
  102. "fetch_account_article_list",
  103. "fetch_article_detail",
  104. ]
  105. runner = AgentRunner(
  106. llm_call=create_openrouter_llm_call(model=model),
  107. trace_store=store,
  108. skills_dir=skills_dir,
  109. )
  110. config = RunConfig(
  111. name="内容寻找",
  112. model=model,
  113. temperature=temperature,
  114. max_iterations=max_iterations,
  115. tools=allowed_tools,
  116. extra_llm_params={"max_tokens": 8192},
  117. knowledge=KnowledgeConfig(
  118. enable_extraction=False,
  119. enable_completion_extraction=False,
  120. enable_injection=False,
  121. # owner="content_finder_agent",
  122. # default_tags={"project": "content_finder"},
  123. # default_scopes=["com.piaoquantv.supply"],
  124. # default_search_types=["tool", "usecase", "definition"],
  125. # default_search_owner="content_finder_agent"
  126. )
  127. )
  128. # 执行
  129. trace_id = None
  130. try:
  131. async for item in runner.run(messages=messages, config=config):
  132. if isinstance(item, Trace):
  133. trace_id = item.trace_id
  134. if item.status == "completed":
  135. if trace_id:
  136. output_json_path = Path(output_dir) / trace_id / "output.json"
  137. _sanitize_output_json(output_json_path)
  138. logger.info(f"Agent 执行完成: trace_id={trace_id}")
  139. return {
  140. "trace_id": trace_id,
  141. "status": "completed"
  142. }
  143. elif item.status == "failed":
  144. logger.error(f"Agent 执行失败: {item.error_message}")
  145. return {
  146. "trace_id": trace_id,
  147. "status": "failed",
  148. "error": item.error_message
  149. }
  150. elif isinstance(item, Message) and stream_output:
  151. # 流式输出(仅 run.py 需要)
  152. if item.role == "assistant":
  153. content = item.content
  154. if isinstance(content, dict):
  155. text = content.get("text", "")
  156. tool_calls = content.get("tool_calls", [])
  157. if text:
  158. # 如果有推荐结果,完整输出
  159. if len(text) > 500 and ("推荐结果" in text or "推荐内容" in text or "🎯" in text):
  160. print(f"\n{text}")
  161. # 如果有工具调用且文本较短,只输出摘要
  162. elif tool_calls and len(text) > 100:
  163. print(f"[思考] {text[:100]}...")
  164. # 其他情况输出完整文本
  165. else:
  166. print(f"\n{text}")
  167. # 输出工具调用信息
  168. if tool_calls:
  169. for tc in tool_calls:
  170. tool_name = tc.get("function", {}).get("name", "unknown")
  171. # 跳过 goal 工具的输出,减少噪音
  172. if tool_name != "goal":
  173. print(f"[工具] {tool_name}")
  174. elif isinstance(content, str) and content:
  175. print(f"\n{content}")
  176. elif item.role == "tool":
  177. content = item.content
  178. if isinstance(content, dict):
  179. tool_name = content.get("tool_name", "unknown")
  180. print(f"[结果] {tool_name} ✓")
  181. # 如果循环结束但没有返回,说明异常退出
  182. return {
  183. "trace_id": trace_id,
  184. "status": "failed",
  185. "error": "Agent 异常退出"
  186. }
  187. except KeyboardInterrupt:
  188. logger.info("用户中断")
  189. if stream_output:
  190. print("\n用户中断")
  191. return {
  192. "trace_id": trace_id,
  193. "status": "failed",
  194. "error": "用户中断"
  195. }
  196. except Exception as e:
  197. logger.error(f"Agent 执行异常: {e}", exc_info=True)
  198. if stream_output:
  199. print(f"\n执行失败: {e}")
  200. return {
  201. "trace_id": trace_id,
  202. "status": "failed",
  203. "error": str(e)
  204. }
  205. if __name__ == "__main__":
  206. import asyncio
  207. asyncio.run(run_agent())