Lengyue пре 2 година
родитељ
комит
bd8c904825
1 измењен фајл, 68 додатих и 20 уклоњених редова
  1. fish_speech/models/text2semantic/modules.py (+68 −20)

+ 68 - 20
fish_speech/models/text2semantic/modules.py

@@ -440,15 +440,13 @@ class FishSpeechTransformer(nn.Module):
 
         return codes
 
-    def decode_one_token(
+    def sample_decoder(
         self,
         x: torch.Tensor,
         context: torch.Tensor,
         input_pos: torch.Tensor,
         **sampling_kwargs,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [B, 1]
-        assert input_pos.shape[-1] == 1
+    ):
         attn_bias = self.alibi.alibi[:, input_pos, : self.max_seq_length]
         causual_mask = self.causual_mask[input_pos, : self.max_seq_length]
 
@@ -480,7 +478,7 @@ class FishSpeechTransformer(nn.Module):
             next_token.append(next_token_i)
             probs.append(probs_i)
 
-        return torch.stack(next_token, dim=1), torch.stack(probs, dim=1)
+        return torch.stack(next_token, dim=0), torch.stack(probs, dim=0)
 
     @staticmethod
     def multinomial_sample_one_no_sync(
@@ -490,18 +488,42 @@ class FishSpeechTransformer(nn.Module):
         return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
     @staticmethod
-    def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+    def logits_to_probs(
+        logits,
+        temperature: float = 1.0,
+        top_p: Optional[int] = None,
+        top_k: Optional[int] = None,
+    ):
+        if top_p is not None:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+            cum_probs = torch.cumsum(
+                torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+            )
+            sorted_indices_to_remove = cum_probs > top_p
+            sorted_indices_to_remove[0] = False  # keep at least one option
+            indices_to_remove = sorted_indices_to_remove.scatter(
+                dim=0, index=sorted_indices, src=sorted_indices_to_remove
+            )
+            logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
         logits = logits / max(temperature, 1e-5)
 
         if top_k is not None:
             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
             pivot = v.select(-1, -1).unsqueeze(-1)
             logits = torch.where(logits < pivot, -float("Inf"), logits)
+
         probs = torch.nn.functional.softmax(logits, dim=-1)
         return probs
 
-    def sample(self, logits, temperature: float = 1.0, top_k: Optional[int] = None):
-        probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
+    def sample(
+        self,
+        logits,
+        temperature: float = 1.0,
+        top_p: Optional[int] = None,
+        top_k: Optional[int] = None,
+    ):
+        probs = self.logits_to_probs(logits[0, -1], temperature, top_p, top_k)
         idx_next = self.multinomial_sample_one_no_sync(probs)
         return idx_next, probs
 
@@ -517,24 +539,37 @@ class FishSpeechTransformer(nn.Module):
         new_tokens, new_probs = [], []
 
         for i in range(num_new_tokens):
-            next_token, next_prob = self.decode_one_token(
+            next_token, next_prob = self.sample_decoder(
                 cur_token, context, input_pos, **sampling_kwargs
             )
+
             input_pos += 1
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             new_probs.append(next_prob.clone())
+
+            if next_token[0, 0] == 1:
+                break
+
             cur_token = next_token.view(1, self.num_codebooks, -1)
 
         return new_tokens, new_probs
 
     @torch.no_grad()
-    def inference(self, inputs, max_new_tokens=1024, top_k=5, temperature=1.0):
-        # x: (B, T)
-        # y: (B, C, T)
+    def inference(self, inputs, prompt=None, max_new_tokens=1024, **sampling_kwargs):
+        # inputs: (B, T)
+        # prompt: (B, C, T)
 
         assert inputs.size(0) == 1, "Only support batch size 1 for now"
 
+        if prompt is None:
+            prompt = torch.tensor(
+                [[[0]] * self.num_codebooks], device=inputs.device, dtype=torch.long
+            )
+
+        T = prompt.size(2)
+        T_new = T + max_new_tokens
+
         # Encode Features
         inputs = self.encoder_embedding(inputs)
         attn_bias = self.alibi(inputs)
@@ -545,24 +580,37 @@ class FishSpeechTransformer(nn.Module):
 
         # Decode
         with torch.device(inputs.device):
-            self.setup_kv_caches(max_batch_size=1, max_seq_length=max_new_tokens)
+            self.setup_kv_caches(max_batch_size=1, max_seq_length=T_new)
 
         # create an empty tensor of the expected final shape and fill in the current tokens
-        input_pos = torch.tensor([0], device=device, dtype=torch.long)
-        next_token = torch.tensor(
-            [[0] * self.num_codebooks], device=device, dtype=torch.long
-        )  # BOS of decoder
+        empty = torch.empty(
+            (1, self.num_codebooks, T_new), dtype=torch.long, device=device
+        )
+        empty[:, :, :T] = prompt
+        seq = empty
+        input_pos = torch.arange(0, T, device=device)
+
+        # prefill
+        next_token, _ = self.sample_decoder(
+            prompt.view(1, self.num_codebooks, -1), inputs, input_pos, **sampling_kwargs
+        )
+        seq[:, :, T] = next_token
 
+        # create an empty tensor of the expected final shape and fill in the current tokens
+        input_pos = torch.tensor([T], device=device, dtype=torch.long)
         generated_tokens, _ = self.decode_n_tokens(
             next_token.view(1, self.num_codebooks, -1),
             context=inputs,
             input_pos=input_pos,
             num_new_tokens=max_new_tokens - 1,
-            top_k=top_k,
-            temperature=temperature,
+            **sampling_kwargs,
         )
 
-        return [i[0, 0].item() for i in generated_tokens]
+        generated_tokens = torch.stack(generated_tokens, dim=-1)
+        seq = seq[:, :, : T + 1 + generated_tokens.size(-1)]
+        seq[:, :, T + 1 :] = generated_tokens
+
+        return seq
 
 
 if __name__ == "__main__":