@@ -6,6 +6,7 @@ from http import HTTPStatus
 from threading import Lock
 from typing import Annotated, Literal, Optional
 
+import librosa
 import numpy as np
 import soundfile as sf
 import torch
@@ -29,6 +30,7 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer
 
 import tools.llama.generate
+from fish_speech.models.vqgan.utils import sequence_mask
 from tools.llama.generate import encode_tokens, generate, load_model
 
 
@@ -133,7 +135,6 @@ class VQGANModel:
         logger.info("The vqgan model is removed from memory.")
 
     @torch.no_grad()
-    @torch.autocast(device_type="cuda", enabled=True)
     def sematic_to_wav(self, indices):
         model = self.model
         indices = indices.to(model.device).long()
@@ -170,6 +171,62 @@ class VQGANModel:
         return fake_audio, model.sampling_rate
 
+    @torch.no_grad()
+    def wav_to_semantic(self, audio):
+        model = self.model
+        # Load the reference audio, resampled to the model's sampling rate
+        audio, _ = librosa.load(
+            audio,
+            sr=model.sampling_rate,
+            mono=True,
+        )
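+        # Reshape (T,) -> (1, 1, T): add batch and channel dimensions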
+        audios = torch.from_numpy(audio).to(model.device)[None, None, :]
+        logger.info(
+            f"Loaded {audios.shape[2] / model.sampling_rate:.2f} seconds of audio"
+        )
+
+        # VQ encoder: turn the waveform into discrete semantic token indices
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=model.device, dtype=torch.long
+        )
+
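+        # Compute mel-spectrogram features from the waveform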
+        features = gt_mels = model.mel_transform(
+            audios, sample_rate=model.sampling_rate
+        )
+
+        if model.downsample is not None:
+            features = model.downsample(features)
+
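+        # One mel frame per hop_length samples; the optional downsampler lowers the frame rate further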
+        mel_lengths = audio_lengths // model.hop_length
+        feature_lengths = (
+            audio_lengths
+            / model.hop_length
+            / (model.downsample.total_strides if model.downsample is not None else 1)
+        ).long()
+
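+        # Mask of valid frames; effectively all ones for this single, unpadded clip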
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+
+        # vq_features run at 50 Hz and need to be converted to the true mel size
+        text_features = model.mel_encoder(features, feature_masks)
+        _, indices, _ = model.vq_encoder(text_features, feature_masks)
+
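+        # Expect indices shaped (batch, 1, time, 1); squeeze the singleton axes to (batch, time)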
+        if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
+            indices = indices[:, 0, :, 0]
+        else:
+            logger.error(f"Unknown indices shape: {indices.shape}")
+            return
+
+        logger.info(f"Generated indices of shape {indices.shape}")
+
+        return indices
+
 
 class LoadLlamaModelRequest(BaseModel):
     config_name: str = "text2semantic_finetune"
@@ -275,6 +332,8 @@ class InvokeRequest(BaseModel):
     top_p: float = 0.5
     repetition_penalty: float = 1.5
     temperature: float = 0.7
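+    # g2p language priority, comma-separated (assumed semantics: try zh, then jp, then en)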
+    order: str = "zh,jp,en"
     use_g2p: bool = True
     seed: Optional[int] = None
     speaker: Optional[str] = None
@@ -299,19 +358,28 @@ def api_invoke_model(
     llama_model_manager = model["llama"]
     vqgan_model_manager = model["vqgan"]
 
-    # Lock
-    model["lock"].acquire()
-
     device = llama_model_manager.device
     seed = req.seed
     prompt_tokens = req.prompt_tokens
     logger.info(f"Device: {device}")
 
-    prompt_tokens = (
-        torch.from_numpy(np.load(prompt_tokens)).to(device)
-        if prompt_tokens is not None
-        else None
-    )
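+    # prompt_tokens may be a path to precomputed codes (.npy) or to reference audio (.wav)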
+    if prompt_tokens is not None and prompt_tokens.endswith(".npy"):
+        prompt_tokens = torch.from_numpy(np.load(prompt_tokens)).to(device)
+    elif prompt_tokens is not None and prompt_tokens.endswith(".wav"):
+        prompt_tokens = vqgan_model_manager.wav_to_semantic(prompt_tokens)
+    elif prompt_tokens is not None:
+        logger.error(f"Unknown prompt tokens: {prompt_tokens}")
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            content="Unknown prompt tokens; expected a .npy or .wav file.",
+        )
+    else:
+        prompt_tokens = None
+
+    # Acquire the lock to serialize access to the shared models
+    model["lock"].acquire()
+
     encoded = encode_tokens(
         llama_model_manager.tokenizer,
         req.text,
@@ -321,6 +389,7 @@ def api_invoke_model(
         device=device,
         use_g2p=req.use_g2p,
         speaker=req.speaker,
+        order=req.order,
     )
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")