Lengyue 2 года назад
Родитель
Commit
df811976d5
3 изменённых файла: 4 добавлено и 11 удалено
  1. 1 1
      data_server/src/main.rs
  2. 3 3
      tools/vqgan/extract_vq.py
  3. 0 7
      tools/vqgan/inference.py

+ 1 - 1
data_server/src/main.rs

@@ -252,7 +252,7 @@ struct Args {
     files: Vec<String>,
     files: Vec<String>,
 
 
     /// Causal sampling
     /// Causal sampling
-    #[clap(short, long, default_value = "false")]
+    #[clap(short, long, default_value = "true")]
     causal: bool,
     causal: bool,
 
 
     /// Address to bind to
     /// Address to bind to

+ 3 - 3
tools/vqgan/extract_vq.py

@@ -63,6 +63,7 @@ def get_model(
     return model
     return model
 
 
 
 
+@torch.inference_mode()
 def process_batch(files: list[Path], model) -> float:
 def process_batch(files: list[Path], model) -> float:
     wavs = []
     wavs = []
     audio_lengths = []
     audio_lengths = []
@@ -87,8 +88,7 @@ def process_batch(files: list[Path], model) -> float:
     audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
     audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
 
 
     # Calculate lengths
     # Calculate lengths
-    with torch.no_grad():
-        indices, feature_lengths = model.encode(audios, audio_lengths)
+    indices, feature_lengths = model.encode(audios, audio_lengths)
 
 
     # Save to disk
     # Save to disk
     outputs = indices.cpu().numpy()
     outputs = indices.cpu().numpy()
@@ -111,7 +111,7 @@ def process_batch(files: list[Path], model) -> float:
 @click.option("--config-name", default="vqgan_pretrain")
 @click.option("--config-name", default="vqgan_pretrain")
 @click.option(
 @click.option(
     "--checkpoint-path",
     "--checkpoint-path",
-    default="checkpoints/vqgan-v1.pth",
+    default="checkpoints/vq-gan-group-fsq-8x1024-wn-20x768-30kh.pth",
 )
 )
 @click.option("--batch-size", default=64)
 @click.option("--batch-size", default=64)
 @click.option("--filelist", default=None, type=Path)
 @click.option("--filelist", default=None, type=Path)

+ 0 - 7
tools/vqgan/inference.py

@@ -5,15 +5,12 @@ import librosa
 import numpy as np
 import numpy as np
 import soundfile as sf
 import soundfile as sf
 import torch
 import torch
-import torch.nn.functional as F
-from einops import rearrange
 from hydra import compose, initialize
 from hydra import compose, initialize
 from hydra.utils import instantiate
 from hydra.utils import instantiate
 from lightning import LightningModule
 from lightning import LightningModule
 from loguru import logger
 from loguru import logger
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
 
 
-from fish_speech.models.vqgan.utils import sequence_mask
 from fish_speech.utils.file import AUDIO_EXTENSIONS
 from fish_speech.utils.file import AUDIO_EXTENSIONS
 
 
 # register eval resolver
 # register eval resolver
@@ -85,10 +82,6 @@ def main(input_path, output_path, config_name, checkpoint_path):
     else:
     else:
         raise ValueError(f"Unknown input type: {input_path}")
         raise ValueError(f"Unknown input type: {input_path}")
 
 
-    # random destroy 10% of indices
-    # mask = torch.rand_like(indices, dtype=torch.float) > 0.9
-    # indices[mask] = torch.randint(0, 1000, mask.shape, device=indices.device, dtype=indices.dtype)[mask]
-
     # Restore
     # Restore
     feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
     feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
     fake_audios = model.decode(
     fake_audios = model.decode(