| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- """即梦任务工具 — FastAPI 调用层(范本同 liblibai_controlnet)。
- 无本地缓存:/add_task、/query_task 直接转发到上游 HTTP(见 ji_meng_client.py)。
- 环境变量:
- JI_MENG_API_BASE 必填,上游根 URL
- JI_MENG_ADD_TASK_PATH 可选,默认 /add_task
- JI_MENG_QUERY_TASK_PATH 可选,默认 /query_task
- JI_MENG_API_KEY 可选,Bearer Token
- 注册(由 Agent 写入 registry + sources):
- tool_id=ji_meng_add_task / ji_meng_query_task,host_dir=tools/local/ji_meng,
- endpoint_path 分别为 /add_task、/query_task。
- """
- from __future__ import annotations
- import argparse
- from typing import Literal
- import uvicorn
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel, Field
- from ji_meng_client import JiMengClient
# FastAPI application object: the /health, /add_task and /query_task routes in
# this module are registered on it, and the __main__ block serves it via uvicorn.
app = FastAPI(title="Ji Meng Task API")
class AddTaskRequest(BaseModel):
    # Request body accepted by POST /add_task. The handler passes each field
    # unchanged to JiMengClient.submit_task as a keyword argument.
    # (Plain comments rather than a docstring, so the generated JSON schema
    # is not altered.)

    # Required; only the literals 'image' or 'video' pass validation.
    task_type: Literal['image', 'video'] = Field(..., description="任务类型")
    # Required; the task description / prompt text.
    prompt: str = Field(default=..., description="任务描述 / 提示词")
    # Optional; None when the client omits it.
    # NOTE(review): presumably an input/reference image for the task —
    # confirm against the upstream API.
    image_url: str | None = Field(description="图片 URL", default=None)
class QueryTaskRequest(BaseModel):
    # Request body accepted by POST /query_task. The handler forwards task_id
    # to JiMengClient.query_task.

    # Required; the task ID that the task-creation endpoint returned.
    task_id: str = Field(default=..., description="创建任务接口返回的任务 ID")
@app.get("/health")
def health() -> dict:
    # Liveness probe. It always answers {"status": "ok"} and does not build a
    # JiMengClient or contact the upstream service.
    return dict(status="ok")
@app.post("/add_task")
def add_task(req: AddTaskRequest) -> dict:
    # Create a task by forwarding the validated request body to the upstream
    # Ji Meng service. The upstream result is returned as-is.
    #
    # Error mapping covers the client constructor as well as the call,
    # because both sit inside the try:
    #   ValueError          -> HTTP 503
    #   any other Exception -> HTTP 502
    # In both cases the exception text becomes the response detail.
    # NOTE(review): the 503 arm presumably targets client misconfiguration
    # (e.g. JI_MENG_API_BASE unset) — confirm in ji_meng_client.py. Note also
    # that json.JSONDecodeError subclasses ValueError, so a malformed upstream
    # JSON body (if the client parses JSON) would also surface as 503.
    try:
        upstream = JiMengClient()
        result = upstream.submit_task(
            task_type=req.task_type,
            prompt=req.prompt,
            image_url=req.image_url,
        )
    except Exception as exc:
        status = 503 if isinstance(exc, ValueError) else 502
        raise HTTPException(status_code=status, detail=str(exc)) from exc
    return result
@app.post("/query_task")
def query_task(req: QueryTaskRequest) -> dict:
    # Look up a task on the upstream Ji Meng service by ID. The upstream
    # result is returned as-is.
    #
    # Error mapping covers the client constructor as well as the call,
    # because both sit inside the try:
    #   ValueError          -> HTTP 503
    #   any other Exception -> HTTP 502
    # In both cases the exception text becomes the response detail.
    # NOTE(review): as in /add_task, a json.JSONDecodeError (a ValueError
    # subclass) raised by the client would be reported as 503, not 502.
    try:
        upstream = JiMengClient()
        result = upstream.query_task(req.task_id)
    except Exception as exc:
        status = 503 if isinstance(exc, ValueError) else 502
        raise HTTPException(status_code=status, detail=str(exc)) from exc
    return result
if __name__ == "__main__":
    # Run the service directly: `python <this file> [--host HOST] [--port PORT]`.
    # The bind address used to be hard-coded to 0.0.0.0. It is now an option
    # whose default is unchanged, so existing launch commands behave as before.
    parser = argparse.ArgumentParser(description="Run the Ji Meng task API server.")
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="interface to bind (default: 0.0.0.0, i.e. all interfaces)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="TCP port to listen on (default: 8001)",
    )
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)
|