api_server.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. """
  2. API Server - FastAPI 应用入口
  3. 聚合所有模块的 API 路由:
  4. - GET /api/traces — 查询(trace/api.py)
  5. - POST /api/traces — 执行控制(trace/run_api.py,需配置 Runner)
  6. - WS /api/traces/{id}/watch — 实时推送(trace/websocket.py)
  7. - GET /api/experiences — 经验查询(trace/run_api.py,需配置 Runner)
  8. """
  9. import logging
  10. import json
  11. import os
  12. from dotenv import load_dotenv
  13. load_dotenv()
  14. from fastapi import FastAPI, Request, WebSocket
  15. from fastapi.middleware.cors import CORSMiddleware
  16. import uvicorn
  17. from agent.trace import FileSystemTraceStore
  18. from agent.trace.api import router as api_router, set_trace_store as set_api_trace_store
  19. from agent.trace.run_api import router as run_router, experiences_router, set_runner
  20. from agent.trace.websocket import router as ws_router, set_trace_store as set_ws_trace_store
  21. from agent.trace.examples_api import router as examples_router
  22. from agent.trace.logs_websocket import router as logs_router, setup_websocket_logging
  23. from agent.trace.upload_api import router as upload_router, set_trace_store as set_upload_trace_store
  # ===== Logging configuration =====
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"

# Configure the root logger first, then attach WebSocket log pushing on top.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable pushing log output to WebSocket clients (agent.trace.logs_websocket).
setup_websocket_logging(level=logging.INFO)
  # ===== FastAPI application =====
app = FastAPI(
    title="Agent API",
    description="Agent 查询 + 执行 API",
    version="1.0.0",
)

# ----- CORS (lets the browser front-end call this API cross-origin) -----
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable, e.g. "https://app.example.com,http://localhost:5173". When it is
# unset or blank, every origin is allowed ("*"), which matches the previous
# hard-coded behaviour.
#
# NOTE(security): with allow_credentials=True and "*" in allow_origins,
# Starlette's CORSMiddleware echoes the caller's Origin back (on preflight and
# cookie-bearing requests) instead of sending "*", so any website can make
# credentialed requests. Set CORS_ALLOW_ORIGINS to explicit origins in production.
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  # ===== Trace storage =====
# File-system backed store, chosen so traces persist and can be shared across
# processes.
trace_store = FileSystemTraceStore(base_path=".trace")

# Hand the single shared store to each router module that needs it:
# the trace query API, the trace WebSocket, and the upload API (same order as before).
for _inject_trace_store in (
    set_api_trace_store,
    set_ws_trace_store,
    set_upload_trace_store,
):
    _inject_trace_store(trace_store)
del _inject_trace_store
  # ===== Runner (enables the execution API) =====
# Registering a runner enables POST /api/traces (create / run / stop / reflect)
# and the experiences API. This wiring is active, not optional: a runner backed
# by OpenRouter is always created at import time. The model can be overridden
# with the AGENT_LLM_MODEL environment variable; unset or blank keeps the default.
from agent.core.runner import AgentRunner
from agent.llm import create_openrouter_llm_call

runner = AgentRunner(
    trace_store=trace_store,
    llm_call=create_openrouter_llm_call(
        model=os.getenv("AGENT_LLM_MODEL") or "anthropic/claude-sonnet-4.5",
    ),
)
set_runner(runner)
  # ===== Router registration =====
# Order matters: routes are matched in registration order, so run_router
# (which serves GET /running) must be included before api_router; otherwise
# "/running" would be captured by api_router's "/{trace_id}" route.
for _router in (
    examples_router,     # Examples API (GET /api/examples)
    upload_router,       # trace upload (POST /api/traces/upload)
    run_router,          # trace execution (POST + GET /running; needs a runner)
    experiences_router,  # experiences (GET /api/experiences; needs a runner)
    api_router,          # trace queries (GET)
    ws_router,           # trace WebSocket (live push)
    logs_router,         # logs WebSocket (log push)
):
    app.include_router(_router)
del _router
  # ----- Startup hook -----
# NOTE(review): FastAPI marks on_event("startup") as deprecated in favour of a
# lifespan handler; migrating needs the FastAPI(...) constructor to change too.
@app.on_event("startup")
async def on_startup():
    """Reconcile trace state once when the server starts."""
    # Imported inside the hook, exactly as before.
    from agent.trace.run_api import reconcile_traces as _reconcile

    await _reconcile()
  # ----- WebSocket liveness probe -----
@app.websocket("/ws_ping")
async def ws_ping(websocket: WebSocket):
    """Accept the connection, send a single "pong" text frame, then close."""
    await websocket.accept()
    reply = "pong"
    await websocket.send_text(reply)
    await websocket.close()
  # ===== Health check =====
@app.get("/health")
async def health_check():
    """Report liveness together with the service name and version.

    The version is read from the FastAPI application metadata (``app.version``)
    instead of a duplicated literal, so it cannot drift from the version the
    app advertises in its OpenAPI schema.
    """
    return {
        "status": "ok",
        "service": "Agent Step Tree API",
        "version": app.version,
    }
  # ===== Entry point =====
if __name__ == "__main__":
    # Bind address, port and auto-reload are overridable from the environment.
    # The defaults match the previous hard-coded development settings.
    # NOTE: reload is a development feature; set API_RELOAD=0 in production.
    # API_HOST defaults to 0.0.0.0, which listens on every interface.
    api_host = os.getenv("API_HOST") or "0.0.0.0"
    api_port = int(os.getenv("API_PORT") or "8000")
    api_reload = (os.getenv("API_RELOAD") or "1").strip().lower() not in {
        "0",
        "false",
        "no",
        "off",
    }

    logger.info("Starting API server...")
    uvicorn.run(
        "api_server:app",  # import string: uvicorn requires one for reload mode
        host=api_host,
        port=api_port,
        reload=api_reload,
        log_level="info",
    )