update_schema.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. """
  2. Schema 自动更新脚本
  3. 根据 prompt 文件中定义的输出格式,调用 GPT(via OpenRouter)自动生成/更新对应的 JSON Schema。
  4. 用法:
  5. python -m examples.process_pipeline.script.update_schema researcher
  6. python -m examples.process_pipeline.script.update_schema researcher --dry-run
  7. python -m examples.process_pipeline.script.update_schema researcher --model openai/gpt-5.4
  8. """
  9. import argparse
  10. import asyncio
  11. import json
  12. import os
  13. import re
  14. import sys
  15. from pathlib import Path
  16. from typing import Any, Dict, List, Optional
  17. PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
  18. sys.path.insert(0, str(PROJECT_ROOT))
  19. from dotenv import load_dotenv
  20. load_dotenv()
  21. PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts"
  22. SYSTEM_PROMPT = """\
  23. 你是一个 JSON Schema 专家。你的任务是根据给定的 prompt 文件中定义的"输出格式"部分,\
  24. 生成或更新一个 JSON Schema (Draft-07)。
  25. ## Contract Suffix 约定
  26. 这个项目的 schema 使用特殊的字段名后缀来标记字段的"契约等级":
  27. 1. **`-boundary`** 后缀:标记"容器边界"字段。
  28. - 这些字段通常是 array 或 object 类型
  29. - 表示该容器内部的元素结构可以自由演化
  30. - 但容器本身(字段名、类型、是否 required)是稳定的
  31. - 例:`cases-boundary`, `steps-boundary`, `clusters-boundary`
  32. 2. **`-ref`** 后缀:标记"引用锚点"字段。
  33. - 这些字段的名称和类型是不可变的(被外部引用)
  34. - 通常是 id、name、url 等标识性字段
  35. - 例:`case_id-ref`, `source_url-ref`, `name-ref`
  36. 3. **无后缀**:普通内部字段,可以自由演化。
  37. ## 规则
  38. 1. 只关注 prompt 中"输出格式"或"Output Format"部分定义的 JSON 结构
  39. 2. 为每个字段选择合适的 contract suffix:
  40. - 顶层数组容器 → `-boundary`
  41. - ID、URL、name 等标识字段 → `-ref`
  42. - 其他字段 → 无后缀
  43. 3. 合理设置 `required` 字段(只包含核心必需字段)
  44. 4. 对于枚举值使用 `enum`
  45. 5. 对于可选字段使用 `type: ["string", "null"]` 或不放入 required
  46. 6. 输出必须是合法的 JSON Schema Draft-07,不要包含任何注释或额外文本
  47. 7. 如果提供了现有 schema,尽量保留其中合理的约束(如 pattern、enum),只更新结构变化的部分
  48. ## 输出要求
  49. 直接输出完整的 JSON Schema,不要包含 markdown 代码块标记或任何其他文本。\
  50. """
  51. def build_messages(prompt_content: str, existing_schema: Optional[dict]) -> List[Dict[str, Any]]:
  52. """构造发给 LLM 的消息列表"""
  53. user_parts = [f"## Prompt 文件内容\n\n{prompt_content}"]
  54. if existing_schema:
  55. user_parts.append(
  56. f"\n\n## 现有 Schema(请在此基础上更新)\n\n```json\n{json.dumps(existing_schema, ensure_ascii=False, indent=2)}\n```"
  57. )
  58. else:
  59. user_parts.append("\n\n## 现有 Schema\n\n无(请从零生成)")
  60. user_parts.append(
  61. "\n\n## 任务\n\n"
  62. "请根据上面 prompt 中定义的输出 JSON 结构,生成完整的 JSON Schema (Draft-07)。"
  63. "注意应用 contract suffix 约定(-boundary / -ref)。"
  64. "直接输出 JSON,不要包含任何其他文本。"
  65. )
  66. return [
  67. {"role": "system", "content": SYSTEM_PROMPT},
  68. {"role": "user", "content": "".join(user_parts)},
  69. ]
  70. def extract_json_from_response(content: str) -> dict:
  71. """从 LLM 响应中提取 JSON(处理可能的 markdown 代码块包裹)"""
  72. content = content.strip()
  73. # 去掉 markdown 代码块
  74. if content.startswith("```"):
  75. lines = content.split("\n")
  76. # 去掉首行 ```json 和末行 ```
  77. if lines[0].startswith("```"):
  78. lines = lines[1:]
  79. if lines and lines[-1].strip() == "```":
  80. lines = lines[:-1]
  81. content = "\n".join(lines)
  82. return json.loads(content)
  83. async def update_schema(
  84. prompt_name: str,
  85. model: str = "openai/gpt-5.4",
  86. dry_run: bool = False,
  87. ) -> dict:
  88. """
  89. 根据 prompt 文件更新对应的 schema。
  90. Args:
  91. prompt_name: prompt 名称(不含扩展名)
  92. model: 使用的模型
  93. dry_run: 如果为 True,只打印不写入
  94. Returns:
  95. 生成的 schema dict
  96. """
  97. from agent.llm.openrouter import create_openrouter_llm_call
  98. prompt_file = PROMPTS_DIR / f"{prompt_name}.prompt"
  99. schema_file = PROMPTS_DIR / f"{prompt_name}.schema.json"
  100. if not prompt_file.exists():
  101. raise FileNotFoundError(f"Prompt file not found: {prompt_file}")
  102. # 读取 prompt
  103. prompt_content = prompt_file.read_text(encoding="utf-8")
  104. # 读取现有 schema(如果有)
  105. existing_schema = None
  106. if schema_file.exists():
  107. try:
  108. existing_schema = json.loads(schema_file.read_text(encoding="utf-8"))
  109. except json.JSONDecodeError:
  110. print(f"⚠️ 现有 schema 文件 JSON 格式错误,将从零生成")
  111. # 构造消息
  112. messages = build_messages(prompt_content, existing_schema)
  113. # 调用 LLM
  114. llm_call = create_openrouter_llm_call(model=model)
  115. print(f"🤖 Calling {model} to generate schema for '{prompt_name}'...")
  116. result = await llm_call(messages=messages, model=model, temperature=0.1)
  117. content = result.get("content", "")
  118. if not content:
  119. raise ValueError("LLM returned empty response")
  120. # 解析 JSON
  121. try:
  122. new_schema = extract_json_from_response(content)
  123. except json.JSONDecodeError as e:
  124. print(f"❌ LLM 返回的内容不是合法 JSON:")
  125. print(content[:500])
  126. raise ValueError(f"Failed to parse schema from LLM response: {e}")
  127. # 验证生成的 schema 本身是否合法
  128. try:
  129. import jsonschema
  130. jsonschema.Draft7Validator.check_schema(new_schema)
  131. except jsonschema.SchemaError as e:
  132. print(f"⚠️ 生成的 schema 不是合法的 Draft-07: {e.message}")
  133. print("仍然输出结果,但请手动检查。")
  134. # 输出
  135. schema_json = json.dumps(new_schema, ensure_ascii=False, indent=2)
  136. if dry_run:
  137. print(f"\n{'='*60}")
  138. print(f"[Dry Run] Generated schema for '{prompt_name}':")
  139. print(f"{'='*60}")
  140. print(schema_json)
  141. else:
  142. schema_file.write_text(schema_json + "\n", encoding="utf-8")
  143. print(f"✅ Schema written to: {schema_file}")
  144. # 打印 diff 摘要
  145. if existing_schema:
  146. old_keys = set(_flatten_keys(existing_schema))
  147. new_keys = set(_flatten_keys(new_schema))
  148. added = new_keys - old_keys
  149. removed = old_keys - new_keys
  150. if added:
  151. print(f" + Added: {', '.join(sorted(added)[:10])}")
  152. if removed:
  153. print(f" - Removed: {', '.join(sorted(removed)[:10])}")
  154. if not added and not removed:
  155. print(f" (no structural changes)")
  156. return new_schema
  157. def _flatten_keys(obj: Any, prefix: str = "") -> List[str]:
  158. """递归提取 schema 中所有 properties 的 key 路径"""
  159. keys = []
  160. if isinstance(obj, dict):
  161. if "properties" in obj:
  162. for k, v in obj["properties"].items():
  163. full_key = f"{prefix}.{k}" if prefix else k
  164. keys.append(full_key)
  165. keys.extend(_flatten_keys(v, full_key))
  166. if "items" in obj:
  167. keys.extend(_flatten_keys(obj["items"], f"{prefix}[]"))
  168. for variant_key in ("oneOf", "anyOf", "allOf"):
  169. if variant_key in obj:
  170. for i, variant in enumerate(obj[variant_key]):
  171. keys.extend(_flatten_keys(variant, f"{prefix}|{i}"))
  172. return keys
  173. def main():
  174. parser = argparse.ArgumentParser(description="根据 prompt 自动更新 JSON Schema")
  175. parser.add_argument("prompt_name", help="Prompt 名称(不含 .prompt 扩展名)")
  176. parser.add_argument("--model", default="openai/gpt-5.4", help="使用的模型(默认 openai/gpt-5.4)")
  177. parser.add_argument("--dry-run", action="store_true", help="只打印生成结果,不写入文件")
  178. args = parser.parse_args()
  179. asyncio.run(update_schema(
  180. prompt_name=args.prompt_name,
  181. model=args.model,
  182. dry_run=args.dry_run,
  183. ))
  184. if __name__ == "__main__":
  185. main()