wrapper.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. """
  2. Prompt Wrapper - 为 .prompt 文件提供 Prompt 实现
  3. 类似 Resonote 的 SimpleHPrompt,但增加了多模态支持
  4. """
  5. import base64
  6. from pathlib import Path
  7. from typing import List, Dict, Any, Union, Optional
  8. from agent.prompts.loader import load_prompt, get_message
  9. class SimplePrompt:
  10. """
  11. 通用的 Prompt 包装器
  12. 特性:
  13. - 加载 .prompt 文件(YAML frontmatter + sections)
  14. - 支持参数替换(%variable%)
  15. - 支持多模态消息(图片)
  16. 使用示例:
  17. # 纯文本
  18. prompt = SimplePrompt(Path("task.prompt"))
  19. messages = prompt.build_messages(text="内容")
  20. # 多模态(文本 + 图片)
  21. messages = prompt.build_messages(
  22. text="分析这张图片",
  23. images="path/to/image.png" # 或 images=["img1.png", "img2.png"]
  24. )
  25. """
  26. def __init__(self, prompt_path: Union[Path, str]):
  27. """
  28. Args:
  29. prompt_path: .prompt 文件路径
  30. """
  31. self.prompt_path = Path(prompt_path) if isinstance(prompt_path, str) else prompt_path
  32. # 加载 .prompt 文件
  33. self.config, self._messages = load_prompt(self.prompt_path)
  34. def build_messages(self, **context) -> List[Dict[str, Any]]:
  35. """
  36. 构造消息列表(支持多模态)
  37. Args:
  38. **context: 参数
  39. - 普通参数:用于替换 %variable%
  40. - images: 图片资源(可选)
  41. - 单个图片:str 或 Path
  42. - 多个图片:List[str | Path]
  43. - 格式:文件路径或 base64 字符串
  44. Returns:
  45. 消息列表,格式遵循 OpenAI API 规范
  46. Example:
  47. >>> messages = prompt.build_messages(
  48. ... text="特征描述",
  49. ... images="input/image.png"
  50. ... )
  51. [
  52. {"role": "system", "content": "..."},
  53. {
  54. "role": "user",
  55. "content": [
  56. {"type": "text", "text": "..."},
  57. {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
  58. ]
  59. }
  60. ]
  61. """
  62. # 提取图片资源(从 context 中移除,避免传入 get_message)
  63. images = context.pop('images', None)
  64. # 构建文本内容(支持参数替换)
  65. system_content = get_message(self._messages, 'system', **context)
  66. user_content = get_message(self._messages, 'user', **context)
  67. messages = []
  68. # 添加 system 消息
  69. if system_content:
  70. messages.append({"role": "system", "content": system_content})
  71. # 添加 user 消息(可能是多模态)
  72. if images:
  73. # 多模态消息
  74. user_message = {"role": "user", "content": []}
  75. # 添加文本部分
  76. if user_content:
  77. user_message["content"].append({
  78. "type": "text",
  79. "text": user_content
  80. })
  81. # 添加图片部分
  82. if isinstance(images, (list, tuple)):
  83. for img in images:
  84. user_message["content"].append(self._build_image_content(img))
  85. else:
  86. user_message["content"].append(self._build_image_content(images))
  87. messages.append(user_message)
  88. else:
  89. # 纯文本消息
  90. if user_content:
  91. messages.append({"role": "user", "content": user_content})
  92. return messages
  93. def _build_image_content(self, image: Union[str, Path]) -> Dict[str, Any]:
  94. """
  95. 构建图片内容部分(OpenAI 格式)
  96. Args:
  97. image: 图片路径或 base64 字符串
  98. Returns:
  99. {"type": "image_url", "image_url": {"url": "data:..."}}
  100. """
  101. # 如果已经是 base64 data URL,直接使用
  102. if isinstance(image, str) and image.startswith("data:"):
  103. return {
  104. "type": "image_url",
  105. "image_url": {"url": image}
  106. }
  107. # 否则,读取文件并转为 base64
  108. image_path = Path(image) if isinstance(image, str) else image
  109. # 推断 MIME type
  110. suffix = image_path.suffix.lower()
  111. mime_type_map = {
  112. '.png': 'image/png',
  113. '.jpg': 'image/jpeg',
  114. '.jpeg': 'image/jpeg',
  115. '.gif': 'image/gif',
  116. '.webp': 'image/webp'
  117. }
  118. mime_type = mime_type_map.get(suffix, 'image/png')
  119. # 读取并编码
  120. with open(image_path, 'rb') as f:
  121. image_data = base64.b64encode(f.read()).decode('utf-8')
  122. data_url = f"data:{mime_type};base64,{image_data}"
  123. return {
  124. "type": "image_url",
  125. "image_url": {"url": data_url}
  126. }
  127. def create_prompt(prompt_path: Union[Path, str]) -> SimplePrompt:
  128. """
  129. 工厂函数:创建 SimplePrompt 实例
  130. Args:
  131. prompt_path: .prompt 文件路径
  132. Returns:
  133. SimplePrompt 实例
  134. """
  135. return SimplePrompt(prompt_path)