| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309 |
- import os
- from sqlalchemy.orm import Session
- from sqlalchemy.exc import IntegrityError
- from app.models import Project, DataVersion, DataFile, DataRecord
- from app.config import settings
- from app.services.gogs_client import GogsClient
- from app.services.oss_client import oss_client
- import logging
- import hashlib
- logger = logging.getLogger(__name__)
class StorageService:
    """Persistence layer tying Gogs repository metadata to OSS-backed file storage."""

    def __init__(self, db: Session, gogs_client: GogsClient):
        # Hold onto the DB session and Gogs API client used by every method.
        self.db = db
        self.gogs = gogs_client
- def get_or_create_project(self, project_name: str, description: str = None) -> Project:
- project = self.db.query(Project).filter(Project.project_name == project_name).first()
- if not project:
- project = Project(project_name=project_name, description=description)
- self.db.add(project)
- self.db.commit()
- self.db.refresh(project)
- return project
- def create_version(self, project_id: str, stage: str, commit_id: str, author: str, manifest: str, commit_message: str = None) -> DataVersion | None:
- """Create a new data version. Returns None if a duplicate exists (IntegrityError)."""
- version = DataVersion(
- project_id=project_id,
- stage=stage,
- commit_id=commit_id,
- author=author,
- commit_message=commit_message,
- manifest_snapshot=manifest
- )
- try:
- self.db.add(version)
- self.db.commit()
- self.db.refresh(version)
- return version
- except IntegrityError:
- self.db.rollback()
- logger.info(f"Version already exists for project {project_id}, stage {stage}, commit {commit_id[:8]}.")
- return None
- def rollback_version(self, version: DataVersion):
- """Remove a version and all its associated file records."""
- self.db.query(DataFile).filter(DataFile.version_id == version.id).delete()
- self.db.delete(version)
- self.db.commit()
- logger.info(f"Rolled back unchanged version {version.id}")
- def is_snapshot_changed(self, version: DataVersion, has_new_uploads: bool) -> bool:
- """
- Determine if this version represents a meaningful change.
- With differential processing (only webhook-changed files are processed),
- a version is meaningful if any file had new content uploaded to OSS.
- """
- return has_new_uploads
    def aggregate_version_records(self, version: DataVersion):
        """Aggregate files in a version into DataRecord groups based on parent directory.

        Files sharing a group key (explicit ``group_key`` or parent directory)
        are bundled together; one DataRecord is emitted per output file in the
        group, but only when the combined input/output SHA fingerprint differs
        from the latest previously-stored record for the same output path.
        """
        from collections import defaultdict

        # 1. Clean existing records for this version (idempotency)
        self.db.query(DataRecord).filter(DataRecord.version_id == version.id).delete()

        files = self.db.query(DataFile).filter(DataFile.version_id == version.id).all()

        # 2. Group by dirname
        groups = defaultdict(lambda: {"inputs": [], "outputs": []})

        for f in files:
            # Group key falls back to immediate parent directory if not explicitly saved in f.group_key
            group_key = f.group_key if f.group_key is not None else os.path.dirname(f.relative_path)

            file_data = {
                "id": f.id,
                "relative_path": f.relative_path,
                "file_type": f.file_type,
                "file_size": f.file_size,
                "file_sha": f.file_sha,
                "direction": f.direction,
                "label": f.label,
                "extracted_value": f.extracted_value,
                "storage_path": f.storage_path
            }
            if f.direction == "input":
                groups[group_key]["inputs"].append(file_data)
            else:
                # Treat 'output' or None as output by default for rendering purposes
                groups[group_key]["outputs"].append(file_data)
        # 3. Insert aggregated records (One record per output file, with differential logic)
        for group_key, data in groups.items():
            inputs = data["inputs"]
            outputs = data["outputs"]

            # Pre-fetch the latest known state of every output path under this
            # group_key, used to decide whether this commit actually produced a change.
            latest_hashes = {}
            past_records = (
                self.db.query(DataRecord)
                .filter(
                    DataRecord.project_id == version.project_id,
                    DataRecord.stage == version.stage,
                    DataRecord.group_key == group_key,
                    DataRecord.version_id != version.id
                )
                .order_by(DataRecord.created_at.desc())
                .limit(200)  # covers the typical number of files per directory
                .all()
            )
            for pr in past_records:
                # Remember the most recent hash for each output path; records are
                # ordered newest-first, so the first occurrence wins.
                p_out = pr.outputs[0]['relative_path'] if pr.outputs else None
                if p_out not in latest_hashes:
                    latest_hashes[p_out] = pr.content_hash
            # Even when there are no outputs, still create one logical group
            output_groups = [[o] for o in outputs] if outputs else [[]]

            for out_list in output_groups:
                # 1. Compute the current fingerprint: SHAs of all inputs plus this output's SHA
                all_shas = [f["file_sha"] for f in inputs] + [f["file_sha"] for f in out_list]
                all_shas.sort()
                combined_string = "|".join(all_shas)
                content_hash = hashlib.sha256(combined_string.encode('utf-8')).hexdigest()
                # 2. Differential check: skip if the latest stored record already has this fingerprint
                out_path = out_list[0]['relative_path'] if out_list else None
                if out_path in latest_hashes and latest_hashes[out_path] == content_hash:
                    logger.info(f"Skipping unchanged record: {group_key} -> {out_path}")
                    continue
                # 3. Persist only when something actually changed
                record = DataRecord(
                    project_id=version.project_id,
                    version_id=version.id,
                    stage=version.stage,
                    commit_id=version.commit_id,
                    commit_message=version.commit_message,
                    group_key=group_key,
                    inputs=inputs,
                    outputs=out_list,
                    content_hash=content_hash,
                    author=version.author,
                )
                self.db.add(record)

        self.db.commit()
        logger.info(f"Aggregated version {version.id} with refined differential logic.")
- async def process_file_with_sha(
- self,
- version: DataVersion,
- relative_path: str,
- file_sha: str,
- owner: str,
- repo: str,
- direction: str = None,
- label: str = None,
- extract_json_key: str = None,
- directory_depth: int = None,
- group_key: str = None,
- content_ref: str | None = None,
- ) -> bool:
- """Process a file and create a snapshot record.
- Returns
- -------
- bool
- ``True`` if the file content actually changed (new upload),
- ``False`` if unchanged (record reuses previous OSS key).
- """
- # Find the latest record for this file in the same project + stage
- last_file = (
- self.db.query(DataFile)
- .join(DataVersion)
- .filter(
- DataVersion.project_id == version.project_id,
- DataVersion.stage == version.stage,
- DataFile.relative_path == relative_path,
- )
- .order_by(DataVersion.created_at.desc())
- .first()
- )
- should_extract = bool(extract_json_key and relative_path.lower().endswith(".json"))
- extracted_val = None
- # Calculate group_key: explicit override > directory_depth > dirname fallback
- if group_key is not None:
- calc_group_key = group_key
- elif directory_depth is not None and directory_depth > 0:
- parts = relative_path.split("/")
- # Remove filename
- if len(parts) > 1:
- parts = parts[:-1]
- # Combine up to directory_depth
- calc_group_key = "/".join(parts[:directory_depth])
- else:
- calc_group_key = "" # File is in root directory
- else:
- calc_group_key = os.path.dirname(relative_path) # Default fallback
- download_ref = content_ref or version.commit_id
- async def _extract_val() -> str | None:
- try:
- content_bytes = await self.gogs.get_file_content(
- owner, repo, relative_path, ref=download_ref
- )
- if not content_bytes:
- return None
- import json
- parsed = json.loads(content_bytes.decode('utf-8'))
- val = parsed
- for key_part in extract_json_key.split("."):
- if isinstance(val, dict):
- val = val.get(key_part)
- else:
- val = None
- break
- if val is not None:
- if isinstance(val, (dict, list)):
- return json.dumps(val, ensure_ascii=False)
- return str(val)
- except Exception as e:
- logger.warning(f"Failed to extract json key {extract_json_key} from {relative_path}: {e}")
- return None
- if last_file and last_file.file_sha == file_sha:
- # ── Unchanged: reuse previous OSS key, still record a snapshot entry ──
- # Optimization: Try to reuse previously extracted value if the SHA hasn't changed
- if should_extract:
- if last_file.extracted_value is not None:
- extracted_val = last_file.extracted_value
- else:
- extracted_val = await _extract_val()
-
- new_file = DataFile(
- version_id=version.id,
- relative_path=relative_path,
- storage_path=last_file.storage_path,
- file_size=last_file.file_size,
- file_type=last_file.file_type,
- file_sha=file_sha,
- direction=direction,
- label=label,
- extracted_value=extracted_val,
- group_key=calc_group_key,
- )
- self.db.add(new_file)
- self.db.commit()
- logger.info(
- f"File {relative_path} (SHA: {file_sha[:8]}…) "
- f"unchanged — snapshot recorded, reusing OSS key"
- )
- return False
- # ── Changed or new: download → upload → record ──
- logger.info(f"File {relative_path} (SHA: {file_sha[:8]}…) changed — downloading")
- content = await self.gogs.get_file_content(owner, repo, relative_path, ref=download_ref)
- file_size = len(content)
- project_name = version.project.project_name
- oss_key = oss_client._build_key(
- project_name, version.stage, version.commit_id, relative_path
- )
- oss_client.upload(oss_key, content)
- if should_extract:
- try:
- import json
- parsed = json.loads(content.decode('utf-8'))
- val = parsed
- for key_part in extract_json_key.split("."):
- if isinstance(val, dict):
- val = val.get(key_part)
- else:
- val = None
- break
- if val is not None:
- if isinstance(val, (dict, list)):
- extracted_val = json.dumps(val, ensure_ascii=False)
- else:
- extracted_val = str(val)
- except Exception as e:
- logger.warning(f"Failed to extract json key {extract_json_key} from {relative_path}: {e}")
- new_file = DataFile(
- version_id=version.id,
- relative_path=relative_path,
- storage_path=oss_key,
- file_size=file_size,
- file_type=os.path.splitext(relative_path)[1],
- file_sha=file_sha,
- direction=direction,
- label=label,
- extracted_value=extracted_val,
- group_key=calc_group_key,
- )
- self.db.add(new_file)
- self.db.commit()
- return True
|