Просмотр исходного кода

feat:添加【PYTORCH_CUDA_ALLOC_CONF】环境变量

zhaohaipeng 2 недель назад
Родитель
Сommit
93f5aaa610
2 измененных файлов с 12 добавлено и 10 удалено
  1. 1 0
      compose.base.yml
  2. 11 10
      tools/api_server.py

+ 1 - 0
compose.base.yml

@@ -13,6 +13,7 @@ services:
       - /root/fish-references/s2-pro:/app/references
     environment:
       COMPILE: ${COMPILE:-0}
+      PYTORCH_CUDA_ALLOC_CONF: ${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}
     # GPU (remove this block if CPU-only):
     deploy:
       resources:

+ 11 - 10
tools/api_server.py

@@ -11,13 +11,10 @@ from kui.asgi import (
     Depends,
     FactoryClass,
     HTTPException,
-    HttpRoute,
     Kui,
-    OpenAPI,
     Routes,
 )
 from kui.cors import CORSConfig
-from kui.openapi.specification import Info
 from kui.security import bearer_auth
 from loguru import logger
 from typing_extensions import Annotated
@@ -84,6 +81,9 @@ class API(ExceptionHandler):
         # Associate the app with the model manager
         self.app.on_startup(self.initialize_app)
 
+        # Print args
+        self.args_print()
+
     async def initialize_app(self, app: Kui):
         # Make the ModelManager available to the views
         app.state.model_manager = ModelManager(
@@ -96,15 +96,16 @@ class API(ExceptionHandler):
             decoder_config_name=self.args.decoder_config_name,
         )
 
-        logger.info(f"self.args.mode={self.args.mode}")
-        logger.info(f"self.args.device={self.args.device}")
-        logger.info(f"self.args.half={self.args.half}")
-        logger.info(f"self.args.compile={self.args.compile}")
-        logger.info(f"self.args.llama_checkpoint_path={self.args.llama_checkpoint_path}")
-        logger.info(f"self.args.decoder_checkpoint_path={self.args.decoder_checkpoint_path}")
-        logger.info(f"self.args.decoder_config_name={self.args.decoder_config_name}")
         logger.info(f"Startup done, listening server at http://{self.args.listen}")
 
+    def args_print(self):
+        if not self.args:
+            return
+        print("Loaded arguments:")
+        for key, value in vars(self.args).items():
+            print(f"  self.args.{key}: {value}")
+
+
 
 def create_app():
     args_env = os.environ.get(ENV_ARGS_KEY)