|
|
@@ -440,15 +440,13 @@ class FishSpeechTransformer(nn.Module):
|
|
|
|
|
|
return codes
|
|
|
|
|
|
- def decode_one_token(
|
|
|
+ def sample_decoder(
|
|
|
self,
|
|
|
x: torch.Tensor,
|
|
|
context: torch.Tensor,
|
|
|
input_pos: torch.Tensor,
|
|
|
**sampling_kwargs,
|
|
|
- ) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
- # input_pos: [B, 1]
|
|
|
- assert input_pos.shape[-1] == 1
|
|
|
+ ):
|
|
|
attn_bias = self.alibi.alibi[:, input_pos, : self.max_seq_length]
|
|
|
causual_mask = self.causual_mask[input_pos, : self.max_seq_length]
|
|
|
|
|
|
@@ -480,7 +478,7 @@ class FishSpeechTransformer(nn.Module):
|
|
|
next_token.append(next_token_i)
|
|
|
probs.append(probs_i)
|
|
|
|
|
|
- return torch.stack(next_token, dim=1), torch.stack(probs, dim=1)
|
|
|
+ return torch.stack(next_token, dim=0), torch.stack(probs, dim=0)
|
|
|
|
|
|
@staticmethod
|
|
|
def multinomial_sample_one_no_sync(
|
|
|
@@ -490,18 +488,42 @@ class FishSpeechTransformer(nn.Module):
|
|
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
|
|
|
|
|
@staticmethod
def logits_to_probs(
    logits,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
):
    """Turn raw logits into a probability distribution for sampling.

    The filters run in a fixed order: nucleus (top-p) filtering on the
    unscaled logits, temperature scaling, top-k filtering, and finally a
    softmax over the last dimension.

    Args:
        logits: Tensor of unnormalised scores; the vocabulary is the last
            dimension. Any number of leading (batch) dimensions is accepted.
        temperature: Divisor applied to the logits. Values below 1e-5 are
            clamped to 1e-5 to avoid dividing by zero.
        top_p: Cumulative-probability threshold in (0, 1]. Walking the
            vocabulary from most to least likely, every entry whose running
            probability mass exceeds ``top_p`` is masked out; the most likely
            entry is always kept. ``None`` disables the filter.
        top_k: Keep only the ``top_k`` highest logits (capped at the
            vocabulary size). ``None`` disables the filter.

    Returns:
        Tensor with the same shape as ``logits`` whose last dimension sums
        to 1; filtered entries have probability 0.
    """
    if top_p is not None:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        # Keep at least one option per distribution. Indexing the last axis
        # (instead of axis 0) keeps this correct for batched logits too.
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # The smallest of the k largest logits is the cut-off value.
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
|
|
|
|
|
|
def sample(
    self,
    logits,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
):
    """Sample the next token from the last position of the first batch item.

    Args:
        logits: Tensor indexed as ``logits[0, -1]`` to obtain the scores to
            sample from (first batch entry, final position).
        temperature: Forwarded to ``logits_to_probs``.
        top_p: Optional nucleus-sampling threshold, forwarded to
            ``logits_to_probs``; ``None`` disables it.
        top_k: Optional top-k cut-off, forwarded to ``logits_to_probs``;
            ``None`` disables it.

    Returns:
        ``(idx_next, probs)``: the value returned by
        ``multinomial_sample_one_no_sync`` and the probabilities it was
        drawn from.
    """
    probs = self.logits_to_probs(logits[0, -1], temperature, top_p, top_k)
    idx_next = self.multinomial_sample_one_no_sync(probs)
    return idx_next, probs
|
|
|
|
|
|
@@ -517,24 +539,37 @@ class FishSpeechTransformer(nn.Module):
|
|
|
new_tokens, new_probs = [], []
|
|
|
|
|
|
for i in range(num_new_tokens):
|
|
|
- next_token, next_prob = self.decode_one_token(
|
|
|
+ next_token, next_prob = self.sample_decoder(
|
|
|
cur_token, context, input_pos, **sampling_kwargs
|
|
|
)
|
|
|
+
|
|
|
input_pos += 1
|
|
|
new_tokens.append(next_token.clone())
|
|
|
callback(new_tokens[-1])
|
|
|
new_probs.append(next_prob.clone())
|
|
|
+
|
|
|
+ if next_token[0, 0] == 1:
|
|
|
+ break
|
|
|
+
|
|
|
cur_token = next_token.view(1, self.num_codebooks, -1)
|
|
|
|
|
|
return new_tokens, new_probs
|
|
|
|
|
|
@torch.no_grad()
|
|
|
- def inference(self, inputs, max_new_tokens=1024, top_k=5, temperature=1.0):
|
|
|
- # x: (B, T)
|
|
|
- # y: (B, C, T)
|
|
|
+ def inference(self, inputs, prompt=None, max_new_tokens=1024, **sampling_kwargs):
|
|
|
+ # inputs: (B, T)
|
|
|
+ # prompt: (B, C, T)
|
|
|
|
|
|
assert inputs.size(0) == 1, "Only support batch size 1 for now"
|
|
|
|
|
|
+ if prompt is None:
|
|
|
+ prompt = torch.tensor(
|
|
|
+ [[[0]] * self.num_codebooks], device=inputs.device, dtype=torch.long
|
|
|
+ )
|
|
|
+
|
|
|
+ T = prompt.size(2)
|
|
|
+ T_new = T + max_new_tokens
|
|
|
+
|
|
|
# Encode Features
|
|
|
inputs = self.encoder_embedding(inputs)
|
|
|
attn_bias = self.alibi(inputs)
|
|
|
@@ -545,24 +580,37 @@ class FishSpeechTransformer(nn.Module):
|
|
|
|
|
|
# Decode
|
|
|
with torch.device(inputs.device):
|
|
|
- self.setup_kv_caches(max_batch_size=1, max_seq_length=max_new_tokens)
|
|
|
+ self.setup_kv_caches(max_batch_size=1, max_seq_length=T_new)
|
|
|
|
|
|
# create an empty tensor of the expected final shape and fill in the current tokens
|
|
|
- input_pos = torch.tensor([0], device=device, dtype=torch.long)
|
|
|
- next_token = torch.tensor(
|
|
|
- [[0] * self.num_codebooks], device=device, dtype=torch.long
|
|
|
- ) # BOS of decoder
|
|
|
+ empty = torch.empty(
|
|
|
+ (1, self.num_codebooks, T_new), dtype=torch.long, device=device
|
|
|
+ )
|
|
|
+ empty[:, :, :T] = prompt
|
|
|
+ seq = empty
|
|
|
+ input_pos = torch.arange(0, T, device=device)
|
|
|
+
|
|
|
+ # prefill
|
|
|
+ next_token, _ = self.sample_decoder(
|
|
|
+ prompt.view(1, self.num_codebooks, -1), inputs, input_pos, **sampling_kwargs
|
|
|
+ )
|
|
|
+ seq[:, :, T] = next_token
|
|
|
|
|
|
+ # the token sampled during prefill sits at position T; decoding resumes from there
|
|
|
+ input_pos = torch.tensor([T], device=device, dtype=torch.long)
|
|
|
generated_tokens, _ = self.decode_n_tokens(
|
|
|
next_token.view(1, self.num_codebooks, -1),
|
|
|
context=inputs,
|
|
|
input_pos=input_pos,
|
|
|
num_new_tokens=max_new_tokens - 1,
|
|
|
- top_k=top_k,
|
|
|
- temperature=temperature,
|
|
|
+ **sampling_kwargs,
|
|
|
)
|
|
|
|
|
|
- return [i[0, 0].item() for i in generated_tokens]
|
|
|
+ generated_tokens = torch.stack(generated_tokens, dim=-1)
|
|
|
+ seq = seq[:, :, : T + 1 + generated_tokens.size(-1)]
|
|
|
+ seq[:, :, T + 1 :] = generated_tokens
|
|
|
+
|
|
|
+ return seq
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|