# main.py
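"""Local Midjourney proxy: three POST endpoints whose JSON bodies mirror the upstream API.

Environment variables:
    MIDJOURNEY_API_BASE  Required. Base URL of the upstream service, e.g.
                         https://your-host (endpoint paths such as /submit_job
                         are appended to it).

Endpoints (must stay consistent with the routes registered in the Router):
    GET  /health
    POST /submit_job        cookie, prompt, user_id, mode ∈ {relaxed, fast}
    POST /query_job_status  cookie, job_id
    POST /get_image_urls    job_id
"""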
from __future__ import annotations

import argparse
import logging
from typing import Any, Literal

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from midjourney_client import forward_post
  17. app = FastAPI(title="Midjourney API Proxy")
  18. class SubmitJobRequest(BaseModel):
  19. cookie: str = Field(..., description="Midjourney 会话 cookie")
  20. prompt: str = Field(..., description="提示词")
  21. user_id: str = Field(..., description="用户 ID")
  22. mode: Literal["relaxed", "fast"] = Field(..., description="relaxed 或 fast")
  23. class QueryJobStatusRequest(BaseModel):
  24. cookie: str = Field(..., description="Midjourney 会话 cookie")
  25. job_id: str = Field(..., description="submit_job 返回的任务 ID")
  26. class GetImageUrlsRequest(BaseModel):
  27. job_id: str = Field(..., description="任务 ID")
  28. @app.get("/health")
  29. def health() -> dict[str, str]:
  30. return {"status": "ok"}
  31. @app.post("/submit_job")
  32. def submit_job(req: SubmitJobRequest) -> Any:
  33. try:
  34. return forward_post(
  35. "/submit_job",
  36. {
  37. "cookie": req.cookie,
  38. "prompt": req.prompt,
  39. "user_id": req.user_id,
  40. "mode": req.mode,
  41. },
  42. )
  43. except ValueError as e:
  44. raise HTTPException(status_code=503, detail=str(e)) from e
  45. except RuntimeError as e:
  46. raise HTTPException(status_code=502, detail=str(e)) from e
  47. except Exception as e:
  48. raise HTTPException(status_code=502, detail=str(e)) from e
  49. @app.post("/query_job_status")
  50. def query_job_status(req: QueryJobStatusRequest) -> Any:
  51. try:
  52. return forward_post(
  53. "/query_job_status",
  54. {"cookie": req.cookie, "job_id": req.job_id},
  55. )
  56. except ValueError as e:
  57. raise HTTPException(status_code=503, detail=str(e)) from e
  58. except RuntimeError as e:
  59. raise HTTPException(status_code=502, detail=str(e)) from e
  60. except Exception as e:
  61. raise HTTPException(status_code=502, detail=str(e)) from e
  62. @app.post("/get_image_urls")
  63. def get_image_urls(req: GetImageUrlsRequest) -> Any:
  64. try:
  65. return forward_post("/get_image_urls", {"job_id": req.job_id})
  66. except ValueError as e:
  67. raise HTTPException(status_code=503, detail=str(e)) from e
  68. except RuntimeError as e:
  69. raise HTTPException(status_code=502, detail=str(e)) from e
  70. except Exception as e:
  71. raise HTTPException(status_code=502, detail=str(e)) from e
  72. if __name__ == "__main__":
  73. parser = argparse.ArgumentParser()
  74. parser.add_argument("--port", type=int, default=8001)
  75. args = parser.parse_args()
  76. uvicorn.run(app, host="0.0.0.0", port=args.port)