Lengyue 1 rok temu
rodzic
commit
0cd7df0309
1 zmieniony plik z 2 dodaniami i 4 usunięciami
  1. tools/llama/generate.py (+2 -4)

+ 2 - 4
tools/llama/generate.py

@@ -94,7 +94,6 @@ def decode_one_token_ar(
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
-    logits = x.token_logits
 
     sampling_kwargs_main = sampling_kwargs.copy()
     sampling_kwargs_main["temperature"] = 0.1
@@ -103,7 +102,7 @@ def decode_one_token_ar(
 
     codebooks = [
         sample(
-            logits,
+            x.logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs_main,
         )[0]
@@ -142,7 +141,6 @@ def decode_one_token_naive(
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
-    logits = x.token_logits
 
     sampling_kwargs_main = sampling_kwargs.copy()
     sampling_kwargs_main["temperature"] = 0.1
@@ -151,7 +149,7 @@ def decode_one_token_naive(
 
     codebooks = [
         sample(
-            logits,
+            x.logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs_main,
         )[0]