loader.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. """
  2. Prompt Loader
  3. 支持 .prompt 文件格式:
  4. - YAML frontmatter 配置(model, temperature等)
  5. - $section$ 分节语法
  6. - %variable% 参数替换
  7. 格式:
  8. ---
  9. model: gemini-2.5-flash
  10. temperature: 0.3
  11. ---
  12. $system$
  13. 系统提示...
  14. $user$
  15. 用户提示...
  16. %variable%
  17. """
  18. import re
  19. import yaml
  20. from pathlib import Path
  21. from typing import Dict, Any, Tuple, Union
  22. import logging
  23. logger = logging.getLogger(__name__)
  24. def load_prompt(path: Union[Path, str]) -> Tuple[Dict[str, Any], Dict[str, str]]:
  25. """
  26. 加载 .prompt 文件
  27. Args:
  28. path: .prompt 文件路径
  29. Returns:
  30. (config, messages)
  31. - config: 配置字典 {'model': 'gemini-2.5-flash', 'temperature': 0.3, ...}
  32. - messages: 消息字典 {'system': '...', 'user': '...'}
  33. Raises:
  34. FileNotFoundError: 文件不存在
  35. ValueError: 文件格式错误
  36. Example:
  37. >>> config, messages = load_prompt(Path("task.prompt"))
  38. >>> config['model']
  39. 'gemini-2.5-flash'
  40. >>> messages['system']
  41. '你是一位计算机视觉专家...'
  42. """
  43. path = Path(path) if isinstance(path, str) else path
  44. if not path.exists():
  45. raise FileNotFoundError(f".prompt 文件不存在: {path}")
  46. try:
  47. content = path.read_text(encoding='utf-8')
  48. except Exception as e:
  49. raise ValueError(f"读取 .prompt 文件失败: {e}")
  50. # 解析文件
  51. config, messages = _parse_prompt(content)
  52. logger.debug(f"加载 .prompt 文件: {path}, 配置项: {len(config)}, 消息段: {len(messages)}")
  53. return config, messages
  54. def _parse_prompt(content: str) -> Tuple[Dict[str, Any], Dict[str, str]]:
  55. """
  56. 解析 .prompt 文件内容
  57. 格式:
  58. ---
  59. model: gemini-2.5-flash
  60. temperature: 0.3
  61. ---
  62. $system$
  63. 系统提示...
  64. $user$
  65. 用户提示...
  66. """
  67. # 1. 分离 YAML frontmatter 和正文
  68. parts = content.split('---', 2)
  69. if len(parts) < 3:
  70. raise ValueError(".prompt 文件格式错误: 缺少 YAML frontmatter(需要 --- 包裹)")
  71. # 2. 解析 YAML 配置
  72. try:
  73. config = yaml.safe_load(parts[1]) or {}
  74. except yaml.YAMLError as e:
  75. raise ValueError(f".prompt 文件 YAML 解析失败: {e}")
  76. # 3. 解析正文(按 $section$ 分割)
  77. body = parts[2]
  78. messages = _parse_sections(body)
  79. return config, messages
  80. def _parse_sections(body: str) -> Dict[str, str]:
  81. """
  82. 解析 .prompt 正文分节
  83. 支持语法:
  84. - $section$ (如 $system$, $user$)
  85. Args:
  86. body: .prompt 正文内容
  87. Returns:
  88. 消息字典 {'system': '...', 'user': '...'}
  89. Example:
  90. >>> body = "$system$\\n你好\\n$user$\\n世界"
  91. >>> _parse_sections(body)
  92. {'system': '你好', 'user': '世界'}
  93. """
  94. messages = {}
  95. # 使用正则表达式分割(匹配 $key$)
  96. pattern = r'\$([^$]+)\$'
  97. parts = re.split(pattern, body)
  98. # parts 格式:['前置空白', 'key1', '内容1', 'key2', '内容2', ...]
  99. # 跳过 parts[0](前置空白)
  100. i = 1
  101. while i < len(parts):
  102. if i + 1 >= len(parts):
  103. break
  104. key = parts[i].strip()
  105. value = parts[i + 1].strip()
  106. if key and value:
  107. messages[key] = value
  108. i += 2
  109. if not messages:
  110. logger.warning(".prompt 文件没有找到任何分节($section$)")
  111. return messages
  112. def get_message(messages: Dict[str, str], key: str, **params) -> str:
  113. """
  114. 获取消息(带参数替换)
  115. 参数替换:
  116. - 使用 %variable% 语法
  117. - 直接字符串替换
  118. Args:
  119. messages: 消息字典(来自 load_prompt)
  120. key: 消息键(如 'system', 'user')
  121. **params: 参数替换(如 text='内容')
  122. Returns:
  123. 替换后的消息字符串
  124. Example:
  125. >>> messages = {'user': '特征:%text%'}
  126. >>> get_message(messages, 'user', text='整体构图')
  127. '特征:整体构图'
  128. """
  129. message = messages.get(key, "")
  130. if not message:
  131. logger.warning(f".prompt 消息未找到: key='{key}'")
  132. return ""
  133. # 参数替换(%variable% 直接替换)
  134. if params:
  135. try:
  136. for param_name, param_value in params.items():
  137. placeholder = f"%{param_name}%"
  138. if placeholder in message:
  139. message = message.replace(placeholder, str(param_value))
  140. except Exception as e:
  141. logger.error(f".prompt 参数替换错误: key='{key}', error={e}")
  142. return message