evaluate_tool.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. """
  2. 图像质量评估工具
  3. 输入:需求文档路径 + 图片路径(单图或多图)+ 质量标准(可选)
  4. 输出:评分 + 详细反馈
  5. 通过多模态 VL 大模型对生成图像进行质量评估:
  6. - 单图模式:对照需求文档检查是否满足要求
  7. - 多图模式:检查跨图一致性(角色、服装、色调等)
  8. """
  9. import json
  10. from pathlib import Path
  11. from typing import Dict, Any, Optional, List, Union
  12. from agent.tools import tool, ToolResult
  13. from agent.llm import create_qwen_llm_call
  14. @tool(
  15. display={
  16. "zh": {"name": "图像质量评估", "params": {
  17. "requirement_path": "需求文档路径",
  18. "image_paths": "图片路径(单个字符串或列表)",
  19. "quality_criteria": "质量标准(可选)"
  20. }},
  21. "en": {"name": "Image Quality Evaluation", "params": {
  22. "requirement_path": "Requirement document path",
  23. "image_paths": "Image path(s) (string or list)",
  24. "quality_criteria": "Quality criteria (optional)"
  25. }},
  26. }
  27. )
  28. async def evaluate_image(
  29. requirement_path: str,
  30. image_paths: Union[str, List[str]],
  31. quality_criteria: Optional[str] = None
  32. ) -> ToolResult:
  33. """评估生成图像是否满足需求文档的要求
  34. 使用多模态 VL 大模型对生成图像进行质量评估:
  35. **单图模式**(传入单个路径字符串):
  36. - 姿态、服装、光影、背景等是否符合规格
  37. - 材质、细节的真实感
  38. - 整体构图和色调
  39. **多图模式**(传入路径列表):
  40. - 检查跨图一致性:角色外观、服装款式、色调风格是否统一
  41. - 识别不一致的图片并给出修复建议
  42. Args:
  43. requirement_path: 需求文档路径(JSON 或文本文件)
  44. image_paths: 待评估的图片路径(单个字符串或路径列表)
  45. quality_criteria: 额外的质量标准描述(可选)
  46. Returns:
  47. ToolResult 包含评分(0-10)和详细反馈
  48. """
  49. # 统一处理为列表
  50. if isinstance(image_paths, str):
  51. paths_list = [image_paths]
  52. is_multi_image = False
  53. else:
  54. paths_list = image_paths
  55. is_multi_image = len(paths_list) > 1
  56. # 1. 读取需求文档
  57. req_path = Path(requirement_path)
  58. if not req_path.exists():
  59. return ToolResult(
  60. title="评估失败",
  61. output="",
  62. error=f"需求文档不存在: {requirement_path}",
  63. )
  64. requirement_text = req_path.read_text(encoding="utf-8")
  65. # 如果是 JSON,尝试智能提取内容
  66. requirement_summary = requirement_text
  67. if requirement_path.endswith(".json"):
  68. try:
  69. req_data = json.loads(requirement_text)
  70. parts = []
  71. # 单图需求
  72. if "required_spec" in req_data:
  73. parts.append("## 单图需求\n" + json.dumps(req_data["required_spec"], ensure_ascii=False, indent=2))
  74. if "prompt" in req_data:
  75. parts.append(f"Prompt: {req_data['prompt']}")
  76. # 多图一致性需求
  77. if "consistency_checks" in req_data:
  78. parts.append("## 一致性检查标准\n" + json.dumps(req_data["consistency_checks"], ensure_ascii=False, indent=2))
  79. # 多图各自的需求(pipeline.json 整体)
  80. if "images" in req_data:
  81. img_specs = {}
  82. for img_id, img_data in req_data["images"].items():
  83. if "required_spec" in img_data:
  84. img_specs[img_id] = img_data["required_spec"]
  85. if img_specs:
  86. parts.append("## 各图需求规格\n" + json.dumps(img_specs, ensure_ascii=False, indent=2))
  87. if parts:
  88. requirement_summary = "\n\n".join(parts)
  89. except:
  90. pass
  91. # 2. 检查所有图片文件
  92. import base64
  93. image_contents = []
  94. missing_files = []
  95. for p_str in paths_list:
  96. p = Path(p_str)
  97. if not p.exists():
  98. missing_files.append(p_str)
  99. continue
  100. img_bytes = p.read_bytes()
  101. img_b64 = base64.b64encode(img_bytes).decode("utf-8")
  102. mime_type = "image/png"
  103. if p.suffix.lower() in (".jpg", ".jpeg"):
  104. mime_type = "image/jpeg"
  105. elif p.suffix.lower() == ".webp":
  106. mime_type = "image/webp"
  107. image_contents.append({
  108. "path": p_str,
  109. "b64": img_b64,
  110. "mime": mime_type,
  111. })
  112. if missing_files:
  113. return ToolResult(
  114. title="评估失败",
  115. output="",
  116. error=f"以下图片文件不存在: {', '.join(missing_files)}",
  117. )
  118. if not image_contents:
  119. return ToolResult(
  120. title="评估失败",
  121. output="",
  122. error="没有可评估的图片",
  123. )
  124. # 3. 构建评估 prompt(根据模式不同)
  125. if is_multi_image:
  126. image_labels = "\n".join([f"- 图片 {i+1}: {ic['path']}" for i, ic in enumerate(image_contents)])
  127. eval_prompt = f"""你是一个专业的图像质量评估专家。请对以下 {len(image_contents)} 张生成图像进行评估,重点检查**跨图一致性**。
  128. ## 需求文档
  129. {requirement_summary}
  130. ## 图片列表
  131. {image_labels}
  132. ## 质量标准
  133. {quality_criteria if quality_criteria else "按照需求文档中的一致性检查标准进行评估"}
  134. ## 评估维度
  135. ### A. 跨图一致性(每项 0-10 分)
  136. 1. **角色一致性**:所有图中的人物面部特征、发型、肤色是否保持一致
  137. 2. **服装一致性**:白色长裙的款式、材质、颜色是否 100% 统一
  138. 3. **色调一致性**:白绿配色方案、色彩饱和度是否贯穿所有图像
  139. 4. **光影一致性**:逆光/轮廓光方向、光晕效果是否统一
  140. 5. **风格一致性**:摄影风格、镜头参数感(85mm、f/1.8 景深)是否统一
  141. ### B. 单图质量(每张图 0-10 分)
  142. 对每张图分别给出质量评分。
  143. ## 输出格式
  144. 请严格按照以下 JSON 格式输出:
  145. ```json
  146. {{
  147. "overall_score": <0-10 的总分>,
  148. "consistency_scores": {{
  149. "character": <0-10>,
  150. "clothing": <0-10>,
  151. "color_scheme": <0-10>,
  152. "lighting": <0-10>,
  153. "style": <0-10>
  154. }},
  155. "per_image_scores": {{
  156. "图片1": <0-10>,
  157. "图片2": <0-10>
  158. }},
  159. "inconsistent_images": ["<列出不一致的图片编号及问题>"],
  160. "feedback": "<详细的文字反馈,指出一致性的优点和不足>",
  161. "suggestions": "<改进建议,哪些图需要重新生成、怎么调整>"
  162. }}
  163. ```
  164. 请仔细对比所有图像,给出客观、专业的评估。"""
  165. else:
  166. eval_prompt = f"""你是一个专业的图像质量评估专家。请根据以下需求文档,对生成的图像进行详细评估。
  167. ## 需求文档
  168. {requirement_summary}
  169. ## 质量标准
  170. {quality_criteria if quality_criteria else "按照需求文档中的 required_spec 和 prompt 描述进行评估"}
  171. ## 评估维度
  172. 请从以下维度评估图像质量(每项 0-10 分):
  173. 1. **姿态准确性**:人物姿态是否符合需求描述
  174. 2. **服装还原度**:服装款式、材质、细节是否符合要求
  175. 3. **光影效果**:光线方向、强度、轮廓光等是否符合描述
  176. 4. **背景一致性**:背景元素、虚化效果是否符合要求
  177. 5. **材质真实感**:服装、道具的材质是否真实自然
  178. 6. **整体构图**:构图、色调、氛围是否符合预期
  179. ## 输出格式
  180. 请严格按照以下 JSON 格式输出评估结果:
  181. ```json
  182. {{
  183. "overall_score": <0-10 的总分>,
  184. "dimension_scores": {{
  185. "pose": <0-10>,
  186. "clothing": <0-10>,
  187. "lighting": <0-10>,
  188. "background": <0-10>,
  189. "material": <0-10>,
  190. "composition": <0-10>
  191. }},
  192. "feedback": "<详细的文字反馈,指出优点和不足>",
  193. "suggestions": "<改进建议,如需调整哪些参数或换用哪些工具>"
  194. }}
  195. ```
  196. 请仔细观察图像,给出客观、专业的评估。"""
  197. # 4. 构建多模态消息
  198. content_parts = [{"type": "text", "text": eval_prompt}]
  199. for ic in image_contents:
  200. content_parts.append({
  201. "type": "image_url",
  202. "image_url": {
  203. "url": f"data:{ic['mime']};base64,{ic['b64']}"
  204. }
  205. })
  206. messages = [{"role": "user", "content": content_parts}]
  207. # 5. 调用 VL 模型
  208. try:
  209. llm_call = create_qwen_llm_call(model="qwen-vl-max")
  210. response = await llm_call(messages, model="qwen-vl-max", temperature=0.3)
  211. # 6. 解析评估结果
  212. response_text = response["content"].strip()
  213. # 提取 JSON
  214. if "```json" in response_text:
  215. json_start = response_text.find("```json") + 7
  216. json_end = response_text.find("```", json_start)
  217. json_str = response_text[json_start:json_end].strip()
  218. elif "```" in response_text:
  219. json_start = response_text.find("```") + 3
  220. json_end = response_text.find("```", json_start)
  221. json_str = response_text[json_start:json_end].strip()
  222. else:
  223. json_str = response_text
  224. eval_result = json.loads(json_str)
  225. # 7. 格式化输出
  226. output = {
  227. "mode": "multi_image_consistency" if is_multi_image else "single_image",
  228. "requirement_path": requirement_path,
  229. "image_paths": paths_list,
  230. "evaluation": eval_result,
  231. }
  232. overall_score = eval_result.get("overall_score", 0)
  233. image_count = len(paths_list)
  234. if is_multi_image:
  235. title = f"多图一致性评估完成({image_count} 张,总分: {overall_score}/10)"
  236. memory = f"Consistency evaluation of {image_count} images: score={overall_score}/10"
  237. else:
  238. title = f"图像评估完成(总分: {overall_score}/10)"
  239. memory = f"Evaluated {paths_list[0]}: score={overall_score}/10"
  240. return ToolResult(
  241. title=title,
  242. output=json.dumps(output, ensure_ascii=False, indent=2),
  243. long_term_memory=memory,
  244. )
  245. except json.JSONDecodeError as e:
  246. return ToolResult(
  247. title="评估完成(JSON 解析失败,返回原始文本)",
  248. output=f"LLM 返回内容:\n{response_text}",
  249. error=f"无法解析 LLM 返回的 JSON: {e}",
  250. )
  251. except Exception as e:
  252. return ToolResult(
  253. title="评估失败",
  254. output="",
  255. error=f"评估过程出错: {e}",
  256. )