Explorar el Código

Implement delay pattern generate

Lengyue hace 2 años
padre
commit
37f4b517c2
Se han modificado 1 ficheros con 15 adiciones y 18 borrados
  1. 15 18
      tools/llama/generate.py

+ 15 - 18
tools/llama/generate.py

@@ -109,8 +109,6 @@ def decode_one_token(
 
     # Disable <s> and </s> tokens for codebooks
     if model.config.num_codebooks != 0:
-        logits.codebook_logits[:, :, :, :1] = -float("Inf")
-
         for i in range(model.config.num_codebooks):
             codebooks.append(
                 sample(
@@ -122,12 +120,7 @@ def decode_one_token(
                 )[0]
             )
 
-    codebooks = torch.stack(codebooks, dim=0)
-    codebooks[1] = torch.where(
-        codebooks[0] <= 32311, codebooks[1], codebooks[0] - 32311 + 2
-    )
-
-    return codebooks
+    return torch.stack(codebooks, dim=0)
 
 
 def prefill(
@@ -143,10 +136,7 @@ def prefill(
         )[0]
     ]
 
-    # Disable <s> and </s> tokens for codebooks
     if model.config.num_codebooks != 0:
-        logits.codebook_logits[:, :, :, :2] = -float("Inf")
-
         for i in range(model.config.num_codebooks):
             codebooks.append(
                 sample(
@@ -330,9 +320,9 @@ def encode_tokens(
         data = data[:num_codebooks]
 
     # Since 1.0, we use <s:xxx> to replace <semantic>
-    main_tokens = [f"<s:{i}>" for i in data[0]]
-    main_token_ids = tokenizer.convert_tokens_to_ids(main_tokens)
-    main_token_ids = torch.tensor([main_token_ids], dtype=torch.int, device=device)
+    main_token_ids = torch.tensor(
+        [[tokenizer.pad_token_id] * data.size(1)], dtype=torch.int, device=device
+    )
 
     data = torch.cat((main_token_ids, data), dim=0)
     prompt = torch.cat((prompt, data), dim=1)
@@ -502,7 +492,7 @@ def main(
             decode_one_token, mode="reduce-overhead", fullgraph=True
         )
 
-    for i in range(num_samples):
+    for idx in range(num_samples):
         torch.cuda.synchronize()
 
         t0 = time.perf_counter()
@@ -518,7 +508,7 @@ def main(
             repetition_penalty=repetition_penalty,
         )
 
-        if i == 0 and compile:
+        if idx == 0 and compile:
             logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
         torch.cuda.synchronize()
@@ -535,11 +525,18 @@ def main(
         )
 
         codes = y[1:, prompt_length:-1]
+        new_codes = []
+        for j, code in enumerate(codes):
+            new_codes.append(
+                code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
+            )
+
+        codes = torch.stack(new_codes, dim=0)
         codes = codes - 2
         assert (codes >= 0).all(), "Codes should be >= 0"
 
-        np.save(f"codes_{i}.npy", codes.cpu().numpy())
-        logger.info(f"Saved codes to codes_{i}.npy")
+        np.save(f"codes_{idx}.npy", codes.cpu().numpy())
+        logger.info(f"Saved codes to codes_{idx}.npy")
 
 
 if __name__ == "__main__":