main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. from fastapi import FastAPI, BackgroundTasks, Request, Depends, HTTPException, Header, Query
  2. from fastapi.responses import FileResponse
  3. from sqlalchemy.orm import Session
  4. from sqlalchemy import func as sqlfunc
  5. from typing import List, Optional
  6. from app.config import settings
  7. from app.database import engine, Base, get_db, SessionLocal
  8. from app.services.webhook_service import WebhookService
  9. from app.models import Project, DataVersion, DataFile
  10. from app import schemas
  11. import logging
  12. import os
  13. import hmac
  14. import hashlib
# Absolute path of the "static" directory next to this module; the HTML page
# endpoints in this module serve their files from here.
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
# Create tables for the models registered on Base. This runs as a side effect
# of importing this module. NOTE(review): create_all skips tables that already
# exist and never alters them, so schema changes need a separate migration
# step — confirm one exists.
Base.metadata.create_all(bind=engine)
# Configure root logging at INFO so this module's logger output is emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The ASGI application that every route decorator in this module registers on.
app = FastAPI(title="Data Nexus", version="0.1.0")
  22. async def process_webhook_task(payload: dict):
  23. """Background task that creates its own db session."""
  24. db = SessionLocal()
  25. try:
  26. service = WebhookService(db)
  27. await service.process_webhook(payload)
  28. except Exception as e:
  29. logger.error(f"Webhook processing failed: {e}", exc_info=True)
  30. finally:
  31. db.close()
  32. def build_file_tree(files: List[DataFile]) -> list:
  33. """Convert flat file list to tree structure."""
  34. tree = {}
  35. for f in files:
  36. parts = f.relative_path.split("/")
  37. current = tree
  38. for i, part in enumerate(parts):
  39. if i == len(parts) - 1:
  40. # It's a file
  41. if "_files" not in current:
  42. current["_files"] = []
  43. current["_files"].append({
  44. "name": part,
  45. "type": "file",
  46. "id": f.id,
  47. "size": f.file_size,
  48. "file_type": f.file_type,
  49. "sha": f.file_sha,
  50. "direction": f.direction,
  51. "label": f.label,
  52. "extracted_value": f.extracted_value,
  53. "group_key": f.group_key
  54. })
  55. else:
  56. # It's a folder
  57. if part not in current:
  58. current[part] = {}
  59. current = current[part]
  60. def convert_to_list(node: dict) -> list:
  61. result = []
  62. for key, value in node.items():
  63. if key == "_files":
  64. result.extend(value)
  65. else:
  66. result.append({
  67. "name": key,
  68. "type": "folder",
  69. "children": convert_to_list(value)
  70. })
  71. # Sort: folders first, then files
  72. result.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"]))
  73. return result
  74. return convert_to_list(tree)
  75. @app.get("/")
  76. def read_root():
  77. """Serve the unified console UI."""
  78. return FileResponse(os.path.join(STATIC_DIR, "console.html"), media_type="text/html")
  79. @app.get("/fs")
  80. def filesystem_page():
  81. """Serve the legacy file system UI."""
  82. return FileResponse(os.path.join(STATIC_DIR, "index.html"), media_type="text/html")
  83. @app.get("/records")
  84. def records_page():
  85. """Serve the data records UI."""
  86. return FileResponse(os.path.join(STATIC_DIR, "records.html"), media_type="text/html")
  87. @app.get("/api/health")
  88. def health_check():
  89. """Health check endpoint."""
  90. return {"status": "ok"}
  91. def verify_webhook_signature(payload_body: bytes, signature: str) -> bool:
  92. """Verify Gogs webhook signature."""
  93. if not settings.GOGS_SECRET:
  94. return True # No secret configured, skip verification
  95. if not signature:
  96. return False
  97. expected = hmac.new(
  98. settings.GOGS_SECRET.encode(),
  99. payload_body,
  100. hashlib.sha256
  101. ).hexdigest()
  102. return hmac.compare_digest(f"sha256={expected}", signature)
  103. @app.post("/webhook")
  104. async def webhook_handler(
  105. request: Request,
  106. background_tasks: BackgroundTasks,
  107. x_gogs_signature: Optional[str] = Header(None)
  108. ):
  109. body = await request.body()
  110. # Verify signature if secret is configured
  111. if settings.GOGS_SECRET and not verify_webhook_signature(body, x_gogs_signature):
  112. raise HTTPException(status_code=401, detail="Invalid signature")
  113. try:
  114. import json
  115. payload = json.loads(body)
  116. except Exception:
  117. raise HTTPException(status_code=400, detail="Invalid JSON")
  118. # Process in background with its own db session
  119. background_tasks.add_task(process_webhook_task, payload)
  120. return {"status": "ok", "message": "Webhook received"}
  121. # ==================== Project APIs ====================
  122. @app.get("/projects", response_model=List[schemas.ProjectOut])
  123. def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
  124. """List all projects."""
  125. projects = db.query(Project).offset(skip).limit(limit).all()
  126. return projects
  127. @app.get("/projects/{project_id}", response_model=schemas.ProjectOut)
  128. def get_project(project_id: str, db: Session = Depends(get_db)):
  129. """Get a single project by ID."""
  130. project = db.query(Project).filter(Project.id == project_id).first()
  131. if not project:
  132. raise HTTPException(status_code=404, detail="Project not found")
  133. return project
  134. @app.get("/projects/name/{project_name}", response_model=schemas.ProjectOut)
  135. def get_project_by_name(project_name: str, db: Session = Depends(get_db)):
  136. """Get a project by name."""
  137. project = db.query(Project).filter(Project.project_name == project_name).first()
  138. if not project:
  139. raise HTTPException(status_code=404, detail="Project not found")
  140. return project
  141. # ==================== Console APIs ====================
  142. @app.get("/stages/all")
  143. def get_all_stages(db: Session = Depends(get_db)):
  144. """Get all stages across all projects with version counts."""
  145. results = db.query(
  146. DataVersion.stage,
  147. DataVersion.project_id,
  148. Project.project_name,
  149. sqlfunc.count(DataVersion.id).label("count")
  150. ).join(Project).group_by(
  151. DataVersion.stage, DataVersion.project_id, Project.project_name
  152. ).all()
  153. return [{
  154. "name": r[0],
  155. "project_id": r[1],
  156. "project_name": r[2],
  157. "version_count": r[3]
  158. } for r in results]
  159. @app.get("/projects/{project_id}/stages")
  160. def get_project_stages(project_id: str, db: Session = Depends(get_db)):
  161. """Get all unique stages for a project with version counts."""
  162. results = db.query(
  163. DataVersion.stage,
  164. sqlfunc.count(DataVersion.id).label("count")
  165. ).filter(
  166. DataVersion.project_id == project_id
  167. ).group_by(DataVersion.stage).all()
  168. return [{"name": r[0], "version_count": r[1]} for r in results]
  169. @app.get("/projects/{project_id}/stage-files")
  170. def get_stage_files(
  171. project_id: str,
  172. stage: str = Query(...),
  173. skip: int = 0,
  174. limit: int = 20,
  175. db: Session = Depends(get_db)
  176. ):
  177. """Get versions with files for a specific stage, ordered by newest first."""
  178. versions = db.query(DataVersion).filter(
  179. DataVersion.project_id == project_id,
  180. DataVersion.stage == stage
  181. ).order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  182. result = []
  183. for v in versions:
  184. files = db.query(DataFile).filter(DataFile.version_id == v.id).all()
  185. result.append({
  186. "version_id": v.id,
  187. "commit_id": v.commit_id,
  188. "author": v.author,
  189. "created_at": v.created_at.isoformat() if v.created_at else None,
  190. "files": [{
  191. "id": f.id,
  192. "name": f.relative_path.split("/")[-1] if f.relative_path else "",
  193. "relative_path": f.relative_path,
  194. "file_size": f.file_size,
  195. "file_type": f.file_type,
  196. "file_sha": f.file_sha,
  197. "direction": f.direction,
  198. "label": f.label,
  199. "extracted_value": f.extracted_value,
  200. "group_key": f.group_key,
  201. } for f in files]
  202. })
  203. return result
  204. @app.get("/projects/{project_id}/records", response_model=List[schemas.DataRecordOut])
  205. def list_data_records(
  206. project_id: str,
  207. stage: Optional[str] = None,
  208. skip: int = 0,
  209. limit: int = 100,
  210. db: Session = Depends(get_db)
  211. ):
  212. """List data records for a project, optionally filtered by stage."""
  213. from app.models import DataRecord
  214. query = db.query(DataRecord).filter(DataRecord.project_id == project_id)
  215. if stage:
  216. query = query.filter(DataRecord.stage == stage)
  217. records = query.order_by(DataRecord.created_at.desc()).offset(skip).limit(limit).all()
  218. return records
  219. # ==================== Version APIs ====================
  220. @app.get("/projects/{project_id}/versions", response_model=List[schemas.DataVersionOut])
  221. def list_versions(
  222. project_id: str,
  223. stage: Optional[str] = None,
  224. skip: int = 0,
  225. limit: int = 100,
  226. db: Session = Depends(get_db)
  227. ):
  228. """List versions for a project, optionally filtered by stage."""
  229. query = db.query(DataVersion).filter(DataVersion.project_id == project_id)
  230. if stage:
  231. query = query.filter(DataVersion.stage == stage)
  232. versions = query.order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  233. return versions
  234. @app.get("/versions/{version_id}", response_model=schemas.DataVersionOut)
  235. def get_version(version_id: str, db: Session = Depends(get_db)):
  236. """Get a single version by ID."""
  237. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  238. if not version:
  239. raise HTTPException(status_code=404, detail="Version not found")
  240. return version
  241. @app.get("/versions/{version_id}/files")
  242. def get_version_files(version_id: str, flat: bool = False, db: Session = Depends(get_db)):
  243. """
  244. Get files for a version.
  245. - flat=False (default): Returns tree structure
  246. - flat=True: Returns flat list
  247. """
  248. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  249. if not version:
  250. raise HTTPException(status_code=404, detail="Version not found")
  251. files = db.query(DataFile).filter(DataFile.version_id == version_id).all()
  252. if flat:
  253. return [schemas.DataFileOut.model_validate(f) for f in files]
  254. return build_file_tree(files)
  255. # ==================== File APIs ====================
  256. import urllib.parse
  257. from fastapi.responses import RedirectResponse # noqa: E811
  258. from app.services.oss_client import oss_client
  259. @app.get("/files/{file_id}", response_model=schemas.DataFileOut)
  260. def get_file_info(file_id: int, db: Session = Depends(get_db)):
  261. """Get file metadata."""
  262. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  263. if not file_record:
  264. raise HTTPException(status_code=404, detail="File not found")
  265. return file_record
  266. @app.get("/files/{file_id}/url")
  267. def get_file_url(file_id: int, db: Session = Depends(get_db)):
  268. """Get file CDN URL."""
  269. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  270. if not file_record:
  271. raise HTTPException(status_code=404, detail="File not found")
  272. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  273. return {"url": cdn_url}
  274. @app.get("/files/{file_id}/content")
  275. def get_file_content(file_id: int, db: Session = Depends(get_db)):
  276. """Redirect to CDN URL for file download with forced attachment header."""
  277. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  278. if not file_record:
  279. raise HTTPException(status_code=404, detail="File not found")
  280. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  281. # Try to force download by adding Aliyun OSS specific query parameter
  282. # This works for Aliyun OSS even on custom domains if not explicitly disabled
  283. filename = os.path.basename(file_record.relative_path)
  284. quoted_filename = urllib.parse.quote(filename)
  285. # Using both filename and filename* for maximum compatibility
  286. disposition = f"attachment; filename=\"{quoted_filename}\"; filename*=UTF-8''{quoted_filename}"
  287. separator = "&" if "?" in cdn_url else "?"
  288. download_url = f"{cdn_url}{separator}response-content-disposition={urllib.parse.quote(disposition)}"
  289. return RedirectResponse(url=download_url)
  290. if __name__ == "__main__":
  291. import uvicorn
  292. uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)