schema_manager.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. """
  2. Schema 管理工具:统一加载和验证 JSON Schema
  3. 设计原则:
  4. 1. 每个 prompt 文件对应一个 .schema.json 文件
  5. 2. Schema 文件和 prompt 文件放在同一目录
  6. 3. 每个 schema 文件完全独立,不依赖外部引用
  7. 4. 所有验证逻辑都通过 jsonschema 库自动完成,不再硬编码字段名
  8. 后缀约定(用于标注契约边界):
  9. - `-boundary`: 容器字段,名称不可变,内部元素可演进(如 abilities-boundary)
  10. - `-ref`: 被外部直接引用的字段,名称和类型都不可变(如 ability_id-ref)
  11. - 无后缀: 内部字段,可自由演进
  12. 校验时自动剥离后缀,实际匹配的 key 是去掉后缀的版本。
  13. """
  14. import json
  15. from pathlib import Path
  16. from typing import Any, Dict, Optional, Tuple
  17. import copy
  18. try:
  19. import jsonschema
  20. from jsonschema import Draft7Validator, ValidationError
  21. JSONSCHEMA_AVAILABLE = True
  22. except ImportError:
  23. JSONSCHEMA_AVAILABLE = False
  24. print("Warning: jsonschema not installed. Run: pip install jsonschema")
  25. class SchemaManager:
  26. """Schema 管理器,负责加载和验证 JSON Schema"""
  27. CONTRACT_SUFFIXES = ["-boundary", "-ref"]
  28. def __init__(self, prompts_dir: Path):
  29. """
  30. 初始化 Schema 管理器
  31. Args:
  32. prompts_dir: prompts 目录路径
  33. """
  34. self.prompts_dir = Path(prompts_dir)
  35. self._schema_cache: Dict[str, Dict] = {}
  36. def load_schema(self, prompt_name: str) -> Optional[Dict]:
  37. """
  38. 加载指定 prompt 对应的 schema
  39. Args:
  40. prompt_name: prompt 文件名(不含 .prompt 后缀)
  41. Returns:
  42. Schema 字典,如果文件不存在则返回 None
  43. """
  44. # 检查缓存
  45. if prompt_name in self._schema_cache:
  46. return self._schema_cache[prompt_name]
  47. # 加载 schema 文件(先找 prompts/,再找 prompts/temp_schema/)
  48. schema_file = self.prompts_dir / f"{prompt_name}.schema.json"
  49. if not schema_file.exists():
  50. schema_file = self.prompts_dir / "temp_schema" / f"{prompt_name}.schema.json"
  51. if not schema_file.exists():
  52. return None
  53. try:
  54. with open(schema_file, "r", encoding="utf-8") as f:
  55. schema = json.load(f)
  56. self._schema_cache[prompt_name] = schema
  57. return schema
  58. except Exception as e:
  59. print(f"Error loading schema {schema_file}: {e}")
  60. return None
  61. @classmethod
  62. def _strip_suffix(cls, key: str) -> str:
  63. """剥离契约后缀,返回实际字段名"""
  64. for suffix in cls.CONTRACT_SUFFIXES:
  65. if key.endswith(suffix):
  66. return key[:-len(suffix)]
  67. return key
  68. @classmethod
  69. def _strip_schema(cls, schema: Any) -> Any:
  70. """
  71. 递归遍历 schema,将所有带后缀的 key 替换为剥离后的版本。
  72. 返回一份新的 schema(不修改原始对象)。
  73. """
  74. if isinstance(schema, dict):
  75. result = {}
  76. for k, v in schema.items():
  77. new_key = k
  78. # 只对 properties 和 required 里的 key 做剥离
  79. if k == "properties":
  80. # properties 的 value 是 {field_name: field_schema}
  81. result[k] = {
  82. cls._strip_suffix(fk): cls._strip_schema(fv)
  83. for fk, fv in v.items()
  84. }
  85. elif k == "required":
  86. # required 是字段名数组
  87. result[k] = [cls._strip_suffix(r) for r in v]
  88. else:
  89. result[k] = cls._strip_schema(v)
  90. return result
  91. elif isinstance(schema, list):
  92. return [cls._strip_schema(item) for item in schema]
  93. else:
  94. return schema
  95. def validate(self, data: Any, prompt_name: str) -> Tuple[bool, Optional[str]]:
  96. """
  97. 使用 JSON Schema 验证数据
  98. Args:
  99. data: 要验证的数据
  100. prompt_name: prompt 文件名(不含 .prompt 后缀)
  101. Returns:
  102. (is_valid, error_message) 元组
  103. """
  104. if not JSONSCHEMA_AVAILABLE:
  105. return True, None
  106. schema = self.load_schema(prompt_name)
  107. if schema is None:
  108. return True, None
  109. try:
  110. clean_schema = self._strip_schema(schema)
  111. validator = Draft7Validator(clean_schema)
  112. validator.validate(data)
  113. return True, None
  114. except ValidationError as e:
  115. path = ".".join(str(p) for p in e.absolute_path) if e.absolute_path else "root"
  116. return False, f"{path}: {e.message}"
  117. except Exception as e:
  118. return False, f"Validation error: {str(e)}"
  119. def get_example_output(self, prompt_name: str) -> Optional[Dict]:
  120. """
  121. 从 schema 中提取示例输出(如果有的话)
  122. Args:
  123. prompt_name: prompt 文件名(不含 .prompt 后缀)
  124. Returns:
  125. 示例输出字典,如果没有则返回 None
  126. """
  127. schema = self.load_schema(prompt_name)
  128. if schema is None:
  129. return None
  130. # 尝试从 schema 中提取 examples
  131. if "examples" in schema:
  132. return schema["examples"][0] if schema["examples"] else None
  133. # 或者根据 schema 生成一个最小示例
  134. return self._generate_minimal_example(schema)
  135. def get_stripped_schema(self, prompt_name: str) -> Optional[Dict]:
  136. """
  137. 获取剥离后缀的 schema(用于传给 LLM 的 response_format)
  138. Args:
  139. prompt_name: prompt 文件名(不含 .prompt 后缀)
  140. Returns:
  141. 剥离后缀的 schema 字典,如果文件不存在则返回 None
  142. """
  143. schema = self.load_schema(prompt_name)
  144. if schema is None:
  145. return None
  146. stripped = self._strip_schema(schema)
  147. return self._inline_local_refs(stripped)
  148. @classmethod
  149. def _inline_local_refs(cls, schema: Any, root: Optional[Dict] = None) -> Any:
  150. """
  151. 展开 schema 内的本地 $ref,避免结构化输出 strict mode 对 $ref 支持不一致。
  152. 仅支持形如 "#/definitions/xxx" 或 "#/$defs/xxx" 的本地引用。
  153. """
  154. if root is None and isinstance(schema, dict):
  155. root = schema
  156. if isinstance(schema, list):
  157. return [cls._inline_local_refs(item, root) for item in schema]
  158. if not isinstance(schema, dict):
  159. return schema
  160. if set(schema.keys()) == {"$ref"} and isinstance(schema.get("$ref"), str):
  161. ref = schema["$ref"]
  162. if ref.startswith("#/") and root is not None:
  163. target: Any = root
  164. for part in ref[2:].split("/"):
  165. part = part.replace("~1", "/").replace("~0", "~")
  166. target = target[part]
  167. return cls._inline_local_refs(copy.deepcopy(target), root)
  168. result = {}
  169. for key, value in schema.items():
  170. if key in ("definitions", "$defs"):
  171. continue
  172. result[key] = cls._inline_local_refs(value, root)
  173. return result
  174. def _generate_minimal_example(self, schema: Dict) -> Dict:
  175. """
  176. 根据 schema 生成一个最小示例
  177. Args:
  178. schema: JSON Schema 字典
  179. Returns:
  180. 最小示例字典
  181. """
  182. if schema.get("type") != "object":
  183. return {}
  184. example = {}
  185. required = schema.get("required", [])
  186. properties = schema.get("properties", {})
  187. for key in required:
  188. if key in properties:
  189. prop = properties[key]
  190. prop_type = prop.get("type")
  191. if prop_type == "string":
  192. example[key] = prop.get("examples", [""])[0] if "examples" in prop else ""
  193. elif prop_type == "integer":
  194. example[key] = prop.get("examples", [0])[0] if "examples" in prop else 0
  195. elif prop_type == "boolean":
  196. example[key] = prop.get("default", False)
  197. elif prop_type == "array":
  198. example[key] = []
  199. elif prop_type == "object":
  200. example[key] = {}
  201. elif isinstance(prop_type, list) and "null" in prop_type:
  202. example[key] = None
  203. return example
  204. # 全局单例
  205. _schema_manager: Optional[SchemaManager] = None
  206. def get_schema_manager(prompts_dir: Optional[Path] = None) -> SchemaManager:
  207. """
  208. 获取全局 Schema 管理器单例
  209. Args:
  210. prompts_dir: prompts 目录路径(首次调用时必须提供)
  211. Returns:
  212. SchemaManager 实例
  213. """
  214. global _schema_manager
  215. if _schema_manager is None:
  216. if prompts_dir is None:
  217. # 默认路径
  218. base_dir = Path(__file__).parent.parent
  219. prompts_dir = base_dir / "prompts"
  220. _schema_manager = SchemaManager(prompts_dir)
  221. return _schema_manager
  222. def validate_with_schema(data: Any, prompt_name: str) -> Optional[str]:
  223. """
  224. 便捷函数:使用 schema 验证数据
  225. Args:
  226. data: 要验证的数据
  227. prompt_name: prompt 文件名(不含 .prompt 后缀)
  228. Returns:
  229. 错误消息字符串,如果验证通过则返回 None
  230. """
  231. manager = get_schema_manager()
  232. is_valid, error = manager.validate(data, prompt_name)
  233. return error if not is_valid else None
  234. # 示例用法
  235. if __name__ == "__main__":
  236. # 测试加载 schema
  237. manager = get_schema_manager()
  238. # 测试 extract_workflow schema
  239. schema = manager.load_schema("extract_workflow")
  240. if schema:
  241. print("✓ Loaded extract_workflow.schema.json")
  242. # 测试验证
  243. test_data = {
  244. "id": "strategy-001",
  245. "name": "测试工序",
  246. "description": "这是一个测试",
  247. "modality": "图文",
  248. "inputs": {},
  249. "outputs": {},
  250. "steps": [
  251. {
  252. "order": 1,
  253. "type": "capability",
  254. "description": "测试步骤",
  255. "inputs": {},
  256. "outputs": {}
  257. }
  258. ]
  259. }
  260. is_valid, error = manager.validate(test_data, "extract_workflow")
  261. if is_valid:
  262. print("✓ Validation passed")
  263. else:
  264. print(f"✗ Validation failed: {error}")