# webhook_service.py
  1. import yaml
  2. import logging
  3. import fnmatch
  4. import os
  5. import asyncio
  6. import re
  7. from sqlalchemy.orm import Session
  8. from app.models import Project, DataVersion
  9. from app.services.gogs_client import GogsClient
  10. from app.services.storage_service import StorageService
  11. logger = logging.getLogger(__name__)
  12. def normalize_path(path: str) -> str:
  13. """Normalize path by removing ./ prefix."""
  14. path = path.strip()
  15. if path.startswith("./"):
  16. path = path[2:]
  17. return path
  18. def is_directory_pattern(path: str) -> bool:
  19. """Check if the path pattern represents a directory."""
  20. return path.endswith("/")
  21. class WebhookService:
  22. def __init__(self, db: Session):
  23. self.db = db
  24. self.gogs = GogsClient()
  25. self.storage = StorageService(db, self.gogs)
  26. async def process_webhook(self, payload: dict):
  27. # 1. Parse payload
  28. ref = payload.get("ref")
  29. if not ref:
  30. logger.warning("No ref in payload")
  31. return
  32. after_sha = payload.get("after")
  33. repo = payload.get("repository", {})
  34. repo_name = repo.get("name")
  35. owner = repo.get("owner", {}).get("username")
  36. pusher = payload.get("pusher", {})
  37. author_name = pusher.get("username", "unknown")
  38. if not after_sha or not repo_name or not owner:
  39. logger.error("Invalid payload: missing essential fields")
  40. return
  41. logger.info(f"Processing push for {owner}/{repo_name} commit {after_sha}")
  42. branch_name = self._extract_branch_name(ref)
  43. # 2. Get manifest
  44. manifest_content = await self.gogs.get_manifest(owner, repo_name, after_sha)
  45. if not manifest_content:
  46. logger.info("No manifest.yaml found. Skipping.")
  47. return
  48. try:
  49. manifest = yaml.safe_load(manifest_content)
  50. except yaml.YAMLError as e:
  51. logger.error(f"Failed to parse manifest: {e}")
  52. return
  53. # 3. Validation
  54. project_name = manifest.get("project_name")
  55. if not project_name:
  56. logger.error("Manifest missing project_name")
  57. return
  58. # 4. Get or create project
  59. project = self.storage.get_or_create_project(project_name)
  60. # 5. Get all changed files from payload for pre-filtering
  61. all_changed_files = self._get_all_changed_files(payload)
  62. manifest_changed = "manifest.yaml" in all_changed_files
  63. logger.info(f"Detected {len(all_changed_files)} changed files. Manifest changed: {manifest_changed}")
  64. # 6. Process stages
  65. stages = manifest.get("stages", [])
  66. # Backward compatibility: old single-stage format
  67. if not stages and manifest.get("stage"):
  68. stages = [{
  69. "name": manifest.get("stage"),
  70. "outputs": manifest.get("outputs", [])
  71. }]
  72. if not stages:
  73. logger.error("Manifest missing stages configuration")
  74. return
  75. for stage_config in stages:
  76. stage_name = stage_config.get("name")
  77. outputs = stage_config.get("outputs", [])
  78. if not stage_name:
  79. logger.warning("Stage missing name, skipping")
  80. continue
  81. # --- PRE-FILTERING LOGIC ---
  82. # Skip if manifest hasn't changed AND no files in this stage's scope have changed
  83. if not manifest_changed and not self._is_stage_affected(outputs, all_changed_files):
  84. logger.info(f"Stage '{stage_name}': No relevant files changed. Skipping processing.")
  85. continue
  86. # Check if this stage+commit already processed (idempotency)
  87. existing_version = self.db.query(DataVersion).filter(
  88. DataVersion.project_id == project.id,
  89. DataVersion.stage == stage_name,
  90. DataVersion.commit_id == after_sha
  91. ).first()
  92. if existing_version:
  93. logger.info(f"Stage '{stage_name}' already processed. Skipping.")
  94. continue
  95. # Get commit message from payload if available
  96. commit_msg = None
  97. commits = payload.get("commits", [])
  98. if commits:
  99. commit_msg = commits[0].get("message")
  100. # Create version for this stage
  101. version = self.storage.create_version(
  102. project.id, stage_name, after_sha, author_name, manifest=manifest_content, commit_message=commit_msg
  103. )
  104. if not version:
  105. logger.info(f"Stage '{stage_name}' (commit {after_sha[:8]}) is already being processed. Skipping.")
  106. continue
  107. logger.info(f"Processing stage '{stage_name}' with {len(outputs)} output rules")
  108. # Process ONLY changed files that match output rules (no directory tree fetching)
  109. has_new_uploads = await self._process_outputs(
  110. version,
  111. outputs,
  112. owner,
  113. repo_name,
  114. ref=after_sha,
  115. fallback_ref=branch_name,
  116. changed_files=all_changed_files,
  117. )
  118. # Check if this version represents a real change (content OR file set)
  119. if not self.storage.is_snapshot_changed(version, has_new_uploads):
  120. # No changes detected — this was a code-only push, discard the snapshot
  121. self.storage.rollback_version(version)
  122. logger.info(
  123. f"Stage '{stage_name}': no data changes detected (content and file set same). "
  124. f"Version discarded."
  125. )
  126. else:
  127. self.storage.aggregate_version_records(version)
  128. def _get_all_changed_files(self, payload: dict) -> set[str]:
  129. """Extract all added, modified, and removed files from all commits in payload."""
  130. files = set()
  131. commits = payload.get("commits", [])
  132. for commit in commits:
  133. # for key in ["added", "modified", "removed"]:
  134. for key in ["added", "modified"]:
  135. for f in (commit.get(key) or []):
  136. files.add(normalize_path(f))
  137. return files
  138. @staticmethod
  139. def _extract_branch_name(git_ref: str | None) -> str | None:
  140. """Extract branch name from webhook ref, e.g. refs/heads/main -> main."""
  141. if not git_ref:
  142. return None
  143. prefix = "refs/heads/"
  144. if git_ref.startswith(prefix):
  145. return git_ref[len(prefix):]
  146. return git_ref
  147. def _is_stage_affected(self, outputs: list, changed_files: set[str]) -> bool:
  148. """Check if any of the changed files fall under the scope of the stage's outputs."""
  149. if not changed_files:
  150. return True # Fallback: if we can't tell what changed, process it
  151. for output in outputs:
  152. path_pattern = normalize_path(output.get("path", ""))
  153. is_dir = is_directory_pattern(output.get("path", ""))
  154. for f in changed_files:
  155. if is_dir:
  156. # If it's a directory output, any change inside that directory counts
  157. dir_path = path_pattern.rstrip("/")
  158. if '*' in dir_path:
  159. import fnmatch
  160. if fnmatch.fnmatch(f, dir_path + "/*") or fnmatch.fnmatch(f, dir_path):
  161. return True
  162. else:
  163. if f == dir_path or f.startswith(dir_path + "/"):
  164. return True
  165. else:
  166. # Single file output: exact match
  167. if f == path_pattern:
  168. return True
  169. return False
  170. def _find_matching_output(self, file_path: str, outputs: list) -> dict | None:
  171. """Check if a file path matches any manifest output rule using LOCAL logic only.
  172. No Gogs API calls are made — this is pure string/glob matching.
  173. Returns the matching output config dict, or None.
  174. """
  175. for output in outputs:
  176. raw_path = output.get("path", "")
  177. path_pattern = normalize_path(raw_path)
  178. is_dir = is_directory_pattern(raw_path)
  179. patterns = output.get("pattern", "*")
  180. excludes = output.get("exclude")
  181. if is_dir:
  182. dir_path = path_pattern.rstrip("/")
  183. if '*' in dir_path:
  184. if not fnmatch.fnmatch(file_path, dir_path + "/*") and not fnmatch.fnmatch(file_path, dir_path):
  185. continue
  186. else:
  187. if not file_path.startswith(dir_path + "/"):
  188. continue
  189. filename = os.path.basename(file_path)
  190. if self._match_patterns(filename, patterns, excludes):
  191. return output
  192. else:
  193. if file_path == path_pattern:
  194. return output
  195. return None
  196. async def _fetch_and_process_file(
  197. self, version, file_path: str, output_config: dict,
  198. owner: str, repo_name: str, ref: str, fallback_ref: str | None,
  199. processed_keys: set
  200. ) -> bool:
  201. """Get file SHA from Gogs and process a single changed file, plus its paired input if configured."""
  202. # Pre-collect paired inputs to help determine grouping logic
  203. paired_configs = list(output_config.get("paired_inputs", []))
  204. if "paired_input" in output_config:
  205. paired_configs.append(output_config["paired_input"])
  206. # Calculate group_key here so both paired input and output can share it
  207. directory_depth = output_config.get("directory_depth")
  208. if directory_depth is not None and directory_depth > 0:
  209. parts = file_path.split("/")
  210. if len(parts) > 1:
  211. group_key = "/".join(parts[:-1][:directory_depth])
  212. else:
  213. group_key = ""
  214. elif paired_configs:
  215. # If we have paired inputs, use the full file path as a unique group key
  216. # to avoid "cross-talk" where multiple outputs in the same directory
  217. # share all paired inputs from that directory.
  218. group_key = file_path
  219. else:
  220. group_key = os.path.dirname(file_path)
  221. # Deduplicate API calls and DB entries across concurrently running tasks
  222. task_key = (file_path, group_key)
  223. if task_key in processed_keys:
  224. return False
  225. processed_keys.add(task_key)
  226. file_info = await self.gogs.get_file_info(owner, repo_name, file_path, ref=ref)
  227. if not file_info:
  228. logger.info(f"File {file_path} not found at ref {ref[:8]} (removed). Skipping.")
  229. return False
  230. has_change = await self.storage.process_file_with_sha(
  231. version, file_path, file_info.get("sha"), owner, repo_name,
  232. direction=output_config.get("direction"),
  233. label=output_config.get("label"),
  234. extract_json_key=output_config.get("extract_json_key"),
  235. directory_depth=directory_depth,
  236. group_key=group_key,
  237. content_ref=file_info.get("ref", ref),
  238. )
  239. for paired_config in paired_configs:
  240. extract_regex = paired_config.get("extract_regex")
  241. path_template = paired_config.get("path_template")
  242. if extract_regex and path_template:
  243. match = re.search(extract_regex, file_path)
  244. if match:
  245. # Construct paired file path using named capture groups
  246. try:
  247. paired_path = path_template.format(**match.groupdict())
  248. except KeyError as e:
  249. logger.error(f"Failed to format paired_path: missing {e} in regex match for {file_path}")
  250. paired_path = None
  251. if paired_path:
  252. # Deduplicate paired input fetches
  253. paired_task_key = (paired_path, group_key)
  254. if paired_task_key in processed_keys:
  255. continue
  256. processed_keys.add(paired_task_key)
  257. # Actively fetch paired file info from Gogs
  258. paired_info = await self.gogs.get_file_info(
  259. owner,
  260. repo_name,
  261. paired_path,
  262. ref=ref,
  263. fallback_ref=fallback_ref,
  264. )
  265. if paired_info:
  266. paired_changed = await self.storage.process_file_with_sha(
  267. version, paired_path, paired_info.get("sha"), owner, repo_name,
  268. direction=paired_config.get("direction", "input"),
  269. label=paired_config.get("label"),
  270. extract_json_key=paired_config.get("extract_json_key"),
  271. group_key=group_key, # Link them together!
  272. content_ref=paired_info.get("ref", ref),
  273. )
  274. has_change = has_change or paired_changed
  275. else:
  276. logger.warning(f"Paired input file not found at ref {ref[:8]}: {paired_path}")
  277. return has_change
  278. async def _process_outputs(
  279. self,
  280. version,
  281. outputs: list,
  282. owner: str,
  283. repo_name: str,
  284. ref: str,
  285. fallback_ref: str | None,
  286. changed_files: set[str],
  287. ) -> bool:
  288. """Process ONLY changed files that match manifest output rules.
  289. Instead of fetching entire directory trees from Gogs API (slow),
  290. we match the webhook payload's changed-file list against manifest
  291. rules using LOCAL string/glob logic — zero API calls for matching.
  292. Returns True if at least one file had actual content changes.
  293. """
  294. # Step 1: Local matching — zero API calls
  295. matched_files = []
  296. for file_path in changed_files:
  297. matched_output = self._find_matching_output(file_path, outputs)
  298. if matched_output is not None:
  299. matched_files.append((file_path, matched_output))
  300. if not matched_files:
  301. logger.info("No changed files matched any output rule.")
  302. return False
  303. logger.info(f"Matched {len(matched_files)} changed file(s) to output rules. Processing in parallel.")
  304. # Step 2: Fetch file info + download/upload in parallel
  305. processed_keys = set()
  306. tasks = [
  307. self._fetch_and_process_file(
  308. version, fp, oc, owner, repo_name, ref, fallback_ref, processed_keys
  309. )
  310. for fp, oc in matched_files
  311. ]
  312. has_changes = False
  313. results = await asyncio.gather(*tasks, return_exceptions=True)
  314. for i, res in enumerate(results):
  315. if isinstance(res, Exception):
  316. logger.error(f"Error processing {matched_files[i][0]}: {res}")
  317. elif res is True:
  318. has_changes = True
  319. return has_changes
  320. def _match_patterns(
  321. self,
  322. filename: str,
  323. include_patterns: str | list[str],
  324. exclude_patterns: str | list[str] | None = None,
  325. ) -> bool:
  326. """Helper to match filename against multiple include and exclude glob patterns."""
  327. # Normalize to lists
  328. includes = (
  329. [include_patterns] if isinstance(include_patterns, str) else include_patterns
  330. )
  331. excludes = []
  332. if exclude_patterns:
  333. excludes = (
  334. [exclude_patterns] if isinstance(exclude_patterns, str) else exclude_patterns
  335. )
  336. # 1. Check if it matches ANY include pattern (OR logic)
  337. is_included = any(fnmatch.fnmatch(filename, p) for p in includes)
  338. if not is_included:
  339. return False
  340. # 2. Check if it matches ANY exclude pattern (OR logic: any match means reject)
  341. is_excluded = any(fnmatch.fnmatch(filename, p) for p in excludes)
  342. return not is_excluded