Przeglądaj źródła

fix: asyncio.gather分支

tanjingyu 1 miesiąc temu
rodzic
commit
5ec7232cba
2 zmienione pliki z 22 dodaniami i 24 usunięciami
  1. 17 22
      app/services/gogs_client.py
  2. 5 2
      app/services/storage_service.py

+ 17 - 22
app/services/gogs_client.py

@@ -126,30 +126,24 @@ class GogsClient:
             return None
 
     async def get_directory_tree(self, owner: str, repo: str, commit_id: str, dir_path: str) -> list:
-        """Get all files under a specific directory (recursive).
-
-        Args:
-            dir_path: Directory path without trailing slash (e.g., "data/output")
-
-        Returns:
-            List of file info dicts with 'path', 'sha', 'size', 'type'
-        """
+        """Get all files under a specific directory (recursive) using concurrency."""
+        import asyncio
         all_files = []
 
-        async def fetch_contents(path: str):
-            """Recursively fetch directory contents using contents API."""
-            url = f"{self.base_url}/api/v1/repos/{owner}/{repo}/contents/{path}?ref={commit_id}"
-            try:
-                async with httpx.AsyncClient(timeout=_DEFAULT_TIMEOUT) as client:
-                    resp = await client.get(url, headers=self.headers)
+        async with httpx.AsyncClient(timeout=_DEFAULT_TIMEOUT, headers=self.headers) as client:
+            async def fetch_contents(path: str):
+                """Recursively fetch directory contents using contents API in parallel."""
+                url = f"{self.base_url}/api/v1/repos/{owner}/{repo}/contents/{path}?ref={commit_id}"
+                try:
+                    resp = await client.get(url)
                     if resp.status_code == 404:
                         logger.warning(f"Directory not found: {path}")
                         return
                     resp.raise_for_status()
                     data = resp.json()
 
-                    # contents API returns list for directories
                     if isinstance(data, list):
+                        tasks = []
                         for item in data:
                             if item.get("type") == "file":
                                 all_files.append({
@@ -159,13 +153,14 @@ class GogsClient:
                                     "type": "blob"
                                 })
                             elif item.get("type") == "dir":
-                                # Recursively fetch subdirectory
-                                await fetch_contents(item.get("path"))
-
-            except httpx.HTTPStatusError as e:
-                logger.error(f"Failed to get contents for {path}: {e}")
-
-        await fetch_contents(dir_path)
+                                tasks.append(fetch_contents(item.get("path")))
+                        
+                        if tasks:
+                            await asyncio.gather(*tasks)
+                except Exception as e:
+                    logger.error(f"Failed to get contents for {path}: {e}")
+
+            await fetch_contents(dir_path)
         return all_files
 
     async def get_file_content(self, owner: str, repo: str, commit_id: str, file_path: str) -> bytes:

+ 5 - 2
app/services/storage_service.py

@@ -230,9 +230,12 @@ class StorageService:
 
         if last_file and last_file.file_sha == file_sha:
             # ── Unchanged: reuse previous OSS key, still record a snapshot entry ──
-            # Re-extract if needed, or reuse previous extracted_val
+            # Optimization: Try to reuse previously extracted value if the SHA hasn't changed
             if should_extract:
-                extracted_val = await _extract_val()
+                if last_file.extracted_value is not None:
+                    extracted_val = last_file.extracted_value
+                else:
+                    extracted_val = await _extract_val()
             
             new_file = DataFile(
                 version_id=version.id,