"""
Pipeline output validation script (schema-driven).

All structural validation is driven by .schema.json files; no field names
are hard-coded here. Non-structural checks (deduplication, count
consistency, referential integrity) are split out into the
validate_invariants_* functions.

Filename → schema mapping:
    case_*.json                 → researcher.schema.json
    source.json                 → source.schema.json
    case_detailed.json          → case_detailed.schema.json
    blueprint_temp.json         → process_cluster.schema.json
    process.json                → process.schema.json
    capabilities_temp.json      → extract_capabilities.schema.json
    capabilities.json           → capabilities.schema.json
    capabilities_extracted.json → capabilities_extracted.schema.json
    strategy.json               → assemble_strategy.schema.json
"""
- import json
- from pathlib import Path
- import re
- import argparse
- from typing import Dict, List, Optional, Set, Tuple
- from .schema_manager import validate_with_schema, get_schema_manager
- VALID_PLATFORMS = {"xhs", "youtube", "bili", "x", "zhihu", "gzh"}
- # ── 文件名 → schema 名映射 ──────────────────────────────
- FILENAME_SCHEMA_MAP = {
- "source.json": "source",
- "case_detailed.json": "case_detailed",
- "blueprint_temp.json": "process_cluster",
- "process.json": "process",
- "capabilities_temp.json": "extract_capabilities",
- "capabilities.json": "capabilities",
- "capabilities_extracted.json": "capabilities_extracted",
- "strategy.json": "assemble_strategy",
- }
- def _resolve_schema_name(filename: str) -> Optional[str]:
- """根据文件名解析对应的 schema 名称"""
- if filename in FILENAME_SCHEMA_MAP:
- return FILENAME_SCHEMA_MAP[filename]
- if filename.startswith("case_") and filename.endswith(".json"):
- return "researcher"
- return None
# ── Schema-driven validation functions ──────────────────────────────
# Each function does exactly one thing: call validate_with_schema.
# Public signatures stay unchanged: validate_X(data) -> Optional[str]
- def validate_case(data):
- """验证 case_*.json(Phase 1 输出)"""
- return validate_with_schema(data, "researcher")
- def validate_source(data):
- """验证 source.json(Phase 1.5 输出)"""
- return validate_with_schema(data, "source")
- def validate_case_detailed(data):
- """验证 case_detailed.json(Phase 1.6 输出)"""
- return validate_with_schema(data, "case_detailed")
- def validate_blueprint_temp(data):
- """验证 blueprint_temp.json(Phase 2.1.1 输出)"""
- return validate_with_schema(data, "process_cluster")
- def validate_process(data):
- """验证 process.json(Phase 2.1.2 输出)"""
- return validate_with_schema(data, "process")
- def validate_blueprint(data):
- """[Legacy] blueprint.json 已废弃,保留函数签名兼容旧调用"""
- return None
- def validate_capabilities_temp(data):
- """验证 capabilities_temp.json(Phase 2.2.1 输出)"""
- return validate_with_schema(data, "extract_capabilities")
- def validate_capabilities_enriched(data):
- """验证 capabilities.json(Phase 2.2.2 输出)"""
- return validate_with_schema(data, "capabilities")
- def validate_capabilities(data):
- """验证 capabilities_extracted.json(Phase 2 输出)"""
- return validate_with_schema(data, "capabilities_extracted")
- def validate_strategy(data):
- """验证 strategy.json(Phase 3 输出)"""
- return validate_with_schema(data, "assemble_strategy")
# ── Non-structural checks (invariants) ──────────────────────────
- def is_valid_case_id(case_id: str) -> bool:
- """检查 case_id 是否为 {platform}_{content_id} 格式"""
- if not case_id or "_" not in case_id:
- return False
- platform = case_id.split("_", 1)[0]
- return platform in VALID_PLATFORMS
- def validate_invariants_source(data) -> Optional[str]:
- """source.json 的非结构性检查:去重、total 一致性"""
- sources = data.get("sources", [])
- seen_ids: Set[str] = set()
- for i, src in enumerate(sources):
- p = src.get("platform", "")
- cid = src.get("channel_content_id", "")
- if p and cid:
- dedup_key = f"{p}_{cid}"
- if dedup_key in seen_ids:
- return f"sources[{i}] duplicate: {dedup_key}"
- seen_ids.add(dedup_key)
- total = data.get("total")
- if total is not None and total != len(sources):
- return f"total ({total}) != len(sources) ({len(sources)})"
- return None
- def validate_invariants_case_detailed(data) -> Optional[str]:
- """case_detailed.json 的非结构性检查:去重、计数一致性"""
- cases = data.get("cases", [])
- seen_ids: Set[str] = set()
- success_count = 0
- for i, c in enumerate(cases):
- p = c.get("platform", "")
- cid = c.get("channel_content_id", "")
- if p and cid:
- dedup_key = f"{p}_{cid}"
- if dedup_key in seen_ids:
- return f"cases[{i}] duplicate: {dedup_key}"
- seen_ids.add(dedup_key)
- if c.get("workflow") is not None:
- success_count += 1
- total = data.get("total")
- if total is not None and total != len(cases):
- return f"total ({total}) != len(cases) ({len(cases)})"
- success = data.get("success")
- if success is not None and success != success_count:
- return f"success ({success}) != actual success count ({success_count})"
- return None
# ── Cross-file reference checks ──────────────────────────────────────
- def collect_valid_case_ids(raw_cases_dir: Path) -> Set[str]:
- """从 source.json 和 case_*.json 收集所有有效的 case_id"""
- valid_ids = set()
- source_file = raw_cases_dir / "source.json"
- if source_file.exists():
- try:
- with open(source_file, "r", encoding="utf-8") as f:
- data = json.load(f)
- for src in data.get("sources", []):
- p = src.get("platform")
- cid = src.get("channel_content_id")
- if p and cid:
- valid_ids.add(f"{p}_{cid}")
- case_id = src.get("case_id")
- if case_id:
- valid_ids.add(case_id)
- except Exception:
- pass
- for case_file in raw_cases_dir.glob("case_*.json"):
- if case_file.name in ("case_detailed.json",):
- continue
- try:
- with open(case_file, "r", encoding="utf-8") as f:
- data = json.load(f)
- for c in data.get("cases", []):
- case_id = c.get("case_id")
- if case_id:
- valid_ids.add(case_id)
- except Exception:
- pass
- return valid_ids
- def check_referential_integrity(req_dir: Path) -> List[Tuple[Path, str]]:
- """检查跨文件的引用一致性"""
- errors = []
- for filename in ["blueprint.json", "capabilities_extracted.json", "strategy.json"]:
- file_path = req_dir / filename
- if not file_path.exists():
- continue
- try:
- with open(file_path, "r", encoding="utf-8") as f:
- content = f.read()
- legacy_refs = re.findall(r'\bcase_\d{3}\b', content)
- for ref in set(legacy_refs):
- errors.append((file_path, f"Legacy reference: {ref} (should use {{platform}}_{{content_id}} format)"))
- except Exception:
- pass
- return errors
# ── Missing-file checks ──────────────────────────────────────
- def check_missing_files(base_dir: Path) -> List[Tuple[str, str]]:
- """检查每个需求目录是否缺少必需的文件"""
- missing_files = []
- req_dirs = sorted([d for d in base_dir.iterdir() if d.is_dir() and d.name.isdigit()])
- for req_dir in req_dirs:
- req_id = req_dir.name
- required = {
- "raw_cases": req_dir / "raw_cases",
- "blueprint.json": req_dir / "blueprint.json",
- "capabilities_extracted.json": req_dir / "capabilities_extracted.json",
- "strategy.json": req_dir / "strategy.json"
- }
- for file_name, file_path in required.items():
- if file_name == "raw_cases":
- if not file_path.exists():
- missing_files.append((req_id, "raw_cases directory missing"))
- elif not list(file_path.glob("case_*.json")):
- missing_files.append((req_id, "raw_cases has no case files"))
- else:
- if not file_path.exists():
- missing_files.append((req_id, f"{file_name} missing"))
- return missing_files
# ── Main entry point ──────────────────────────────────────
- def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--dir", default="output", help="Directory to validate")
- args = parser.parse_args()
- base_dir = Path(__file__).parent.parent / args.dir
- if not base_dir.exists():
- print(f"Error: {base_dir} does not exist.")
- return
- print(f"[Start] Checking for missing files...")
- missing_files = check_missing_files(base_dir)
- if missing_files:
- print(f"[WARNING] Found {len(missing_files)} missing files:")
- for req_id, issue in missing_files:
- print(f" - REQ_{req_id}: {issue}")
- else:
- print("[OK] All required files are present.")
- print("-" * 50)
- json_files = list(base_dir.rglob("*.json"))
- total_files = len(json_files)
- format_errors = []
- # 引用完整性检查
- req_dirs = sorted([d for d in base_dir.iterdir() if d.is_dir() and d.name.isdigit()])
- for req_dir in req_dirs:
- ref_errors = check_referential_integrity(req_dir)
- for path, err in ref_errors:
- rel_path = path.relative_to(base_dir.parent)
- format_errors.append((rel_path, f"Referential Integrity: {err}"))
- print(f"[Start] Validating schema for {total_files} JSON files...")
- for file_path in json_files:
- try:
- with open(file_path, "r", encoding="utf-8") as f:
- data = json.load(f)
- except Exception as e:
- format_errors.append((file_path, f"JSON Parsing Error: {e}"))
- continue
- filename = file_path.name
- rel_path = file_path.relative_to(base_dir.parent)
- # Schema 结构验证
- schema_name = _resolve_schema_name(filename)
- if schema_name:
- err = validate_with_schema(data, schema_name)
- if err:
- format_errors.append((rel_path, f"Schema mismatch: {err}"))
- # Invariant 检查
- if filename == "source.json":
- err = validate_invariants_source(data)
- if err:
- format_errors.append((rel_path, f"Invariant: {err}"))
- elif filename == "case_detailed.json":
- err = validate_invariants_case_detailed(data)
- if err:
- format_errors.append((rel_path, f"Invariant: {err}"))
- report_path = Path(__file__).parent / "schema_errors_report.txt"
- print("-" * 50)
- with open(report_path, "w", encoding="utf-8") as out_f:
- if not format_errors:
- msg = f"[OK] All {total_files} JSON files match their expected schemas!"
- print(msg)
- out_f.write(msg + "\n")
- else:
- msg = f"[ERROR] Found {len(format_errors)} files with issues:"
- print(msg)
- out_f.write(msg + "\n")
- for path, error in format_errors:
- print(f" - {path}: {error}")
- out_f.write(f" - {path}: {error}\n")
- print("-" * 50)
- print(f"Report saved to {report_path}")
- if __name__ == "__main__":
- main()