tanjingyu 1 неделя назад
Родитель
Commit
5bd2bf9ce3
3 измененных файлов с 189 добавлено и 145 удалено
  1. 44 54
      app/services/storage_service.py
  2. 123 91
      app/services/webhook_service.py
  3. 22 0
      manifest.yaml.example

+ 44 - 54
app/services/storage_service.py

@@ -54,47 +54,12 @@ class StorageService:
 
     def is_snapshot_changed(self, version: DataVersion, has_new_uploads: bool) -> bool:
         """
-        Determine if this version represents a meaningful change compared to the previous one.
-        A version is meaningful if:
-        1. Any file content changed (has_new_uploads is True)
-        2. The set of files (paths) is different from the previous version record for this stage.
-        """
-        # 1. If content changed, it's definitely a new version
-        if has_new_uploads:
-            return True
-
-        # 2. Get current file paths
-        current_files = self.db.query(DataFile).filter(DataFile.version_id == version.id).all()
-        if not current_files:
-            return False  # No data files at all, don't keep
-
-        # 3. Find the most recent previous version for this project and stage
-        prev_version = (
-            self.db.query(DataVersion)
-            .filter(
-                DataVersion.project_id == version.project_id,
-                DataVersion.stage == version.stage,
-                DataVersion.id != version.id
-            )
-            .order_by(DataVersion.created_at.desc())
-            .first()
-        )
-
-        # 4. If there's no previous version, we keep this one as the baseline
-        if not prev_version:
-            return True
+        Determine if this version represents a meaningful change.
 
-        # 5. Compare the set of relative paths
-        prev_files = self.db.query(DataFile).filter(DataFile.version_id == prev_version.id).all()
-        
-        prev_paths = {f.relative_path for f in prev_files}
-        curr_paths = {f.relative_path for f in current_files}
-
-        if prev_paths != curr_paths:
-            logger.info(f"Snapshot file set changed for stage '{version.stage}': {len(prev_paths)} -> {len(curr_paths)} files")
-            return True
-
-        return False
+        With differential processing (only webhook-changed files are processed),
+        a version is meaningful if any file had new content uploaded to OSS.
+        """
+        return has_new_uploads
 
     def aggregate_version_records(self, version: DataVersion):
         """Aggregate files in a version into DataRecord groups based on parent directory."""
@@ -128,22 +93,49 @@ class StorageService:
             else:
                 # Treat 'output' or None as output by default for rendering purposes
                 groups[group_key]["outputs"].append(file_data)
-        # 3. Insert aggregated records (One record per output file)
+        # 3. Insert aggregated records (One record per output file, with differential logic)
         for group_key, data in groups.items():
             inputs = data["inputs"]
             outputs = data["outputs"]
             
-            # If there are no outputs, still create one record for the inputs
-            # If there are outputs, create one record for EACH output
+            # 预先获取该 group_key 下所有输出路径的最新状态
+            # 用于判定当前这次 Commit 是否真的产生了变化
+            latest_hashes = {}
+            past_records = (
+                self.db.query(DataRecord)
+                .filter(
+                    DataRecord.project_id == version.project_id,
+                    DataRecord.stage == version.stage,
+                    DataRecord.group_key == group_key,
+                    DataRecord.version_id != version.id
+                )
+                .order_by(DataRecord.created_at.desc())
+                .limit(200) # 覆盖常见的目录文件数
+                .all()
+            )
+            for pr in past_records:
+                # 记录每一个输出路径对应的最新 hash
+                p_out = pr.outputs[0]['relative_path'] if pr.outputs else None
+                if p_out not in latest_hashes:
+                    latest_hashes[p_out] = pr.content_hash
+
+            # 如果没有输出,也创建一个逻辑组
             output_groups = [[o] for o in outputs] if outputs else [[]]
             
             for out_list in output_groups:
-                # Calculate a deterministic content_hash for this combination
+                # 1. 计算当前指纹:输入集合 SHA + 对应输出 SHA
                 all_shas = [f["file_sha"] for f in inputs] + [f["file_sha"] for f in out_list]
                 all_shas.sort()
                 combined_string = "|".join(all_shas)
                 content_hash = hashlib.sha256(combined_string.encode('utf-8')).hexdigest()
 
+                # 2. 差异化判定
+                out_path = out_list[0]['relative_path'] if out_list else None
+                if out_path in latest_hashes and latest_hashes[out_path] == content_hash:
+                    logger.info(f"Skipping unchanged record: {group_key} -> {out_path}")
+                    continue
+
+                # 3. 只有变化了才记录
                 record = DataRecord(
                     project_id=version.project_id,
                     version_id=version.id,
@@ -155,12 +147,11 @@ class StorageService:
                     outputs=out_list,
                     content_hash=content_hash,
                     author=version.author,
-                    # letting server_default handle created_at
                 )
                 self.db.add(record)
             
         self.db.commit()
-        logger.info(f"Aggregated version {version.id} into DataRecords (one per output).")
+        logger.info(f"Aggregated version {version.id} with refined differential logic.")
 
     async def process_file_with_sha(
         self,
@@ -173,14 +164,10 @@ class StorageService:
         label: str = None,
         extract_json_key: str = None,
         directory_depth: int = None,
+        group_key: str = None,
     ) -> bool:
         """Process a file and create a snapshot record.
 
-
-        **Snapshot semantics**: a record is ALWAYS created regardless of
-        whether the file changed.  This ensures every version is a
-        self-contained snapshot of all declared output files.
-
         Returns
         -------
         bool
@@ -203,9 +190,10 @@ class StorageService:
         should_extract = bool(extract_json_key and relative_path.lower().endswith(".json"))
         extracted_val = None
 
-        # Calculate group_key based on directory_depth
-        calc_group_key = os.path.dirname(relative_path)  # Default fallback
-        if directory_depth is not None and directory_depth > 0:
+        # Calculate group_key: explicit override > directory_depth > dirname fallback
+        if group_key is not None:
+            calc_group_key = group_key
+        elif directory_depth is not None and directory_depth > 0:
             parts = relative_path.split("/")
             # Remove filename
             if len(parts) > 1:
@@ -214,6 +202,8 @@ class StorageService:
                 calc_group_key = "/".join(parts[:directory_depth])
             else:
                 calc_group_key = "" # File is in root directory
+        else:
+            calc_group_key = os.path.dirname(relative_path)  # Default fallback
 
         async def _extract_val() -> str | None:
             try:

+ 123 - 91
app/services/webhook_service.py

@@ -1,6 +1,9 @@
 import yaml
 import logging
 import fnmatch
+import os
+import asyncio
+import re
 from sqlalchemy.orm import Session
 from app.models import Project, DataVersion
 from app.services.gogs_client import GogsClient
@@ -131,9 +134,9 @@ class WebhookService:
 
             logger.info(f"Processing stage '{stage_name}' with {len(outputs)} output rules")
 
-            # Process outputs and check if any file actually changed
+            # Process ONLY changed files that match output rules (no directory tree fetching)
             has_new_uploads = await self._process_outputs(
-                version, outputs, owner, repo_name, after_sha
+                version, outputs, owner, repo_name, after_sha, all_changed_files
             )
 
             # Check if this version represents a real change (content OR file set)
@@ -183,107 +186,136 @@ class WebhookService:
                         return True
         return False
 
-    async def _process_outputs(
-        self, version, outputs: list, owner: str, repo_name: str, commit_id: str
-    ) -> bool:
-        """Process output rules, create snapshot records for ALL matching files.
+    def _find_matching_output(self, file_path: str, outputs: list) -> dict | None:
+        """Check if a file path matches any manifest output rule using LOCAL logic only.
 
-        Returns
-        -------
-        bool
-            ``True`` if at least one file had actual content changes,
-            ``False`` if every file was unchanged.
+        No Gogs API calls are made — this is pure string/glob matching.
+        Returns the matching output config dict, or None.
         """
-        has_changes = False
-
         for output in outputs:
-            raw_path_pattern = output.get("path", "")
-            # Support both string and list for pattern and exclude
+            raw_path = output.get("path", "")
+            path_pattern = normalize_path(raw_path)
+            is_dir = is_directory_pattern(raw_path)
             patterns = output.get("pattern", "*")
             excludes = output.get("exclude")
 
-            direction = output.get("direction")
-            label = output.get("label")
-            extract_json_key = output.get("extract_json_key")
-            directory_depth = output.get("directory_depth")
-
-            path_pattern = normalize_path(raw_path_pattern)
-            is_dir = is_directory_pattern(raw_path_pattern)
-            dir_path = path_pattern.rstrip("/")
-
             if is_dir:
-                # Directory pattern: fetch files from the closest static parent directory
-                # For `data/*/test/`, that is `data/`
-                import re
-                
-                # Split by first wildcard chunk path
-                wildcard_idx = dir_path.find('*')
-                if wildcard_idx != -1:
-                    static_base = dir_path[:wildcard_idx]
-                    # Trim back to the nearest directory separator
-                    last_sep = static_base.rfind('/')
-                    if last_sep != -1:
-                        static_base = static_base[:last_sep]
-                    else:
-                        static_base = "" # ROOT
+                dir_path = path_pattern.rstrip("/")
+                if '*' in dir_path:
+                    if not fnmatch.fnmatch(file_path, dir_path + "/*") and not fnmatch.fnmatch(file_path, dir_path):
+                        continue
                 else:
-                    static_base = dir_path
-                    
-                static_base = static_base.strip('/')
-                
-                logger.info(f"Fetching directory: {static_base} (to match wildcard path: {dir_path}) with patterns: {patterns}, excludes: {excludes}")
+                    if not file_path.startswith(dir_path + "/"):
+                        continue
+                filename = os.path.basename(file_path)
+                if self._match_patterns(filename, patterns, excludes):
+                    return output
+            else:
+                if file_path == path_pattern:
+                    return output
 
-                files = await self.gogs.get_directory_tree(owner, repo_name, commit_id, static_base)
+        return None
 
-                for file_info in files:
-                    file_path = file_info.get("path")
-                    
-                    # 1. First verify if the full path matches the wildcard directory path provided
-                    if '*' in dir_path:
-                        # e.g dir_path: data/*/test/ -> match: data/*/test/*
-                        if not fnmatch.fnmatch(file_path, dir_path + "/*") and not fnmatch.fnmatch(file_path, dir_path):
-                            continue
-                    else:
-                        if not file_path.startswith(dir_path + "/"):
-                            continue
-                        
-                    # Calculate name relative to the matched base path segment for pattern matching
-                    import os
-                    rel_name = os.path.basename(file_path)
-
-                    if self._match_patterns(rel_name, patterns, excludes):
-                        try:
-                            changed = await self.storage.process_file_with_sha(
-                                version, file_path, file_info.get("sha"), owner, repo_name,
-                                direction=direction, label=label, extract_json_key=extract_json_key,
-                                directory_depth=directory_depth
-                            )
-                            if changed:
-                                has_changes = True
-                        except Exception as e:
-                            logger.error(f"Failed to process file {file_path}: {e}")
+    async def _fetch_and_process_file(
+        self, version, file_path: str, output_config: dict,
+        owner: str, repo_name: str, commit_id: str
+    ) -> bool:
+        """Get file SHA from Gogs and process a single changed file, plus its paired input if configured."""
+        file_info = await self.gogs.get_file_info(owner, repo_name, commit_id, file_path)
+        if not file_info:
+            logger.info(f"File {file_path} not found at commit {commit_id[:8]} (removed). Skipping.")
+            return False
+
+        # Calculate group_key here so both paired input and output can share it
+        directory_depth = output_config.get("directory_depth")
+        if directory_depth is not None and directory_depth > 0:
+            parts = file_path.split("/")
+            if len(parts) > 1:
+                group_key = "/".join(parts[:-1][:directory_depth])
             else:
-                # Single file: fetch only this file's info
-                logger.info(f"Fetching single file: {path_pattern}")
-
-                file_info = await self.gogs.get_file_info(owner, repo_name, commit_id, path_pattern)
-                if file_info:
-                    # Apply pattern matching to the filename for consistency
-                    import os
-                    filename = os.path.basename(path_pattern)
-                    if self._match_patterns(filename, patterns, excludes):
-                        try:
-                            changed = await self.storage.process_file_with_sha(
-                                version, path_pattern, file_info.get("sha"), owner, repo_name,
-                                direction=direction, label=label, extract_json_key=extract_json_key,
-                                directory_depth=directory_depth
+                group_key = ""
+        else:
+            group_key = os.path.dirname(file_path)
+
+        has_change = await self.storage.process_file_with_sha(
+            version, file_path, file_info.get("sha"), owner, repo_name,
+            direction=output_config.get("direction"),
+            label=output_config.get("label"),
+            extract_json_key=output_config.get("extract_json_key"),
+            directory_depth=directory_depth,
+            group_key=group_key,
+        )
+
+        # Handle paired_input active fetching
+        paired_input = output_config.get("paired_input")
+        if paired_input:
+            extract_regex = paired_input.get("extract_regex")
+            path_template = paired_input.get("path_template")
+            if extract_regex and path_template:
+                match = re.search(extract_regex, file_path)
+                if match:
+                    # Construct paired file path using named capture groups
+                    try:
+                        paired_path = path_template.format(**match.groupdict())
+                    except KeyError as e:
+                        logger.error(f"Failed to format paired_path: missing {e} in regex match for {file_path}")
+                        paired_path = None
+                    
+                    if paired_path:
+                        # Actively fetch paired file info from Gogs
+                        paired_info = await self.gogs.get_file_info(owner, repo_name, commit_id, paired_path)
+                        if paired_info:
+                            paired_changed = await self.storage.process_file_with_sha(
+                                version, paired_path, paired_info.get("sha"), owner, repo_name,
+                                direction=paired_input.get("direction", "input"),
+                                label=paired_input.get("label"),
+                                extract_json_key=paired_input.get("extract_json_key"),
+                                group_key=group_key,  # Link them together!
                             )
-                            if changed:
-                                has_changes = True
-                        except Exception as e:
-                            logger.error(f"Failed to process file {path_pattern}: {e}")
-                else:
-                    logger.warning(f"File not found: {path_pattern}")
+                            has_change = has_change or paired_changed
+                        else:
+                            logger.warning(f"Paired input file not found at commit {commit_id[:8]}: {paired_path}")
+
+        return has_change
+
+    async def _process_outputs(
+        self, version, outputs: list, owner: str, repo_name: str, commit_id: str,
+        changed_files: set[str]
+    ) -> bool:
+        """Process ONLY changed files that match manifest output rules.
+
+        Instead of fetching entire directory trees from Gogs API (slow),
+        we match the webhook payload's changed-file list against manifest
+        rules using LOCAL string/glob logic — zero API calls for matching.
+
+        Returns True if at least one file had actual content changes.
+        """
+        # Step 1: Local matching — zero API calls
+        matched_files = []
+        for file_path in changed_files:
+            matched_output = self._find_matching_output(file_path, outputs)
+            if matched_output is not None:
+                matched_files.append((file_path, matched_output))
+
+        if not matched_files:
+            logger.info("No changed files matched any output rule.")
+            return False
+
+        logger.info(f"Matched {len(matched_files)} changed file(s) to output rules. Processing in parallel.")
+
+        # Step 2: Fetch file info + download/upload in parallel
+        tasks = [
+            self._fetch_and_process_file(version, fp, oc, owner, repo_name, commit_id)
+            for fp, oc in matched_files
+        ]
+
+        has_changes = False
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for i, res in enumerate(results):
+            if isinstance(res, Exception):
+                logger.error(f"Error processing {matched_files[i][0]}: {res}")
+            elif res is True:
+                has_changes = True
 
         return has_changes
 

+ 22 - 0
manifest.yaml.example

@@ -57,6 +57,23 @@ stages:
         label: 灵感点
         extract_json_key: "data.idea_content"  # 会解析 JSON 并提取对应 key 的值保存
         directory_depth: 2
+        
+  # ---------- 阶段 5:动态配对输入文件(自动提取) ----------
+  - name: auto_paired_data
+    outputs:
+      # 示例 G:当产生一个结果文件时,根据正则表达式从输出路径中提取变量,
+      #          主动去 Gogs 拉取与其配对的原始输入文件,并将它们同归到一组中。
+      - path: aiddit/decode/topic/result/
+        pattern: "*.json"
+        direction: output
+        label: 标准化结果
+        paired_input:
+          # 正则表达式,使用命名捕获组如 (?P<name>...) 提取变量
+          extract_regex: "aiddit/decode/topic/result/(?P<name>[^/]+)/final_normalization/(?P<filename>[^/]+)"
+          # 使用提取的变量构造对应的 input 路径
+          path_template: "aigc_data/{name}/{filename}"
+          direction: input
+          label: 原始数据
 
 # ============================================================
 # 字段说明
@@ -80,6 +97,11 @@ stages:
 #     - label     (可选) 该文件的业务称呼/标签(如 '帖子输入', '灵感点' 等)
#     - extract_json_key (可选) 针对 JSON 文件,配置要提取解析的 json key 路径(支持嵌套的 . 分隔,例如 'data.content')。提取的值会被记录在数据库中。
 #     - directory_depth  (可选) 定义这组规则生成的文件关联用的父目录深度(如 1 或 2,用来将不同子目录的关联文件合并到一行展示)。
+#     - paired_input     (可选) 动态输入映射规则,用于输出生成后主动拉取关联的输入。包含:
+#       - extract_regex:   提取路径变量的正则表达式 (必需使用命名捕获组,如 (?P<var>...))
+#       - path_template:   组装对应输入文件路径的模板 (如 "aigc_data/{var}/file.json")
+#       - direction:       配对文件的方向 (通常为 "input")
+#       - label:           配对文件的业务名称
 #
 # ============================================================
 # 工作流程