Преглед изворни кода

feat:修改HALF的传值方式

zhaohaipeng пре 2 дана
родитељ
комит
d2316b268a
2 измењених фајлова са 18 додато и 3 уклоњено
  1. 11 3
      docker/Dockerfile
  2. 7 0
      tools/api_server.py

+ 11 - 3
docker/Dockerfile

@@ -267,6 +267,15 @@ RUN printf '%s\n' \
     '    echo "$@"' \
     '}' \
     '' \
+    '# Build half arguments' \
+    'build_half_args() {' \
+    '    if [ "${1:-}" = "half" ] || [ "${HALF:-}" = "1" ] || [ "${HALF:-}" = "true" ]; then' \
+    '        echo "--half"' \
+    '        shift' \
+    '    fi' \
+    '    echo "$@"' \
+    '}' \
+    '' \
     '# Health check function' \
     'health_check() {' \
     '    local port=${1:-7860}' \
@@ -337,7 +346,6 @@ ARG API_SERVER_PORT=8080
 ARG LLAMA_CHECKPOINT_PATH="checkpoints/s2-pro"
 ARG DECODER_CHECKPOINT_PATH="checkpoints/s2-pro/codec.pth"
 ARG DECODER_CONFIG_NAME="modded_dac_vq"
-ARG FISH_API_SERVER_ARGS="{}"
 # Expose port
 EXPOSE ${API_SERVER_PORT}
 
@@ -347,7 +355,6 @@ ENV API_SERVER_PORT=${API_SERVER_PORT}
 ENV LLAMA_CHECKPOINT_PATH=${LLAMA_CHECKPOINT_PATH}
 ENV DECODER_CHECKPOINT_PATH=${DECODER_CHECKPOINT_PATH}
 ENV DECODER_CONFIG_NAME=${DECODER_CONFIG_NAME}
-ENV FISH_API_SERVER_ARGS=${FISH_API_SERVER_ARGS}
 
 # Create server entrypoint
 RUN printf '%s\n' \
@@ -359,6 +366,7 @@ RUN printf '%s\n' \
     '' \
     'DEVICE_ARGS=$(build_device_args)' \
     'COMPILE_ARGS=$(build_compile_args "$@")' \
+    'HALF_ARGS=$(build_half_args "$@")' \
     '' \
     'log "Device args: ${DEVICE_ARGS:-none}"' \
     'log "Compile args: ${COMPILE_ARGS}"' \
@@ -369,7 +377,7 @@ RUN printf '%s\n' \
     '  --llama-checkpoint-path "${LLAMA_CHECKPOINT_PATH}" \' \
     '  --decoder-checkpoint-path "${DECODER_CHECKPOINT_PATH}" \' \
     '  --decoder-config-name "${DECODER_CONFIG_NAME}" \' \
-    '  ${DEVICE_ARGS} ${COMPILE_ARGS}' \
+    '  ${DEVICE_ARGS} ${COMPILE_ARGS} ${HALF_ARGS}' \
     > /app/start_server.sh && chmod +x /app/start_server.sh
 
 # Health check

+ 7 - 0
tools/api_server.py

@@ -96,6 +96,13 @@ class API(ExceptionHandler):
             decoder_config_name=self.args.decoder_config_name,
         )
 
+        logger.info(f"self.args.mode={self.args.mode}")
+        logger.info(f"self.args.device={self.args.device}")
+        logger.info(f"self.args.half={self.args.half}")
+        logger.info(f"self.args.compile={self.args.compile}")
+        logger.info(f"self.args.llama_checkpoint_path={self.args.llama_checkpoint_path}")
+        logger.info(f"self.args.decoder_checkpoint_path={self.args.decoder_checkpoint_path}")
+        logger.info(f"self.args.decoder_config_name={self.args.decoder_config_name}")
         logger.info(f"Startup done, listening server at http://{self.args.listen}")