Lengyue 2 года назад
Родитель
Commit
ca11d4b869

+ 1 - 1
README.md

@@ -34,7 +34,7 @@ Download required `vqgan` and `text2semantic` model from our huggingface repo.
 
 ```bash
 wget https://huggingface.co/fishaudio/speech-lm-v1/raw/main/vqgan-v1.pth -O checkpoints/vqgan-v1.pth
-wget https://huggingface.co/fishaudio/speech-lm-v1/blob/main/text2semantic-400m-v0.1-4k.pth -O checkpoints/text2semantic-400m-v0.1-4k.pth
+wget https://huggingface.co/fishaudio/speech-lm-v1/resolve/main/text2semantic-400m-v0.2-4k.pth -O checkpoints/text2semantic-400m-v0.2-4k.pth
 ```
 
 Generate semantic tokens from text:

+ 2 - 1
README.zh.md

@@ -35,7 +35,7 @@ pip3 install -e .
     
 ```bash
 wget https://huggingface.co/fishaudio/speech-lm-v1/raw/main/vqgan-v1.pth -O checkpoints/vqgan-v1.pth
-wget https://huggingface.co/fishaudio/speech-lm-v1/blob/main/text2semantic-400m-v0.1-4k.pth -O checkpoints/text2semantic-400m-v0.1-4k.pth
+wget https://huggingface.co/fishaudio/speech-lm-v1/resolve/main/text2semantic-400m-v0.2-4k.pth -O checkpoints/text2semantic-400m-v0.2-4k.pth
 ```
 
 ### 1. [可选] 从语音生成 prompt: 
@@ -74,6 +74,7 @@ cargo build --release
 
 ## 更新日志
 
+- 2023/12/17: 更新了 `text2semantic` 模型, 支持无音素模式.
 - 2023/12/13: 测试版发布, 包含 VQGAN 模型和一个基于 LLAMA 的语言模型 (只支持音素).
 
 ## 致谢

+ 4 - 1
fish_speech/models/vqgan/lit_module.py

@@ -72,6 +72,7 @@ class VQGAN(L.LightningModule):
         self.hop_length = hop_length
         self.sampling_rate = sample_rate
         self.freeze_hifigan = freeze_hifigan
+        self.freeze_vq = freeze_vq
 
         # Disable automatic optimization
         self.automatic_optimization = False
@@ -164,7 +165,9 @@ class VQGAN(L.LightningModule):
 
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.mel_encoder(features, feature_masks)
-        text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
+        text_features, _, loss_vq = self.vq_encoder(
+            text_features, feature_masks, freeze_codebook=self.freeze_vq
+        )
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )

+ 2 - 2
fish_speech/models/vqgan/modules/encoders.py

@@ -308,7 +308,7 @@ class VQEncoder(nn.Module):
             nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
         )
 
-    def forward(self, x, x_mask):
+    def forward(self, x, x_mask, freeze_codebook=False):
         # x: [B, C, T], x_mask: [B, 1, T]
         x_len = x.shape[2]
 
@@ -317,7 +317,7 @@ class VQEncoder(nn.Module):
             x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
 
         x = self.conv_in(x)
-        q, indices, loss = self.vq(x.mT)
+        q, indices, loss = self.vq(x.mT, freeze_codebook=freeze_codebook)
         q = q.mT
 
         if self.codebook_groups > 1: