generate_flux.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import os
  2. import time
  3. import argparse
  4. import requests
  5. from dotenv import load_dotenv
  6. from openai import OpenAI
  7. # 加载根目录或当前目录的 .env 文件
  8. env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), ".env")
  9. if os.path.exists(env_path):
  10. load_dotenv(env_path)
  11. else:
  12. load_dotenv()
  13. APIYI_KEY = os.getenv("APIYI_KEY")
  14. if APIYI_KEY and not APIYI_KEY.startswith("sk-"):
  15. APIYI_KEY = f"sk-{APIYI_KEY}"
  16. APIYI_BASE_URL = os.getenv("APIYI_BASE_URL", "https://api.apiyi.com/v1")
  17. if not APIYI_KEY:
  18. print("错误: 请在 .env 文件中设置 APIYI_KEY")
  19. exit(1)
  20. # 初始化客户端
  21. client = OpenAI(
  22. api_key=APIYI_KEY,
  23. base_url=APIYI_BASE_URL
  24. )
  25. import tempfile
  26. from urllib.parse import urlparse
  27. def download_image_from_url(url):
  28. """从 URL 下载图片到临时文件"""
  29. try:
  30. response = requests.get(url, timeout=30)
  31. response.raise_for_status()
  32. parsed_url = urlparse(url)
  33. extension = parsed_url.path.split('.')[-1] if '.' in parsed_url.path else 'jpg'
  34. temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f'.{extension}')
  35. temp_file.write(response.content)
  36. temp_file.close()
  37. return temp_file.name
  38. except Exception as e:
  39. print(f"下载图片失败: {e}")
  40. return None
  41. def generate_flux_image(prompt, aspect_ratio="1:1", model="flux-kontext-pro",
  42. seed=None, safety_tolerance=2, output_format="jpeg",
  43. prompt_upsampling=False, image_url=None, mask_url=None):
  44. """
  45. 生成或编辑 Flux 图像
  46. """
  47. mode = "编辑" if image_url else "生成"
  48. print(f"正在{mode}图像...")
  49. print(f" - 模型: {model}")
  50. print(f" - 比例: {aspect_ratio}")
  51. print(f" - 提示词: {prompt}")
  52. image_path = None
  53. mask_path = None
  54. image_file = None
  55. mask_file = None
  56. try:
  57. # 构建 extra_body 参数
  58. extra_params = {
  59. "aspect_ratio": aspect_ratio,
  60. "safety_tolerance": safety_tolerance,
  61. "output_format": output_format
  62. }
  63. if seed is not None:
  64. extra_params["seed"] = seed
  65. if not image_url:
  66. # 纯生成特有参数
  67. extra_params["prompt_upsampling"] = prompt_upsampling
  68. response = client.images.generate(
  69. model=model,
  70. prompt=prompt,
  71. extra_body=extra_params
  72. )
  73. else:
  74. # 图生图编辑模式
  75. print(f" - 参考图: {image_url}")
  76. image_path = download_image_from_url(image_url)
  77. if not image_path:
  78. raise ValueError("无法下载参考图")
  79. image_file = open(image_path, "rb")
  80. edit_kwargs = {
  81. "model": model,
  82. "image": image_file,
  83. "prompt": prompt,
  84. "extra_body": extra_params
  85. }
  86. if mask_url:
  87. print(f" - 蒙版图: {mask_url}")
  88. mask_path = download_image_from_url(mask_url)
  89. if mask_path:
  90. mask_file = open(mask_path, "rb")
  91. edit_kwargs["mask"] = mask_file
  92. response = client.images.edit(**edit_kwargs)
  93. image_out_url = response.data[0].url
  94. return image_out_url
  95. except Exception as e:
  96. print(f" 请求失败: {e}")
  97. return None
  98. finally:
  99. # 关闭文件并删除临时文件
  100. if image_file:
  101. image_file.close()
  102. if mask_file:
  103. mask_file.close()
  104. for path in [image_path, mask_path]:
  105. if path and os.path.exists(path):
  106. try:
  107. os.unlink(path)
  108. except:
  109. pass
  110. def save_image_from_url(url, prompt, ratio):
  111. """从 URL 下载并保存图像"""
  112. if not url:
  113. return None
  114. try:
  115. print("正在下载临时链接中的图像...")
  116. response = requests.get(url, timeout=30)
  117. response.raise_for_status()
  118. # 生成文件名,替换掉可能导致问题的非法字符
  119. timestamp = int(time.time())
  120. safe_prompt = ''.join(c if c.isalnum() else '_' for c in prompt[:30]).strip('_')
  121. # 保存到当前目录下的 outputs 文件夹
  122. output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "outputs")
  123. os.makedirs(output_dir, exist_ok=True)
  124. filename = f"flux_{safe_prompt}_{ratio.replace(':', 'x')}_{timestamp}.png"
  125. filepath = os.path.join(output_dir, filename)
  126. with open(filepath, 'wb') as f:
  127. f.write(response.content)
  128. print(f"✅ 已成功保存至: {filepath}")
  129. return filepath
  130. except requests.exceptions.RequestException as e:
  131. print(f"❌ 下载失败: {e}")
  132. print("提示:Flux URL 仅 10 分钟有效,请确保及时手动下载!")
  133. return None
  134. if __name__ == "__main__":
  135. parser = argparse.ArgumentParser(description="Flux API 图像生成高级调试工具")
  136. parser.add_argument("--prompt", type=str, default="A futuristic city with flying cars and neon lights, high resolution, cyberpunk style", help="图像描述提示词")
  137. parser.add_argument("--ratio", type=str, default="16:9", help="宽高比 (如: 1:1, 16:9, 9:16, 2:3, 3:2)")
  138. parser.add_argument("--model", type=str, default="flux-kontext-pro", choices=["flux-kontext-pro", "flux-kontext-max"], help="模型名称")
  139. parser.add_argument("--seed", type=int, default=None, help="随机种子 (默认随机)")
  140. parser.add_argument("--format", type=str, default="png", choices=["jpeg", "png"], help="输出图片格式")
  141. parser.add_argument("--upsampling", action="store_true", help="开启提示词自动增强")
  142. args = parser.parse_args()
  143. image_url = generate_flux_image(
  144. prompt=args.prompt,
  145. aspect_ratio=args.ratio,
  146. model=args.model,
  147. seed=args.seed,
  148. output_format=args.format,
  149. prompt_upsampling=args.upsampling
  150. )
  151. if image_url:
  152. print(f"\n✅ 生成成功!图像 URL (10分钟有效): \n{image_url}")
  153. save_image_from_url(image_url, args.prompt, args.ratio)