main.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
import hashlib
import hmac
import logging
import os
from typing import List, Optional

from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException, Query, Request
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session

from app.config import settings
from app.database import engine, Base, get_db, SessionLocal
from app.services.webhook_service import WebhookService
from app.models import Project, DataVersion, DataFile
from app import schemas
  14. # Static files directory
  15. STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
  16. # Create tables
  17. Base.metadata.create_all(bind=engine)
  18. logging.basicConfig(level=logging.INFO)
  19. logger = logging.getLogger(__name__)
  20. app = FastAPI(title="Data Nexus", version="0.1.0")
  21. async def process_webhook_task(payload: dict):
  22. """Background task that creates its own db session."""
  23. db = SessionLocal()
  24. try:
  25. service = WebhookService(db)
  26. await service.process_webhook(payload)
  27. except Exception as e:
  28. logger.error(f"Webhook processing failed: {e}", exc_info=True)
  29. finally:
  30. db.close()
  31. def build_file_tree(files: List[DataFile]) -> list:
  32. """Convert flat file list to tree structure."""
  33. tree = {}
  34. for f in files:
  35. parts = f.relative_path.split("/")
  36. current = tree
  37. for i, part in enumerate(parts):
  38. if i == len(parts) - 1:
  39. # It's a file
  40. if "_files" not in current:
  41. current["_files"] = []
  42. current["_files"].append({
  43. "name": part,
  44. "type": "file",
  45. "id": f.id,
  46. "size": f.file_size,
  47. "file_type": f.file_type,
  48. "sha": f.file_sha
  49. })
  50. else:
  51. # It's a folder
  52. if part not in current:
  53. current[part] = {}
  54. current = current[part]
  55. def convert_to_list(node: dict) -> list:
  56. result = []
  57. for key, value in node.items():
  58. if key == "_files":
  59. result.extend(value)
  60. else:
  61. result.append({
  62. "name": key,
  63. "type": "folder",
  64. "children": convert_to_list(value)
  65. })
  66. # Sort: folders first, then files
  67. result.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"]))
  68. return result
  69. return convert_to_list(tree)
  70. @app.get("/")
  71. def read_root():
  72. """Serve the frontend UI."""
  73. return FileResponse(os.path.join(STATIC_DIR, "index.html"), media_type="text/html")
  74. @app.get("/api/health")
  75. def health_check():
  76. """Health check endpoint."""
  77. return {"status": "ok"}
  78. def verify_webhook_signature(payload_body: bytes, signature: str) -> bool:
  79. """Verify Gogs webhook signature."""
  80. if not settings.GOGS_SECRET:
  81. return True # No secret configured, skip verification
  82. if not signature:
  83. return False
  84. expected = hmac.new(
  85. settings.GOGS_SECRET.encode(),
  86. payload_body,
  87. hashlib.sha256
  88. ).hexdigest()
  89. return hmac.compare_digest(f"sha256={expected}", signature)
  90. @app.post("/webhook")
  91. async def webhook_handler(
  92. request: Request,
  93. background_tasks: BackgroundTasks,
  94. x_gogs_signature: Optional[str] = Header(None)
  95. ):
  96. body = await request.body()
  97. # Verify signature if secret is configured
  98. if settings.GOGS_SECRET and not verify_webhook_signature(body, x_gogs_signature):
  99. raise HTTPException(status_code=401, detail="Invalid signature")
  100. try:
  101. import json
  102. payload = json.loads(body)
  103. except Exception:
  104. raise HTTPException(status_code=400, detail="Invalid JSON")
  105. # Process in background with its own db session
  106. background_tasks.add_task(process_webhook_task, payload)
  107. return {"status": "ok", "message": "Webhook received"}
  108. # ==================== Project APIs ====================
  109. @app.get("/projects", response_model=List[schemas.ProjectOut])
  110. def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
  111. """List all projects."""
  112. projects = db.query(Project).offset(skip).limit(limit).all()
  113. return projects
  114. @app.get("/projects/{project_id}", response_model=schemas.ProjectOut)
  115. def get_project(project_id: str, db: Session = Depends(get_db)):
  116. """Get a single project by ID."""
  117. project = db.query(Project).filter(Project.id == project_id).first()
  118. if not project:
  119. raise HTTPException(status_code=404, detail="Project not found")
  120. return project
  121. @app.get("/projects/name/{project_name}", response_model=schemas.ProjectOut)
  122. def get_project_by_name(project_name: str, db: Session = Depends(get_db)):
  123. """Get a project by name."""
  124. project = db.query(Project).filter(Project.project_name == project_name).first()
  125. if not project:
  126. raise HTTPException(status_code=404, detail="Project not found")
  127. return project
  128. # ==================== Version APIs ====================
  129. @app.get("/projects/{project_id}/versions", response_model=List[schemas.DataVersionOut])
  130. def list_versions(
  131. project_id: str,
  132. stage: Optional[str] = None,
  133. skip: int = 0,
  134. limit: int = 100,
  135. db: Session = Depends(get_db)
  136. ):
  137. """List versions for a project, optionally filtered by stage."""
  138. query = db.query(DataVersion).filter(DataVersion.project_id == project_id)
  139. if stage:
  140. query = query.filter(DataVersion.stage == stage)
  141. versions = query.order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  142. return versions
  143. @app.get("/versions/{version_id}", response_model=schemas.DataVersionOut)
  144. def get_version(version_id: str, db: Session = Depends(get_db)):
  145. """Get a single version by ID."""
  146. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  147. if not version:
  148. raise HTTPException(status_code=404, detail="Version not found")
  149. return version
  150. @app.get("/versions/{version_id}/files")
  151. def get_version_files(version_id: str, flat: bool = False, db: Session = Depends(get_db)):
  152. """
  153. Get files for a version.
  154. - flat=False (default): Returns tree structure
  155. - flat=True: Returns flat list
  156. """
  157. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  158. if not version:
  159. raise HTTPException(status_code=404, detail="Version not found")
  160. files = db.query(DataFile).filter(DataFile.version_id == version_id).all()
  161. if flat:
  162. return [schemas.DataFileOut.model_validate(f) for f in files]
  163. return build_file_tree(files)
  164. # ==================== File APIs ====================
  165. import urllib.parse
  166. from fastapi.responses import RedirectResponse # noqa: E811
  167. from app.services.oss_client import oss_client
  168. @app.get("/files/{file_id}", response_model=schemas.DataFileOut)
  169. def get_file_info(file_id: int, db: Session = Depends(get_db)):
  170. """Get file metadata."""
  171. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  172. if not file_record:
  173. raise HTTPException(status_code=404, detail="File not found")
  174. return file_record
  175. @app.get("/files/{file_id}/url")
  176. def get_file_url(file_id: int, db: Session = Depends(get_db)):
  177. """Get file CDN URL."""
  178. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  179. if not file_record:
  180. raise HTTPException(status_code=404, detail="File not found")
  181. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  182. return {"url": cdn_url}
  183. @app.get("/files/{file_id}/content")
  184. def get_file_content(file_id: int, db: Session = Depends(get_db)):
  185. """Redirect to CDN URL for file download with forced attachment header."""
  186. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  187. if not file_record:
  188. raise HTTPException(status_code=404, detail="File not found")
  189. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  190. # Try to force download by adding Aliyun OSS specific query parameter
  191. # This works for Aliyun OSS even on custom domains if not explicitly disabled
  192. filename = os.path.basename(file_record.relative_path)
  193. quoted_filename = urllib.parse.quote(filename)
  194. # Using both filename and filename* for maximum compatibility
  195. disposition = f"attachment; filename=\"{quoted_filename}\"; filename*=UTF-8''{quoted_filename}"
  196. separator = "&" if "?" in cdn_url else "?"
  197. download_url = f"{cdn_url}{separator}response-content-disposition={urllib.parse.quote(disposition)}"
  198. return RedirectResponse(url=download_url)
  199. if __name__ == "__main__":
  200. import uvicorn
  201. uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)