|
@@ -2,14 +2,12 @@ import html
|
|
|
import os
|
|
import os
|
|
|
import threading
|
|
import threading
|
|
|
from argparse import ArgumentParser
|
|
from argparse import ArgumentParser
|
|
|
-from io import BytesIO
|
|
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import gradio as gr
|
|
import gradio as gr
|
|
|
import librosa
|
|
import librosa
|
|
|
import torch
|
|
import torch
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
-from torchaudio import functional as AF
|
|
|
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
from tools.llama.generate import launch_thread_safe_queue
|
|
from tools.llama.generate import launch_thread_safe_queue
|
|
@@ -74,7 +72,10 @@ def inference(
|
|
|
speaker,
|
|
speaker,
|
|
|
):
|
|
):
|
|
|
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
|
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
|
|
- return None, "Text is too long, please keep it under 1000 characters."
|
|
|
|
|
|
|
+ return (
|
|
|
|
|
+ None,
|
|
|
|
|
+ f"Text is too long, please keep it under {args.max_gradio_length} characters.",
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# Parse reference audio aka prompt
|
|
# Parse reference audio aka prompt
|
|
|
prompt_tokens = None
|
|
prompt_tokens = None
|
|
@@ -266,10 +267,10 @@ def parse_args():
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--llama-checkpoint-path",
|
|
"--llama-checkpoint-path",
|
|
|
type=Path,
|
|
type=Path,
|
|
|
- default="checkpoints/text2semantic-medium-v1-2k.pth",
|
|
|
|
|
|
|
+ default="checkpoints/text2semantic-sft-large-v1-4k.pth",
|
|
|
)
|
|
)
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
- "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
|
|
|
|
|
|
|
+ "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
|
|
|
)
|
|
)
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--vqgan-checkpoint-path",
|
|
"--vqgan-checkpoint-path",
|