zhaohaipeng před 1 měsícem
rodič
revize
f8be1c3822

+ 2 - 0
fish_speech/models/dac/modded_dac.py

@@ -249,6 +249,8 @@ class Attention(nn.Module):
     ) -> Tensor:
         bsz, seqlen, _ = x.shape
 
+        print(f"Attention forward self.n_local_heads {self.n_local_heads}, self.head_dim {self.head_dim}")
+
         kv_size = self.n_local_heads * self.head_dim
         q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
         context_seqlen = seqlen

+ 1 - 0
fish_speech/models/text2semantic/llama.py

@@ -198,6 +198,7 @@ class KVCache(nn.Module):
         self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
     ):
         super().__init__()
+        logger.info(f"Initializing KVCache max_batch_size={max_batch_size}, max_seq_len={max_seq_len}, n_heads={n_heads}, head_dim={head_dim}")
         cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
         self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
         self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))