@@ -1,4 +1,5 @@
import io
+import os
import queue
import sys
import traceback
@@ -88,7 +89,8 @@ def load_audio(reference_audio, sr):
        reference_audio = io.BytesIO(audio_data)

    waveform, original_sr = torchaudio.load(
-        reference_audio, backend="ffmpeg" if sys.platform == "linux" else "soundfile"
+        reference_audio,
+        backend="soundfile",  # not every linux release supports 'sox' or 'ffmpeg'
    )

    if waveform.shape[0] > 1:
@@ -166,6 +168,8 @@ def get_content_type(audio_format):
@torch.inference_mode()
def inference(req: ServeTTSRequest):

+    global prompt_tokens, prompt_texts
+
    idstr: str | None = req.reference_id
    if idstr is not None:
        ref_folder = Path("references") / idstr
@@ -173,33 +177,43 @@ def inference(req: ServeTTSRequest):
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )
-        prompt_tokens = [
-            encode_reference(
-                decoder_model=decoder_model,
-                reference_audio=audio_to_bytes(str(ref_audio)),
-                enable_reference_audio=True,
-            )
-            for ref_audio in ref_audios
-        ]
-        prompt_texts = [
-            read_ref_text(str(ref_audio.with_suffix(".lab")))
-            for ref_audio in ref_audios
-        ]
+
+        if req.use_memory_cache == "never" or (
+            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+        ):
+            prompt_tokens = [
+                encode_reference(
+                    decoder_model=decoder_model,
+                    reference_audio=audio_to_bytes(str(ref_audio)),
+                    enable_reference_audio=True,
+                )
+                for ref_audio in ref_audios
+            ]
+            prompt_texts = [
+                read_ref_text(str(ref_audio.with_suffix(".lab")))
+                for ref_audio in ref_audios
+            ]
+        else:
+            logger.info("Use same references")

    else:
        # Parse reference audio aka prompt
        refs = req.references
-        if refs is None:
-            refs = []
-        prompt_tokens = [
-            encode_reference(
-                decoder_model=decoder_model,
-                reference_audio=ref.audio,
-                enable_reference_audio=True,
-            )
-            for ref in refs
-        ]
-        prompt_texts = [ref.text for ref in refs]
+
+        if req.use_memory_cache == "never" or (
+            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+        ):
+            prompt_tokens = [
+                encode_reference(
+                    decoder_model=decoder_model,
+                    reference_audio=ref.audio,
+                    enable_reference_audio=True,
+                )
+                for ref in refs
+            ]
+            prompt_texts = [ref.text for ref in refs]
+        else:
+            logger.info("Use same references")

    # LLAMA Inference
    request = dict(
@@ -397,11 +411,23 @@ app = Kui(
)


-if __name__ == "__main__":
+# Each worker process created by Uvicorn has its own memory space,
+# meaning that models and variables are not shared between processes.
+# Therefore, any global variables (like `llama_queue` or `decoder_model`)
+# will not be shared across workers.

-    import uvicorn

-    args = parse_args()
+# Multi-threading for deep learning can cause issues, such as inconsistent
+# outputs if multiple threads access the same buffers simultaneously.
+# Instead, it's better to use multiprocessing or independent models per thread.
+@app.on_startup
+def initialize_app(app: Kui):
+
+    global args, llama_queue, decoder_model, prompt_tokens, prompt_texts
+
+    prompt_tokens, prompt_texts = [], []
+
+    args = parse_args()  # args are the same as in the other worker processes
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
@@ -411,6 +437,7 @@ if __name__ == "__main__":
        precision=args.precision,
        compile=args.compile,
    )
+
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
@@ -421,7 +448,7 @@ if __name__ == "__main__":

    logger.info("VQ-GAN model loaded, warming up...")

-    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    # Dry run to ensure models work and avoid first-time latency
    list(
        inference(
            ServeTTSRequest(
@@ -440,5 +467,18 @@ if __name__ == "__main__":
    )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
+
+
+if __name__ == "__main__":
+
+    import uvicorn
+
+    args = parse_args()
    host, port = args.listen.split(":")
-    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
+    uvicorn.run(
+        "tools.api:app",
+        host=host,
+        port=int(port),
+        workers=args.workers,
+        log_level="info",
+    )