import os
import hashlib
import json
import logging
from collections import defaultdict

from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError

from app.models import Project, DataVersion, DataFile, DataRecord
from app.config import settings
from app.services.gogs_client import GogsClient
from app.services.oss_client import oss_client

logger = logging.getLogger(__name__)


class StorageService:
    """Persist repository snapshots: Gogs file content → OSS objects + DB records.

    Works at three levels:
      * ``DataVersion``  — one row per (project, stage, commit),
      * ``DataFile``     — one row per file per version (snapshot entry),
      * ``DataRecord``   — aggregated input/output groups with differential
                           (change-only) semantics.
    """

    def __init__(self, db: Session, gogs_client: GogsClient):
        self.db = db
        self.gogs = gogs_client

    def get_or_create_project(self, project_name: str, description: str | None = None) -> Project:
        """Return the project named *project_name*, creating it if absent.

        NOTE(review): not safe against concurrent creation races — a parallel
        insert between the query and the commit would raise IntegrityError here.
        """
        project = self.db.query(Project).filter(Project.project_name == project_name).first()
        if not project:
            project = Project(project_name=project_name, description=description)
            self.db.add(project)
            self.db.commit()
            self.db.refresh(project)
        return project

    def create_version(
        self,
        project_id: str,
        stage: str,
        commit_id: str,
        author: str,
        manifest: str,
        commit_message: str | None = None,
    ) -> DataVersion | None:
        """Create a new data version.

        Returns None if a duplicate exists (IntegrityError).
        """
        version = DataVersion(
            project_id=project_id,
            stage=stage,
            commit_id=commit_id,
            author=author,
            commit_message=commit_message,
            manifest_snapshot=manifest,
        )
        try:
            self.db.add(version)
            self.db.commit()
            self.db.refresh(version)
            return version
        except IntegrityError:
            # A unique constraint on (project, stage, commit) already holds a row.
            self.db.rollback()
            logger.info(f"Version already exists for project {project_id}, stage {stage}, commit {commit_id[:8]}.")
            return None

    def rollback_version(self, version: DataVersion):
        """Remove a version and all its associated file records."""
        # Capture the id BEFORE deleting: after delete()+commit() the ORM
        # instance is expired/deleted and attribute access may raise.
        version_id = version.id
        self.db.query(DataFile).filter(DataFile.version_id == version_id).delete()
        self.db.delete(version)
        self.db.commit()
        logger.info(f"Rolled back unchanged version {version_id}")

    def is_snapshot_changed(self, version: DataVersion, has_new_uploads: bool) -> bool:
        """
        Determine if this version represents a meaningful change.

        With differential processing (only webhook-changed files are
        processed), a version is meaningful if any file had new content
        uploaded to OSS.
        """
        return has_new_uploads

    def aggregate_version_records(self, version: DataVersion):
        """Aggregate files in a version into DataRecord groups based on parent directory.

        Idempotent: existing records for *version* are wiped and rebuilt.
        Records whose content hash matches the latest prior record for the
        same output path are skipped (differential logic).
        """
        # 1. Clean existing records for this version (idempotency)
        self.db.query(DataRecord).filter(DataRecord.version_id == version.id).delete()
        files = self.db.query(DataFile).filter(DataFile.version_id == version.id).all()

        # 2. Group files by their group key (explicit key, else parent dir)
        groups = defaultdict(lambda: {"inputs": [], "outputs": []})
        for f in files:
            group_key = f.group_key if f.group_key is not None else os.path.dirname(f.relative_path)
            file_data = {
                "id": f.id,
                "relative_path": f.relative_path,
                "file_type": f.file_type,
                "file_size": f.file_size,
                "file_sha": f.file_sha,
                "direction": f.direction,
                "label": f.label,
                "extracted_value": f.extracted_value,
                "storage_path": f.storage_path,
            }
            if f.direction == "input":
                groups[group_key]["inputs"].append(file_data)
            else:
                # Treat 'output' or None as output by default for rendering purposes
                groups[group_key]["outputs"].append(file_data)

        # 3. Insert aggregated records (one record per output file, differential)
        for group_key, data in groups.items():
            inputs = data["inputs"]
            outputs = data["outputs"]

            # Pre-fetch the latest known state of every output path under this
            # group_key, used to decide whether this commit actually changed it.
            latest_hashes = {}
            past_records = (
                self.db.query(DataRecord)
                .filter(
                    DataRecord.project_id == version.project_id,
                    DataRecord.stage == version.stage,
                    DataRecord.group_key == group_key,
                    DataRecord.version_id != version.id,
                )
                .order_by(DataRecord.created_at.desc())
                .limit(200)  # covers the common per-directory file count
                .all()
            )
            for pr in past_records:
                # Keep only the most recent hash per output path (records are
                # ordered newest-first, so first-seen wins).
                p_out = pr.outputs[0]['relative_path'] if pr.outputs else None
                if p_out not in latest_hashes:
                    latest_hashes[p_out] = pr.content_hash

            # With no outputs, still create one logical (input-only) group.
            output_groups = [[o] for o in outputs] if outputs else [[]]

            for out_list in output_groups:
                # 1. Fingerprint: sorted SHAs of all inputs + this output.
                #    NOTE(review): assumes file_sha is never None — a None
                #    would break the sort; confirm against the DataFile model.
                all_shas = [f["file_sha"] for f in inputs] + [f["file_sha"] for f in out_list]
                all_shas.sort()
                combined_string = "|".join(all_shas)
                content_hash = hashlib.sha256(combined_string.encode('utf-8')).hexdigest()

                # 2. Differential check: skip if the latest record for this
                #    output path already carries the same fingerprint.
                out_path = out_list[0]['relative_path'] if out_list else None
                if out_path in latest_hashes and latest_hashes[out_path] == content_hash:
                    logger.info(f"Skipping unchanged record: {group_key} -> {out_path}")
                    continue

                # 3. Only changed groups produce a record.
                record = DataRecord(
                    project_id=version.project_id,
                    version_id=version.id,
                    stage=version.stage,
                    commit_id=version.commit_id,
                    commit_message=version.commit_message,
                    group_key=group_key,
                    inputs=inputs,
                    outputs=out_list,
                    content_hash=content_hash,
                    author=version.author,
                )
                self.db.add(record)

        self.db.commit()
        logger.info(f"Aggregated version {version.id} with refined differential logic.")

    @staticmethod
    def _extract_json_value(content_bytes: bytes, extract_json_key: str, relative_path: str) -> str | None:
        """Extract a dotted key path (e.g. ``a.b.c``) from JSON bytes.

        Returns the value as a string (dict/list values are JSON-encoded,
        non-ASCII preserved); returns None on any parse/lookup failure,
        logging a warning. *relative_path* is only used for the log message.
        """
        try:
            val = json.loads(content_bytes.decode('utf-8'))
            for key_part in extract_json_key.split("."):
                if isinstance(val, dict):
                    val = val.get(key_part)
                else:
                    val = None
                    break
            if val is not None:
                if isinstance(val, (dict, list)):
                    return json.dumps(val, ensure_ascii=False)
                return str(val)
        except Exception as e:
            logger.warning(f"Failed to extract json key {extract_json_key} from {relative_path}: {e}")
        return None

    async def process_file_with_sha(
        self,
        version: DataVersion,
        relative_path: str,
        file_sha: str,
        owner: str,
        repo: str,
        direction: str = None,
        label: str = None,
        extract_json_key: str = None,
        directory_depth: int = None,
        group_key: str = None,
        content_ref: str | None = None,
    ) -> bool:
        """Process a file and create a snapshot record.

        If the file's SHA matches the latest prior snapshot in the same
        project+stage, the previous OSS object is reused; otherwise the
        content is downloaded from Gogs and uploaded to OSS.

        Returns
        -------
        bool
            ``True`` if the file content actually changed (new upload),
            ``False`` if unchanged (record reuses previous OSS key).
        """
        # Find the latest record for this file in the same project + stage
        last_file = (
            self.db.query(DataFile)
            .join(DataVersion)
            .filter(
                DataVersion.project_id == version.project_id,
                DataVersion.stage == version.stage,
                DataFile.relative_path == relative_path,
            )
            .order_by(DataVersion.created_at.desc())
            .first()
        )

        should_extract = bool(extract_json_key and relative_path.lower().endswith(".json"))
        extracted_val = None

        # Calculate group_key: explicit override > directory_depth > dirname fallback
        if group_key is not None:
            calc_group_key = group_key
        elif directory_depth is not None and directory_depth > 0:
            parts = relative_path.split("/")
            # Remove filename
            if len(parts) > 1:
                parts = parts[:-1]
                # Combine up to directory_depth
                calc_group_key = "/".join(parts[:directory_depth])
            else:
                calc_group_key = ""  # File is in root directory
        else:
            calc_group_key = os.path.dirname(relative_path)  # Default fallback

        # Prefer an explicit content ref (e.g. a branch) over the commit id.
        download_ref = content_ref or version.commit_id

        async def _download_and_extract() -> str | None:
            # Download the file just for value extraction (unchanged-file path).
            try:
                content_bytes = await self.gogs.get_file_content(
                    owner, repo, relative_path, ref=download_ref
                )
            except Exception as e:
                # Same warning as extraction failure: original code wrapped the
                # download and the parse in one try block.
                logger.warning(f"Failed to extract json key {extract_json_key} from {relative_path}: {e}")
                return None
            if not content_bytes:
                return None
            return self._extract_json_value(content_bytes, extract_json_key, relative_path)

        if last_file and last_file.file_sha == file_sha:
            # ── Unchanged: reuse previous OSS key, still record a snapshot entry ──
            # Optimization: reuse the previously extracted value when the SHA matches.
            if should_extract:
                if last_file.extracted_value is not None:
                    extracted_val = last_file.extracted_value
                else:
                    extracted_val = await _download_and_extract()

            new_file = DataFile(
                version_id=version.id,
                relative_path=relative_path,
                storage_path=last_file.storage_path,
                file_size=last_file.file_size,
                file_type=last_file.file_type,
                file_sha=file_sha,
                direction=direction,
                label=label,
                extracted_value=extracted_val,
                group_key=calc_group_key,
            )
            self.db.add(new_file)
            self.db.commit()
            logger.info(
                f"File {relative_path} (SHA: {file_sha[:8]}…) "
                f"unchanged — snapshot recorded, reusing OSS key"
            )
            return False

        # ── Changed or new: download → upload → record ──
        logger.info(f"File {relative_path} (SHA: {file_sha[:8]}…) changed — downloading")
        content = await self.gogs.get_file_content(owner, repo, relative_path, ref=download_ref)
        file_size = len(content)

        project_name = version.project.project_name
        oss_key = oss_client._build_key(
            project_name, version.stage, version.commit_id, relative_path
        )
        oss_client.upload(oss_key, content)

        if should_extract:
            extracted_val = self._extract_json_value(content, extract_json_key, relative_path)

        new_file = DataFile(
            version_id=version.id,
            relative_path=relative_path,
            storage_path=oss_key,
            file_size=file_size,
            file_type=os.path.splitext(relative_path)[1],
            file_sha=file_sha,
            direction=direction,
            label=label,
            extracted_value=extracted_val,
            group_key=calc_group_key,
        )
        self.db.add(new_file)
        self.db.commit()
        return True