main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. from fastapi import FastAPI, BackgroundTasks, Request, Depends, HTTPException, Header, Query
  2. from fastapi.responses import FileResponse
  3. from sqlalchemy.orm import Session
  4. from sqlalchemy import func as sqlfunc
  5. from typing import List, Optional
  6. from app.config import settings
  7. from app.database import engine, Base, get_db, SessionLocal
  8. from app.services.webhook_service import WebhookService
  9. from app.models import Project, DataVersion, DataFile
  10. from app import schemas
  11. import logging
  12. import os
  13. import hmac
  14. import hashlib
  15. # Static files directory
  16. STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
  17. # Create tables
  18. Base.metadata.create_all(bind=engine)
  19. logging.basicConfig(level=logging.INFO)
  20. logger = logging.getLogger(__name__)
  21. app = FastAPI(title="Data Nexus", version="0.1.0")
  22. async def process_webhook_task(payload: dict):
  23. """Background task that creates its own db session."""
  24. db = SessionLocal()
  25. try:
  26. service = WebhookService(db)
  27. await service.process_webhook(payload)
  28. except Exception as e:
  29. logger.error(f"Webhook processing failed: {e}", exc_info=True)
  30. finally:
  31. db.close()
  32. def build_file_tree(files: List[DataFile]) -> list:
  33. """Convert flat file list to tree structure."""
  34. tree = {}
  35. for f in files:
  36. parts = f.relative_path.split("/")
  37. current = tree
  38. for i, part in enumerate(parts):
  39. if i == len(parts) - 1:
  40. # It's a file
  41. if "_files" not in current:
  42. current["_files"] = []
  43. current["_files"].append({
  44. "name": part,
  45. "type": "file",
  46. "id": f.id,
  47. "size": f.file_size,
  48. "file_type": f.file_type,
  49. "sha": f.file_sha
  50. })
  51. else:
  52. # It's a folder
  53. if part not in current:
  54. current[part] = {}
  55. current = current[part]
  56. def convert_to_list(node: dict) -> list:
  57. result = []
  58. for key, value in node.items():
  59. if key == "_files":
  60. result.extend(value)
  61. else:
  62. result.append({
  63. "name": key,
  64. "type": "folder",
  65. "children": convert_to_list(value)
  66. })
  67. # Sort: folders first, then files
  68. result.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"]))
  69. return result
  70. return convert_to_list(tree)
  71. @app.get("/")
  72. def read_root():
  73. """Serve the unified console UI."""
  74. return FileResponse(os.path.join(STATIC_DIR, "console.html"), media_type="text/html")
  75. @app.get("/fs")
  76. def filesystem_page():
  77. """Serve the legacy file system UI."""
  78. return FileResponse(os.path.join(STATIC_DIR, "index.html"), media_type="text/html")
  79. @app.get("/api/health")
  80. def health_check():
  81. """Health check endpoint."""
  82. return {"status": "ok"}
  83. def verify_webhook_signature(payload_body: bytes, signature: str) -> bool:
  84. """Verify Gogs webhook signature."""
  85. if not settings.GOGS_SECRET:
  86. return True # No secret configured, skip verification
  87. if not signature:
  88. return False
  89. expected = hmac.new(
  90. settings.GOGS_SECRET.encode(),
  91. payload_body,
  92. hashlib.sha256
  93. ).hexdigest()
  94. return hmac.compare_digest(f"sha256={expected}", signature)
  95. @app.post("/webhook")
  96. async def webhook_handler(
  97. request: Request,
  98. background_tasks: BackgroundTasks,
  99. x_gogs_signature: Optional[str] = Header(None)
  100. ):
  101. body = await request.body()
  102. # Verify signature if secret is configured
  103. if settings.GOGS_SECRET and not verify_webhook_signature(body, x_gogs_signature):
  104. raise HTTPException(status_code=401, detail="Invalid signature")
  105. try:
  106. import json
  107. payload = json.loads(body)
  108. except Exception:
  109. raise HTTPException(status_code=400, detail="Invalid JSON")
  110. # Process in background with its own db session
  111. background_tasks.add_task(process_webhook_task, payload)
  112. return {"status": "ok", "message": "Webhook received"}
  113. # ==================== Project APIs ====================
  114. @app.get("/projects", response_model=List[schemas.ProjectOut])
  115. def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
  116. """List all projects."""
  117. projects = db.query(Project).offset(skip).limit(limit).all()
  118. return projects
  119. @app.get("/projects/{project_id}", response_model=schemas.ProjectOut)
  120. def get_project(project_id: str, db: Session = Depends(get_db)):
  121. """Get a single project by ID."""
  122. project = db.query(Project).filter(Project.id == project_id).first()
  123. if not project:
  124. raise HTTPException(status_code=404, detail="Project not found")
  125. return project
  126. @app.get("/projects/name/{project_name}", response_model=schemas.ProjectOut)
  127. def get_project_by_name(project_name: str, db: Session = Depends(get_db)):
  128. """Get a project by name."""
  129. project = db.query(Project).filter(Project.project_name == project_name).first()
  130. if not project:
  131. raise HTTPException(status_code=404, detail="Project not found")
  132. return project
  133. # ==================== Console APIs ====================
  134. @app.get("/stages/all")
  135. def get_all_stages(db: Session = Depends(get_db)):
  136. """Get all stages across all projects with version counts."""
  137. results = db.query(
  138. DataVersion.stage,
  139. DataVersion.project_id,
  140. Project.project_name,
  141. sqlfunc.count(DataVersion.id).label("count")
  142. ).join(Project).group_by(
  143. DataVersion.stage, DataVersion.project_id, Project.project_name
  144. ).all()
  145. return [{
  146. "name": r[0],
  147. "project_id": r[1],
  148. "project_name": r[2],
  149. "version_count": r[3]
  150. } for r in results]
  151. @app.get("/projects/{project_id}/stages")
  152. def get_project_stages(project_id: str, db: Session = Depends(get_db)):
  153. """Get all unique stages for a project with version counts."""
  154. results = db.query(
  155. DataVersion.stage,
  156. sqlfunc.count(DataVersion.id).label("count")
  157. ).filter(
  158. DataVersion.project_id == project_id
  159. ).group_by(DataVersion.stage).all()
  160. return [{"name": r[0], "version_count": r[1]} for r in results]
  161. @app.get("/projects/{project_id}/stage-files")
  162. def get_stage_files(
  163. project_id: str,
  164. stage: str = Query(...),
  165. skip: int = 0,
  166. limit: int = 20,
  167. db: Session = Depends(get_db)
  168. ):
  169. """Get versions with files for a specific stage, ordered by newest first."""
  170. versions = db.query(DataVersion).filter(
  171. DataVersion.project_id == project_id,
  172. DataVersion.stage == stage
  173. ).order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  174. result = []
  175. for v in versions:
  176. files = db.query(DataFile).filter(DataFile.version_id == v.id).all()
  177. result.append({
  178. "version_id": v.id,
  179. "commit_id": v.commit_id,
  180. "author": v.author,
  181. "created_at": v.created_at.isoformat() if v.created_at else None,
  182. "files": [{
  183. "id": f.id,
  184. "name": f.relative_path.split("/")[-1] if f.relative_path else "",
  185. "relative_path": f.relative_path,
  186. "file_size": f.file_size,
  187. "file_type": f.file_type,
  188. "file_sha": f.file_sha,
  189. } for f in files]
  190. })
  191. return result
  192. # ==================== Version APIs ====================
  193. @app.get("/projects/{project_id}/versions", response_model=List[schemas.DataVersionOut])
  194. def list_versions(
  195. project_id: str,
  196. stage: Optional[str] = None,
  197. skip: int = 0,
  198. limit: int = 100,
  199. db: Session = Depends(get_db)
  200. ):
  201. """List versions for a project, optionally filtered by stage."""
  202. query = db.query(DataVersion).filter(DataVersion.project_id == project_id)
  203. if stage:
  204. query = query.filter(DataVersion.stage == stage)
  205. versions = query.order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  206. return versions
  207. @app.get("/versions/{version_id}", response_model=schemas.DataVersionOut)
  208. def get_version(version_id: str, db: Session = Depends(get_db)):
  209. """Get a single version by ID."""
  210. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  211. if not version:
  212. raise HTTPException(status_code=404, detail="Version not found")
  213. return version
  214. @app.get("/versions/{version_id}/files")
  215. def get_version_files(version_id: str, flat: bool = False, db: Session = Depends(get_db)):
  216. """
  217. Get files for a version.
  218. - flat=False (default): Returns tree structure
  219. - flat=True: Returns flat list
  220. """
  221. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  222. if not version:
  223. raise HTTPException(status_code=404, detail="Version not found")
  224. files = db.query(DataFile).filter(DataFile.version_id == version_id).all()
  225. if flat:
  226. return [schemas.DataFileOut.model_validate(f) for f in files]
  227. return build_file_tree(files)
  228. # ==================== File APIs ====================
  229. import urllib.parse
  230. from fastapi.responses import RedirectResponse # noqa: E811
  231. from app.services.oss_client import oss_client
  232. @app.get("/files/{file_id}", response_model=schemas.DataFileOut)
  233. def get_file_info(file_id: int, db: Session = Depends(get_db)):
  234. """Get file metadata."""
  235. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  236. if not file_record:
  237. raise HTTPException(status_code=404, detail="File not found")
  238. return file_record
  239. @app.get("/files/{file_id}/url")
  240. def get_file_url(file_id: int, db: Session = Depends(get_db)):
  241. """Get file CDN URL."""
  242. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  243. if not file_record:
  244. raise HTTPException(status_code=404, detail="File not found")
  245. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  246. return {"url": cdn_url}
  247. @app.get("/files/{file_id}/content")
  248. def get_file_content(file_id: int, db: Session = Depends(get_db)):
  249. """Redirect to CDN URL for file download with forced attachment header."""
  250. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  251. if not file_record:
  252. raise HTTPException(status_code=404, detail="File not found")
  253. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  254. # Try to force download by adding Aliyun OSS specific query parameter
  255. # This works for Aliyun OSS even on custom domains if not explicitly disabled
  256. filename = os.path.basename(file_record.relative_path)
  257. quoted_filename = urllib.parse.quote(filename)
  258. # Using both filename and filename* for maximum compatibility
  259. disposition = f"attachment; filename=\"{quoted_filename}\"; filename*=UTF-8''{quoted_filename}"
  260. separator = "&" if "?" in cdn_url else "?"
  261. download_url = f"{cdn_url}{separator}response-content-disposition={urllib.parse.quote(disposition)}"
  262. return RedirectResponse(url=download_url)
  263. if __name__ == "__main__":
  264. import uvicorn
  265. uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)