|
|
@@ -40,21 +40,6 @@ HEADER_MD = f"""# Fish Speech
|
|
|
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
|
|
|
SPACE_IMPORTED = False
|
|
|
|
|
|
-try:
|
|
|
- import spaces
|
|
|
-
|
|
|
- GPU_DECORATOR = spaces.GPU
|
|
|
- SPACE_IMPORTED = True
|
|
|
-except ImportError:
|
|
|
-
|
|
|
- def GPU_DECORATOR(func):
|
|
|
- @wraps(func)
|
|
|
- def wrapper(*args, **kwargs):
|
|
|
- return func(*args, **kwargs)
|
|
|
-
|
|
|
- wrapper.original = func # ref
|
|
|
- return wrapper
|
|
|
-
|
|
|
|
|
|
def build_html_error_message(error):
|
|
|
return f"""
|
|
|
@@ -65,7 +50,6 @@ def build_html_error_message(error):
|
|
|
"""
|
|
|
|
|
|
|
|
|
-@GPU_DECORATOR
|
|
|
@torch.inference_mode()
|
|
|
def inference(
|
|
|
text,
|
|
|
@@ -173,11 +157,6 @@ def inference(
|
|
|
|
|
|
inference_stream = partial(inference, streaming=True)
|
|
|
|
|
|
-if not SPACE_IMPORTED:
|
|
|
- logger.info("‘spaces’ not imported, use original")
|
|
|
- inference = inference.original
|
|
|
- inference_stream = partial(inference, streaming=True)
|
|
|
-
|
|
|
|
|
|
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
|
|
buffer = io.BytesIO()
|
|
|
@@ -343,7 +322,7 @@ def parse_args():
|
|
|
parser.add_argument(
|
|
|
"--llama-checkpoint-path",
|
|
|
type=Path,
|
|
|
- default="checkpoints/text2semantic-sft-large-v1-4k.pth",
|
|
|
+ default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
|
|
|
)
|
|
|
parser.add_argument(
|
|
|
"--llama-config-name", type=str, default="dual_ar_2_codebook_large"
|