| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235 |
- """
- Schema 自动更新脚本
- 根据 prompt 文件中定义的输出格式,调用 GPT(via OpenRouter)自动生成/更新对应的 JSON Schema。
- 用法:
- python -m examples.process_pipeline.script.update_schema researcher
- python -m examples.process_pipeline.script.update_schema researcher --dry-run
- python -m examples.process_pipeline.script.update_schema researcher --model openai/gpt-5.4
- """
- import argparse
- import asyncio
- import json
- import os
- import re
- import sys
- from pathlib import Path
- from typing import Any, Dict, List, Optional
- PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
- sys.path.insert(0, str(PROJECT_ROOT))
- from dotenv import load_dotenv
- load_dotenv()
- PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts"
- SYSTEM_PROMPT = """\
- 你是一个 JSON Schema 专家。你的任务是根据给定的 prompt 文件中定义的"输出格式"部分,\
- 生成或更新一个 JSON Schema (Draft-07)。
- ## Contract Suffix 约定
- 这个项目的 schema 使用特殊的字段名后缀来标记字段的"契约等级":
- 1. **`-boundary`** 后缀:标记"容器边界"字段。
- - 这些字段通常是 array 或 object 类型
- - 表示该容器内部的元素结构可以自由演化
- - 但容器本身(字段名、类型、是否 required)是稳定的
- - 例:`cases-boundary`, `steps-boundary`, `clusters-boundary`
- 2. **`-ref`** 后缀:标记"引用锚点"字段。
- - 这些字段的名称和类型是不可变的(被外部引用)
- - 通常是 id、name、url 等标识性字段
- - 例:`case_id-ref`, `source_url-ref`, `name-ref`
- 3. **无后缀**:普通内部字段,可以自由演化。
- ## 规则
- 1. 只关注 prompt 中"输出格式"或"Output Format"部分定义的 JSON 结构
- 2. 为每个字段选择合适的 contract suffix:
- - 顶层数组容器 → `-boundary`
- - ID、URL、name 等标识字段 → `-ref`
- - 其他字段 → 无后缀
- 3. 合理设置 `required` 字段(只包含核心必需字段)
- 4. 对于枚举值使用 `enum`
- 5. 对于可选字段使用 `type: ["string", "null"]` 或不放入 required
- 6. 输出必须是合法的 JSON Schema Draft-07,不要包含任何注释或额外文本
- 7. 如果提供了现有 schema,尽量保留其中合理的约束(如 pattern、enum),只更新结构变化的部分
- ## 输出要求
- 直接输出完整的 JSON Schema,不要包含 markdown 代码块标记或任何其他文本。\
- """
- def build_messages(prompt_content: str, existing_schema: Optional[dict]) -> List[Dict[str, Any]]:
- """构造发给 LLM 的消息列表"""
- user_parts = [f"## Prompt 文件内容\n\n{prompt_content}"]
- if existing_schema:
- user_parts.append(
- f"\n\n## 现有 Schema(请在此基础上更新)\n\n```json\n{json.dumps(existing_schema, ensure_ascii=False, indent=2)}\n```"
- )
- else:
- user_parts.append("\n\n## 现有 Schema\n\n无(请从零生成)")
- user_parts.append(
- "\n\n## 任务\n\n"
- "请根据上面 prompt 中定义的输出 JSON 结构,生成完整的 JSON Schema (Draft-07)。"
- "注意应用 contract suffix 约定(-boundary / -ref)。"
- "直接输出 JSON,不要包含任何其他文本。"
- )
- return [
- {"role": "system", "content": SYSTEM_PROMPT},
- {"role": "user", "content": "".join(user_parts)},
- ]
- def extract_json_from_response(content: str) -> dict:
- """从 LLM 响应中提取 JSON(处理可能的 markdown 代码块包裹)"""
- content = content.strip()
- # 去掉 markdown 代码块
- if content.startswith("```"):
- lines = content.split("\n")
- # 去掉首行 ```json 和末行 ```
- if lines[0].startswith("```"):
- lines = lines[1:]
- if lines and lines[-1].strip() == "```":
- lines = lines[:-1]
- content = "\n".join(lines)
- return json.loads(content)
- async def update_schema(
- prompt_name: str,
- model: str = "openai/gpt-5.4",
- dry_run: bool = False,
- ) -> dict:
- """
- 根据 prompt 文件更新对应的 schema。
- Args:
- prompt_name: prompt 名称(不含扩展名)
- model: 使用的模型
- dry_run: 如果为 True,只打印不写入
- Returns:
- 生成的 schema dict
- """
- from agent.llm.openrouter import create_openrouter_llm_call
- prompt_file = PROMPTS_DIR / f"{prompt_name}.prompt"
- schema_file = PROMPTS_DIR / f"{prompt_name}.schema.json"
- if not prompt_file.exists():
- raise FileNotFoundError(f"Prompt file not found: {prompt_file}")
- # 读取 prompt
- prompt_content = prompt_file.read_text(encoding="utf-8")
- # 读取现有 schema(如果有)
- existing_schema = None
- if schema_file.exists():
- try:
- existing_schema = json.loads(schema_file.read_text(encoding="utf-8"))
- except json.JSONDecodeError:
- print(f"⚠️ 现有 schema 文件 JSON 格式错误,将从零生成")
- # 构造消息
- messages = build_messages(prompt_content, existing_schema)
- # 调用 LLM
- llm_call = create_openrouter_llm_call(model=model)
- print(f"🤖 Calling {model} to generate schema for '{prompt_name}'...")
- result = await llm_call(messages=messages, model=model, temperature=0.1)
- content = result.get("content", "")
- if not content:
- raise ValueError("LLM returned empty response")
- # 解析 JSON
- try:
- new_schema = extract_json_from_response(content)
- except json.JSONDecodeError as e:
- print(f"❌ LLM 返回的内容不是合法 JSON:")
- print(content[:500])
- raise ValueError(f"Failed to parse schema from LLM response: {e}")
- # 验证生成的 schema 本身是否合法
- try:
- import jsonschema
- jsonschema.Draft7Validator.check_schema(new_schema)
- except jsonschema.SchemaError as e:
- print(f"⚠️ 生成的 schema 不是合法的 Draft-07: {e.message}")
- print("仍然输出结果,但请手动检查。")
- # 输出
- schema_json = json.dumps(new_schema, ensure_ascii=False, indent=2)
- if dry_run:
- print(f"\n{'='*60}")
- print(f"[Dry Run] Generated schema for '{prompt_name}':")
- print(f"{'='*60}")
- print(schema_json)
- else:
- schema_file.write_text(schema_json + "\n", encoding="utf-8")
- print(f"✅ Schema written to: {schema_file}")
- # 打印 diff 摘要
- if existing_schema:
- old_keys = set(_flatten_keys(existing_schema))
- new_keys = set(_flatten_keys(new_schema))
- added = new_keys - old_keys
- removed = old_keys - new_keys
- if added:
- print(f" + Added: {', '.join(sorted(added)[:10])}")
- if removed:
- print(f" - Removed: {', '.join(sorted(removed)[:10])}")
- if not added and not removed:
- print(f" (no structural changes)")
- return new_schema
- def _flatten_keys(obj: Any, prefix: str = "") -> List[str]:
- """递归提取 schema 中所有 properties 的 key 路径"""
- keys = []
- if isinstance(obj, dict):
- if "properties" in obj:
- for k, v in obj["properties"].items():
- full_key = f"{prefix}.{k}" if prefix else k
- keys.append(full_key)
- keys.extend(_flatten_keys(v, full_key))
- if "items" in obj:
- keys.extend(_flatten_keys(obj["items"], f"{prefix}[]"))
- for variant_key in ("oneOf", "anyOf", "allOf"):
- if variant_key in obj:
- for i, variant in enumerate(obj[variant_key]):
- keys.extend(_flatten_keys(variant, f"{prefix}|{i}"))
- return keys
- def main():
- parser = argparse.ArgumentParser(description="根据 prompt 自动更新 JSON Schema")
- parser.add_argument("prompt_name", help="Prompt 名称(不含 .prompt 扩展名)")
- parser.add_argument("--model", default="openai/gpt-5.4", help="使用的模型(默认 openai/gpt-5.4)")
- parser.add_argument("--dry-run", action="store_true", help="只打印生成结果,不写入文件")
- args = parser.parse_args()
- asyncio.run(update_schema(
- prompt_name=args.prompt_name,
- model=args.model,
- dry_run=args.dry_run,
- ))
- if __name__ == "__main__":
- main()
|