model_api.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
"""
Model management and inference tools.

Thin async wrappers around the local agent backend's model HTTP API
(list / deploy / inference), exposed as agent tools.
"""
import httpx
from agent import tool, ToolResult

# Base URL of the local agent backend's model-management API.
BASE_URL = "http://localhost:8100/api/agent"
  7. @tool(description="查询已部署和可部署的模型列表")
  8. async def list_models() -> ToolResult:
  9. """获取模型列表,包含已部署和可用模型的元数据"""
  10. async with httpx.AsyncClient(trust_env=False) as client:
  11. response = await client.get(f"{BASE_URL}/models", timeout=30.0)
  12. response.raise_for_status()
  13. return ToolResult(title="模型列表", output=response.text)
  14. @tool(description="部署指定模型到 GPU")
  15. async def deploy_model(model_id: str, epoch: int = None) -> ToolResult:
  16. """部署模型
  17. Args:
  18. model_id: 模型唯一标识符
  19. epoch: 可选的 checkpoint epoch
  20. """
  21. data = {"model_id": model_id}
  22. if epoch is not None:
  23. data["epoch"] = epoch
  24. async with httpx.AsyncClient(trust_env=False) as client:
  25. response = await client.post(f"{BASE_URL}/deploy", json=data, timeout=30.0)
  26. response.raise_for_status()
  27. return ToolResult(title="部署结果", output=response.text)
  28. @tool(description="向已部署的模型发送查询并获取响应")
  29. async def inference(model_id: str, query: str, temperature: float = 0.7) -> ToolResult:
  30. """模型推理
  31. Args:
  32. model_id: 已部署的模型标识符
  33. query: 推理查询内容
  34. temperature: 温度参数,范围 0.0-1.0,默认 0.7
  35. """
  36. data = {
  37. "model_id": model_id,
  38. "query": query,
  39. "temperature": temperature
  40. }
  41. async with httpx.AsyncClient(trust_env=False) as client:
  42. response = await client.post(f"{BASE_URL}/inference", json=data, timeout=120.0)
  43. response.raise_for_status()
  44. return ToolResult(title="推理结果", output=response.text)