"""FastAPI application: UI pages, webhook intake, and project/stage read APIs."""

import hashlib
import hmac
import json
import logging
import os
from typing import List, Optional

from fastapi import FastAPI, BackgroundTasks, Request, Depends, HTTPException, Header, Query
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from sqlalchemy import func as sqlfunc

from app.config import settings
from app.database import engine, Base, get_db, SessionLocal
from app.services.webhook_service import WebhookService
from app.models import Project, DataVersion, DataFile
from app import schemas

# Static files directory
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")

# Create tables
Base.metadata.create_all(bind=engine)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Data Nexus", version="0.1.0")

# Prefix some webhook senders put in front of the hex digest.
_SIGNATURE_PREFIX = "sha256="


async def process_webhook_task(payload: dict) -> None:
    """Background task that creates its own db session.

    Opens a fresh ``SessionLocal`` session, hands the payload to
    ``WebhookService.process_webhook`` and always closes the session.
    Failures are logged and swallowed because nobody awaits the result of a
    background task.

    Args:
        payload: Decoded webhook JSON body.
    """
    db = SessionLocal()
    try:
        service = WebhookService(db)
        await service.process_webhook(payload)
    except Exception as e:
        logger.error(f"Webhook processing failed: {e}", exc_info=True)
        # Call out lost-connection errors explicitly; the session is still
        # closed in ``finally`` below.
        message = str(e)
        if "Lost connection" in message or "MySQL server has gone away" in message:
            # NOTE(review): the message assumes the engine is configured with
            # pool_pre_ping=True -- confirm in app.database.
            logger.warning("MySQL connection lost, the pool should auto-reconnect due to pool_pre_ping=True")
    finally:
        try:
            db.close()
        except Exception as close_error:
            logger.error(f"Error closing database session: {close_error}")


def build_file_tree(files: List[DataFile]) -> list:
    """Convert flat file list to tree structure.

    Each file's ``relative_path`` is split on ``/``; every segment but the
    last becomes a nested folder and the last segment becomes a file entry.

    Args:
        files: File records to arrange.

    Returns:
        A list of nodes. Folder nodes are
        ``{"name", "type": "folder", "children": [...]}``; file nodes carry
        ``name``, ``type: "file"`` and the file's metadata fields. Within each
        level folders come first, then files, each sorted by name.
    """

    def new_node() -> dict:
        # Folders and files live in separate slots so that a directory with
        # any name (including "_files") cannot collide with the file list.
        return {"folders": {}, "files": []}

    root = new_node()
    for f in files:
        *folder_names, file_name = f.relative_path.split("/")
        node = root
        for folder_name in folder_names:
            if folder_name not in node["folders"]:
                node["folders"][folder_name] = new_node()
            node = node["folders"][folder_name]
        node["files"].append({
            "name": file_name,
            "type": "file",
            "id": f.id,
            "size": f.file_size,
            "file_type": f.file_type,
            "sha": f.file_sha,
            "direction": f.direction,
            "label": f.label,
            "extracted_value": f.extracted_value,
            "group_key": f.group_key,
        })

    def convert_to_list(node: dict) -> list:
        result = [
            {"name": name, "type": "folder", "children": convert_to_list(child)}
            for name, child in node["folders"].items()
        ]
        result.extend(node["files"])
        # Sort: folders first, then files
        result.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"]))
        return result

    return convert_to_list(root)


@app.get("/")
def read_root():
    """Serve the unified console UI."""
    return FileResponse(os.path.join(STATIC_DIR, "records.html"), media_type="text/html")


@app.get("/fs")
def filesystem_page():
    """Serve the legacy file system UI."""
    return FileResponse(os.path.join(STATIC_DIR, "index.html"), media_type="text/html")


@app.get("/records")
def records_page():
    """Serve the data records UI."""
    return FileResponse(os.path.join(STATIC_DIR, "records.html"), media_type="text/html")


@app.get("/api/health")
def health_check():
    """Health check endpoint."""
    return {"status": "ok"}


def verify_webhook_signature(payload_body: bytes, signature: Optional[str]) -> bool:
    """Verify Gogs webhook signature.

    Computes the HMAC-SHA256 of the raw body with ``settings.GOGS_SECRET`` and
    compares it, in constant time, with the supplied signature. Both the bare
    hex digest and the ``sha256=<hex>`` form are accepted.

    Args:
        payload_body: Raw request body exactly as received.
        signature: Value of the signature header, or ``None`` if absent.

    Returns:
        ``True`` when no secret is configured or the signature matches,
        ``False`` otherwise.
    """
    if not settings.GOGS_SECRET:
        return True  # No secret configured, skip verification
    if not signature:
        return False

    expected = hmac.new(
        settings.GOGS_SECRET.encode(),
        payload_body,
        hashlib.sha256,
    ).hexdigest()

    provided = signature
    if provided[:len(_SIGNATURE_PREFIX)].lower() == _SIGNATURE_PREFIX:
        provided = provided[len(_SIGNATURE_PREFIX):]

    # Compare bytes, not str: compare_digest raises TypeError on non-ASCII
    # str input, which would turn a malformed header into a 500.
    return hmac.compare_digest(
        expected.encode("ascii"),
        provided.lower().encode("utf-8", "replace"),
    )


@app.post("/webhook")
async def webhook_handler(
    request: Request,
    background_tasks: BackgroundTasks,
    x_gogs_signature: Optional[str] = Header(None),
):
    """Accept a webhook delivery and queue it for background processing.

    Raises:
        HTTPException: 401 when a secret is configured and the signature does
            not match; 400 when the body is not a JSON object.
    """
    body = await request.body()

    # verify_webhook_signature() returns True when no secret is configured.
    if not verify_webhook_signature(body, x_gogs_signature):
        raise HTTPException(status_code=401, detail="Invalid signature")

    try:
        payload = json.loads(body)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError.
        raise HTTPException(status_code=400, detail="Invalid JSON")
    if not isinstance(payload, dict):
        # The background task expects a dict payload.
        raise HTTPException(status_code=400, detail="Invalid JSON")

    # Process in background with its own db session
    background_tasks.add_task(process_webhook_task, payload)
    return {"status": "ok", "message": "Webhook received"}


# ==================== Project APIs ====================

@app.get("/projects", response_model=List[schemas.ProjectOut])
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    """List all projects."""
    projects = db.query(Project).offset(skip).limit(limit).all()
    return projects


@app.get("/projects/{project_id}", response_model=schemas.ProjectOut)
def get_project(project_id: str, db: Session = Depends(get_db)):
    """Get a single project by ID.

    Raises:
        HTTPException: 404 when no project has this ID.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    return project


@app.get("/projects/name/{project_name}", response_model=schemas.ProjectOut)
def get_project_by_name(project_name: str, db: Session = Depends(get_db)):
    """Get a project by name.

    Raises:
        HTTPException: 404 when no project has this name.
    """
    project = db.query(Project).filter(Project.project_name == project_name).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    return project


# ==================== Console APIs ====================

@app.get("/stages/all")
def get_all_stages(db: Session = Depends(get_db)):
    """Get all stages across all projects with version counts."""
    results = db.query(
        DataVersion.stage,
        DataVersion.project_id,
        Project.project_name,
        sqlfunc.count(DataVersion.id).label("count"),
    ).join(Project).group_by(
        DataVersion.stage,
        DataVersion.project_id,
        Project.project_name,
    ).all()

    return [
        {
            "name": stage,
            "project_id": project_id,
            "project_name": project_name,
            "version_count": version_count,
        }
        for stage, project_id, project_name, version_count in results
    ]


@app.get("/projects/{project_id}/stages")
def get_project_stages(project_id: str, db: Session = Depends(get_db)):
    """Get all unique stages for a project with version counts."""
    results = db.query(
        DataVersion.stage,
        sqlfunc.count(DataVersion.id).label("count"),
    ).filter(
        DataVersion.project_id == project_id
    ).group_by(DataVersion.stage).all()

    return [{"name": stage, "version_count": version_count} for stage, version_count in results]
@app.get("/projects/{project_id}/stage-files")
def get_stage_files(
    project_id: str,
    stage: str = Query(...),
    skip: int = 0,
    limit: int = 20,
    db: Session = Depends(get_db),
):
    """Get versions with files for a specific stage, ordered by newest first.

    Args:
        project_id: Project whose versions are listed.
        stage: Stage name to filter on (required query parameter).
        skip: Number of versions to skip.
        limit: Maximum number of versions to return.

    Returns:
        One dict per version with ``version_id``, ``commit_id``, ``author``,
        ``created_at`` (ISO string or ``None``) and a ``files`` list.
    """
    versions = db.query(DataVersion).filter(
        DataVersion.project_id == project_id,
        DataVersion.stage == stage,
    ).order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
    if not versions:
        return []

    # Load the files of every listed version in one query instead of one
    # query per version, then group them by owning version.
    version_ids = [v.id for v in versions]
    files_by_version = {}
    for f in db.query(DataFile).filter(DataFile.version_id.in_(version_ids)).all():
        files_by_version.setdefault(f.version_id, []).append(f)

    result = []
    for v in versions:
        result.append({
            "version_id": v.id,
            "commit_id": v.commit_id,
            "author": v.author,
            "created_at": v.created_at.isoformat() if v.created_at else None,
            "files": [{
                "id": f.id,
                "name": f.relative_path.split("/")[-1] if f.relative_path else "",
                "relative_path": f.relative_path,
                "file_size": f.file_size,
                "file_type": f.file_type,
                "file_sha": f.file_sha,
                "direction": f.direction,
                "label": f.label,
                "extracted_value": f.extracted_value,
                "group_key": f.group_key,
            } for f in files_by_version.get(v.id, [])],
        })
    return result


@app.get("/projects/{project_id}/records", response_model=List[schemas.DataRecordOut])
def list_data_records(
    project_id: str,
    stage: Optional[str] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List data records for a project, optionally filtered by stage.

    Records are returned newest first (by ``created_at``).
    """
    from app.models import DataRecord

    query = db.query(DataRecord).filter(DataRecord.project_id == project_id)
    if stage:
        query = query.filter(DataRecord.stage == stage)
    records = query.order_by(DataRecord.created_at.desc()).offset(skip).limit(limit).all()
    return records


# ==================== Version APIs ====================

def _get_version_or_404(db: Session, version_id: str) -> DataVersion:
    """Return the version with this ID or raise a 404 HTTPException."""
    version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
    if not version:
        raise HTTPException(status_code=404, detail="Version not found")
    return version


@app.get("/projects/{project_id}/versions", response_model=List[schemas.DataVersionOut])
def list_versions(
    project_id: str,
    stage: Optional[str] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List versions for a project, optionally filtered by stage.

    Versions are returned newest first (by ``created_at``).
    """
    query = db.query(DataVersion).filter(DataVersion.project_id == project_id)
    if stage:
        query = query.filter(DataVersion.stage == stage)
    versions = query.order_by(DataVersion.created_at.desc()).offset(skip).limit(limit).all()
    return versions


@app.get("/versions/{version_id}", response_model=schemas.DataVersionOut)
def get_version(version_id: str, db: Session = Depends(get_db)):
    """Get a single version by ID.

    Raises:
        HTTPException: 404 when no version has this ID.
    """
    return _get_version_or_404(db, version_id)


@app.get("/versions/{version_id}/files")
def get_version_files(version_id: str, flat: bool = False, db: Session = Depends(get_db)):
    """
    Get files for a version.
    - flat=False (default): Returns tree structure
    - flat=True: Returns flat list

    Raises:
        HTTPException: 404 when no version has this ID.
    """
    _get_version_or_404(db, version_id)

    files = db.query(DataFile).filter(DataFile.version_id == version_id).all()

    if flat:
        return [schemas.DataFileOut.model_validate(f) for f in files]
    return build_file_tree(files)


# ==================== File APIs ====================

import urllib.parse
from fastapi.responses import RedirectResponse  # noqa: E811
from app.services.oss_client import oss_client


def _get_file_or_404(db: Session, file_id: int) -> DataFile:
    """Return the file record with this ID or raise a 404 HTTPException."""
    file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
    if not file_record:
        raise HTTPException(status_code=404, detail="File not found")
    return file_record


@app.get("/files/{file_id}", response_model=schemas.DataFileOut)
def get_file_info(file_id: int, db: Session = Depends(get_db)):
    """Get file metadata.

    Raises:
        HTTPException: 404 when no file has this ID.
    """
    return _get_file_or_404(db, file_id)


@app.get("/files/{file_id}/url")
def get_file_url(file_id: int, db: Session = Depends(get_db)):
    """Get file CDN URL.

    Raises:
        HTTPException: 404 when no file has this ID.
    """
    file_record = _get_file_or_404(db, file_id)
    cdn_url = oss_client.get_cdn_url(file_record.storage_path)
    return {"url": cdn_url}


@app.get("/files/{file_id}/content")
def get_file_content(file_id: int, db: Session = Depends(get_db)):
    """Redirect to CDN URL for file download with forced attachment header.

    Raises:
        HTTPException: 404 when no file has this ID.
    """
    file_record = _get_file_or_404(db, file_id)
    cdn_url = oss_client.get_cdn_url(file_record.storage_path)

    # Try to force a download by adding the Aliyun OSS specific
    # ``response-content-disposition`` query parameter.
    # NOTE(review): the original author states this works on custom domains
    # unless explicitly disabled -- not verifiable from this file.
    filename = os.path.basename(file_record.relative_path)
    quoted_filename = urllib.parse.quote(filename)

    # Using both filename and filename* for maximum compatibility
    disposition = f"attachment; filename=\"{quoted_filename}\"; filename*=UTF-8''{quoted_filename}"

    # Append to an existing query string if the CDN URL already has one.
    separator = "&" if "?" in cdn_url else "?"
    download_url = f"{cdn_url}{separator}response-content-disposition={urllib.parse.quote(disposition)}"

    return RedirectResponse(url=download_url)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)