- """
- 模型管理和推理工具
- """
- import httpx
- from agent import tool, ToolResult
- BASE_URL = "http://localhost:8100/api/agent"
- @tool(description="查询已部署和可部署的模型列表")
- async def list_models() -> ToolResult:
- """获取模型列表,包含已部署和可用模型的元数据"""
- async with httpx.AsyncClient(trust_env=False) as client:
- response = await client.get(f"{BASE_URL}/models", timeout=30.0)
- response.raise_for_status()
- return ToolResult(title="模型列表", output=response.text)
- @tool(description="部署指定模型到 GPU")
- async def deploy_model(model_id: str, epoch: int = None) -> ToolResult:
- """部署模型
- Args:
- model_id: 模型唯一标识符
- epoch: 可选的 checkpoint epoch
- """
- data = {"model_id": model_id}
- if epoch is not None:
- data["epoch"] = epoch
- async with httpx.AsyncClient(trust_env=False) as client:
- response = await client.post(f"{BASE_URL}/deploy", json=data, timeout=30.0)
- response.raise_for_status()
- return ToolResult(title="部署结果", output=response.text)
- @tool(description="向已部署的模型发送查询并获取响应")
- async def inference(model_id: str, query: str, temperature: float = 0.7) -> ToolResult:
- """模型推理
- Args:
- model_id: 已部署的模型标识符
- query: 推理查询内容
- temperature: 温度参数,范围 0.0-1.0,默认 0.7
- """
- data = {
- "model_id": model_id,
- "query": query,
- "temperature": temperature
- }
- async with httpx.AsyncClient(trust_env=False) as client:
- response = await client.post(f"{BASE_URL}/inference", json=data, timeout=120.0)
- response.raise_for_status()
- return ToolResult(title="推理结果", output=response.text)