Просмотр исходного кода

Add dropout options to optimize overfitting

Lengyue 2 года назад
Родитель
Commit
39f6902119

+ 1 - 0
fish_speech/configs/text2semantic_finetune.yaml

@@ -59,6 +59,7 @@ model:
       norm_eps: 1e-5
       num_codebooks: 4  # single codebook
       codebook_size: 168 # codebook size 160 + 2 special tokens
+      dropout: 0.1 # For small dataset, dropout helps to prevent overfitting
 
   optimizer:
     _target_: torch.optim.AdamW

+ 1 - 0
fish_speech/configs/text2semantic_finetune_lora.yaml

@@ -59,6 +59,7 @@ model:
       norm_eps: 1e-5
       num_codebooks: 4  # single codebook
       codebook_size: 168 # codebook size 160 + 2 special tokens
+      dropout: 0.1 # For small dataset, dropout helps to prevent overfitting
 
   lora_config:
     _target_: fish_speech.models.text2semantic.lit_module.LoraConfig

+ 10 - 2
fish_speech/models/text2semantic/llama.py

@@ -31,6 +31,7 @@ class ModelArgs:
     rope_base: float = 10000
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
+    dropout: float = 0.0
 
     # Additional decoding heads
     codebook_size: int = 160
@@ -260,6 +261,7 @@ class Attention(nn.Module):
         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None
 
+        self.dropout = config.dropout
         self.n_head = config.n_head
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
@@ -301,7 +303,13 @@ class Attention(nn.Module):
 
             k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
             v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+            y = F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout if self.training else 0.0,
+            )
 
             y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
         else:
@@ -311,7 +319,7 @@ class Attention(nn.Module):
 
             # We don't need to transpose q, k, v here because flash_attn_varlen_func
             attn_output = self._flash_attention_forward(
-                q, k, v, mask, seqlen, dropout=0.0
+                q, k, v, mask, seqlen, dropout=self.dropout if self.training else 0.0
             )
 
             y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
             y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()