|
|
@@ -46,6 +46,7 @@ class HubertVQ(nn.Module):
|
|
|
param.requires_grad = True
|
|
|
|
|
|
# Quantization
|
|
|
+ self.quantizer_ln = nn.LayerNorm(self.hubert.config.hidden_size)
|
|
|
self.quantizer = VectorQuantization(
|
|
|
codebook_size=codebook_size,
|
|
|
dim=self.hubert.config.hidden_size,
|
|
|
@@ -164,6 +165,7 @@ class HubertVQ(nn.Module):
|
|
|
)
|
|
|
|
|
|
# Quantize
|
|
|
+ hidden_states = self.quantizer_ln(hidden_states)
|
|
|
quantize, _, vq_loss = self.quantizer(hidden_states.transpose(1, 2))
|
|
|
quantize = quantize.transpose(1, 2)
|
|
|
|