소스 검색

feat: Bearer auth for HTTP API (#746)

* feat: Bearer auth for HTTP API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hao Guan 1 년 전
부모
커밋
d8d71b2cbb
3개의 변경된 파일38개의 추가작업 그리고 3개의 파일을 삭제
  1. 7 1
      tools/api_client.py
  2. 30 2
      tools/api_server.py
  3. 1 0
      tools/server/api_utils.py

+ 7 - 1
tools/api_client.py

@@ -119,6 +119,12 @@ def parse_args():
         help="`None` means randomized inference, otherwise deterministic.\n"
         "It can't be used for fixing a timbre.",
     )
+    parser.add_argument(
+        "--api_key",
+        type=str,
+        default="YOUR_API_KEY",
+        help="API key for authentication",
+    )
 
     return parser.parse_args()
 
@@ -173,7 +179,7 @@ if __name__ == "__main__":
         data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
         stream=args.streaming,
         headers={
-            "authorization": "Bearer YOUR_API_KEY",
+            "authorization": f"Bearer {args.api_key}",
             "content-type": "application/msgpack",
         },
     )

+ 30 - 2
tools/api_server.py

@@ -2,8 +2,18 @@ from threading import Lock
 
 import pyrootutils
 import uvicorn
-from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
+from kui.asgi import (
+    Depends,
+    FactoryClass,
+    HTTPException,
+    HttpRoute,
+    Kui,
+    OpenAPI,
+    Routes,
+)
+from kui.security import bearer_auth
 from loguru import logger
+from typing_extensions import Annotated
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
@@ -31,7 +41,25 @@ class API(ExceptionHandler):
             ("/v1/tts", TTSView),
             ("/v1/chat", ChatView),
         ]
-        self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
+
+        def api_auth(endpoint):
+            async def verify(token: Annotated[str, Depends(bearer_auth)]):
+                if token != self.args.api_key:
+                    raise HTTPException(401, None, "Invalid token")
+                return await endpoint()
+
+            async def passthrough():
+                return await endpoint()
+
+            if self.args.api_key is not None:
+                return verify
+            else:
+                return passthrough
+
+        self.routes = Routes(
+            [HttpRoute(path, view) for path, view in self.routes],
+            http_middlewares=[api_auth],
+        )
 
         self.openapi = OpenAPI(
             {

+ 1 - 0
tools/server/api_utils.py

@@ -32,6 +32,7 @@ def parse_args():
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--api-key", type=str, default=None)
 
     return parser.parse_args()