Lengyue 2 года назад
Родитель
Commit
3d002fd18b
2 измененных файлов с 28 добавлено и 3 удалено
  1. 7 1
      fish_speech/models/text2semantic/llama.py
  2. 21 2
      tools/llama/generate.py

+ 7 - 1
fish_speech/models/text2semantic/llama.py

@@ -248,7 +248,13 @@ class Transformer(nn.Module):
         codebook_logits = self.fast_output(fast_out)
 
         # Re-pad the codebook_logits
-        buffer = torch.zeros(x_bs, x_len, codebook_logits.size(-1), device=x.device)
+        buffer = torch.zeros(
+            x_bs,
+            x_len,
+            codebook_logits.size(-1),
+            device=codebook_logits.device,
+            dtype=codebook_logits.dtype,
+        )
         buffer[~codebook_mask] = codebook_logits
         codebook_logits = buffer
 

+ 21 - 2
tools/llama/generate.py

@@ -509,7 +509,7 @@ def main(
 
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
-    texts = split_text(text, 20) if iterative_prompt else [text]
+    texts = split_text(text, 30) if iterative_prompt else [text]
     for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
@@ -561,7 +561,26 @@ def main(
         while seg_idx < len(encoded):
             seg = encoded[seg_idx]
             global_encoded.append(seg)
-            cat_encoded = torch.cat(global_encoded, dim=1)
+
+            lengths = reversed([seg.size(1) for seg in global_encoded])
+            # Pick last 2000 tokens
+            count = 0
+            for i, length in enumerate(lengths):
+                count += length
+                if count >= 2000:
+                    break
+
+            if i != 0 and i % 2 == 0:
+                i -= 1
+
+            if i < len(global_encoded) - 2:
+                partial_encoded = global_encoded[-i:]
+                print(f"Loaded partial encoded")
+            else:
+                partial_encoded = global_encoded
+                print(f"Using full encoded")
+
+            cat_encoded = torch.cat(partial_encoded, dim=1)
             prompt_length = cat_encoded.size(1)
 
             t0 = time.perf_counter()