main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. from fastapi import FastAPI, BackgroundTasks, Request, Depends, HTTPException, Header, Query
  2. from fastapi.responses import FileResponse
  3. from sqlalchemy.orm import Session
  4. from sqlalchemy import func as sqlfunc
  5. from typing import List, Optional
  6. from app.config import settings
  7. from app.database import engine, Base, get_db, SessionLocal
  8. from app.services.webhook_service import WebhookService
  9. from app.models import Project, DataVersion, DataFile
  10. from app import schemas
  11. import logging
  12. import os
  13. import hmac
  14. import hashlib
# Directory holding the HTML pages returned by the UI routes (next to this module).
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
# Create any tables registered on Base that do not exist yet. This runs as a
# side effect of importing this module.
# NOTE(review): schema management at import time is convenient for dev; confirm a
# migration tool is not expected to own this in production.
Base.metadata.create_all(bind=engine)
# Root logging config for the process; module logger used by the handlers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The ASGI application object that the route decorators in this module register on.
app = FastAPI(title="Data Nexus", version="0.1.0")
  22. async def process_webhook_task(payload: dict):
  23. """Background task that creates its own db session."""
  24. db = SessionLocal()
  25. try:
  26. service = WebhookService(db)
  27. await service.process_webhook(payload)
  28. except Exception as e:
  29. logger.error(f"Webhook processing failed: {e}", exc_info=True)
  30. # 确保在异常情况下也能正确关闭数据库连接
  31. if hasattr(e, '__cause__') and 'Lost connection' in str(e) or 'MySQL server has gone away' in str(e):
  32. logger.warning("MySQL connection lost, the pool should auto-reconnect due to pool_pre_ping=True")
  33. finally:
  34. try:
  35. db.close()
  36. except Exception as close_error:
  37. logger.error(f"Error closing database session: {close_error}")
  38. def build_file_tree(files: List[DataFile]) -> list:
  39. """Convert flat file list to tree structure."""
  40. tree = {}
  41. for f in files:
  42. parts = f.relative_path.split("/")
  43. current = tree
  44. for i, part in enumerate(parts):
  45. if i == len(parts) - 1:
  46. # It's a file
  47. if "_files" not in current:
  48. current["_files"] = []
  49. current["_files"].append({
  50. "name": part,
  51. "type": "file",
  52. "id": f.id,
  53. "size": f.file_size,
  54. "file_type": f.file_type,
  55. "sha": f.file_sha,
  56. "direction": f.direction,
  57. "label": f.label,
  58. "extracted_value": f.extracted_value,
  59. "group_key": f.group_key
  60. })
  61. else:
  62. # It's a folder
  63. if part not in current:
  64. current[part] = {}
  65. current = current[part]
  66. def convert_to_list(node: dict) -> list:
  67. result = []
  68. for key, value in node.items():
  69. if key == "_files":
  70. result.extend(value)
  71. else:
  72. result.append({
  73. "name": key,
  74. "type": "folder",
  75. "children": convert_to_list(value)
  76. })
  77. # Sort: folders first, then files
  78. result.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"]))
  79. return result
  80. return convert_to_list(tree)
  81. @app.get("/")
  82. def read_root():
  83. """Serve the unified console UI."""
  84. return FileResponse(os.path.join(STATIC_DIR, "records.html"), media_type="text/html")
  85. @app.get("/fs")
  86. def filesystem_page():
  87. """Serve the legacy file system UI."""
  88. return FileResponse(os.path.join(STATIC_DIR, "index.html"), media_type="text/html")
  89. @app.get("/records")
  90. def records_page():
  91. """Serve the data records UI."""
  92. return FileResponse(os.path.join(STATIC_DIR, "records.html"), media_type="text/html")
  93. @app.get("/api/health")
  94. def health_check():
  95. """Health check endpoint."""
  96. return {"status": "ok"}
  97. def verify_webhook_signature(payload_body: bytes, signature: str) -> bool:
  98. """Verify Gogs webhook signature."""
  99. if not settings.GOGS_SECRET:
  100. return True # No secret configured, skip verification
  101. if not signature:
  102. return False
  103. expected = hmac.new(
  104. settings.GOGS_SECRET.encode(),
  105. payload_body,
  106. hashlib.sha256
  107. ).hexdigest()
  108. return hmac.compare_digest(f"sha256={expected}", signature)
  109. @app.post("/webhook")
  110. async def webhook_handler(
  111. request: Request,
  112. background_tasks: BackgroundTasks,
  113. x_gogs_signature: Optional[str] = Header(None)
  114. ):
  115. body = await request.body()
  116. # Verify signature if secret is configured
  117. if settings.GOGS_SECRET and not verify_webhook_signature(body, x_gogs_signature):
  118. raise HTTPException(status_code=401, detail="Invalid signature")
  119. try:
  120. import json
  121. payload = json.loads(body)
  122. except Exception:
  123. raise HTTPException(status_code=400, detail="Invalid JSON")
  124. # Process in background with its own db session
  125. background_tasks.add_task(process_webhook_task, payload)
  126. return {"status": "ok", "message": "Webhook received"}
  127. # ==================== Project APIs ====================
  128. @app.get("/projects", response_model=List[schemas.ProjectOut])
  129. def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
  130. """List all projects."""
  131. projects = db.query(Project).offset(skip).limit(limit).all()
  132. return projects
  133. @app.get("/projects/{project_id}", response_model=schemas.ProjectOut)
  134. def get_project(project_id: str, db: Session = Depends(get_db)):
  135. """Get a single project by ID."""
  136. project = db.query(Project).filter(Project.id == project_id).first()
  137. if not project:
  138. raise HTTPException(status_code=404, detail="Project not found")
  139. return project
  140. @app.get("/projects/name/{project_name}", response_model=schemas.ProjectOut)
  141. def get_project_by_name(project_name: str, db: Session = Depends(get_db)):
  142. """Get a project by name."""
  143. project = db.query(Project).filter(Project.project_name == project_name).first()
  144. if not project:
  145. raise HTTPException(status_code=404, detail="Project not found")
  146. return project
  147. # ==================== Console APIs ====================
  148. @app.get("/stages/all")
  149. def get_all_stages(db: Session = Depends(get_db)):
  150. """Get all stages across all projects with version counts."""
  151. results = db.query(
  152. DataVersion.stage,
  153. DataVersion.project_id,
  154. Project.project_name,
  155. sqlfunc.count(DataVersion.id).label("count")
  156. ).join(Project).group_by(
  157. DataVersion.stage, DataVersion.project_id, Project.project_name
  158. ).all()
  159. return [{
  160. "name": r[0],
  161. "project_id": r[1],
  162. "project_name": r[2],
  163. "version_count": r[3]
  164. } for r in results]
  165. @app.get("/projects/{project_id}/stages")
  166. def get_project_stages(project_id: str, db: Session = Depends(get_db)):
  167. """Get all unique stages for a project with version counts."""
  168. results = db.query(
  169. DataVersion.stage,
  170. sqlfunc.count(DataVersion.id).label("count")
  171. ).filter(
  172. DataVersion.project_id == project_id
  173. ).group_by(DataVersion.stage).all()
  174. return [{"name": r[0], "version_count": r[1]} for r in results]
  175. @app.get("/projects/{project_id}/stage-files")
  176. def get_stage_files(
  177. project_id: str,
  178. stage: str = Query(...),
  179. skip: int = 0,
  180. limit: int = 20,
  181. db: Session = Depends(get_db)
  182. ):
  183. """Get versions with files for a specific stage, ordered by newest first."""
  184. versions = db.query(DataVersion).filter(
  185. DataVersion.project_id == project_id,
  186. DataVersion.stage == stage
  187. ).order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  188. result = []
  189. for v in versions:
  190. files = db.query(DataFile).filter(DataFile.version_id == v.id).all()
  191. result.append({
  192. "version_id": v.id,
  193. "commit_id": v.commit_id,
  194. "author": v.author,
  195. "created_at": v.created_at.isoformat() if v.created_at else None,
  196. "files": [{
  197. "id": f.id,
  198. "name": f.relative_path.split("/")[-1] if f.relative_path else "",
  199. "relative_path": f.relative_path,
  200. "file_size": f.file_size,
  201. "file_type": f.file_type,
  202. "file_sha": f.file_sha,
  203. "direction": f.direction,
  204. "label": f.label,
  205. "extracted_value": f.extracted_value,
  206. "group_key": f.group_key,
  207. } for f in files]
  208. })
  209. return result
  210. @app.get("/projects/{project_id}/records", response_model=List[schemas.DataRecordOut])
  211. def list_data_records(
  212. project_id: str,
  213. stage: Optional[str] = None,
  214. skip: int = 0,
  215. limit: int = 100,
  216. db: Session = Depends(get_db)
  217. ):
  218. """List data records for a project, optionally filtered by stage."""
  219. from app.models import DataRecord
  220. query = db.query(DataRecord).filter(DataRecord.project_id == project_id)
  221. if stage:
  222. query = query.filter(DataRecord.stage == stage)
  223. records = query.order_by(DataRecord.created_at.desc()).offset(skip).limit(limit).all()
  224. return records
  225. # ==================== Version APIs ====================
  226. @app.get("/projects/{project_id}/versions", response_model=List[schemas.DataVersionOut])
  227. def list_versions(
  228. project_id: str,
  229. stage: Optional[str] = None,
  230. skip: int = 0,
  231. limit: int = 100,
  232. db: Session = Depends(get_db)
  233. ):
  234. """List versions for a project, optionally filtered by stage."""
  235. query = db.query(DataVersion).filter(DataVersion.project_id == project_id)
  236. if stage:
  237. query = query.filter(DataVersion.stage == stage)
  238. versions = query.order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  239. return versions
  240. @app.get("/versions/{version_id}", response_model=schemas.DataVersionOut)
  241. def get_version(version_id: str, db: Session = Depends(get_db)):
  242. """Get a single version by ID."""
  243. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  244. if not version:
  245. raise HTTPException(status_code=404, detail="Version not found")
  246. return version
  247. @app.get("/versions/{version_id}/files")
  248. def get_version_files(version_id: str, flat: bool = False, db: Session = Depends(get_db)):
  249. """
  250. Get files for a version.
  251. - flat=False (default): Returns tree structure
  252. - flat=True: Returns flat list
  253. """
  254. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  255. if not version:
  256. raise HTTPException(status_code=404, detail="Version not found")
  257. files = db.query(DataFile).filter(DataFile.version_id == version_id).all()
  258. if flat:
  259. return [schemas.DataFileOut.model_validate(f) for f in files]
  260. return build_file_tree(files)
  261. # ==================== File APIs ====================
  262. import urllib.parse
  263. from fastapi.responses import RedirectResponse # noqa: E811
  264. from app.services.oss_client import oss_client
  265. @app.get("/files/{file_id}", response_model=schemas.DataFileOut)
  266. def get_file_info(file_id: int, db: Session = Depends(get_db)):
  267. """Get file metadata."""
  268. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  269. if not file_record:
  270. raise HTTPException(status_code=404, detail="File not found")
  271. return file_record
  272. @app.get("/files/{file_id}/url")
  273. def get_file_url(file_id: int, db: Session = Depends(get_db)):
  274. """Get file CDN URL."""
  275. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  276. if not file_record:
  277. raise HTTPException(status_code=404, detail="File not found")
  278. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  279. return {"url": cdn_url}
  280. @app.get("/files/{file_id}/content")
  281. def get_file_content(file_id: int, db: Session = Depends(get_db)):
  282. """Redirect to CDN URL for file download with forced attachment header."""
  283. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  284. if not file_record:
  285. raise HTTPException(status_code=404, detail="File not found")
  286. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  287. # Try to force download by adding Aliyun OSS specific query parameter
  288. # This works for Aliyun OSS even on custom domains if not explicitly disabled
  289. filename = os.path.basename(file_record.relative_path)
  290. quoted_filename = urllib.parse.quote(filename)
  291. # Using both filename and filename* for maximum compatibility
  292. disposition = f"attachment; filename=\"{quoted_filename}\"; filename*=UTF-8''{quoted_filename}"
  293. separator = "&" if "?" in cdn_url else "?"
  294. download_url = f"{cdn_url}{separator}response-content-disposition={urllib.parse.quote(disposition)}"
  295. return RedirectResponse(url=download_url)
  296. if __name__ == "__main__":
  297. import uvicorn
  298. uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)