midjourney_client.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. """将请求转发到自建 Midjourney HTTP 服务(与 tools/local 约定一致)。"""
  2. from __future__ import annotations
  3. import os
  4. from typing import Any
  5. import httpx
  6. from dotenv import load_dotenv
  7. _ = load_dotenv()
  8. def _base_url() -> str:
  9. base = os.environ.get("MIDJOURNEY_API_BASE", "").strip().rstrip("/")
  10. if not base:
  11. raise ValueError("缺少环境变量 MIDJOURNEY_API_BASE(上游根 URL,不含尾路径)")
  12. return base
  13. def forward_post(path: str, json_body: dict[str, Any]) -> Any:
  14. """POST {MIDJOURNEY_API_BASE}{path},Content-Type: application/json。"""
  15. url = f"{_base_url()}{path if path.startswith('/') else '/' + path}"
  16. with httpx.Client(timeout=300.0) as client:
  17. r = client.post(
  18. url,
  19. json=json_body,
  20. headers={"accept": "application/json", "Content-Type": "application/json"},
  21. )
  22. ct = (r.headers.get("content-type") or "").lower()
  23. if "application/json" not in ct:
  24. r.raise_for_status()
  25. raise RuntimeError(f"非 JSON 响应 ({r.status_code}): {r.text[:1500]}")
  26. try:
  27. data = r.json()
  28. except Exception:
  29. raise RuntimeError(f"无效 JSON ({r.status_code}): {r.text[:1500]}") from None
  30. if r.status_code >= 400:
  31. if isinstance(data, dict):
  32. msg = data.get("detail", data.get("message", data.get("error", str(data))))
  33. else:
  34. msg = str(data)
  35. raise RuntimeError(f"上游 HTTP {r.status_code}: {msg}")
  36. return data