Explorar o código

default causal & clean code

Lengyue hai 2 anos
pai
achega
df811976d5
Modificáronse 3 ficheiros con 4 adicións e 11 borrados
  1. 1 1
      data_server/src/main.rs
  2. 3 3
      tools/vqgan/extract_vq.py
  3. 0 7
      tools/vqgan/inference.py

+ 1 - 1
data_server/src/main.rs

@@ -252,7 +252,7 @@ struct Args {
     files: Vec<String>,
 
     /// Causual sampling
-    #[clap(short, long, default_value = "false")]
+    #[clap(short, long, default_value = "true")]
     causal: bool,
 
     /// Address to bind to

+ 3 - 3
tools/vqgan/extract_vq.py

@@ -63,6 +63,7 @@ def get_model(
     return model
 
 
+@torch.inference_mode()
 def process_batch(files: list[Path], model) -> float:
     wavs = []
     audio_lengths = []
@@ -87,8 +88,7 @@ def process_batch(files: list[Path], model) -> float:
     audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
 
     # Calculate lengths
-    with torch.no_grad():
-        indices, feature_lengths = model.encode(audios, audio_lengths)
+    indices, feature_lengths = model.encode(audios, audio_lengths)
 
     # Save to disk
     outputs = indices.cpu().numpy()
@@ -111,7 +111,7 @@ def process_batch(files: list[Path], model) -> float:
 @click.option("--config-name", default="vqgan_pretrain")
 @click.option(
     "--checkpoint-path",
-    default="checkpoints/vqgan-v1.pth",
+    default="checkpoints/vq-gan-group-fsq-8x1024-wn-20x768-30kh.pth",
 )
 @click.option("--batch-size", default=64)
 @click.option("--filelist", default=None, type=Path)

+ 0 - 7
tools/vqgan/inference.py

@@ -5,15 +5,12 @@ import librosa
 import numpy as np
 import soundfile as sf
 import torch
-import torch.nn.functional as F
-from einops import rearrange
 from hydra import compose, initialize
 from hydra.utils import instantiate
 from lightning import LightningModule
 from loguru import logger
 from omegaconf import OmegaConf
 
-from fish_speech.models.vqgan.utils import sequence_mask
 from fish_speech.utils.file import AUDIO_EXTENSIONS
 
 # register eval resolver
@@ -85,10 +82,6 @@ def main(input_path, output_path, config_name, checkpoint_path):
     else:
         raise ValueError(f"Unknown input type: {input_path}")
 
-    # random destroy 10% of indices
-    # mask = torch.rand_like(indices, dtype=torch.float) > 0.9
-    # indices[mask] = torch.randint(0, 1000, mask.shape, device=indices.device, dtype=indices.dtype)[mask]
-
     # Restore
     feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
     fake_audios = model.decode(