|
|
@@ -9,8 +9,9 @@ NUM_WORKERS=${1:-2} # 默认2个worker
|
|
|
PORT=${2:-8080} # 默认端口8080
|
|
|
GPU_ID=${3:-0} # 默认GPU 0
|
|
|
|
|
|
-LLAMA_CHECKPOINT="checkpoints/s2-pro"
|
|
|
-DECODER_CHECKPOINT="checkpoints/s2-pro/codec.pth"
|
|
|
+DECODER_CONFIG_NAME="modded_dac_vq"
|
|
|
+LLAMA_CHECKPOINT="/root/fish-checkpoints/s2-pro"
|
|
|
+DECODER_CHECKPOINT="/root/fish-checkpoints/s2-pro/codec.pth"
|
|
|
|
|
|
# 设置环境变量
|
|
|
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
|
|
@@ -29,6 +30,7 @@ python tools/api_server.py \
|
|
|
--listen 0.0.0.0:${PORT} \
|
|
|
--llama-checkpoint-path ${LLAMA_CHECKPOINT} \
|
|
|
--decoder-checkpoint-path ${DECODER_CHECKPOINT} \
|
|
|
+ --decoder-config-name "${DECODER_CONFIG_NAME} \
|
|
|
--half \
|
|
|
--workers 1 \
|
|
|
--num-workers ${NUM_WORKERS}
|