| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- """
- Prompt Loader
- 支持 .prompt 文件格式:
- - YAML frontmatter 配置(model, temperature等)
- - $section$ 分节语法
- - %variable% 参数替换
- 格式:
- ---
- model: gemini-2.5-flash
- temperature: 0.3
- ---
- $system$
- 系统提示...
- $user$
- 用户提示...
- %variable%
- """
- import re
- import yaml
- from pathlib import Path
- from typing import Dict, Any, Tuple, Union
- import logging
- logger = logging.getLogger(__name__)
- def load_prompt(path: Union[Path, str]) -> Tuple[Dict[str, Any], Dict[str, str]]:
- """
- 加载 .prompt 文件
- Args:
- path: .prompt 文件路径
- Returns:
- (config, messages)
- - config: 配置字典 {'model': 'gemini-2.5-flash', 'temperature': 0.3, ...}
- - messages: 消息字典 {'system': '...', 'user': '...'}
- Raises:
- FileNotFoundError: 文件不存在
- ValueError: 文件格式错误
- Example:
- >>> config, messages = load_prompt(Path("task.prompt"))
- >>> config['model']
- 'gemini-2.5-flash'
- >>> messages['system']
- '你是一位计算机视觉专家...'
- """
- path = Path(path) if isinstance(path, str) else path
- if not path.exists():
- raise FileNotFoundError(f".prompt 文件不存在: {path}")
- try:
- content = path.read_text(encoding='utf-8')
- except Exception as e:
- raise ValueError(f"读取 .prompt 文件失败: {e}")
- # 解析文件
- config, messages = _parse_prompt(content)
- logger.debug(f"加载 .prompt 文件: {path}, 配置项: {len(config)}, 消息段: {len(messages)}")
- return config, messages
- def _parse_prompt(content: str) -> Tuple[Dict[str, Any], Dict[str, str]]:
- """
- 解析 .prompt 文件内容
- 格式:
- ---
- model: gemini-2.5-flash
- temperature: 0.3
- ---
- $system$
- 系统提示...
- $user$
- 用户提示...
- """
- # 1. 分离 YAML frontmatter 和正文
- parts = content.split('---', 2)
- if len(parts) < 3:
- raise ValueError(".prompt 文件格式错误: 缺少 YAML frontmatter(需要 --- 包裹)")
- # 2. 解析 YAML 配置
- try:
- config = yaml.safe_load(parts[1]) or {}
- except yaml.YAMLError as e:
- raise ValueError(f".prompt 文件 YAML 解析失败: {e}")
- # 3. 解析正文(按 $section$ 分割)
- body = parts[2]
- messages = _parse_sections(body)
- return config, messages
- def _parse_sections(body: str) -> Dict[str, str]:
- """
- 解析 .prompt 正文分节
- 支持语法:
- - $section$ (如 $system$, $user$)
- Args:
- body: .prompt 正文内容
- Returns:
- 消息字典 {'system': '...', 'user': '...'}
- Example:
- >>> body = "$system$\\n你好\\n$user$\\n世界"
- >>> _parse_sections(body)
- {'system': '你好', 'user': '世界'}
- """
- messages = {}
- # 使用正则表达式分割(匹配 $key$)
- pattern = r'\$([^$]+)\$'
- parts = re.split(pattern, body)
- # parts 格式:['前置空白', 'key1', '内容1', 'key2', '内容2', ...]
- # 跳过 parts[0](前置空白)
- i = 1
- while i < len(parts):
- if i + 1 >= len(parts):
- break
- key = parts[i].strip()
- value = parts[i + 1].strip()
- if key and value:
- messages[key] = value
- i += 2
- if not messages:
- logger.warning(".prompt 文件没有找到任何分节($section$)")
- return messages
- def get_message(messages: Dict[str, str], key: str, **params) -> str:
- """
- 获取消息(带参数替换)
- 参数替换:
- - 使用 %variable% 语法
- - 直接字符串替换
- Args:
- messages: 消息字典(来自 load_prompt)
- key: 消息键(如 'system', 'user')
- **params: 参数替换(如 text='内容')
- Returns:
- 替换后的消息字符串
- Example:
- >>> messages = {'user': '特征:%text%'}
- >>> get_message(messages, 'user', text='整体构图')
- '特征:整体构图'
- """
- message = messages.get(key, "")
- if not message:
- logger.warning(f".prompt 消息未找到: key='{key}'")
- return ""
- # 参数替换(%variable% 直接替换)
- if params:
- try:
- for param_name, param_value in params.items():
- placeholder = f"%{param_name}%"
- if placeholder in message:
- message = message.replace(placeholder, str(param_value))
- except Exception as e:
- logger.error(f".prompt 参数替换错误: key='{key}', error={e}")
- return message
|