# webhook_service.py
import fnmatch
import logging
import os

import yaml
from sqlalchemy.orm import Session

from app.models import Project, DataVersion
from app.services.gogs_client import GogsClient
from app.services.storage_service import StorageService
  8. logger = logging.getLogger(__name__)
  9. def normalize_path(path: str) -> str:
  10. """Normalize path by removing ./ prefix."""
  11. path = path.strip()
  12. if path.startswith("./"):
  13. path = path[2:]
  14. return path
  15. def is_directory_pattern(path: str) -> bool:
  16. """Check if the path pattern represents a directory."""
  17. return path.endswith("/")
  18. class WebhookService:
  19. def __init__(self, db: Session):
  20. self.db = db
  21. self.gogs = GogsClient()
  22. self.storage = StorageService(db, self.gogs)
  23. async def process_webhook(self, payload: dict):
  24. # 1. Parse payload
  25. ref = payload.get("ref")
  26. if not ref:
  27. logger.warning("No ref in payload")
  28. return
  29. after_sha = payload.get("after")
  30. repo = payload.get("repository", {})
  31. repo_name = repo.get("name")
  32. owner = repo.get("owner", {}).get("username")
  33. pusher = payload.get("pusher", {})
  34. author_name = pusher.get("username", "unknown")
  35. if not after_sha or not repo_name or not owner:
  36. logger.error("Invalid payload: missing essential fields")
  37. return
  38. logger.info(f"Processing push for {owner}/{repo_name} commit {after_sha}")
  39. # 2. Get manifest
  40. manifest_content = await self.gogs.get_manifest(owner, repo_name, after_sha)
  41. if not manifest_content:
  42. logger.info("No manifest.yaml found. Skipping.")
  43. return
  44. try:
  45. manifest = yaml.safe_load(manifest_content)
  46. except yaml.YAMLError as e:
  47. logger.error(f"Failed to parse manifest: {e}")
  48. return
  49. # 3. Validation
  50. project_name = manifest.get("project_name")
  51. if not project_name:
  52. logger.error("Manifest missing project_name")
  53. return
  54. # 4. Get or create project
  55. project = self.storage.get_or_create_project(project_name)
  56. # 5. Get all changed files from payload for pre-filtering
  57. all_changed_files = self._get_all_changed_files(payload)
  58. manifest_changed = "manifest.yaml" in all_changed_files
  59. logger.info(f"Detected {len(all_changed_files)} changed files. Manifest changed: {manifest_changed}")
  60. # 6. Process stages
  61. stages = manifest.get("stages", [])
  62. # Backward compatibility: old single-stage format
  63. if not stages and manifest.get("stage"):
  64. stages = [{
  65. "name": manifest.get("stage"),
  66. "outputs": manifest.get("outputs", [])
  67. }]
  68. if not stages:
  69. logger.error("Manifest missing stages configuration")
  70. return
  71. for stage_config in stages:
  72. stage_name = stage_config.get("name")
  73. outputs = stage_config.get("outputs", [])
  74. if not stage_name:
  75. logger.warning("Stage missing name, skipping")
  76. continue
  77. # --- PRE-FILTERING LOGIC ---
  78. # Skip if manifest hasn't changed AND no files in this stage's scope have changed
  79. if not manifest_changed and not self._is_stage_affected(outputs, all_changed_files):
  80. logger.info(f"Stage '{stage_name}': No relevant files changed. Skipping processing.")
  81. continue
  82. # Check if this stage+commit already processed (idempotency)
  83. existing_version = self.db.query(DataVersion).filter(
  84. DataVersion.project_id == project.id,
  85. DataVersion.stage == stage_name,
  86. DataVersion.commit_id == after_sha
  87. ).first()
  88. if existing_version:
  89. logger.info(f"Stage '{stage_name}' already processed. Skipping.")
  90. continue
  91. # Get commit message from payload if available
  92. commit_msg = None
  93. commits = payload.get("commits", [])
  94. if commits:
  95. commit_msg = commits[0].get("message")
  96. # Create version for this stage
  97. version = self.storage.create_version(
  98. project.id, stage_name, after_sha, author_name, manifest=manifest_content, commit_message=commit_msg
  99. )
  100. if not version:
  101. logger.info(f"Stage '{stage_name}' (commit {after_sha[:8]}) is already being processed. Skipping.")
  102. continue
  103. logger.info(f"Processing stage '{stage_name}' with {len(outputs)} output rules")
  104. # Process outputs and check if any file actually changed
  105. has_new_uploads = await self._process_outputs(
  106. version, outputs, owner, repo_name, after_sha
  107. )
  108. # Check if this version represents a real change (content OR file set)
  109. if not self.storage.is_snapshot_changed(version, has_new_uploads):
  110. # No changes detected — this was a code-only push, discard the snapshot
  111. self.storage.rollback_version(version)
  112. logger.info(
  113. f"Stage '{stage_name}': no data changes detected (content and file set same). "
  114. f"Version discarded."
  115. )
  116. else:
  117. self.storage.aggregate_version_records(version)
  118. def _get_all_changed_files(self, payload: dict) -> set[str]:
  119. """Extract all added, modified, and removed files from all commits in payload."""
  120. files = set()
  121. commits = payload.get("commits", [])
  122. for commit in commits:
  123. for key in ["added", "modified", "removed"]:
  124. for f in (commit.get(key) or []):
  125. files.add(normalize_path(f))
  126. return files
  127. def _is_stage_affected(self, outputs: list, changed_files: set[str]) -> bool:
  128. """Check if any of the changed files fall under the scope of the stage's outputs."""
  129. if not changed_files:
  130. return True # Fallback: if we can't tell what changed, process it
  131. for output in outputs:
  132. path_pattern = normalize_path(output.get("path", ""))
  133. is_dir = is_directory_pattern(output.get("path", ""))
  134. for f in changed_files:
  135. if is_dir:
  136. # If it's a directory output, any change inside that directory counts
  137. dir_path = path_pattern.rstrip("/")
  138. if '*' in dir_path:
  139. import fnmatch
  140. if fnmatch.fnmatch(f, dir_path + "/*") or fnmatch.fnmatch(f, dir_path):
  141. return True
  142. else:
  143. if f == dir_path or f.startswith(dir_path + "/"):
  144. return True
  145. else:
  146. # Single file output: exact match
  147. if f == path_pattern:
  148. return True
  149. return False
  150. async def _process_outputs(
  151. self, version, outputs: list, owner: str, repo_name: str, commit_id: str
  152. ) -> bool:
  153. """Process output rules, create snapshot records for ALL matching files.
  154. Returns
  155. -------
  156. bool
  157. ``True`` if at least one file had actual content changes,
  158. ``False`` if every file was unchanged.
  159. """
  160. has_changes = False
  161. for output in outputs:
  162. raw_path_pattern = output.get("path", "")
  163. # Support both string and list for pattern and exclude
  164. patterns = output.get("pattern", "*")
  165. excludes = output.get("exclude")
  166. direction = output.get("direction")
  167. label = output.get("label")
  168. extract_json_key = output.get("extract_json_key")
  169. directory_depth = output.get("directory_depth")
  170. path_pattern = normalize_path(raw_path_pattern)
  171. is_dir = is_directory_pattern(raw_path_pattern)
  172. dir_path = path_pattern.rstrip("/")
  173. if is_dir:
  174. # Directory pattern: fetch files from the closest static parent directory
  175. # For `data/*/test/`, that is `data/`
  176. import re
  177. # Split by first wildcard chunk path
  178. wildcard_idx = dir_path.find('*')
  179. if wildcard_idx != -1:
  180. static_base = dir_path[:wildcard_idx]
  181. # Trim back to the nearest directory separator
  182. last_sep = static_base.rfind('/')
  183. if last_sep != -1:
  184. static_base = static_base[:last_sep]
  185. else:
  186. static_base = "" # ROOT
  187. else:
  188. static_base = dir_path
  189. static_base = static_base.strip('/')
  190. logger.info(f"Fetching directory: {static_base} (to match wildcard path: {dir_path}) with patterns: {patterns}, excludes: {excludes}")
  191. files = await self.gogs.get_directory_tree(owner, repo_name, commit_id, static_base)
  192. for file_info in files:
  193. file_path = file_info.get("path")
  194. # 1. First verify if the full path matches the wildcard directory path provided
  195. if '*' in dir_path:
  196. # e.g dir_path: data/*/test/ -> match: data/*/test/*
  197. if not fnmatch.fnmatch(file_path, dir_path + "/*") and not fnmatch.fnmatch(file_path, dir_path):
  198. continue
  199. else:
  200. if not file_path.startswith(dir_path + "/"):
  201. continue
  202. # Calculate name relative to the matched base path segment for pattern matching
  203. import os
  204. rel_name = os.path.basename(file_path)
  205. if self._match_patterns(rel_name, patterns, excludes):
  206. try:
  207. changed = await self.storage.process_file_with_sha(
  208. version, file_path, file_info.get("sha"), owner, repo_name,
  209. direction=direction, label=label, extract_json_key=extract_json_key,
  210. directory_depth=directory_depth
  211. )
  212. if changed:
  213. has_changes = True
  214. except Exception as e:
  215. logger.error(f"Failed to process file {file_path}: {e}")
  216. else:
  217. # Single file: fetch only this file's info
  218. logger.info(f"Fetching single file: {path_pattern}")
  219. file_info = await self.gogs.get_file_info(owner, repo_name, commit_id, path_pattern)
  220. if file_info:
  221. # Apply pattern matching to the filename for consistency
  222. import os
  223. filename = os.path.basename(path_pattern)
  224. if self._match_patterns(filename, patterns, excludes):
  225. try:
  226. changed = await self.storage.process_file_with_sha(
  227. version, path_pattern, file_info.get("sha"), owner, repo_name,
  228. direction=direction, label=label, extract_json_key=extract_json_key,
  229. directory_depth=directory_depth
  230. )
  231. if changed:
  232. has_changes = True
  233. except Exception as e:
  234. logger.error(f"Failed to process file {path_pattern}: {e}")
  235. else:
  236. logger.warning(f"File not found: {path_pattern}")
  237. return has_changes
  238. def _match_patterns(
  239. self,
  240. filename: str,
  241. include_patterns: str | list[str],
  242. exclude_patterns: str | list[str] | None = None,
  243. ) -> bool:
  244. """Helper to match filename against multiple include and exclude glob patterns."""
  245. # Normalize to lists
  246. includes = (
  247. [include_patterns] if isinstance(include_patterns, str) else include_patterns
  248. )
  249. excludes = []
  250. if exclude_patterns:
  251. excludes = (
  252. [exclude_patterns] if isinstance(exclude_patterns, str) else exclude_patterns
  253. )
  254. # 1. Check if it matches ANY include pattern (OR logic)
  255. is_included = any(fnmatch.fnmatch(filename, p) for p in includes)
  256. if not is_included:
  257. return False
  258. # 2. Check if it matches ANY exclude pattern (OR logic: any match means reject)
  259. is_excluded = any(fnmatch.fnmatch(filename, p) for p in excludes)
  260. return not is_excluded