supeng, 3 weeks ago
Commit ea4f9343ae

+ 34 - 0
deploy_multi_worker.sh

@@ -0,0 +1,34 @@
+#!/bin/bash
+# Multi-worker deployment script - starts the API service on a single machine
+# Usage: ./deploy_multi_worker.sh [num_workers] [port] [gpu_id]
+
+set -e
+
+# Configuration parameters
+NUM_WORKERS=${1:-2}  # default: 2 workers
+PORT=${2:-8080}      # default port: 8080
+GPU_ID=${3:-0}       # default: GPU 0
+
+LLAMA_CHECKPOINT="checkpoints/s2-pro"
+DECODER_CHECKPOINT="checkpoints/s2-pro/codec.pth"
+
+# Set environment variables
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+export CUDA_VISIBLE_DEVICES=${GPU_ID}
+
+echo "========================================="
+echo "Fish-Speech Multi-Worker Deployment"
+echo "========================================="
+echo "Workers: ${NUM_WORKERS}"
+echo "Port: ${PORT}"
+echo "GPU: ${GPU_ID}"
+echo "========================================="
+
+# Start the API service
+python tools/api_server.py \
+  --listen 0.0.0.0:${PORT} \
+  --llama-checkpoint-path ${LLAMA_CHECKPOINT} \
+  --decoder-checkpoint-path ${DECODER_CHECKPOINT} \
+  --half \
+  --workers 1 \
+  --num-workers ${NUM_WORKERS}
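
For reference, a hypothetical invocation setting all three positional arguments (worker count, port, GPU index) would be:

    ./deploy_multi_worker.sh 4 8080 1

Note that the script keeps --workers 1 and passes the worker count through the new --num-workers flag, so all inference threads run inside a single server process on the selected GPU.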

+ 11 - 4
fish_speech/models/text2semantic/inference.py

@@ -774,11 +774,13 @@ def launch_thread_safe_queue(
     device,
     precision,
     compile: bool = False,
+    num_workers: int = 1,
 ):
     input_queue = queue.Queue()
-    init_event = threading.Event()
+    init_events = [threading.Event() for _ in range(num_workers)]

-    def worker():
+    def worker(worker_id, init_event):
+        logger.info(f"Worker {worker_id} starting, loading model...")
         model, decode_one_token = init_model(
             checkpoint_path, device, precision, compile=compile
         )
@@ -788,6 +790,7 @@ def launch_thread_safe_queue(
                 max_seq_len=model.config.max_seq_len,
                 dtype=next(model.parameters()).dtype,
             )
+        logger.info(f"Worker {worker_id} initialized")
         init_event.set()

         while True:
@@ -817,9 +820,13 @@ def launch_thread_safe_queue(
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()

-    threading.Thread(target=worker, daemon=True).start()
-    init_event.wait()
+    for i in range(num_workers):
+        threading.Thread(target=worker, args=(i, init_events[i]), daemon=True).start()

+    for event in init_events:
+        event.wait()
+
+    logger.info(f"All {num_workers} workers initialized successfully")
     return input_queue
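
The pattern here is a classic multi-consumer thread pool over a single thread-safe queue: each worker loads its own model copy, signals readiness via its own Event, then loops pulling requests. A minimal self-contained sketch of the same idea (load_model and handle are illustrative stubs, not functions from this codebase):

import queue
import threading
import time


def load_model():
    """Stub standing in for init_model(...): pretend this is an expensive load."""
    time.sleep(0.1)
    return "model"


def handle(model, item):
    """Stub: process one request with this worker's model copy."""
    print(f"{threading.current_thread().name} handled {item!r}")


def launch_worker_pool(num_workers=2):
    input_queue = queue.Queue()
    init_events = [threading.Event() for _ in range(num_workers)]

    def worker(worker_id, init_event):
        model = load_model()          # each worker owns a full model copy
        init_event.set()              # signal readiness, as in the commit
        while True:
            item = input_queue.get()  # queue.Queue is thread-safe; blocks when empty
            if item is None:          # shutdown sentinel
                input_queue.task_done()
                break
            handle(model, item)
            input_queue.task_done()

    for i, event in enumerate(init_events):
        threading.Thread(
            target=worker, args=(i, event), daemon=True, name=f"worker-{i}"
        ).start()

    for event in init_events:         # block until every model is loaded
        event.wait()
    return input_queue


if __name__ == "__main__":
    q = launch_worker_pool(num_workers=2)
    for i in range(4):
        q.put(f"request-{i}")
    q.join()                          # wait until all requests are processed

Because every thread blocks on the same input_queue.get(), a request is picked up by whichever worker is idle; no explicit dispatcher is needed. The trade-off, as in the commit, is that each worker holds a full model copy, so GPU memory grows roughly linearly with num_workers.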
 
 
 
 

+ 1 - 0
tools/api_server.py

@@ -96,6 +96,7 @@ class API(ExceptionHandler):
             llama_checkpoint_path=self.args.llama_checkpoint_path,
             decoder_checkpoint_path=self.args.decoder_checkpoint_path,
             decoder_config_name=self.args.decoder_config_name,
+            num_workers=self.args.num_workers,
         )

         logger.info(f"Startup done, listening server at http://{self.args.listen}")

+ 1 - 0
tools/server/api_utils.py

@@ -38,6 +38,7 @@ def parse_args():
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
     parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
     parser.add_argument("--workers", type=int, default=1)
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--num-workers", type=int, default=1, help="Number of model worker threads for parallel inference")
     parser.add_argument("--api-key", type=str, default=None)
     parser.add_argument("--api-key", type=str, default=None)
 
 
     return parser.parse_args()
     return parser.parse_args()
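
Note the naming: the pre-existing --workers flag (presumably the number of HTTP server workers) is untouched; the new --num-workers controls only how many model inference threads launch_thread_safe_queue spawns inside one process, which is why the deploy script above pins --workers 1.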

+ 3 - 0
tools/server/model_manager.py

@@ -18,12 +18,14 @@ class ModelManager:
         llama_checkpoint_path: str,
         decoder_checkpoint_path: str,
         decoder_config_name: str,
+        num_workers: int = 1,
     ) -> None:

         self.mode = mode
         self.device = device
         self.half = half
         self.compile = compile
+        self.num_workers = num_workers

         self.precision = torch.half if half else torch.bfloat16

@@ -63,6 +65,7 @@ class ModelManager:
                 device=device,
                 precision=precision,
                 compile=compile,
+                num_workers=self.num_workers,
             )
         else:
             raise ValueError(f"Invalid mode: {mode}")
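
Taken together, the plumbing is straightforward: tools/server/api_utils.py parses --num-workers, tools/api_server.py forwards it into ModelManager, and ModelManager passes it on to launch_thread_safe_queue, where the worker threads are actually created. With the default of 1 everywhere, the code spawns a single worker and waits on a single event, so behavior matches the previous single-worker implementation.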
             raise ValueError(f"Invalid mode: {mode}")