@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from vector_quantize_pytorch import VectorQuantize
+from vector_quantize_pytorch import LFQ, VectorQuantize
 
 from fish_speech.models.vqgan.modules.modules import WN
 from fish_speech.models.vqgan.modules.transformer import (
@@ -234,12 +234,12 @@ class VQEncoder(nn.Module):
     ):
         super().__init__()
 
-        self.vq = VectorQuantize(
+        self.vq = LFQ(
             dim=vq_channels,
             codebook_size=codebook_size,
-            threshold_ema_dead_code=2,
-            kmeans_init=False,
-            channel_last=False,
+            # threshold_ema_dead_code=2,
+            # kmeans_init=False,
+            # channel_last=False,
         )
         self.downsample = downsample
         self.conv_in = nn.Conv1d(
@@ -286,8 +286,8 @@ class VQEncoder(nn.Module):
         x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
 
         x = self.conv_in(x)
-        q, _, loss = self.vq(x)
-        x = self.conv_out(q) * x_mask
+        q, _, loss = self.vq(x.mT)
+        x = self.conv_out(q.mT) * x_mask
         x = x[:, :, :x_len]
 
         return x, loss