# storage_service.py
  1. import os
  2. from sqlalchemy.orm import Session
  3. from sqlalchemy.exc import IntegrityError
  4. from app.models import Project, DataVersion, DataFile, DataRecord
  5. from app.config import settings
  6. from app.services.gogs_client import GogsClient
  7. from app.services.oss_client import oss_client
  8. import logging
  9. import hashlib
  10. logger = logging.getLogger(__name__)
  11. class StorageService:
  12. def __init__(self, db: Session, gogs_client: GogsClient):
  13. self.db = db
  14. self.gogs = gogs_client
  15. def get_or_create_project(self, project_name: str, description: str = None) -> Project:
  16. project = self.db.query(Project).filter(Project.project_name == project_name).first()
  17. if not project:
  18. project = Project(project_name=project_name, description=description)
  19. self.db.add(project)
  20. self.db.commit()
  21. self.db.refresh(project)
  22. return project
  23. def create_version(self, project_id: str, stage: str, commit_id: str, author: str, manifest: str, commit_message: str = None) -> DataVersion | None:
  24. """Create a new data version. Returns None if a duplicate exists (IntegrityError)."""
  25. version = DataVersion(
  26. project_id=project_id,
  27. stage=stage,
  28. commit_id=commit_id,
  29. author=author,
  30. commit_message=commit_message,
  31. manifest_snapshot=manifest
  32. )
  33. try:
  34. self.db.add(version)
  35. self.db.commit()
  36. self.db.refresh(version)
  37. return version
  38. except IntegrityError:
  39. self.db.rollback()
  40. logger.info(f"Version already exists for project {project_id}, stage {stage}, commit {commit_id[:8]}.")
  41. return None
  42. def rollback_version(self, version: DataVersion):
  43. """Remove a version and all its associated file records."""
  44. self.db.query(DataFile).filter(DataFile.version_id == version.id).delete()
  45. self.db.delete(version)
  46. self.db.commit()
  47. logger.info(f"Rolled back unchanged version {version.id}")
  48. def is_snapshot_changed(self, version: DataVersion, has_new_uploads: bool) -> bool:
  49. """
  50. Determine if this version represents a meaningful change.
  51. With differential processing (only webhook-changed files are processed),
  52. a version is meaningful if any file had new content uploaded to OSS.
  53. """
  54. return has_new_uploads
  55. def aggregate_version_records(self, version: DataVersion):
  56. """Aggregate files in a version into DataRecord groups based on parent directory."""
  57. from collections import defaultdict
  58. # 1. Clean existing records for this version (idempotency)
  59. self.db.query(DataRecord).filter(DataRecord.version_id == version.id).delete()
  60. files = self.db.query(DataFile).filter(DataFile.version_id == version.id).all()
  61. # 2. Group by dirname
  62. groups = defaultdict(lambda: {"inputs": [], "outputs": []})
  63. for f in files:
  64. # Group key falls back to immediate parent directory if not explicitly saved in f.group_key
  65. group_key = f.group_key if f.group_key is not None else os.path.dirname(f.relative_path)
  66. file_data = {
  67. "id": f.id,
  68. "relative_path": f.relative_path,
  69. "file_type": f.file_type,
  70. "file_size": f.file_size,
  71. "file_sha": f.file_sha,
  72. "direction": f.direction,
  73. "label": f.label,
  74. "extracted_value": f.extracted_value,
  75. "storage_path": f.storage_path
  76. }
  77. if f.direction == "input":
  78. groups[group_key]["inputs"].append(file_data)
  79. else:
  80. # Treat 'output' or None as output by default for rendering purposes
  81. groups[group_key]["outputs"].append(file_data)
  82. # 3. Insert aggregated records (One record per output file, with differential logic)
  83. for group_key, data in groups.items():
  84. inputs = data["inputs"]
  85. outputs = data["outputs"]
  86. # 预先获取该 group_key 下所有输出路径的最新状态
  87. # 用于判定当前这次 Commit 是否真的产生了变化
  88. latest_hashes = {}
  89. past_records = (
  90. self.db.query(DataRecord)
  91. .filter(
  92. DataRecord.project_id == version.project_id,
  93. DataRecord.stage == version.stage,
  94. DataRecord.group_key == group_key,
  95. DataRecord.version_id != version.id
  96. )
  97. .order_by(DataRecord.created_at.desc())
  98. .limit(200) # 覆盖常见的目录文件数
  99. .all()
  100. )
  101. for pr in past_records:
  102. # 记录每一个输出路径对应的最新 hash
  103. p_out = pr.outputs[0]['relative_path'] if pr.outputs else None
  104. if p_out not in latest_hashes:
  105. latest_hashes[p_out] = pr.content_hash
  106. # 如果没有输出,也创建一个逻辑组
  107. output_groups = [[o] for o in outputs] if outputs else [[]]
  108. for out_list in output_groups:
  109. # 1. 计算当前指纹:输入集合 SHA + 对应输出 SHA
  110. all_shas = [f["file_sha"] for f in inputs] + [f["file_sha"] for f in out_list]
  111. all_shas.sort()
  112. combined_string = "|".join(all_shas)
  113. content_hash = hashlib.sha256(combined_string.encode('utf-8')).hexdigest()
  114. # 2. 差异化判定
  115. out_path = out_list[0]['relative_path'] if out_list else None
  116. if out_path in latest_hashes and latest_hashes[out_path] == content_hash:
  117. logger.info(f"Skipping unchanged record: {group_key} -> {out_path}")
  118. continue
  119. # 3. 只有变化了才记录
  120. record = DataRecord(
  121. project_id=version.project_id,
  122. version_id=version.id,
  123. stage=version.stage,
  124. commit_id=version.commit_id,
  125. commit_message=version.commit_message,
  126. group_key=group_key,
  127. inputs=inputs,
  128. outputs=out_list,
  129. content_hash=content_hash,
  130. author=version.author,
  131. )
  132. self.db.add(record)
  133. self.db.commit()
  134. logger.info(f"Aggregated version {version.id} with refined differential logic.")
  135. async def process_file_with_sha(
  136. self,
  137. version: DataVersion,
  138. relative_path: str,
  139. file_sha: str,
  140. owner: str,
  141. repo: str,
  142. direction: str = None,
  143. label: str = None,
  144. extract_json_key: str = None,
  145. directory_depth: int = None,
  146. group_key: str = None,
  147. content_ref: str | None = None,
  148. ) -> bool:
  149. """Process a file and create a snapshot record.
  150. Returns
  151. -------
  152. bool
  153. ``True`` if the file content actually changed (new upload),
  154. ``False`` if unchanged (record reuses previous OSS key).
  155. """
  156. # Find the latest record for this file in the same project + stage
  157. last_file = (
  158. self.db.query(DataFile)
  159. .join(DataVersion)
  160. .filter(
  161. DataVersion.project_id == version.project_id,
  162. DataVersion.stage == version.stage,
  163. DataFile.relative_path == relative_path,
  164. )
  165. .order_by(DataVersion.created_at.desc())
  166. .first()
  167. )
  168. should_extract = bool(extract_json_key and relative_path.lower().endswith(".json"))
  169. extracted_val = None
  170. # Calculate group_key: explicit override > directory_depth > dirname fallback
  171. if group_key is not None:
  172. calc_group_key = group_key
  173. elif directory_depth is not None and directory_depth > 0:
  174. parts = relative_path.split("/")
  175. # Remove filename
  176. if len(parts) > 1:
  177. parts = parts[:-1]
  178. # Combine up to directory_depth
  179. calc_group_key = "/".join(parts[:directory_depth])
  180. else:
  181. calc_group_key = "" # File is in root directory
  182. else:
  183. calc_group_key = os.path.dirname(relative_path) # Default fallback
  184. download_ref = content_ref or version.commit_id
  185. async def _extract_val() -> str | None:
  186. try:
  187. content_bytes = await self.gogs.get_file_content(
  188. owner, repo, relative_path, ref=download_ref
  189. )
  190. if not content_bytes:
  191. return None
  192. import json
  193. parsed = json.loads(content_bytes.decode('utf-8'))
  194. val = parsed
  195. for key_part in extract_json_key.split("."):
  196. if isinstance(val, dict):
  197. val = val.get(key_part)
  198. else:
  199. val = None
  200. break
  201. if val is not None:
  202. if isinstance(val, (dict, list)):
  203. return json.dumps(val, ensure_ascii=False)
  204. return str(val)
  205. except Exception as e:
  206. logger.warning(f"Failed to extract json key {extract_json_key} from {relative_path}: {e}")
  207. return None
  208. if last_file and last_file.file_sha == file_sha:
  209. # ── Unchanged: reuse previous OSS key, still record a snapshot entry ──
  210. # Optimization: Try to reuse previously extracted value if the SHA hasn't changed
  211. if should_extract:
  212. if last_file.extracted_value is not None:
  213. extracted_val = last_file.extracted_value
  214. else:
  215. extracted_val = await _extract_val()
  216. new_file = DataFile(
  217. version_id=version.id,
  218. relative_path=relative_path,
  219. storage_path=last_file.storage_path,
  220. file_size=last_file.file_size,
  221. file_type=last_file.file_type,
  222. file_sha=file_sha,
  223. direction=direction,
  224. label=label,
  225. extracted_value=extracted_val,
  226. group_key=calc_group_key,
  227. )
  228. self.db.add(new_file)
  229. self.db.commit()
  230. logger.info(
  231. f"File {relative_path} (SHA: {file_sha[:8]}…) "
  232. f"unchanged — snapshot recorded, reusing OSS key"
  233. )
  234. return False
  235. # ── Changed or new: download → upload → record ──
  236. logger.info(f"File {relative_path} (SHA: {file_sha[:8]}…) changed — downloading")
  237. content = await self.gogs.get_file_content(owner, repo, relative_path, ref=download_ref)
  238. file_size = len(content)
  239. project_name = version.project.project_name
  240. oss_key = oss_client._build_key(
  241. project_name, version.stage, version.commit_id, relative_path
  242. )
  243. oss_client.upload(oss_key, content)
  244. if should_extract:
  245. try:
  246. import json
  247. parsed = json.loads(content.decode('utf-8'))
  248. val = parsed
  249. for key_part in extract_json_key.split("."):
  250. if isinstance(val, dict):
  251. val = val.get(key_part)
  252. else:
  253. val = None
  254. break
  255. if val is not None:
  256. if isinstance(val, (dict, list)):
  257. extracted_val = json.dumps(val, ensure_ascii=False)
  258. else:
  259. extracted_val = str(val)
  260. except Exception as e:
  261. logger.warning(f"Failed to extract json key {extract_json_key} from {relative_path}: {e}")
  262. new_file = DataFile(
  263. version_id=version.id,
  264. relative_path=relative_path,
  265. storage_path=oss_key,
  266. file_size=file_size,
  267. file_type=os.path.splitext(relative_path)[1],
  268. file_sha=file_sha,
  269. direction=direction,
  270. label=label,
  271. extracted_value=extracted_val,
  272. group_key=calc_group_key,
  273. )
  274. self.db.add(new_file)
  275. self.db.commit()
  276. return True