|
|
@@ -2,8 +2,18 @@ from threading import Lock
|
|
|
|
|
|
import pyrootutils
|
|
|
import uvicorn
|
|
|
-from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
|
|
|
+from kui.asgi import (
|
|
|
+ Depends,
|
|
|
+ FactoryClass,
|
|
|
+ HTTPException,
|
|
|
+ HttpRoute,
|
|
|
+ Kui,
|
|
|
+ OpenAPI,
|
|
|
+ Routes,
|
|
|
+)
|
|
|
+from kui.security import bearer_auth
|
|
|
from loguru import logger
|
|
|
+from typing_extensions import Annotated
|
|
|
|
|
|
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
|
|
|
@@ -31,7 +41,25 @@ class API(ExceptionHandler):
|
|
|
("/v1/tts", TTSView),
|
|
|
("/v1/chat", ChatView),
|
|
|
]
|
|
|
- self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
|
|
|
+
|
|
|
+ def api_auth(endpoint):
|
|
|
+ async def verify(token: Annotated[str, Depends(bearer_auth)]):
|
|
|
+ if token != self.args.api_key:
|
|
|
+ raise HTTPException(401, None, "Invalid token")
|
|
|
+ return await endpoint()
|
|
|
+
|
|
|
+ async def passthrough():
|
|
|
+ return await endpoint()
|
|
|
+
|
|
|
+ if self.args.api_key is not None:
|
|
|
+ return verify
|
|
|
+ else:
|
|
|
+ return passthrough
|
|
|
+
|
|
|
+ self.routes = Routes(
|
|
|
+ [HttpRoute(path, view) for path, view in self.routes],
|
|
|
+ http_middlewares=[api_auth],
|
|
|
+ )
|
|
|
|
|
|
self.openapi = OpenAPI(
|
|
|
{
|