- import yaml
- import logging
- import fnmatch
- import os
- import asyncio
- import re
- from sqlalchemy.orm import Session
- from app.models import Project, DataVersion, DataFile, DataRecord
- from app.services.gogs_client import GogsClient
- from app.services.storage_service import StorageService
- logger = logging.getLogger(__name__)
def normalize_path(path: str) -> str:
    """Return *path* with surrounding whitespace and any leading "./" removed."""
    cleaned = path.strip()
    return cleaned[2:] if cleaned.startswith("./") else cleaned
def is_directory_pattern(path: str) -> bool:
    """Return True when *path* denotes a directory scope (trailing slash)."""
    return path[-1:] == "/"
class WebhookService:
    """Processes Gogs push webhooks.

    Matches the push's changed files against the repository's manifest and
    records per-stage data snapshots through the storage service.
    """

    def __init__(self, db: Session):
        """Bind the service to a DB session plus shared Gogs/storage clients."""
        self.db = db
        self.gogs = GogsClient()
        self.storage = StorageService(db, self.gogs)
- async def process_webhook(self, payload: dict):
- # 1. Parse payload
- ref = payload.get("ref")
- if not ref:
- logger.warning("No ref in payload")
- return
- after_sha = payload.get("after")
- repo = payload.get("repository", {})
- repo_name = repo.get("name")
- owner = repo.get("owner", {}).get("username")
- pusher = payload.get("pusher", {})
- author_name = pusher.get("username", "unknown")
- if not after_sha or not repo_name or not owner:
- logger.error("Invalid payload: missing essential fields")
- return
- logger.info(f"Processing push for {owner}/{repo_name} commit {after_sha}")
- branch_name = self._extract_branch_name(ref)
- # 2. Get manifest
- manifest_content = await self.gogs.get_manifest(owner, repo_name, after_sha)
- if not manifest_content:
- logger.info("No manifest.yaml found. Skipping.")
- return
- try:
- manifest = yaml.safe_load(manifest_content)
- except yaml.YAMLError as e:
- logger.error(f"Failed to parse manifest: {e}")
- return
- # 3. Validation
- project_name = manifest.get("project_name")
- if not project_name:
- logger.error("Manifest missing project_name")
- return
- # 4. Get or create project
- project = self.storage.get_or_create_project(project_name)
- # 5. Get all changed files from payload for pre-filtering
- all_changed_files = self._get_all_changed_files(payload)
- manifest_changed = "manifest.yaml" in all_changed_files
- logger.info(f"Detected {len(all_changed_files)} changed files. Manifest changed: {manifest_changed}")
- # 6. Process stages
- stages = manifest.get("stages", [])
- if not stages:
- logger.error("Manifest missing stages configuration")
- return
- for stage_config in stages:
- stage_name = stage_config.get("name")
- outputs = stage_config.get("outputs", [])
- if not stage_name:
- logger.warning("Stage missing name, skipping")
- continue
- # --- PRE-FILTERING LOGIC ---
- # Skip if manifest hasn't changed AND no files in this stage's scope have changed
- if not manifest_changed and not self._is_stage_affected(outputs, all_changed_files):
- logger.info(f"Stage '{stage_name}': No relevant files changed. Skipping processing.")
- continue
- # Check if this stage+commit already processed (idempotency)
- existing_version = self.db.query(DataVersion).filter(
- DataVersion.project_id == project.id,
- DataVersion.stage == stage_name,
- DataVersion.commit_id == after_sha
- ).first()
- if existing_version:
- logger.info(f"Stage '{stage_name}' already processed. Skipping.")
- continue
- # Get commit message from payload if available
- commit_msg = None
- commits = payload.get("commits", [])
- if commits:
- commit_msg = commits[0].get("message")
- # Create version for this stage
- version = self.storage.create_version(
- project.id, stage_name, after_sha, author_name, manifest=manifest_content, commit_message=commit_msg
- )
- if not version:
- logger.info(f"Stage '{stage_name}' (commit {after_sha[:8]}) is already being processed. Skipping.")
- continue
- logger.info(f"Processing stage '{stage_name}' with {len(outputs)} output rules")
- # Process ONLY changed files that match output rules (no directory tree fetching)
- has_new_uploads = await self._process_outputs(
- version,
- outputs,
- owner,
- repo_name,
- ref=after_sha,
- fallback_ref=branch_name,
- changed_files=all_changed_files,
- )
- # Check if this version represents a real change (content OR file set)
- if not self.storage.is_snapshot_changed(version, has_new_uploads):
- # No changes detected — this was a code-only push, discard the snapshot
- self.storage.rollback_version(version)
- logger.info(
- f"Stage '{stage_name}': no data changes detected (content and file set same). "
- f"Version discarded."
- )
- else:
- self.storage.aggregate_version_records(version)
- # ── Backfill: supplement missing paired inputs in recent versions ──
- # Handles the case where output files were committed BEFORE their
- # paired input files. When the input file arrives in a later push,
- # we retroactively attach it to the older records that were missing it.
- for stage_config in stages:
- stage_name = stage_config.get("name")
- outputs = stage_config.get("outputs", [])
- if not stage_name or not outputs:
- continue
- has_paired = any(
- oc.get("paired_input") or oc.get("paired_inputs")
- for oc in outputs
- )
- if not has_paired:
- continue
- if not manifest_changed and not self._is_stage_affected(outputs, all_changed_files):
- continue
- await self._backfill_incomplete_records(
- project.id, stage_name, outputs,
- owner, repo_name, after_sha,
- )
- # Close the shared HTTP client (connection pool)
- await self.gogs.aclose()
- def _get_all_changed_files(self, payload: dict) -> set[str]:
- """Extract all added, modified, and removed files from all commits in payload."""
- files = set()
- commits = payload.get("commits", [])
- for commit in commits:
- # for key in ["added", "modified", "removed"]:
- for key in ["added", "modified"]:
- for f in (commit.get(key) or []):
- files.add(normalize_path(f))
- return files
- @staticmethod
- def _extract_branch_name(git_ref: str | None) -> str | None:
- """Extract branch name from webhook ref, e.g. refs/heads/main -> main."""
- if not git_ref:
- return None
- prefix = "refs/heads/"
- if git_ref.startswith(prefix):
- return git_ref[len(prefix):]
- return git_ref
- def _is_stage_affected(self, outputs: list, changed_files: set[str]) -> bool:
- """Check if any of the changed files fall under the scope of the stage's outputs."""
- if not changed_files:
- return True # Fallback: if we can't tell what changed, process it
- for output in outputs:
- path_pattern = normalize_path(output.get("path", ""))
- is_dir = is_directory_pattern(output.get("path", ""))
- for f in changed_files:
- if is_dir:
- # If it's a directory output, any change inside that directory counts
- dir_path = path_pattern.rstrip("/")
- if '*' in dir_path:
- import fnmatch
- if fnmatch.fnmatch(f, dir_path + "/*") or fnmatch.fnmatch(f, dir_path):
- return True
- else:
- if f == dir_path or f.startswith(dir_path + "/"):
- return True
- else:
- # Single file output: exact match
- if f == path_pattern:
- return True
- return False
- def _find_matching_output(self, file_path: str, outputs: list) -> dict | None:
- """Check if a file path matches any manifest output rule using LOCAL logic only.
- No Gogs API calls are made — this is pure string/glob matching.
- Returns the matching output config dict, or None.
- """
- for output in outputs:
- raw_path = output.get("path", "")
- path_pattern = normalize_path(raw_path)
- is_dir = is_directory_pattern(raw_path)
- patterns = output.get("pattern", "*")
- excludes = output.get("exclude")
- if is_dir:
- dir_path = path_pattern.rstrip("/")
- if '*' in dir_path:
- if not fnmatch.fnmatch(file_path, dir_path + "/*") and not fnmatch.fnmatch(file_path, dir_path):
- continue
- else:
- if not file_path.startswith(dir_path + "/"):
- continue
- filename = os.path.basename(file_path)
- if self._match_patterns(filename, patterns, excludes):
- return output
- else:
- if file_path == path_pattern:
- return output
- return None
    async def _fetch_and_process_file(
        self, version, file_path: str, output_config: dict,
        owner: str, repo_name: str, ref: str, fallback_ref: str | None,
        processed_keys: set
    ) -> bool:
        """Get file SHA from Gogs and process a single changed file, plus its paired input if configured.

        Args:
            version: DataVersion snapshot the processed files are recorded under.
            file_path: Repo-relative path of the changed output file.
            output_config: Manifest output rule that matched this file.
            owner: Repository owner username.
            repo_name: Repository name.
            ref: Commit SHA to fetch file metadata at.
            fallback_ref: Branch name used as a fallback ref for paired-input lookups.
            processed_keys: Set shared across the concurrently running tasks
                spawned by ``_process_outputs``; used to deduplicate API calls
                and DB writes for the same (path, group_key) pair.

        Returns:
            True when at least one file (output or paired input) had an
            actual content change according to the storage service.
        """
        # Pre-collect paired inputs to help determine grouping logic.
        # Both the plural "paired_inputs" list and the singular "paired_input"
        # manifest forms are supported.
        paired_configs = list(output_config.get("paired_inputs", []))
        if "paired_input" in output_config:
            paired_configs.append(output_config["paired_input"])
        # The group_key is ALWAYS the output file's exact relative_path.
        # This guarantees 1 Output : N Inputs mapping strictly.
        group_key = file_path
        # Deduplicate API calls and DB entries across concurrently running tasks.
        # NOTE: safe without a lock because asyncio tasks interleave only at
        # await points, and the check+add below has none between them.
        task_key = (file_path, group_key)
        if task_key in processed_keys:
            return False
        processed_keys.add(task_key)
        file_info = await self.gogs.get_file_info(owner, repo_name, file_path, ref=ref)
        if not file_info:
            # File listed as changed but absent at this ref — treat as removed.
            logger.info(f"File {file_path} not found at ref {ref[:8]} (removed). Skipping.")
            return False
        has_change = await self.storage.process_file_with_sha(
            version, file_path, file_info.get("sha"), owner, repo_name,
            direction=output_config.get("direction"),
            label=output_config.get("label"),
            extract_json_key=output_config.get("extract_json_key"),
            group_key=group_key,
            content_ref=file_info.get("ref", ref),
        )
        for paired_config in paired_configs:
            extract_regex = paired_config.get("extract_regex")
            path_template = paired_config.get("path_template")
            if extract_regex and path_template:
                match = re.search(extract_regex, file_path)
                if match:
                    # Construct paired file path using named capture groups
                    # from the regex, substituted into the path template.
                    try:
                        paired_path = path_template.format(**match.groupdict())
                    except KeyError as e:
                        logger.error(f"Failed to format paired_path: missing {e} in regex match for {file_path}")
                        paired_path = None

                    if paired_path:
                        # Deduplicate paired input fetches
                        paired_task_key = (paired_path, group_key)
                        if paired_task_key in processed_keys:
                            continue
                        processed_keys.add(paired_task_key)
                        # Actively fetch paired file info from Gogs
                        paired_info = await self.gogs.get_file_info(
                            owner,
                            repo_name,
                            paired_path,
                            ref=ref,
                            fallback_ref=fallback_ref,
                        )
                        if paired_info:
                            paired_changed = await self.storage.process_file_with_sha(
                                version, paired_path, paired_info.get("sha"), owner, repo_name,
                                direction=paired_config.get("direction", "input"),
                                label=paired_config.get("label"),
                                extract_json_key=paired_config.get("extract_json_key"),
                                group_key=group_key,  # Link them together!
                                content_ref=paired_info.get("ref", ref),
                            )
                            has_change = has_change or paired_changed
                        else:
                            logger.warning(f"Paired input file not found at ref {ref[:8]}: {paired_path}")
        return has_change
    async def _backfill_incomplete_records(
        self, project_id: str, stage_name: str, outputs: list,
        owner: str, repo_name: str, current_commit: str,
    ):
        """Backfill paired inputs that were missing when records were first created.

        When output files are committed before their paired input files, the
        initial records will have empty inputs. This method finds those
        incomplete records and tries to fetch the now-available paired inputs
        using the *current* commit (which may contain files added later).

        Args:
            project_id: Project whose versions are scanned.
            stage_name: Only versions of this stage are considered.
            outputs: Manifest output rules for the stage (source of the
                paired-input regex/template configuration).
            owner: Repository owner username.
            repo_name: Repository name.
            current_commit: Commit SHA used to look up the paired files.
        """
        # Only the 20 most recent versions are scanned; older snapshots are
        # assumed to be settled.
        recent_versions = (
            self.db.query(DataVersion)
            .filter(
                DataVersion.project_id == project_id,
                DataVersion.stage == stage_name,
            )
            .order_by(DataVersion.created_at.desc())
            .limit(20)
            .all()
        )
        # Cache Gogs lookups so we don't fetch the same path twice
        file_info_cache: dict[str, dict | None] = {}
        for version in recent_versions:
            records = (
                self.db.query(DataRecord)
                .filter(DataRecord.version_id == version.id)
                .all()
            )
            needs_reaggregate = False
            for record in records:
                # The record's first output anchors the paired-input lookup.
                out_path = (
                    record.outputs[0]["relative_path"]
                    if record.outputs
                    else None
                )
                if not out_path:
                    continue
                output_config = self._find_matching_output(out_path, outputs)
                if not output_config:
                    continue
                # Support both singular and plural manifest forms.
                paired_configs = list(output_config.get("paired_inputs", []))
                if "paired_input" in output_config:
                    paired_configs.append(output_config["paired_input"])
                if not paired_configs:
                    continue
                existing_input_paths = {
                    inp["relative_path"] for inp in (record.inputs or [])
                }
                for pc in paired_configs:
                    extract_regex = pc.get("extract_regex")
                    path_template = pc.get("path_template")
                    if not extract_regex or not path_template:
                        continue
                    match = re.search(extract_regex, out_path)
                    if not match:
                        continue
                    # Derive the paired path from the regex's named groups.
                    try:
                        paired_path = path_template.format(**match.groupdict())
                    except KeyError:
                        continue
                    # Already present in this record's inputs
                    if paired_path in existing_input_paths:
                        continue
                    # DataFile already exists but not yet reflected in record
                    existing_df = (
                        self.db.query(DataFile)
                        .filter(
                            DataFile.version_id == version.id,
                            DataFile.relative_path == paired_path,
                            DataFile.group_key == record.group_key,
                        )
                        .first()
                    )
                    if existing_df:
                        needs_reaggregate = True
                        continue
                    # Fetch from Gogs (with cache)
                    if paired_path not in file_info_cache:
                        file_info_cache[paired_path] = (
                            await self.gogs.get_file_info(
                                owner, repo_name, paired_path,
                                ref=current_commit,
                            )
                        )
                    paired_info = file_info_cache[paired_path]
                    if not paired_info:
                        continue  # still not available
                    await self.storage.process_file_with_sha(
                        version,
                        paired_path,
                        paired_info.get("sha"),
                        owner,
                        repo_name,
                        direction=pc.get("direction", "input"),
                        label=pc.get("label"),
                        extract_json_key=pc.get("extract_json_key"),
                        group_key=record.group_key,
                        content_ref=paired_info.get("ref", current_commit),
                    )
                    needs_reaggregate = True
                    logger.info(
                        f"Backfilled paired input {paired_path} "
                        f"for version {version.id} (commit {version.commit_id[:8]})"
                    )
            if needs_reaggregate:
                # New inputs were attached — rebuild the version's records.
                self.storage.aggregate_version_records(version)
                logger.info(
                    f"Re-aggregated version {version.id} after backfilling paired inputs"
                )
- async def _process_outputs(
- self,
- version,
- outputs: list,
- owner: str,
- repo_name: str,
- ref: str,
- fallback_ref: str | None,
- changed_files: set[str],
- ) -> bool:
- """Process ONLY changed files that match manifest output rules.
- Instead of fetching entire directory trees from Gogs API (slow),
- we match the webhook payload's changed-file list against manifest
- rules using LOCAL string/glob logic — zero API calls for matching.
- Returns True if at least one file had actual content changes.
- """
- # Step 1: Local matching — zero API calls
- matched_files = []
- for file_path in changed_files:
- matched_output = self._find_matching_output(file_path, outputs)
- if matched_output is not None:
- matched_files.append((file_path, matched_output))
- if not matched_files:
- logger.info("No changed files matched any output rule.")
- return False
- logger.info(f"Matched {len(matched_files)} changed file(s) to output rules. Processing in parallel.")
- # Step 2: Fetch file info + download/upload in parallel
- processed_keys = set()
- tasks = [
- self._fetch_and_process_file(
- version, fp, oc, owner, repo_name, ref, fallback_ref, processed_keys
- )
- for fp, oc in matched_files
- ]
- has_changes = False
- results = await asyncio.gather(*tasks, return_exceptions=True)
- for i, res in enumerate(results):
- if isinstance(res, Exception):
- logger.error(f"Error processing {matched_files[i][0]}: {res}")
- elif res is True:
- has_changes = True
- return has_changes
- def _match_patterns(
- self,
- filename: str,
- include_patterns: str | list[str],
- exclude_patterns: str | list[str] | None = None,
- ) -> bool:
- """Helper to match filename against multiple include and exclude glob patterns."""
- # Normalize to lists
- includes = (
- [include_patterns] if isinstance(include_patterns, str) else include_patterns
- )
- excludes = []
- if exclude_patterns:
- excludes = (
- [exclude_patterns] if isinstance(exclude_patterns, str) else exclude_patterns
- )
- # 1. Check if it matches ANY include pattern (OR logic)
- is_included = any(fnmatch.fnmatch(filename, p) for p in includes)
- if not is_included:
- return False
- # 2. Check if it matches ANY exclude pattern (OR logic: any match means reject)
- is_excluded = any(fnmatch.fnmatch(filename, p) for p in excludes)
- return not is_excluded