upload_api.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. """
  2. Trace Upload API
  3. 提供 Trace 压缩包上传和导入功能
  4. """
  5. import os
  6. import shutil
  7. import tempfile
  8. import zipfile
  9. from typing import List, Dict, Any
  10. from fastapi import APIRouter, UploadFile, File, HTTPException
  11. from pydantic import BaseModel
  12. from .protocols import TraceStore
  13. router = APIRouter(prefix="/api/traces", tags=["traces"])
# ===== Response models =====
class UploadResponse(BaseModel):
    """Response body returned by the trace upload endpoint."""
    # Set by upload_traces to True when at least one trace was imported.
    success: bool
    # Human-readable summary of the import outcome.
    message: str
    # Names of the trace folders that were imported.
    imported_traces: List[str]
    # One entry per trace folder that was not imported; the importer fills
    # each entry with "trace_id" and "reason" keys.
    failed_traces: List[Dict[str, str]]
# ===== Global TraceStore =====
# Module-level TraceStore used by the upload endpoint. It stays None until
# set_trace_store() is called; get_trace_store() raises while it is None.
_trace_store: TraceStore | None = None
  23. def set_trace_store(store: TraceStore):
  24. """设置 TraceStore 实例"""
  25. global _trace_store
  26. _trace_store = store
  27. def get_trace_store() -> TraceStore:
  28. """获取 TraceStore 实例"""
  29. if _trace_store is None:
  30. raise RuntimeError("TraceStore not initialized")
  31. return _trace_store
  32. # ===== 辅助函数 =====
  33. def is_valid_trace_folder(folder_path: str) -> bool:
  34. """
  35. 验证是否是有效的 trace 文件夹
  36. 有效的 trace 文件夹应该包含:
  37. - meta.json 文件
  38. """
  39. return os.path.isfile(os.path.join(folder_path, "meta.json"))
  40. def extract_and_import_traces(zip_path: str, base_trace_path: str) -> tuple[List[str], List[Dict[str, str]]]:
  41. """
  42. 解压并导入 traces
  43. Returns:
  44. (imported_traces, failed_traces)
  45. """
  46. import logging
  47. logger = logging.getLogger(__name__)
  48. imported = []
  49. failed = []
  50. # 创建临时目录
  51. with tempfile.TemporaryDirectory() as temp_dir:
  52. try:
  53. # 解压文件
  54. with zipfile.ZipFile(zip_path, 'r') as zip_ref:
  55. zip_ref.extractall(temp_dir)
  56. logger.info(f"Extracted to temp dir: {temp_dir}")
  57. # 收集所有有效的 trace 文件夹
  58. valid_traces = []
  59. # 遍历解压后的内容
  60. for root, dirs, files in os.walk(temp_dir):
  61. # 检查当前目录是否是 trace 文件夹
  62. if is_valid_trace_folder(root):
  63. valid_traces.append(root)
  64. logger.info(f"Found valid trace folder: {root}")
  65. if not valid_traces:
  66. logger.warning(f"No valid traces found in {temp_dir}")
  67. # 列出临时目录的内容用于调试
  68. for root, dirs, files in os.walk(temp_dir):
  69. logger.info(f"Dir: {root}, Files: {files[:5]}") # 只显示前5个文件
  70. # 导入找到的 trace 文件夹
  71. for trace_folder in valid_traces:
  72. trace_folder_name = os.path.basename(trace_folder)
  73. target_path = os.path.join(base_trace_path, trace_folder_name)
  74. try:
  75. # 如果目标已存在,跳过
  76. if os.path.exists(target_path):
  77. failed.append({
  78. "trace_id": trace_folder_name,
  79. "reason": "Trace already exists"
  80. })
  81. logger.warning(f"Trace already exists: {trace_folder_name}")
  82. continue
  83. # 复制到目标目录
  84. shutil.copytree(trace_folder, target_path)
  85. imported.append(trace_folder_name)
  86. logger.info(f"Imported trace: {trace_folder_name}")
  87. except Exception as e:
  88. failed.append({
  89. "trace_id": trace_folder_name,
  90. "reason": str(e)
  91. })
  92. logger.error(f"Failed to import {trace_folder_name}: {e}")
  93. except zipfile.BadZipFile:
  94. raise HTTPException(status_code=400, detail="Invalid zip file")
  95. except Exception as e:
  96. logger.error(f"Failed to extract zip: {e}")
  97. raise HTTPException(status_code=500, detail=f"Failed to extract zip: {str(e)}")
  98. return imported, failed
  99. # ===== 路由 =====
  100. @router.post("/upload", response_model=UploadResponse)
  101. async def upload_traces(file: UploadFile = File(...)):
  102. """
  103. 上传 trace 压缩包并导入
  104. 支持的格式:.zip
  105. 压缩包可以包含:
  106. - 单个 trace 文件夹
  107. - 多个 trace 文件夹
  108. - 嵌套的 trace 文件夹
  109. Args:
  110. file: 上传的压缩包文件
  111. """
  112. # 验证文件类型
  113. if not file.filename or not file.filename.endswith('.zip'):
  114. raise HTTPException(status_code=400, detail="Only .zip files are supported")
  115. # 获取 trace 存储路径
  116. store = get_trace_store()
  117. # 假设 FileSystemTraceStore 有 base_path 属性
  118. if not hasattr(store, 'base_path'):
  119. raise HTTPException(status_code=500, detail="TraceStore does not support file system operations")
  120. base_trace_path = store.base_path
  121. # 保存上传的文件到临时位置
  122. with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as temp_file:
  123. temp_file_path = temp_file.name
  124. content = await file.read()
  125. temp_file.write(content)
  126. try:
  127. # 解压并导入
  128. imported, failed = extract_and_import_traces(temp_file_path, base_trace_path)
  129. # 构建响应消息
  130. if imported and not failed:
  131. message = f"Successfully imported {len(imported)} trace(s)"
  132. elif imported and failed:
  133. message = f"Imported {len(imported)} trace(s), {len(failed)} failed"
  134. elif not imported and failed:
  135. message = f"Failed to import all traces"
  136. else:
  137. message = "No valid traces found in the zip file"
  138. return UploadResponse(
  139. success=len(imported) > 0,
  140. message=message,
  141. imported_traces=imported,
  142. failed_traces=failed
  143. )
  144. finally:
  145. # 清理临时文件
  146. if os.path.exists(temp_file_path):
  147. os.unlink(temp_file_path)