| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366 |
- from fastapi import FastAPI, BackgroundTasks, Request, Depends, HTTPException, Header, Query
- from fastapi.responses import FileResponse
- from sqlalchemy.orm import Session
- from sqlalchemy import func as sqlfunc
- from typing import List, Optional
- from app.config import settings
- from app.database import engine, Base, get_db, SessionLocal
- from app.services.webhook_service import WebhookService
- from app.models import Project, DataVersion, DataFile
- from app import schemas
- import logging
- import os
- import hmac
- import hashlib
# Static files directory: the HTML pages returned by this module's UI routes
# are read from a "static" folder located next to this file.
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
# Create tables
# NOTE: runs at import time and issues DDL against the database bound to
# `engine`. SQLAlchemy's create_all only creates tables that are missing; it
# does not alter existing ones, so schema changes need a separate migration.
Base.metadata.create_all(bind=engine)
# Process-wide logging setup: the root logger is configured at INFO level.
logging.basicConfig(level=logging.INFO)
# Module logger used by the webhook background task.
logger = logging.getLogger(__name__)
# The ASGI application; every route in this module is registered on it via
# the @app.get / @app.post decorators.
app = FastAPI(title="Data Nexus", version="0.1.0")
async def process_webhook_task(payload: dict):
    """Process a webhook payload as a FastAPI background task.

    Background tasks run after the HTTP response has been sent, outside the
    request-scoped ``get_db`` dependency. This task therefore opens its own
    session from ``SessionLocal`` and always closes it.

    Failures are logged and swallowed, because there is no client left to
    receive an error.

    Args:
        payload: The decoded JSON body of the webhook request.
    """
    db = SessionLocal()
    try:
        service = WebhookService(db)
        await service.process_webhook(payload)
    except Exception as e:
        logger.exception("Webhook processing failed: %s", e)
        # Detect a dropped MySQL connection. Check the exception's own message
        # and the message of the exception it was raised from (e.g. a wrapped
        # DBAPI error chained via `raise ... from ...`). This only adds a log
        # hint; the reconnect itself is expected to come from the engine's
        # pool configuration in app.database (not visible here).
        lost_markers = ("Lost connection", "MySQL server has gone away")
        texts = [str(e)]
        if e.__cause__ is not None:
            texts.append(str(e.__cause__))
        if any(marker in text for text in texts for marker in lost_markers):
            logger.warning("MySQL connection lost, the pool should auto-reconnect due to pool_pre_ping=True")
    finally:
        # Close inside its own try so that an error while releasing the
        # session is logged instead of propagating out of the background task.
        try:
            db.close()
        except Exception as close_error:
            logger.error(f"Error closing database session: {close_error}")
def build_file_tree(files: "List[DataFile]") -> list:
    """Convert a flat list of files into a nested folder tree.

    Each file's ``relative_path`` is split on "/". Every component except the
    last becomes a folder node, and the last component becomes a file node.

    Args:
        files: File records. Only ``relative_path`` and the metadata fields
            copied into the file nodes are read. A ``None`` path is treated
            as an empty one.

    Returns:
        The list of top-level nodes.
        - Folder nodes are ``{"name", "type": "folder", "children": [...]}``.
        - File nodes are ``{"name", "type": "file", "id", "size",
          "file_type", "sha", "direction", "label", "extracted_value",
          "group_key"}``.
        - At every level, folders come first, then files, each group sorted
          by name.
    """
    def new_folder() -> dict:
        # Sub-folders and files live in separate containers. A directory
        # literally named "_files" (or any other name) therefore cannot
        # collide with the container holding a level's files.
        return {"folders": {}, "files": []}

    root = new_folder()
    for f in files:
        *folder_names, file_name = (f.relative_path or "").split("/")
        node = root
        for folder_name in folder_names:
            child = node["folders"].get(folder_name)
            if child is None:
                child = node["folders"][folder_name] = new_folder()
            node = child
        node["files"].append({
            "name": file_name,
            "type": "file",
            "id": f.id,
            "size": f.file_size,
            "file_type": f.file_type,
            "sha": f.file_sha,
            "direction": f.direction,
            "label": f.label,
            "extracted_value": f.extracted_value,
            "group_key": f.group_key,
        })

    def convert_to_list(node: dict) -> list:
        """Flatten one folder node into a sorted list of child nodes."""
        result = [
            {"name": name, "type": "folder", "children": convert_to_list(child)}
            for name, child in node["folders"].items()
        ]
        result.extend(node["files"])
        # Sort: folders first, then files; alphabetical within each group.
        # The sort is stable, so same-named files keep their input order.
        result.sort(key=lambda x: (0 if x["type"] == "folder" else 1, x["name"]))
        return result

    return convert_to_list(root)
def _static_html(filename: str) -> FileResponse:
    """Return a response that serves ``STATIC_DIR/filename`` as text/html."""
    return FileResponse(os.path.join(STATIC_DIR, filename), media_type="text/html")


@app.get("/")
def read_root():
    """Serve the unified console UI."""
    # The root path and /records deliberately serve the same page.
    return _static_html("records.html")


@app.get("/fs")
def filesystem_page():
    """Serve the legacy file system UI."""
    return _static_html("index.html")


@app.get("/records")
def records_page():
    """Serve the data records UI."""
    return _static_html("records.html")
@app.get("/api/health")
def health_check():
    """Health check endpoint."""
    # Liveness only: nothing (database, storage) is probed here.
    return dict(status="ok")
def verify_webhook_signature(
    payload_body: bytes,
    signature: Optional[str],
    *,
    secret: Optional[str] = None,
) -> bool:
    """Verify a Gogs webhook signature.

    Gogs signs the raw request body with HMAC-SHA256, keyed by the webhook
    secret, and sends the hex digest in the ``X-Gogs-Signature`` header.
    Both the bare hex digest and the GitHub-style ``sha256=<hex>`` form are
    accepted, and the hex comparison is case-insensitive.

    Args:
        payload_body: The raw request body, exactly as received.
        signature: The signature header value, or ``None`` when absent.
        secret: Shared secret to verify against. When omitted (``None``),
            ``settings.GOGS_SECRET`` is used.

    Returns:
        ``True`` if no secret is configured (verification disabled) or the
        signature matches; ``False`` otherwise.
    """
    if secret is None:
        secret = settings.GOGS_SECRET
    if not secret:
        return True  # No secret configured, skip verification
    if not signature:
        return False
    expected = hmac.new(secret.encode(), payload_body, hashlib.sha256).hexdigest()
    provided = signature.strip()
    if provided.startswith("sha256="):
        provided = provided[len("sha256="):]
    # Compare bytes, not str. hmac.compare_digest raises TypeError for
    # non-ASCII str input, and the header value is attacker-controlled.
    return hmac.compare_digest(expected.encode(), provided.lower().encode("utf-8"))
@app.post("/webhook")
async def webhook_handler(
    request: Request,
    background_tasks: BackgroundTasks,
    x_gogs_signature: Optional[str] = Header(None)
):
    """Receive a Gogs webhook and queue it for background processing."""
    import json

    # The raw bytes are needed for HMAC verification, so the body is read
    # directly instead of being parsed into a request model.
    body = await request.body()
    # Verify signature if secret is configured
    if settings.GOGS_SECRET and not verify_webhook_signature(body, x_gogs_signature):
        raise HTTPException(status_code=401, detail="Invalid signature")

    try:
        payload = json.loads(body)
    except (ValueError, RecursionError):
        # ValueError covers malformed JSON (json.JSONDecodeError) and bodies
        # that are not valid UTF-8/16/32 (UnicodeDecodeError). RecursionError
        # covers pathologically deep nesting.
        raise HTTPException(status_code=400, detail="Invalid JSON") from None
    # The background task is annotated to take a dict. Reject arrays and
    # scalars here with a 400 while the client can still see the error.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Payload must be a JSON object")

    # Process in background with its own db session
    background_tasks.add_task(process_webhook_task, payload)
    return {"status": "ok", "message": "Webhook received"}
- # ==================== Project APIs ====================
@app.get("/projects", response_model=List[schemas.ProjectOut])
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    """List all projects."""
    # Plain offset/limit pagination with no explicit ordering.
    page = db.query(Project).offset(skip).limit(limit)
    return page.all()
@app.get("/projects/{project_id}", response_model=schemas.ProjectOut)
def get_project(project_id: str, db: Session = Depends(get_db)):
    """Get a single project by ID."""
    match = db.query(Project).filter(Project.id == project_id).first()
    if match:
        return match
    # No row with this primary key.
    raise HTTPException(status_code=404, detail="Project not found")
@app.get("/projects/name/{project_name}", response_model=schemas.ProjectOut)
def get_project_by_name(project_name: str, db: Session = Depends(get_db)):
    """Get a project by name."""
    # If several projects share a name, the first row returned wins.
    match = db.query(Project).filter(Project.project_name == project_name).first()
    if match:
        return match
    raise HTTPException(status_code=404, detail="Project not found")
- # ==================== Console APIs ====================
@app.get("/stages/all")
def get_all_stages(db: Session = Depends(get_db)):
    """Get all stages across all projects with version counts."""
    # One row per (stage, project) pair, with the number of versions that
    # project has in that stage.
    version_count = sqlfunc.count(DataVersion.id).label("count")
    rows = (
        db.query(DataVersion.stage, DataVersion.project_id, Project.project_name, version_count)
        .join(Project)
        .group_by(DataVersion.stage, DataVersion.project_id, Project.project_name)
        .all()
    )
    stages = []
    for stage_name, owner_id, owner_name, total in rows:
        stages.append({
            "name": stage_name,
            "project_id": owner_id,
            "project_name": owner_name,
            "version_count": total,
        })
    return stages
@app.get("/projects/{project_id}/stages")
def get_project_stages(project_id: str, db: Session = Depends(get_db)):
    """Get all unique stages for a project with version counts."""
    version_count = sqlfunc.count(DataVersion.id).label("count")
    rows = (
        db.query(DataVersion.stage, version_count)
        .filter(DataVersion.project_id == project_id)
        .group_by(DataVersion.stage)
        .all()
    )
    stages = []
    for stage_name, total in rows:
        stages.append({"name": stage_name, "version_count": total})
    return stages
@app.get("/projects/{project_id}/stage-files")
def get_stage_files(
    project_id: str,
    stage: str = Query(...),
    skip: int = 0,
    limit: int = 20,
    db: Session = Depends(get_db)
):
    """Get versions with files for a specific stage, ordered by newest first."""
    # `collections` is not among this module's top-level imports.
    from collections import defaultdict

    versions = (
        db.query(DataVersion)
        .filter(DataVersion.project_id == project_id, DataVersion.stage == stage)
        .order_by(DataVersion.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    # Fetch the files of every version on this page with one IN query and
    # group them in memory, instead of one query per version (N+1). The query
    # is skipped for an empty page so no empty IN () is emitted.
    files_by_version = defaultdict(list)
    version_ids = [v.id for v in versions]
    if version_ids:
        page_files = db.query(DataFile).filter(DataFile.version_id.in_(version_ids)).all()
        for f in page_files:
            files_by_version[f.version_id].append(f)

    return [
        {
            "version_id": v.id,
            "commit_id": v.commit_id,
            "author": v.author,
            "created_at": v.created_at.isoformat() if v.created_at else None,
            "files": [
                {
                    "id": f.id,
                    # Last path component; empty string when the path is missing.
                    "name": f.relative_path.split("/")[-1] if f.relative_path else "",
                    "relative_path": f.relative_path,
                    "file_size": f.file_size,
                    "file_type": f.file_type,
                    "file_sha": f.file_sha,
                    "direction": f.direction,
                    "label": f.label,
                    "extracted_value": f.extracted_value,
                    "group_key": f.group_key,
                }
                for f in files_by_version[v.id]
            ],
        }
        for v in versions
    ]
@app.get("/projects/{project_id}/records", response_model=List[schemas.DataRecordOut])
def list_data_records(
    project_id: str,
    stage: Optional[str] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """List data records for a project, optionally filtered by stage."""
    from app.models import DataRecord

    # Conditions are ANDed together; an empty/None stage means "all stages".
    criteria = [DataRecord.project_id == project_id]
    if stage:
        criteria.append(DataRecord.stage == stage)
    return (
        db.query(DataRecord)
        .filter(*criteria)
        .order_by(DataRecord.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
- # ==================== Version APIs ====================
@app.get("/projects/{project_id}/versions", response_model=List[schemas.DataVersionOut])
def list_versions(
    project_id: str,
    stage: Optional[str] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """List versions for a project, optionally filtered by stage."""
    # Conditions are ANDed together; an empty/None stage means "all stages".
    criteria = [DataVersion.project_id == project_id]
    if stage:
        criteria.append(DataVersion.stage == stage)
    return (
        db.query(DataVersion)
        .filter(*criteria)
        .order_by(DataVersion.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
@app.get("/versions/{version_id}", response_model=schemas.DataVersionOut)
def get_version(version_id: str, db: Session = Depends(get_db)):
    """Get a single version by ID."""
    match = db.query(DataVersion).filter(DataVersion.id == version_id).first()
    if match:
        return match
    raise HTTPException(status_code=404, detail="Version not found")
@app.get("/versions/{version_id}/files")
def get_version_files(version_id: str, flat: bool = False, db: Session = Depends(get_db)):
    """
    Get files for a version.
    - flat=False (default): Returns tree structure
    - flat=True: Returns flat list
    """
    # The version row is looked up only to return 404 for an unknown ID.
    known_version = db.query(DataVersion).filter(DataVersion.id == version_id).first()
    if not known_version:
        raise HTTPException(status_code=404, detail="Version not found")

    version_files = db.query(DataFile).filter(DataFile.version_id == version_id).all()
    if not flat:
        return build_file_tree(version_files)
    return [schemas.DataFileOut.model_validate(record) for record in version_files]
- # ==================== File APIs ====================
- import urllib.parse
- from fastapi.responses import RedirectResponse # noqa: E811
- from app.services.oss_client import oss_client
@app.get("/files/{file_id}", response_model=schemas.DataFileOut)
def get_file_info(file_id: int, db: Session = Depends(get_db)):
    """Get file metadata."""
    match = db.query(DataFile).filter(DataFile.id == file_id).first()
    if match:
        return match
    raise HTTPException(status_code=404, detail="File not found")
@app.get("/files/{file_id}/url")
def get_file_url(file_id: int, db: Session = Depends(get_db)):
    """Get file CDN URL."""
    match = db.query(DataFile).filter(DataFile.id == file_id).first()
    if match:
        # The CDN URL is derived from the object's storage path by oss_client.
        return {"url": oss_client.get_cdn_url(match.storage_path)}
    raise HTTPException(status_code=404, detail="File not found")
@app.get("/files/{file_id}/content")
def get_file_content(file_id: int, db: Session = Depends(get_db)):
    """Redirect to CDN URL for file download with forced attachment header."""
    file_record = db.query(DataFile).filter(DataFile.id == file_id).first()
    if not file_record:
        raise HTTPException(status_code=404, detail="File not found")
    cdn_url = oss_client.get_cdn_url(file_record.storage_path)

    # Try to force download by adding Aliyun OSS specific query parameter
    # This works for Aliyun OSS even on custom domains if not explicitly disabled
    # Fall back to a generic name when the record has no usable path.
    filename = os.path.basename(file_record.relative_path or "") or "download"

    # RFC 6266: `filename` is a plain quoted-string that recipients do not
    # percent-decode, so it gets an ASCII-only fallback. Non-printable and
    # non-ASCII characters, quotes and backslashes become "_". The exact
    # UTF-8 name goes in `filename*` (RFC 5987 percent-encoding), which
    # clients that support it prefer.
    ascii_fallback = "".join(
        ch if " " <= ch <= "~" and ch not in '"\\' else "_" for ch in filename
    )
    encoded_filename = urllib.parse.quote(filename, safe="")
    disposition = (
        f'attachment; filename="{ascii_fallback}"; '
        f"filename*=UTF-8''{encoded_filename}"
    )

    separator = "&" if "?" in cdn_url else "?"
    download_url = (
        f"{cdn_url}{separator}response-content-disposition="
        f"{urllib.parse.quote(disposition, safe='')}"
    )

    return RedirectResponse(url=download_url)
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000, with
    # auto-reload. uvicorn requires the app as an import string
    # ("module:attribute") when reload is enabled.
    import uvicorn

    server_options = dict(host="0.0.0.0", port=8000, reload=True)
    uvicorn.run("app.main:app", **server_options)
|