| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402 |
- import yaml
- import logging
- import fnmatch
- import os
- import asyncio
- import re
- from sqlalchemy.orm import Session
- from app.models import Project, DataVersion
- from app.services.gogs_client import GogsClient
- from app.services.storage_service import StorageService
- logger = logging.getLogger(__name__)
def normalize_path(path: str) -> str:
    """Return *path* with surrounding whitespace and a single leading "./" removed."""
    return path.strip().removeprefix("./")
def is_directory_pattern(path: str) -> bool:
    """A manifest path denotes a directory when its last character is a slash."""
    return path[-1:] == "/"
- class WebhookService:
- def __init__(self, db: Session):
- self.db = db
- self.gogs = GogsClient()
- self.storage = StorageService(db, self.gogs)
    async def process_webhook(self, payload: dict):
        """Handle a Gogs push-webhook payload end to end.

        Steps:
          1. Parse and validate the payload (ref, head SHA, repo, owner).
          2. Fetch and parse manifest.yaml at the pushed commit.
          3. Validate the manifest (project_name required).
          4. Get or create the target project.
          5. Pre-compute the changed-file set for cheap per-stage filtering.
          6. For each stage: skip unaffected or already-processed stages,
             create a version, process matching changed files, then either
             discard the version (no data change) or aggregate its records.

        Returns None; all outcomes are reported via logging only.
        """
        # 1. Parse payload
        ref = payload.get("ref")
        if not ref:
            logger.warning("No ref in payload")
            return
        after_sha = payload.get("after")  # SHA of the head commit of this push
        repo = payload.get("repository", {})
        repo_name = repo.get("name")
        owner = repo.get("owner", {}).get("username")
        pusher = payload.get("pusher", {})
        author_name = pusher.get("username", "unknown")
        if not after_sha or not repo_name or not owner:
            logger.error("Invalid payload: missing essential fields")
            return
        logger.info(f"Processing push for {owner}/{repo_name} commit {after_sha}")
        branch_name = self._extract_branch_name(ref)  # e.g. refs/heads/main -> main
        # 2. Get manifest
        manifest_content = await self.gogs.get_manifest(owner, repo_name, after_sha)
        if not manifest_content:
            # A repo without manifest.yaml is simply not tracked — not an error.
            logger.info("No manifest.yaml found. Skipping.")
            return
        try:
            manifest = yaml.safe_load(manifest_content)
        except yaml.YAMLError as e:
            logger.error(f"Failed to parse manifest: {e}")
            return
        # 3. Validation
        project_name = manifest.get("project_name")
        if not project_name:
            logger.error("Manifest missing project_name")
            return
        # 4. Get or create project
        project = self.storage.get_or_create_project(project_name)
        # 5. Get all changed files from payload for pre-filtering
        all_changed_files = self._get_all_changed_files(payload)
        # A manifest change can alter any stage's rules, so it forces every
        # stage to be (re)considered below regardless of which files changed.
        manifest_changed = "manifest.yaml" in all_changed_files
        logger.info(f"Detected {len(all_changed_files)} changed files. Manifest changed: {manifest_changed}")
        # 6. Process stages
        stages = manifest.get("stages", [])
        # Backward compatibility: old single-stage format ("stage" + "outputs"
        # at the manifest top level) is converted to the new list form.
        if not stages and manifest.get("stage"):
            stages = [{
                "name": manifest.get("stage"),
                "outputs": manifest.get("outputs", [])
            }]
        if not stages:
            logger.error("Manifest missing stages configuration")
            return
        for stage_config in stages:
            stage_name = stage_config.get("name")
            outputs = stage_config.get("outputs", [])
            if not stage_name:
                logger.warning("Stage missing name, skipping")
                continue
            # --- PRE-FILTERING LOGIC ---
            # Skip if manifest hasn't changed AND no files in this stage's scope have changed
            if not manifest_changed and not self._is_stage_affected(outputs, all_changed_files):
                logger.info(f"Stage '{stage_name}': No relevant files changed. Skipping processing.")
                continue
            # Check if this stage+commit already processed (idempotency:
            # webhooks may be redelivered for the same push).
            existing_version = self.db.query(DataVersion).filter(
                DataVersion.project_id == project.id,
                DataVersion.stage == stage_name,
                DataVersion.commit_id == after_sha
            ).first()
            if existing_version:
                logger.info(f"Stage '{stage_name}' already processed. Skipping.")
                continue
            # Get commit message from payload if available.
            # NOTE(review): takes commits[0] — assumes the first entry is the
            # relevant commit; confirm Gogs payload ordering (may be oldest-first).
            commit_msg = None
            commits = payload.get("commits", [])
            if commits:
                commit_msg = commits[0].get("message")
            # Create version for this stage
            version = self.storage.create_version(
                project.id, stage_name, after_sha, author_name, manifest=manifest_content, commit_message=commit_msg
            )
            if not version:
                # A falsy return signals another worker already claimed this
                # stage+commit (see log message below).
                logger.info(f"Stage '{stage_name}' (commit {after_sha[:8]}) is already being processed. Skipping.")
                continue
            logger.info(f"Processing stage '{stage_name}' with {len(outputs)} output rules")
            # Process ONLY changed files that match output rules (no directory tree fetching)
            has_new_uploads = await self._process_outputs(
                version,
                outputs,
                owner,
                repo_name,
                ref=after_sha,
                fallback_ref=branch_name,
                changed_files=all_changed_files,
            )
            # Check if this version represents a real change (content OR file set)
            if not self.storage.is_snapshot_changed(version, has_new_uploads):
                # No changes detected — this was a code-only push, discard the snapshot
                self.storage.rollback_version(version)
                logger.info(
                    f"Stage '{stage_name}': no data changes detected (content and file set same). "
                    f"Version discarded."
                )
            else:
                self.storage.aggregate_version_records(version)
- def _get_all_changed_files(self, payload: dict) -> set[str]:
- """Extract all added, modified, and removed files from all commits in payload."""
- files = set()
- commits = payload.get("commits", [])
- for commit in commits:
- # for key in ["added", "modified", "removed"]:
- for key in ["added", "modified"]:
- for f in (commit.get(key) or []):
- files.add(normalize_path(f))
- return files
- @staticmethod
- def _extract_branch_name(git_ref: str | None) -> str | None:
- """Extract branch name from webhook ref, e.g. refs/heads/main -> main."""
- if not git_ref:
- return None
- prefix = "refs/heads/"
- if git_ref.startswith(prefix):
- return git_ref[len(prefix):]
- return git_ref
- def _is_stage_affected(self, outputs: list, changed_files: set[str]) -> bool:
- """Check if any of the changed files fall under the scope of the stage's outputs."""
- if not changed_files:
- return True # Fallback: if we can't tell what changed, process it
- for output in outputs:
- path_pattern = normalize_path(output.get("path", ""))
- is_dir = is_directory_pattern(output.get("path", ""))
- for f in changed_files:
- if is_dir:
- # If it's a directory output, any change inside that directory counts
- dir_path = path_pattern.rstrip("/")
- if '*' in dir_path:
- import fnmatch
- if fnmatch.fnmatch(f, dir_path + "/*") or fnmatch.fnmatch(f, dir_path):
- return True
- else:
- if f == dir_path or f.startswith(dir_path + "/"):
- return True
- else:
- # Single file output: exact match
- if f == path_pattern:
- return True
- return False
- def _find_matching_output(self, file_path: str, outputs: list) -> dict | None:
- """Check if a file path matches any manifest output rule using LOCAL logic only.
- No Gogs API calls are made — this is pure string/glob matching.
- Returns the matching output config dict, or None.
- """
- for output in outputs:
- raw_path = output.get("path", "")
- path_pattern = normalize_path(raw_path)
- is_dir = is_directory_pattern(raw_path)
- patterns = output.get("pattern", "*")
- excludes = output.get("exclude")
- if is_dir:
- dir_path = path_pattern.rstrip("/")
- if '*' in dir_path:
- if not fnmatch.fnmatch(file_path, dir_path + "/*") and not fnmatch.fnmatch(file_path, dir_path):
- continue
- else:
- if not file_path.startswith(dir_path + "/"):
- continue
- filename = os.path.basename(file_path)
- if self._match_patterns(filename, patterns, excludes):
- return output
- else:
- if file_path == path_pattern:
- return output
- return None
    async def _fetch_and_process_file(
        self, version, file_path: str, output_config: dict,
        owner: str, repo_name: str, ref: str, fallback_ref: str | None,
        processed_keys: set
    ) -> bool:
        """Get file SHA from Gogs and process a single changed file, plus its paired input if configured.

        Args:
            version: Version/snapshot the processed records belong to.
            file_path: Normalized repo-relative path of the changed file.
            output_config: The manifest output rule that matched this file.
            owner: Repository owner username.
            repo_name: Repository name.
            ref: Commit SHA of the push head.
            fallback_ref: Branch name tried when a paired input is missing at `ref`.
            processed_keys: Shared mutable set of (path, group_key) tuples used
                to deduplicate work across concurrently running tasks; this
                method mutates it in place.

        Returns:
            True if the file (or one of its paired inputs) had an actual
            content change according to the storage layer.
        """
        # Pre-collect paired inputs to help determine grouping logic.
        # Both the list form ("paired_inputs") and the legacy single form
        # ("paired_input") are accepted.
        paired_configs = list(output_config.get("paired_inputs", []))
        if "paired_input" in output_config:
            paired_configs.append(output_config["paired_input"])
        # Calculate group_key here so both paired input and output can share it
        directory_depth = output_config.get("directory_depth")
        if directory_depth is not None and directory_depth > 0:
            # Group by the first `directory_depth` components of the parent dir.
            parts = file_path.split("/")
            if len(parts) > 1:
                group_key = "/".join(parts[:-1][:directory_depth])
            else:
                group_key = ""  # file at repo root: no directory to group by
        elif paired_configs:
            # If we have paired inputs, use the full file path as a unique group key
            # to avoid "cross-talk" where multiple outputs in the same directory
            # share all paired inputs from that directory.
            group_key = file_path
        else:
            group_key = os.path.dirname(file_path)
        # Deduplicate API calls and DB entries across concurrently running tasks.
        # The check-and-add happens before any await, so concurrent tasks on the
        # same event loop cannot interleave here.
        task_key = (file_path, group_key)
        if task_key in processed_keys:
            return False
        processed_keys.add(task_key)
        file_info = await self.gogs.get_file_info(owner, repo_name, file_path, ref=ref)
        if not file_info:
            # The file was listed as changed but is gone at the push SHA
            # (e.g. modified in one commit, removed in a later one).
            logger.info(f"File {file_path} not found at ref {ref[:8]} (removed). Skipping.")
            return False
        has_change = await self.storage.process_file_with_sha(
            version, file_path, file_info.get("sha"), owner, repo_name,
            direction=output_config.get("direction"),
            label=output_config.get("label"),
            extract_json_key=output_config.get("extract_json_key"),
            directory_depth=directory_depth,
            group_key=group_key,
            content_ref=file_info.get("ref", ref),
        )
        # Derive and process paired input files (e.g. the input that produced
        # this output), located via regex capture groups + a path template.
        for paired_config in paired_configs:
            extract_regex = paired_config.get("extract_regex")
            path_template = paired_config.get("path_template")
            if extract_regex and path_template:
                match = re.search(extract_regex, file_path)
                if match:
                    # Construct paired file path using named capture groups
                    try:
                        paired_path = path_template.format(**match.groupdict())
                    except KeyError as e:
                        # Template references a group the regex didn't capture.
                        logger.error(f"Failed to format paired_path: missing {e} in regex match for {file_path}")
                        paired_path = None

                    if paired_path:
                        # Deduplicate paired input fetches
                        paired_task_key = (paired_path, group_key)
                        if paired_task_key in processed_keys:
                            continue
                        processed_keys.add(paired_task_key)
                        # Actively fetch paired file info from Gogs; the branch
                        # fallback_ref is tried when it's absent at the exact SHA.
                        paired_info = await self.gogs.get_file_info(
                            owner,
                            repo_name,
                            paired_path,
                            ref=ref,
                            fallback_ref=fallback_ref,
                        )
                        if paired_info:
                            paired_changed = await self.storage.process_file_with_sha(
                                version, paired_path, paired_info.get("sha"), owner, repo_name,
                                direction=paired_config.get("direction", "input"),
                                label=paired_config.get("label"),
                                extract_json_key=paired_config.get("extract_json_key"),
                                group_key=group_key,  # Link them together!
                                content_ref=paired_info.get("ref", ref),
                            )
                            has_change = has_change or paired_changed
                        else:
                            logger.warning(f"Paired input file not found at ref {ref[:8]}: {paired_path}")
        return has_change
- async def _process_outputs(
- self,
- version,
- outputs: list,
- owner: str,
- repo_name: str,
- ref: str,
- fallback_ref: str | None,
- changed_files: set[str],
- ) -> bool:
- """Process ONLY changed files that match manifest output rules.
- Instead of fetching entire directory trees from Gogs API (slow),
- we match the webhook payload's changed-file list against manifest
- rules using LOCAL string/glob logic — zero API calls for matching.
- Returns True if at least one file had actual content changes.
- """
- # Step 1: Local matching — zero API calls
- matched_files = []
- for file_path in changed_files:
- matched_output = self._find_matching_output(file_path, outputs)
- if matched_output is not None:
- matched_files.append((file_path, matched_output))
- if not matched_files:
- logger.info("No changed files matched any output rule.")
- return False
- logger.info(f"Matched {len(matched_files)} changed file(s) to output rules. Processing in parallel.")
- # Step 2: Fetch file info + download/upload in parallel
- processed_keys = set()
- tasks = [
- self._fetch_and_process_file(
- version, fp, oc, owner, repo_name, ref, fallback_ref, processed_keys
- )
- for fp, oc in matched_files
- ]
- has_changes = False
- results = await asyncio.gather(*tasks, return_exceptions=True)
- for i, res in enumerate(results):
- if isinstance(res, Exception):
- logger.error(f"Error processing {matched_files[i][0]}: {res}")
- elif res is True:
- has_changes = True
- return has_changes
- def _match_patterns(
- self,
- filename: str,
- include_patterns: str | list[str],
- exclude_patterns: str | list[str] | None = None,
- ) -> bool:
- """Helper to match filename against multiple include and exclude glob patterns."""
- # Normalize to lists
- includes = (
- [include_patterns] if isinstance(include_patterns, str) else include_patterns
- )
- excludes = []
- if exclude_patterns:
- excludes = (
- [exclude_patterns] if isinstance(exclude_patterns, str) else exclude_patterns
- )
- # 1. Check if it matches ANY include pattern (OR logic)
- is_included = any(fnmatch.fnmatch(filename, p) for p in includes)
- if not is_included:
- return False
- # 2. Check if it matches ANY exclude pattern (OR logic: any match means reject)
- is_excluded = any(fnmatch.fnmatch(filename, p) for p in excludes)
- return not is_excluded
|