main.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from fastapi import FastAPI, BackgroundTasks, Request, Depends, HTTPException, Header
  2. from sqlalchemy.orm import Session
  3. from typing import List, Optional
  4. from app.config import settings
  5. from app.database import engine, Base, get_db, SessionLocal
  6. from app.services.webhook_service import WebhookService
  7. from app.models import Project, DataVersion, DataFile
  8. from app import schemas
  9. import logging
  10. import os
  11. import hmac
  12. import hashlib
  13. # Create tables
  14. Base.metadata.create_all(bind=engine)
  15. logging.basicConfig(level=logging.INFO)
  16. logger = logging.getLogger(__name__)
  17. app = FastAPI(title="Data Nexus", version="0.1.0")
  18. async def process_webhook_task(payload: dict):
  19. """Background task that creates its own db session."""
  20. db = SessionLocal()
  21. try:
  22. service = WebhookService(db)
  23. await service.process_webhook(payload)
  24. except Exception as e:
  25. logger.error(f"Webhook processing failed: {e}", exc_info=True)
  26. finally:
  27. db.close()
  28. def build_file_tree(files: List[DataFile]) -> list:
  29. """Convert flat file list to tree structure."""
  30. tree = {}
  31. for f in files:
  32. parts = f.relative_path.split("/")
  33. current = tree
  34. for i, part in enumerate(parts):
  35. if i == len(parts) - 1:
  36. # It's a file
  37. if "_files" not in current:
  38. current["_files"] = []
  39. current["_files"].append({
  40. "name": part,
  41. "type": "file",
  42. "id": f.id,
  43. "size": f.file_size,
  44. "file_type": f.file_type,
  45. "sha": f.file_sha
  46. })
  47. else:
  48. # It's a folder
  49. if part not in current:
  50. current[part] = {}
  51. current = current[part]
  52. def convert_to_list(node: dict) -> list:
  53. result = []
  54. for key, value in node.items():
  55. if key == "_files":
  56. result.extend(value)
  57. else:
  58. result.append({
  59. "name": key,
  60. "type": "folder",
  61. "children": convert_to_list(value)
  62. })
  63. # Sort: folders first, then files
  64. result.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"]))
  65. return result
  66. return convert_to_list(tree)
  67. @app.get("/")
  68. def read_root():
  69. return {"message": "Welcome to Data Nexus API"}
  70. def verify_webhook_signature(payload_body: bytes, signature: str) -> bool:
  71. """Verify Gogs webhook signature."""
  72. if not settings.GOGS_SECRET:
  73. return True # No secret configured, skip verification
  74. if not signature:
  75. return False
  76. expected = hmac.new(
  77. settings.GOGS_SECRET.encode(),
  78. payload_body,
  79. hashlib.sha256
  80. ).hexdigest()
  81. return hmac.compare_digest(f"sha256={expected}", signature)
  82. @app.post("/webhook")
  83. async def webhook_handler(
  84. request: Request,
  85. background_tasks: BackgroundTasks,
  86. x_gogs_signature: Optional[str] = Header(None)
  87. ):
  88. body = await request.body()
  89. # Verify signature if secret is configured
  90. if settings.GOGS_SECRET and not verify_webhook_signature(body, x_gogs_signature):
  91. raise HTTPException(status_code=401, detail="Invalid signature")
  92. try:
  93. import json
  94. payload = json.loads(body)
  95. except Exception:
  96. raise HTTPException(status_code=400, detail="Invalid JSON")
  97. # Process in background with its own db session
  98. background_tasks.add_task(process_webhook_task, payload)
  99. return {"status": "ok", "message": "Webhook received"}
  100. # ==================== Project APIs ====================
  101. @app.get("/projects", response_model=List[schemas.ProjectOut])
  102. def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
  103. """List all projects."""
  104. projects = db.query(Project).offset(skip).limit(limit).all()
  105. return projects
  106. @app.get("/projects/{project_id}", response_model=schemas.ProjectOut)
  107. def get_project(project_id: str, db: Session = Depends(get_db)):
  108. """Get a single project by ID."""
  109. project = db.query(Project).filter(Project.id == project_id).first()
  110. if not project:
  111. raise HTTPException(status_code=404, detail="Project not found")
  112. return project
  113. @app.get("/projects/name/{project_name}", response_model=schemas.ProjectOut)
  114. def get_project_by_name(project_name: str, db: Session = Depends(get_db)):
  115. """Get a project by name."""
  116. project = db.query(Project).filter(Project.project_name == project_name).first()
  117. if not project:
  118. raise HTTPException(status_code=404, detail="Project not found")
  119. return project
  120. # ==================== Version APIs ====================
  121. @app.get("/projects/{project_id}/versions", response_model=List[schemas.DataVersionOut])
  122. def list_versions(
  123. project_id: str,
  124. stage: Optional[str] = None,
  125. skip: int = 0,
  126. limit: int = 100,
  127. db: Session = Depends(get_db)
  128. ):
  129. """List versions for a project, optionally filtered by stage."""
  130. query = db.query(DataVersion).filter(DataVersion.project_id == project_id)
  131. if stage:
  132. query = query.filter(DataVersion.stage == stage)
  133. versions = query.order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  134. return versions
  135. @app.get("/versions/{version_id}", response_model=schemas.DataVersionOut)
  136. def get_version(version_id: str, db: Session = Depends(get_db)):
  137. """Get a single version by ID."""
  138. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  139. if not version:
  140. raise HTTPException(status_code=404, detail="Version not found")
  141. return version
  142. @app.get("/versions/{version_id}/files")
  143. def get_version_files(version_id: str, flat: bool = False, db: Session = Depends(get_db)):
  144. """
  145. Get files for a version.
  146. - flat=False (default): Returns tree structure
  147. - flat=True: Returns flat list
  148. """
  149. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  150. if not version:
  151. raise HTTPException(status_code=404, detail="Version not found")
  152. files = db.query(DataFile).filter(DataFile.version_id == version_id).all()
  153. if flat:
  154. return [schemas.DataFileOut.model_validate(f) for f in files]
  155. return build_file_tree(files)
  156. # ==================== File APIs ====================
  157. from fastapi.responses import RedirectResponse
  158. from app.services.oss_client import oss_client
  159. @app.get("/files/{file_id}", response_model=schemas.DataFileOut)
  160. def get_file_info(file_id: int, db: Session = Depends(get_db)):
  161. """Get file metadata."""
  162. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  163. if not file_record:
  164. raise HTTPException(status_code=404, detail="File not found")
  165. return file_record
  166. @app.get("/files/{file_id}/url")
  167. def get_file_url(file_id: int, db: Session = Depends(get_db)):
  168. """Get file CDN URL."""
  169. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  170. if not file_record:
  171. raise HTTPException(status_code=404, detail="File not found")
  172. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  173. return {"url": cdn_url}
  174. @app.get("/files/{file_id}/content")
  175. def get_file_content(file_id: int, db: Session = Depends(get_db)):
  176. """Redirect to CDN URL for file download."""
  177. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  178. if not file_record:
  179. raise HTTPException(status_code=404, detail="File not found")
  180. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  181. return RedirectResponse(url=cdn_url)
  182. if __name__ == "__main__":
  183. import uvicorn
  184. uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)