main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. from fastapi import FastAPI, BackgroundTasks, Request, Depends, HTTPException, Header, Query
  2. from fastapi.responses import FileResponse
  3. from sqlalchemy.orm import Session
  4. from sqlalchemy import func as sqlfunc
  5. from typing import List, Optional
  6. from app.config import settings
  7. from app.database import engine, Base, get_db, SessionLocal
  8. from app.services.webhook_service import WebhookService
  9. from app.models import Project, DataVersion, DataFile
  10. from app import schemas
  11. import logging
  12. import os
  13. import hmac
  14. import hashlib
  15. # Static files directory
  16. STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
  17. # Create tables
  18. Base.metadata.create_all(bind=engine)
  19. logging.basicConfig(level=logging.INFO)
  20. logger = logging.getLogger(__name__)
  21. app = FastAPI(title="Data Nexus", version="0.1.0")
  22. async def process_webhook_task(payload: dict):
  23. """Background task that creates its own db session."""
  24. db = SessionLocal()
  25. try:
  26. service = WebhookService(db)
  27. await service.process_webhook(payload)
  28. except Exception as e:
  29. logger.error(f"Webhook processing failed: {e}", exc_info=True)
  30. # 确保在异常情况下也能正确关闭数据库连接
  31. if hasattr(e, '__cause__') and 'Lost connection' in str(e) or 'MySQL server has gone away' in str(e):
  32. logger.warning("MySQL connection lost, the pool should auto-reconnect due to pool_pre_ping=True")
  33. finally:
  34. try:
  35. db.close()
  36. except Exception as close_error:
  37. logger.error(f"Error closing database session: {close_error}")
  38. def build_file_tree(files: List[DataFile]) -> list:
  39. """Convert flat file list to tree structure."""
  40. tree = {}
  41. for f in files:
  42. parts = f.relative_path.split("/")
  43. current = tree
  44. for i, part in enumerate(parts):
  45. if i == len(parts) - 1:
  46. # It's a file
  47. if "_files" not in current:
  48. current["_files"] = []
  49. current["_files"].append({
  50. "name": part,
  51. "type": "file",
  52. "id": f.id,
  53. "size": f.file_size,
  54. "file_type": f.file_type,
  55. "sha": f.file_sha,
  56. "direction": f.direction,
  57. "label": f.label,
  58. "extracted_value": f.extracted_value,
  59. "group_key": f.group_key
  60. })
  61. else:
  62. # It's a folder
  63. if part not in current:
  64. current[part] = {}
  65. current = current[part]
  66. def convert_to_list(node: dict) -> list:
  67. result = []
  68. for key, value in node.items():
  69. if key == "_files":
  70. result.extend(value)
  71. else:
  72. result.append({
  73. "name": key,
  74. "type": "folder",
  75. "children": convert_to_list(value)
  76. })
  77. # Sort: folders first, then files
  78. result.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"]))
  79. return result
  80. return convert_to_list(tree)
  81. @app.get("/")
  82. def read_root():
  83. """Serve the unified console UI."""
  84. return FileResponse(os.path.join(STATIC_DIR, "records.html"), media_type="text/html")
  85. @app.get("/records")
  86. def records_page():
  87. """Serve the data records UI."""
  88. return FileResponse(os.path.join(STATIC_DIR, "records.html"), media_type="text/html")
  89. @app.get("/api/health")
  90. def health_check():
  91. """Health check endpoint."""
  92. return {"status": "ok"}
  93. def verify_webhook_signature(payload_body: bytes, signature: str) -> bool:
  94. """Verify Gogs webhook signature."""
  95. if not settings.GOGS_SECRET:
  96. return True # No secret configured, skip verification
  97. if not signature:
  98. return False
  99. expected = hmac.new(
  100. settings.GOGS_SECRET.encode(),
  101. payload_body,
  102. hashlib.sha256
  103. ).hexdigest()
  104. return hmac.compare_digest(f"sha256={expected}", signature)
  105. @app.post("/webhook")
  106. async def webhook_handler(
  107. request: Request,
  108. background_tasks: BackgroundTasks,
  109. x_gogs_signature: Optional[str] = Header(None)
  110. ):
  111. body = await request.body()
  112. # Verify signature if secret is configured
  113. if settings.GOGS_SECRET and not verify_webhook_signature(body, x_gogs_signature):
  114. raise HTTPException(status_code=401, detail="Invalid signature")
  115. try:
  116. import json
  117. payload = json.loads(body)
  118. except Exception:
  119. raise HTTPException(status_code=400, detail="Invalid JSON")
  120. # Process in background with its own db session
  121. background_tasks.add_task(process_webhook_task, payload)
  122. return {"status": "ok", "message": "Webhook received"}
  123. # ==================== Project APIs ====================
  124. @app.get("/projects", response_model=List[schemas.ProjectOut])
  125. def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
  126. """List all projects."""
  127. projects = db.query(Project).offset(skip).limit(limit).all()
  128. return projects
  129. @app.get("/projects/{project_id}", response_model=schemas.ProjectOut)
  130. def get_project(project_id: str, db: Session = Depends(get_db)):
  131. """Get a single project by ID."""
  132. project = db.query(Project).filter(Project.id == project_id).first()
  133. if not project:
  134. raise HTTPException(status_code=404, detail="Project not found")
  135. return project
  136. @app.get("/projects/name/{project_name}", response_model=schemas.ProjectOut)
  137. def get_project_by_name(project_name: str, db: Session = Depends(get_db)):
  138. """Get a project by name."""
  139. project = db.query(Project).filter(Project.project_name == project_name).first()
  140. if not project:
  141. raise HTTPException(status_code=404, detail="Project not found")
  142. return project
  143. # ==================== Console APIs ====================
  144. @app.get("/stages/all")
  145. def get_all_stages(db: Session = Depends(get_db)):
  146. """Get all stages across all projects with version counts."""
  147. results = db.query(
  148. DataVersion.stage,
  149. DataVersion.project_id,
  150. Project.project_name,
  151. sqlfunc.count(DataVersion.id).label("count")
  152. ).join(Project).group_by(
  153. DataVersion.stage, DataVersion.project_id, Project.project_name
  154. ).all()
  155. return [{
  156. "name": r[0],
  157. "project_id": r[1],
  158. "project_name": r[2],
  159. "version_count": r[3]
  160. } for r in results]
  161. @app.get("/projects/{project_id}/stages")
  162. def get_project_stages(project_id: str, db: Session = Depends(get_db)):
  163. """Get all unique stages for a project with version counts."""
  164. results = db.query(
  165. DataVersion.stage,
  166. sqlfunc.count(DataVersion.id).label("count")
  167. ).filter(
  168. DataVersion.project_id == project_id
  169. ).group_by(DataVersion.stage).all()
  170. return [{"name": r[0], "version_count": r[1]} for r in results]
  171. @app.get("/projects/{project_id}/stage-files")
  172. def get_stage_files(
  173. project_id: str,
  174. stage: str = Query(...),
  175. skip: int = 0,
  176. limit: int = 20,
  177. db: Session = Depends(get_db)
  178. ):
  179. """Get versions with files for a specific stage, ordered by newest first."""
  180. versions = db.query(DataVersion).filter(
  181. DataVersion.project_id == project_id,
  182. DataVersion.stage == stage
  183. ).order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  184. result = []
  185. for v in versions:
  186. files = db.query(DataFile).filter(DataFile.version_id == v.id).all()
  187. result.append({
  188. "version_id": v.id,
  189. "commit_id": v.commit_id,
  190. "author": v.author,
  191. "created_at": v.created_at.isoformat() if v.created_at else None,
  192. "files": [{
  193. "id": f.id,
  194. "name": f.relative_path.split("/")[-1] if f.relative_path else "",
  195. "relative_path": f.relative_path,
  196. "file_size": f.file_size,
  197. "file_type": f.file_type,
  198. "file_sha": f.file_sha,
  199. "direction": f.direction,
  200. "label": f.label,
  201. "extracted_value": f.extracted_value,
  202. "group_key": f.group_key,
  203. } for f in files]
  204. })
  205. return result
  206. @app.get("/projects/{project_id}/records", response_model=List[schemas.DataRecordOut])
  207. def list_data_records(
  208. project_id: str,
  209. stage: Optional[str] = None,
  210. skip: int = 0,
  211. limit: int = 100,
  212. db: Session = Depends(get_db)
  213. ):
  214. """List data records for a project, optionally filtered by stage."""
  215. from app.models import DataRecord
  216. query = db.query(DataRecord).filter(DataRecord.project_id == project_id)
  217. if stage:
  218. query = query.filter(DataRecord.stage == stage)
  219. records = query.order_by(DataRecord.created_at.desc()).offset(skip).limit(limit).all()
  220. return records
  221. # ==================== Version APIs ====================
  222. @app.get("/projects/{project_id}/versions", response_model=List[schemas.DataVersionOut])
  223. def list_versions(
  224. project_id: str,
  225. stage: Optional[str] = None,
  226. skip: int = 0,
  227. limit: int = 100,
  228. db: Session = Depends(get_db)
  229. ):
  230. """List versions for a project, optionally filtered by stage."""
  231. query = db.query(DataVersion).filter(DataVersion.project_id == project_id)
  232. if stage:
  233. query = query.filter(DataVersion.stage == stage)
  234. versions = query.order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
  235. return versions
  236. @app.get("/versions/{version_id}", response_model=schemas.DataVersionOut)
  237. def get_version(version_id: str, db: Session = Depends(get_db)):
  238. """Get a single version by ID."""
  239. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  240. if not version:
  241. raise HTTPException(status_code=404, detail="Version not found")
  242. return version
  243. @app.get("/versions/{version_id}/files")
  244. def get_version_files(version_id: str, flat: bool = False, db: Session = Depends(get_db)):
  245. """
  246. Get files for a version.
  247. - flat=False (default): Returns tree structure
  248. - flat=True: Returns flat list
  249. """
  250. version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
  251. if not version:
  252. raise HTTPException(status_code=404, detail="Version not found")
  253. files = db.query(DataFile).filter(DataFile.version_id == version_id).all()
  254. if flat:
  255. return [schemas.DataFileOut.model_validate(f) for f in files]
  256. return build_file_tree(files)
  257. # ==================== File APIs ====================
  258. import urllib.parse
  259. from fastapi.responses import RedirectResponse # noqa: E811
  260. from app.services.oss_client import oss_client
  261. @app.get("/files/{file_id}", response_model=schemas.DataFileOut)
  262. def get_file_info(file_id: int, db: Session = Depends(get_db)):
  263. """Get file metadata."""
  264. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  265. if not file_record:
  266. raise HTTPException(status_code=404, detail="File not found")
  267. return file_record
  268. @app.get("/files/{file_id}/url")
  269. def get_file_url(file_id: int, db: Session = Depends(get_db)):
  270. """Get file CDN URL."""
  271. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  272. if not file_record:
  273. raise HTTPException(status_code=404, detail="File not found")
  274. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  275. return {"url": cdn_url}
  276. @app.get("/files/{file_id}/content")
  277. def get_file_content(file_id: int, db: Session = Depends(get_db)):
  278. """Redirect to CDN URL for file download with forced attachment header."""
  279. file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
  280. if not file_record:
  281. raise HTTPException(status_code=404, detail="File not found")
  282. cdn_url = oss_client.get_cdn_url(file_record.storage_path)
  283. # Try to force download by adding Aliyun OSS specific query parameter
  284. # This works for Aliyun OSS even on custom domains if not explicitly disabled
  285. filename = os.path.basename(file_record.relative_path)
  286. quoted_filename = urllib.parse.quote(filename)
  287. # Using both filename and filename* for maximum compatibility
  288. disposition = f"attachment; filename=\"{quoted_filename}\"; filename*=UTF-8''{quoted_filename}"
  289. separator = "&" if "?" in cdn_url else "?"
  290. download_url = f"{cdn_url}{separator}response-content-disposition={urllib.parse.quote(disposition)}"
  291. return RedirectResponse(url=download_url)
  292. if __name__ == "__main__":
  293. import uvicorn
  294. uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)