prompt_generator.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. """
  2. 图片还原 Prompt 生成系统
  3. 基于解构数据自动构建高质量的图片生成prompt
  4. """
import json
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
  8. class PromptGenerator:
  9. """Prompt生成器"""
  10. def __init__(self, input_dir: str = "input/paragraphs"):
  11. self.input_dir = Path(input_dir)
  12. self.global_elements = self._load_global_elements()
  13. self.global_forms = self._load_global_forms()
  14. def _load_global_elements(self) -> List[Dict]:
  15. """加载全局实质元素(跨图聚合)"""
  16. file_path = self.input_dir / "03_图片制作点实质结果.json"
  17. with open(file_path, 'r', encoding='utf-8') as f:
  18. return json.load(f)
  19. def _load_global_forms(self) -> List[Dict]:
  20. """加载全局形式特征(跨图聚合)"""
  21. file_path = self.input_dir / "04_图片制作点形式结果.json"
  22. with open(file_path, 'r', encoding='utf-8') as f:
  23. return json.load(f)
  24. def _load_image_segment(self, image_num: int) -> Dict:
  25. """加载指定图片的分段数据"""
  26. # 查找对应的分段文件
  27. pattern = f"01_图片分段_{image_num:02d}_*.json"
  28. files = list(self.input_dir.glob(pattern))
  29. if not files:
  30. raise FileNotFoundError(f"未找到图片{image_num}的分段文件")
  31. with open(files[0], 'r', encoding='utf-8') as f:
  32. return json.load(f)
  33. def _load_image_form(self, image_num: int) -> Dict:
  34. """加载指定图片的形式分析"""
  35. file_path = self.input_dir / f"02_图片形式_{image_num:02d}.json"
  36. with open(file_path, 'r', encoding='utf-8') as f:
  37. return json.load(f)
  38. def _extract_key_descriptions(self, segment_data: Dict, form_data: Dict) -> Dict[str, str]:
  39. """提取关键描述信息"""
  40. descriptions = {
  41. "scene": "",
  42. "person": "",
  43. "person_pose": "",
  44. "person_clothing": "",
  45. "person_hair": "",
  46. "easel": "",
  47. "palette": "",
  48. "background": "",
  49. "details": []
  50. }
  51. # 从分段数据提取
  52. sections = segment_data.get("sections", [])
  53. if sections:
  54. main_section = sections[0]
  55. descriptions["scene"] = main_section.get("描述", "")
  56. # 遍历子段落
  57. for sub in main_section.get("子段落", []):
  58. name = sub.get("名称", "")
  59. desc = sub.get("描述", "")
  60. if "人物" in name:
  61. descriptions["person"] = desc
  62. elif "画架" in name:
  63. descriptions["easel"] = desc
  64. elif "调色板" in name:
  65. descriptions["palette"] = desc
  66. elif "背景" in name:
  67. descriptions["background"] = desc
  68. # 从形式数据提取细节
  69. form_elements = form_data.get("form_elements", [])
  70. for elem_group in form_elements:
  71. for form in elem_group.get("形式", []):
  72. name = form.get("名称", "")
  73. desc = form.get("描述", "")
  74. if "人物姿态" in name:
  75. descriptions["person_pose"] = desc
  76. elif "人物着装" in name:
  77. descriptions["person_clothing"] = desc
  78. elif "人物发型" in name:
  79. descriptions["person_hair"] = desc
  80. elif name not in ["人物", "背景", "画架", "调色板"]:
  81. descriptions["details"].append(f"{name}: {desc}")
  82. return descriptions
  83. def _build_prompt_structure(self, descriptions: Dict[str, str], group_id: str) -> str:
  84. """
  85. 构建结构化prompt
  86. 顺序:[主体] + [描述性属性] + [场景环境] + [光照条件] + [情感氛围] + [构图方式] + [艺术风格]
  87. """
  88. prompt_parts = []
  89. # 1. 主体描述(人物 + 道具)
  90. if descriptions["person"]:
  91. prompt_parts.append(descriptions["person"])
  92. # 2. 描述性属性(姿态、着装)
  93. if descriptions["person_pose"]:
  94. # 简化姿态描述,提取关键动作
  95. pose_simplified = self._simplify_pose(descriptions["person_pose"])
  96. prompt_parts.append(pose_simplified)
  97. if descriptions["person_clothing"]:
  98. # 简化着装描述
  99. clothing_simplified = self._simplify_clothing(descriptions["person_clothing"])
  100. prompt_parts.append(clothing_simplified)
  101. # 3. 道具细节
  102. if "g1" in group_id or "g2" in group_id: # 户外绘画场景
  103. if descriptions["easel"]:
  104. prompt_parts.append(descriptions["easel"])
  105. if descriptions["palette"]:
  106. prompt_parts.append(descriptions["palette"])
  107. # 4. 场景环境
  108. if descriptions["background"]:
  109. prompt_parts.append(descriptions["background"])
  110. # 5. 光照条件(从全局形式特征推断)
  111. prompt_parts.append("Natural outdoor lighting, bright and soft sunlight")
  112. # 6. 情感氛围
  113. if "g3" in group_id: # 人物特写
  114. prompt_parts.append("Peaceful and serene atmosphere, eyes closed in contemplation")
  115. else: # 绘画场景
  116. prompt_parts.append("Focused and creative atmosphere, artist at work")
  117. # 7. 艺术风格
  118. prompt_parts.append("Photorealistic style, high quality photography, professional composition")
  119. # 组合成完整prompt
  120. prompt = ". ".join(filter(None, prompt_parts)) + "."
  121. return prompt
  122. def _simplify_pose(self, pose_desc: str) -> str:
  123. """简化姿态描述,提取关键信息"""
  124. # 提取关键动作词
  125. key_actions = []
  126. if "站立" in pose_desc:
  127. key_actions.append("standing")
  128. if "侧身" in pose_desc or "侧向" in pose_desc:
  129. key_actions.append("side view")
  130. if "蹲" in pose_desc:
  131. key_actions.append("crouching")
  132. if "背对" in pose_desc:
  133. key_actions.append("back view")
  134. if "持画笔" in pose_desc:
  135. key_actions.append("holding a paintbrush")
  136. if "持调色板" in pose_desc or "托举调色板" in pose_desc:
  137. key_actions.append("holding a palette")
  138. return ", ".join(key_actions) if key_actions else pose_desc[:100]
  139. def _simplify_clothing(self, clothing_desc: str) -> str:
  140. """简化着装描述"""
  141. # 提取关键服饰信息
  142. simplified = []
  143. if "白色" in clothing_desc and "连衣裙" in clothing_desc:
  144. simplified.append("white dress")
  145. if "长袖" in clothing_desc:
  146. simplified.append("long sleeves")
  147. if "V字形领口" in clothing_desc or "V领" in clothing_desc:
  148. simplified.append("V-neck")
  149. return ", ".join(simplified) if simplified else clothing_desc[:100]
  150. def generate_prompt(self, image_num: int) -> Dict[str, Any]:
  151. """
  152. 为指定图片生成prompt
  153. Args:
  154. image_num: 图片编号 (1-9)
  155. Returns:
  156. 包含prompt和元数据的字典
  157. """
  158. # 加载数据
  159. segment_data = self._load_image_segment(image_num)
  160. form_data = self._load_image_form(image_num)
  161. # 确定分组
  162. segment_file = list(self.input_dir.glob(f"01_图片分段_{image_num:02d}_*.json"))[0]
  163. group_id = "g1" # 默认
  164. if "g2" in segment_file.name:
  165. group_id = "g2"
  166. elif "g3" in segment_file.name:
  167. group_id = "g3"
  168. # 提取关键描述
  169. descriptions = self._extract_key_descriptions(segment_data, form_data)
  170. # 构建prompt
  171. prompt = self._build_prompt_structure(descriptions, group_id)
  172. # 确定图片尺寸(竖版)
  173. size = "1024x1792" # DALL-E 3 竖版尺寸
  174. return {
  175. "image_num": image_num,
  176. "group_id": group_id,
  177. "prompt": prompt,
  178. "size": size,
  179. "quality": "hd",
  180. "descriptions": descriptions
  181. }
  182. def generate_all_prompts(self) -> List[Dict[str, Any]]:
  183. """生成所有9张图片的prompts"""
  184. prompts = []
  185. for i in range(1, 10):
  186. try:
  187. prompt_data = self.generate_prompt(i)
  188. prompts.append(prompt_data)
  189. print(f"✓ 图片 {i} prompt已生成")
  190. except Exception as e:
  191. print(f"✗ 图片 {i} 生成失败: {e}")
  192. return prompts
  193. def save_prompts(self, prompts: List[Dict], output_file: str = "output_1/prompts.json"):
  194. """保存生成的prompts到文件"""
  195. output_path = Path(output_file)
  196. output_path.parent.mkdir(parents=True, exist_ok=True)
  197. with open(output_path, 'w', encoding='utf-8') as f:
  198. json.dump(prompts, f, ensure_ascii=False, indent=2)
  199. print(f"\n✓ Prompts已保存到: {output_path}")
  200. def main():
  201. """主函数:生成并保存所有prompts"""
  202. print("=" * 60)
  203. print("图片还原 Prompt 生成系统")
  204. print("=" * 60)
  205. # 创建生成器
  206. generator = PromptGenerator()
  207. # 生成所有prompts
  208. print("\n开始生成prompts...")
  209. prompts = generator.generate_all_prompts()
  210. # 保存结果
  211. generator.save_prompts(prompts)
  212. # 打印预览
  213. print("\n" + "=" * 60)
  214. print("Prompt 预览(前3个):")
  215. print("=" * 60)
  216. for i, p in enumerate(prompts[:3], 1):
  217. print(f"\n图片 {i} ({p['group_id']}):")
  218. print(f"Prompt: {p['prompt'][:200]}...")
  219. print("\n" + "=" * 60)
  220. print(f"✓ 完成!共生成 {len(prompts)} 个prompts")
  221. print("=" * 60)
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()