# webhook_service.py
  1. import yaml
  2. import logging
  3. import fnmatch
  4. import os
  5. import asyncio
  6. import re
  7. from sqlalchemy.orm import Session
  8. from app.models import Project, DataVersion, DataFile, DataRecord
  9. from app.services.gogs_client import GogsClient
  10. from app.services.storage_service import StorageService
  11. logger = logging.getLogger(__name__)
  12. def normalize_path(path: str) -> str:
  13. """Normalize path by removing ./ prefix."""
  14. path = path.strip()
  15. if path.startswith("./"):
  16. path = path[2:]
  17. return path
  18. def is_directory_pattern(path: str) -> bool:
  19. """Check if the path pattern represents a directory."""
  20. return path.endswith("/")
  21. class WebhookService:
  22. def __init__(self, db: Session):
  23. self.db = db
  24. self.gogs = GogsClient()
  25. self.storage = StorageService(db, self.gogs)
  26. async def process_webhook(self, payload: dict):
  27. # 1. Parse payload
  28. ref = payload.get("ref")
  29. if not ref:
  30. logger.warning("No ref in payload")
  31. return
  32. after_sha = payload.get("after")
  33. repo = payload.get("repository", {})
  34. repo_name = repo.get("name")
  35. owner = repo.get("owner", {}).get("username")
  36. pusher = payload.get("pusher", {})
  37. author_name = pusher.get("username", "unknown")
  38. if not after_sha or not repo_name or not owner:
  39. logger.error("Invalid payload: missing essential fields")
  40. return
  41. logger.info(f"Processing push for {owner}/{repo_name} commit {after_sha}")
  42. branch_name = self._extract_branch_name(ref)
  43. # 2. Get manifest
  44. manifest_content = await self.gogs.get_manifest(owner, repo_name, after_sha)
  45. if not manifest_content:
  46. logger.info("No manifest.yaml found. Skipping.")
  47. return
  48. try:
  49. manifest = yaml.safe_load(manifest_content)
  50. except yaml.YAMLError as e:
  51. logger.error(f"Failed to parse manifest: {e}")
  52. return
  53. # 3. Validation
  54. project_name = manifest.get("project_name")
  55. if not project_name:
  56. logger.error("Manifest missing project_name")
  57. return
  58. # 4. Get or create project
  59. project = self.storage.get_or_create_project(project_name)
  60. # 5. Get all changed files from payload for pre-filtering
  61. all_changed_files = self._get_all_changed_files(payload)
  62. manifest_changed = "manifest.yaml" in all_changed_files
  63. logger.info(f"Detected {len(all_changed_files)} changed files. Manifest changed: {manifest_changed}")
  64. # 6. Process stages
  65. stages = manifest.get("stages", [])
  66. if not stages:
  67. logger.error("Manifest missing stages configuration")
  68. return
  69. for stage_config in stages:
  70. stage_name = stage_config.get("name")
  71. outputs = stage_config.get("outputs", [])
  72. if not stage_name:
  73. logger.warning("Stage missing name, skipping")
  74. continue
  75. # --- PRE-FILTERING LOGIC ---
  76. # Skip if manifest hasn't changed AND no files in this stage's scope have changed
  77. if not manifest_changed and not self._is_stage_affected(outputs, all_changed_files):
  78. logger.info(f"Stage '{stage_name}': No relevant files changed. Skipping processing.")
  79. continue
  80. # Check if this stage+commit already processed (idempotency)
  81. existing_version = self.db.query(DataVersion).filter(
  82. DataVersion.project_id == project.id,
  83. DataVersion.stage == stage_name,
  84. DataVersion.commit_id == after_sha
  85. ).first()
  86. if existing_version:
  87. logger.info(f"Stage '{stage_name}' already processed. Skipping.")
  88. continue
  89. # Get commit message from payload if available
  90. commit_msg = None
  91. commits = payload.get("commits", [])
  92. if commits:
  93. commit_msg = commits[0].get("message")
  94. # Create version for this stage
  95. version = self.storage.create_version(
  96. project.id, stage_name, after_sha, author_name, manifest=manifest_content, commit_message=commit_msg
  97. )
  98. if not version:
  99. logger.info(f"Stage '{stage_name}' (commit {after_sha[:8]}) is already being processed. Skipping.")
  100. continue
  101. logger.info(f"Processing stage '{stage_name}' with {len(outputs)} output rules")
  102. # Process ONLY changed files that match output rules (no directory tree fetching)
  103. has_new_uploads = await self._process_outputs(
  104. version,
  105. outputs,
  106. owner,
  107. repo_name,
  108. ref=after_sha,
  109. fallback_ref=branch_name,
  110. changed_files=all_changed_files,
  111. )
  112. # Check if this version represents a real change (content OR file set)
  113. if not self.storage.is_snapshot_changed(version, has_new_uploads):
  114. # No changes detected — this was a code-only push, discard the snapshot
  115. self.storage.rollback_version(version)
  116. logger.info(
  117. f"Stage '{stage_name}': no data changes detected (content and file set same). "
  118. f"Version discarded."
  119. )
  120. else:
  121. self.storage.aggregate_version_records(version)
  122. # ── Backfill: supplement missing paired inputs in recent versions ──
  123. # Handles the case where output files were committed BEFORE their
  124. # paired input files. When the input file arrives in a later push,
  125. # we retroactively attach it to the older records that were missing it.
  126. for stage_config in stages:
  127. stage_name = stage_config.get("name")
  128. outputs = stage_config.get("outputs", [])
  129. if not stage_name or not outputs:
  130. continue
  131. has_paired = any(
  132. oc.get("paired_input") or oc.get("paired_inputs")
  133. for oc in outputs
  134. )
  135. if not has_paired:
  136. continue
  137. if not manifest_changed and not self._is_stage_affected(outputs, all_changed_files):
  138. continue
  139. await self._backfill_incomplete_records(
  140. project.id, stage_name, outputs,
  141. owner, repo_name, after_sha,
  142. )
  143. # Close the shared HTTP client (connection pool)
  144. await self.gogs.aclose()
  145. def _get_all_changed_files(self, payload: dict) -> set[str]:
  146. """Extract all added, modified, and removed files from all commits in payload."""
  147. files = set()
  148. commits = payload.get("commits", [])
  149. for commit in commits:
  150. # for key in ["added", "modified", "removed"]:
  151. for key in ["added", "modified"]:
  152. for f in (commit.get(key) or []):
  153. files.add(normalize_path(f))
  154. return files
  155. @staticmethod
  156. def _extract_branch_name(git_ref: str | None) -> str | None:
  157. """Extract branch name from webhook ref, e.g. refs/heads/main -> main."""
  158. if not git_ref:
  159. return None
  160. prefix = "refs/heads/"
  161. if git_ref.startswith(prefix):
  162. return git_ref[len(prefix):]
  163. return git_ref
  164. def _is_stage_affected(self, outputs: list, changed_files: set[str]) -> bool:
  165. """Check if any of the changed files fall under the scope of the stage's outputs."""
  166. if not changed_files:
  167. return True # Fallback: if we can't tell what changed, process it
  168. for output in outputs:
  169. path_pattern = normalize_path(output.get("path", ""))
  170. is_dir = is_directory_pattern(output.get("path", ""))
  171. for f in changed_files:
  172. if is_dir:
  173. # If it's a directory output, any change inside that directory counts
  174. dir_path = path_pattern.rstrip("/")
  175. if '*' in dir_path:
  176. import fnmatch
  177. if fnmatch.fnmatch(f, dir_path + "/*") or fnmatch.fnmatch(f, dir_path):
  178. return True
  179. else:
  180. if f == dir_path or f.startswith(dir_path + "/"):
  181. return True
  182. else:
  183. # Single file output: exact match
  184. if f == path_pattern:
  185. return True
  186. return False
  187. def _find_matching_output(self, file_path: str, outputs: list) -> dict | None:
  188. """Check if a file path matches any manifest output rule using LOCAL logic only.
  189. No Gogs API calls are made — this is pure string/glob matching.
  190. Returns the matching output config dict, or None.
  191. """
  192. for output in outputs:
  193. raw_path = output.get("path", "")
  194. path_pattern = normalize_path(raw_path)
  195. is_dir = is_directory_pattern(raw_path)
  196. patterns = output.get("pattern", "*")
  197. excludes = output.get("exclude")
  198. if is_dir:
  199. dir_path = path_pattern.rstrip("/")
  200. if '*' in dir_path:
  201. if not fnmatch.fnmatch(file_path, dir_path + "/*") and not fnmatch.fnmatch(file_path, dir_path):
  202. continue
  203. else:
  204. if not file_path.startswith(dir_path + "/"):
  205. continue
  206. filename = os.path.basename(file_path)
  207. if self._match_patterns(filename, patterns, excludes):
  208. return output
  209. else:
  210. if file_path == path_pattern:
  211. return output
  212. return None
  213. async def _fetch_and_process_file(
  214. self, version, file_path: str, output_config: dict,
  215. owner: str, repo_name: str, ref: str, fallback_ref: str | None,
  216. processed_keys: set
  217. ) -> bool:
  218. """Get file SHA from Gogs and process a single changed file, plus its paired input if configured."""
  219. # Pre-collect paired inputs to help determine grouping logic
  220. paired_configs = list(output_config.get("paired_inputs", []))
  221. if "paired_input" in output_config:
  222. paired_configs.append(output_config["paired_input"])
  223. # The group_key is ALWAYS the output file's exact relative_path.
  224. # This guarantees 1 Output : N Inputs mapping strictly.
  225. group_key = file_path
  226. # Deduplicate API calls and DB entries across concurrently running tasks
  227. task_key = (file_path, group_key)
  228. if task_key in processed_keys:
  229. return False
  230. processed_keys.add(task_key)
  231. file_info = await self.gogs.get_file_info(owner, repo_name, file_path, ref=ref)
  232. if not file_info:
  233. logger.info(f"File {file_path} not found at ref {ref[:8]} (removed). Skipping.")
  234. return False
  235. has_change = await self.storage.process_file_with_sha(
  236. version, file_path, file_info.get("sha"), owner, repo_name,
  237. direction=output_config.get("direction"),
  238. label=output_config.get("label"),
  239. extract_json_key=output_config.get("extract_json_key"),
  240. group_key=group_key,
  241. content_ref=file_info.get("ref", ref),
  242. )
  243. for paired_config in paired_configs:
  244. extract_regex = paired_config.get("extract_regex")
  245. path_template = paired_config.get("path_template")
  246. if extract_regex and path_template:
  247. match = re.search(extract_regex, file_path)
  248. if match:
  249. # Construct paired file path using named capture groups
  250. try:
  251. paired_path = path_template.format(**match.groupdict())
  252. except KeyError as e:
  253. logger.error(f"Failed to format paired_path: missing {e} in regex match for {file_path}")
  254. paired_path = None
  255. if paired_path:
  256. # Deduplicate paired input fetches
  257. paired_task_key = (paired_path, group_key)
  258. if paired_task_key in processed_keys:
  259. continue
  260. processed_keys.add(paired_task_key)
  261. # Actively fetch paired file info from Gogs
  262. paired_info = await self.gogs.get_file_info(
  263. owner,
  264. repo_name,
  265. paired_path,
  266. ref=ref,
  267. fallback_ref=fallback_ref,
  268. )
  269. if paired_info:
  270. paired_changed = await self.storage.process_file_with_sha(
  271. version, paired_path, paired_info.get("sha"), owner, repo_name,
  272. direction=paired_config.get("direction", "input"),
  273. label=paired_config.get("label"),
  274. extract_json_key=paired_config.get("extract_json_key"),
  275. group_key=group_key, # Link them together!
  276. content_ref=paired_info.get("ref", ref),
  277. )
  278. has_change = has_change or paired_changed
  279. else:
  280. logger.warning(f"Paired input file not found at ref {ref[:8]}: {paired_path}")
  281. return has_change
  282. async def _backfill_incomplete_records(
  283. self, project_id: str, stage_name: str, outputs: list,
  284. owner: str, repo_name: str, current_commit: str,
  285. ):
  286. """Backfill paired inputs that were missing when records were first created.
  287. When output files are committed before their paired input files, the
  288. initial records will have empty inputs. This method finds those
  289. incomplete records and tries to fetch the now-available paired inputs
  290. using the *current* commit (which may contain files added later).
  291. """
  292. recent_versions = (
  293. self.db.query(DataVersion)
  294. .filter(
  295. DataVersion.project_id == project_id,
  296. DataVersion.stage == stage_name,
  297. )
  298. .order_by(DataVersion.created_at.desc())
  299. .limit(20)
  300. .all()
  301. )
  302. # Cache Gogs lookups so we don't fetch the same path twice
  303. file_info_cache: dict[str, dict | None] = {}
  304. for version in recent_versions:
  305. records = (
  306. self.db.query(DataRecord)
  307. .filter(DataRecord.version_id == version.id)
  308. .all()
  309. )
  310. needs_reaggregate = False
  311. for record in records:
  312. out_path = (
  313. record.outputs[0]["relative_path"]
  314. if record.outputs
  315. else None
  316. )
  317. if not out_path:
  318. continue
  319. output_config = self._find_matching_output(out_path, outputs)
  320. if not output_config:
  321. continue
  322. paired_configs = list(output_config.get("paired_inputs", []))
  323. if "paired_input" in output_config:
  324. paired_configs.append(output_config["paired_input"])
  325. if not paired_configs:
  326. continue
  327. existing_input_paths = {
  328. inp["relative_path"] for inp in (record.inputs or [])
  329. }
  330. for pc in paired_configs:
  331. extract_regex = pc.get("extract_regex")
  332. path_template = pc.get("path_template")
  333. if not extract_regex or not path_template:
  334. continue
  335. match = re.search(extract_regex, out_path)
  336. if not match:
  337. continue
  338. try:
  339. paired_path = path_template.format(**match.groupdict())
  340. except KeyError:
  341. continue
  342. # Already present in this record's inputs
  343. if paired_path in existing_input_paths:
  344. continue
  345. # DataFile already exists but not yet reflected in record
  346. existing_df = (
  347. self.db.query(DataFile)
  348. .filter(
  349. DataFile.version_id == version.id,
  350. DataFile.relative_path == paired_path,
  351. DataFile.group_key == record.group_key,
  352. )
  353. .first()
  354. )
  355. if existing_df:
  356. needs_reaggregate = True
  357. continue
  358. # Fetch from Gogs (with cache)
  359. if paired_path not in file_info_cache:
  360. file_info_cache[paired_path] = (
  361. await self.gogs.get_file_info(
  362. owner, repo_name, paired_path,
  363. ref=current_commit,
  364. )
  365. )
  366. paired_info = file_info_cache[paired_path]
  367. if not paired_info:
  368. continue # still not available
  369. await self.storage.process_file_with_sha(
  370. version,
  371. paired_path,
  372. paired_info.get("sha"),
  373. owner,
  374. repo_name,
  375. direction=pc.get("direction", "input"),
  376. label=pc.get("label"),
  377. extract_json_key=pc.get("extract_json_key"),
  378. group_key=record.group_key,
  379. content_ref=paired_info.get("ref", current_commit),
  380. )
  381. needs_reaggregate = True
  382. logger.info(
  383. f"Backfilled paired input {paired_path} "
  384. f"for version {version.id} (commit {version.commit_id[:8]})"
  385. )
  386. if needs_reaggregate:
  387. self.storage.aggregate_version_records(version)
  388. logger.info(
  389. f"Re-aggregated version {version.id} after backfilling paired inputs"
  390. )
  391. async def _process_outputs(
  392. self,
  393. version,
  394. outputs: list,
  395. owner: str,
  396. repo_name: str,
  397. ref: str,
  398. fallback_ref: str | None,
  399. changed_files: set[str],
  400. ) -> bool:
  401. """Process ONLY changed files that match manifest output rules.
  402. Instead of fetching entire directory trees from Gogs API (slow),
  403. we match the webhook payload's changed-file list against manifest
  404. rules using LOCAL string/glob logic — zero API calls for matching.
  405. Returns True if at least one file had actual content changes.
  406. """
  407. # Step 1: Local matching — zero API calls
  408. matched_files = []
  409. for file_path in changed_files:
  410. matched_output = self._find_matching_output(file_path, outputs)
  411. if matched_output is not None:
  412. matched_files.append((file_path, matched_output))
  413. if not matched_files:
  414. logger.info("No changed files matched any output rule.")
  415. return False
  416. logger.info(f"Matched {len(matched_files)} changed file(s) to output rules. Processing in parallel.")
  417. # Step 2: Fetch file info + download/upload in parallel
  418. processed_keys = set()
  419. tasks = [
  420. self._fetch_and_process_file(
  421. version, fp, oc, owner, repo_name, ref, fallback_ref, processed_keys
  422. )
  423. for fp, oc in matched_files
  424. ]
  425. has_changes = False
  426. results = await asyncio.gather(*tasks, return_exceptions=True)
  427. for i, res in enumerate(results):
  428. if isinstance(res, Exception):
  429. logger.error(f"Error processing {matched_files[i][0]}: {res}")
  430. elif res is True:
  431. has_changes = True
  432. return has_changes
  433. def _match_patterns(
  434. self,
  435. filename: str,
  436. include_patterns: str | list[str],
  437. exclude_patterns: str | list[str] | None = None,
  438. ) -> bool:
  439. """Helper to match filename against multiple include and exclude glob patterns."""
  440. # Normalize to lists
  441. includes = (
  442. [include_patterns] if isinstance(include_patterns, str) else include_patterns
  443. )
  444. excludes = []
  445. if exclude_patterns:
  446. excludes = (
  447. [exclude_patterns] if isinstance(exclude_patterns, str) else exclude_patterns
  448. )
  449. # 1. Check if it matches ANY include pattern (OR logic)
  450. is_included = any(fnmatch.fnmatch(filename, p) for p in includes)
  451. if not is_included:
  452. return False
  453. # 2. Check if it matches ANY exclude pattern (OR logic: any match means reject)
  454. is_excluded = any(fnmatch.fnmatch(filename, p) for p in excludes)
  455. return not is_excluded