ソースを参照

Add g2p config and semantic convert to api server

Lengyue 2 年 前
コミット
3a08434bfb
2 ファイル変更77 行追加10 行削除
  1. 71 9
      tools/api_server.py
  2. 6 1
      tools/llama/generate.py

+ 71 - 9
tools/api_server.py

@@ -6,6 +6,7 @@ from http import HTTPStatus
 from threading import Lock
 from typing import Annotated, Literal, Optional
 
+import librosa
 import numpy as np
 import soundfile as sf
 import torch
@@ -29,6 +30,7 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer
 
 import tools.llama.generate
+from fish_speech.models.vqgan.utils import sequence_mask
 from tools.llama.generate import encode_tokens, generate, load_model
 
 
@@ -133,7 +135,6 @@ class VQGANModel:
         logger.info("The vqgan model is removed from memory.")
 
     @torch.no_grad()
-    @torch.autocast(device_type="cuda", enabled=True)
     def sematic_to_wav(self, indices):
         model = self.model
         indices = indices.to(model.device).long()
@@ -170,6 +171,57 @@ class VQGANModel:
 
         return fake_audio, model.sampling_rate
 
+    @torch.no_grad()
+    def wav_to_semantic(self, audio):
+        model = self.model
+        # Load audio
+        audio, _ = librosa.load(
+            audio,
+            sr=model.sampling_rate,
+            mono=True,
+        )
+        audios = torch.from_numpy(audio).to(model.device)[None, None, :]
+        logger.info(
+            f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
+        )
+
+        # VQ Encoder
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=model.device, dtype=torch.long
+        )
+
+        features = gt_mels = model.mel_transform(
+            audios, sample_rate=model.sampling_rate
+        )
+
+        if model.downsample is not None:
+            features = model.downsample(features)
+
+        mel_lengths = audio_lengths // model.hop_length
+        feature_lengths = (
+            audio_lengths
+            / model.hop_length
+            / (model.downsample.total_strides if model.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        text_features = model.mel_encoder(features, feature_masks)
+        _, indices, _ = model.vq_encoder(text_features, feature_masks)
+
+        if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
+            indices = indices[:, 0, :, 0]
+        else:
+            logger.error(f"Unknown indices shape: {indices.shape}")
+            return
+
+        logger.info(f"Generated indices of shape {indices.shape}")
+
+        return indices
+
 
 class LoadLlamaModelRequest(BaseModel):
     config_name: str = "text2semantic_finetune"
@@ -275,6 +327,7 @@ class InvokeRequest(BaseModel):
     top_p: float = 0.5
     repetition_penalty: float = 1.5
     temperature: float = 0.7
+    order: str = "zh,jp,en"
     use_g2p: bool = True
     seed: Optional[int] = None
     speaker: Optional[str] = None
@@ -299,19 +352,27 @@ def api_invoke_model(
     llama_model_manager = model["llama"]
     vqgan_model_manager = model["vqgan"]
 
-    # Lock
-    model["lock"].acquire()
-
     device = llama_model_manager.device
     seed = req.seed
     prompt_tokens = req.prompt_tokens
     logger.info(f"Device: {device}")
 
-    prompt_tokens = (
-        torch.from_numpy(np.load(prompt_tokens)).to(device)
-        if prompt_tokens is not None
-        else None
-    )
+    if prompt_tokens is not None and prompt_tokens.endswith(".npy"):
+        prompt_tokens = torch.from_numpy(np.load(prompt_tokens)).to(device)
+    elif prompt_tokens is not None and prompt_tokens.endswith(".wav"):
+        prompt_tokens = vqgan_model_manager.wav_to_semantic(prompt_tokens)
+    elif prompt_tokens is not None:
+        logger.error(f"Unknown prompt tokens: {prompt_tokens}")
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            content="Unknown prompt tokens, it should be either .npy or .wav file.",
+        )
+    else:
+        prompt_tokens = None
+
+    # Lock
+    model["lock"].acquire()
+
     encoded = encode_tokens(
         llama_model_manager.tokenizer,
         req.text,
@@ -321,6 +382,7 @@ def api_invoke_model(
         device=device,
         use_g2p=req.use_g2p,
         speaker=req.speaker,
+        order=req.order,
     )
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")

+ 6 - 1
tools/llama/generate.py

@@ -268,12 +268,14 @@ def encode_tokens(
     prompt_tokens=None,
     use_g2p=False,
     speaker=None,
+    order="zh,jp,en",
 ):
     if prompt_text is not None:
         string = prompt_text + " " + string
 
     if use_g2p:
-        prompt = g2p(string)
+        order = order.split(",")
+        prompt = g2p(string, order=order)
         prompt = [
             (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
             for _, i in prompt
@@ -382,6 +384,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 @click.option("--use-g2p/--no-g2p", default=True)
 @click.option("--seed", type=int, default=42)
 @click.option("--speaker", type=str, default=None)
+@click.option("--order", type=str, default="zh,jp,en")
 @click.option("--half/--no-half", default=False)
 def main(
     text: str,
@@ -400,6 +403,7 @@ def main(
     use_g2p: bool,
     seed: int,
     speaker: Optional[str],
+    order: str,
     half: bool,
 ) -> None:
     device = "cuda"
@@ -429,6 +433,7 @@ def main(
         device=device,
         use_g2p=use_g2p,
         speaker=speaker,
+        order=order,
     )
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")