"""
Trace Upload API.

Provides an endpoint for uploading zipped trace folders and importing them
into the file-system-backed trace store.
"""

import logging
import os
import shutil
import tempfile
import zipfile
from typing import List, Dict, Any

from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from .protocols import TraceStore

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/traces", tags=["traces"])

# Bytes read per iteration when spooling an upload to disk, so a large archive
# is never held in memory all at once.
_UPLOAD_CHUNK_SIZE = 1024 * 1024


# ===== Response models =====


class UploadResponse(BaseModel):
    """Response body of the trace upload endpoint."""

    success: bool
    message: str
    imported_traces: List[str]
    failed_traces: List[Dict[str, str]]


# ===== Global TraceStore =====

_trace_store: TraceStore | None = None


def set_trace_store(store: TraceStore):
    """Register the TraceStore instance used by the routes in this module."""
    global _trace_store
    _trace_store = store


def get_trace_store() -> TraceStore:
    """Return the registered TraceStore.

    Raises:
        RuntimeError: If ``set_trace_store`` has not been called yet.
    """
    if _trace_store is None:
        raise RuntimeError("TraceStore not initialized")
    return _trace_store


# ===== Helpers =====


def is_valid_trace_folder(folder_path: str) -> bool:
    """Return True if *folder_path* is a trace folder.

    A valid trace folder contains a ``meta.json`` file directly inside it.
    """
    return os.path.isfile(os.path.join(folder_path, "meta.json"))


def _is_safe_trace_name(name: str) -> bool:
    """Return True if *name* is usable as one directory name under the trace root."""
    return (
        bool(name)
        and name not in (".", "..")
        and not any(bad in name for bad in ("/", "\\", "\0"))
    )


def extract_and_import_traces(
    zip_path: str,
    base_trace_path: str,
    root_trace_name: str | None = None,
) -> tuple[List[str], List[Dict[str, str]]]:
    """Extract a zip archive and copy every trace folder it contains into the store.

    Every directory in the archive (at any depth) that contains ``meta.json``
    is copied to ``base_trace_path/<folder name>``. A trace whose target
    directory already exists is skipped and reported as a failure.

    Args:
        zip_path: Path of the zip archive on disk.
        base_trace_path: Root directory of the trace store.
        root_trace_name: Trace id to use when ``meta.json`` sits at the archive
            root (the archive holds a trace's contents rather than a trace
            folder). Defaults to the archive's file name without extension.

    Returns:
        ``(imported_traces, failed_traces)``: the imported trace ids, and one
        ``{"trace_id": ..., "reason": ...}`` dict per trace that was not imported.

    Raises:
        HTTPException: 400 if the file is not a valid zip archive; 500 if
            extraction or scanning fails for any other reason.
    """
    imported: List[str] = []
    failed: List[Dict[str, str]] = []

    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            # zipfile.extractall drops absolute paths and ".." components, so
            # members cannot be written outside temp_dir.
            # NOTE(review): the uncompressed size is not bounded; consider a
            # limit if uploads may be hostile (zip bombs).
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(temp_dir)
            logger.info("Extracted to temp dir: %s", temp_dir)

            # Collect every directory that is a trace folder, at any depth.
            valid_traces = []
            for root, _dirs, _files in os.walk(temp_dir):
                if is_valid_trace_folder(root):
                    valid_traces.append(root)
                    logger.info("Found valid trace folder: %s", root)

            if not valid_traces:
                logger.warning("No valid traces found in %s", temp_dir)
                # Log the extracted layout to help debug unexpected archives.
                for root, _dirs, files in os.walk(temp_dir):
                    logger.info("Dir: %s, Files: %s", root, files[:5])  # first 5 files only

            for trace_folder in valid_traces:
                if trace_folder == temp_dir:
                    # meta.json is at the archive root: name the trace after the
                    # archive instead of after the random temporary directory.
                    trace_folder_name = root_trace_name or os.path.splitext(
                        os.path.basename(zip_path)
                    )[0]
                else:
                    trace_folder_name = os.path.basename(trace_folder)

                if not _is_safe_trace_name(trace_folder_name):
                    failed.append({"trace_id": trace_folder_name, "reason": "Invalid trace name"})
                    logger.warning("Invalid trace name: %r", trace_folder_name)
                    continue

                target_path = os.path.join(base_trace_path, trace_folder_name)
                try:
                    # copytree refuses an existing target with FileExistsError,
                    # so no separate (racy) existence check is needed.
                    shutil.copytree(trace_folder, target_path)
                except FileExistsError:
                    # The target belongs to an existing trace; never remove it.
                    failed.append({"trace_id": trace_folder_name, "reason": "Trace already exists"})
                    logger.warning("Trace already exists: %s", trace_folder_name)
                    continue
                except Exception as e:
                    # Remove the partial copy so a retry is not rejected as
                    # "already exists".
                    shutil.rmtree(target_path, ignore_errors=True)
                    failed.append({"trace_id": trace_folder_name, "reason": str(e)})
                    logger.error("Failed to import %s: %s", trace_folder_name, e)
                    continue

                imported.append(trace_folder_name)
                logger.info("Imported trace: %s", trace_folder_name)

        except zipfile.BadZipFile as e:
            raise HTTPException(status_code=400, detail="Invalid zip file") from e
        except Exception as e:
            logger.exception("Failed to extract zip: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to extract zip: {e}") from e

    return imported, failed


# ===== Routes =====


@router.post("/upload", response_model=UploadResponse)
async def upload_traces(file: UploadFile = File(...)):
    """
    Upload a zipped trace archive and import it.

    Supported format: a .zip archive that may contain
    - a single trace folder
    - multiple trace folders
    - nested trace folders
    - the contents of one trace folder at the archive root (imported under
      the uploaded file's name)

    Args:
        file: The uploaded archive.
    """
    # Validate the file type (case-insensitive, so "TRACES.ZIP" is accepted).
    if not file.filename or not file.filename.lower().endswith(".zip"):
        raise HTTPException(status_code=400, detail="Only .zip files are supported")

    store = get_trace_store()
    # Only file-system-backed stores (assumed to expose `base_path`) are supported.
    if not hasattr(store, "base_path"):
        raise HTTPException(
            status_code=500, detail="TraceStore does not support file system operations"
        )
    base_trace_path = store.base_path

    # Client-supplied name, stripped of any directory part (either separator).
    # Used only when meta.json sits at the archive root.
    upload_name = os.path.basename(file.filename.replace("\\", "/"))
    root_trace_name = os.path.splitext(upload_name)[0]

    # Create the temp file first and capture its path, so the `finally` below
    # removes it even if reading the upload fails.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
    temp_file_path = temp_file.name
    try:
        with temp_file:
            while chunk := await file.read(_UPLOAD_CHUNK_SIZE):
                temp_file.write(chunk)

        # Extraction and copying are blocking file-system work; keep them off
        # the event loop.
        imported, failed = await run_in_threadpool(
            extract_and_import_traces,
            temp_file_path,
            base_trace_path,
            root_trace_name=root_trace_name,
        )

        if imported and not failed:
            message = f"Successfully imported {len(imported)} trace(s)"
        elif imported and failed:
            message = f"Imported {len(imported)} trace(s), {len(failed)} failed"
        elif failed:
            message = "Failed to import all traces"
        else:
            message = "No valid traces found in the zip file"

        return UploadResponse(
            success=bool(imported),
            message=message,
            imported_traces=imported,
            failed_traces=failed,
        )
    finally:
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)