Author: Lengyue — 1 year ago
Parent commit: 90f22aabce
1 changed file, 18 insertions(+), 7 deletions(-)
    tools/llama/generate.py (+18 −7)

@@ -94,15 +94,20 @@ def decode_one_token_ar(
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    sampling_kwargs_main["temperature"] = 0.1
+    sampling_kwargs_main["top_p"] = 0.1
+    sampling_kwargs_main["repetition_penalty"] = 1.0
+
     codebooks = [
         sample(
-            x.logits,
-            previous_tokens=(
-                previous_tokens[0] if previous_tokens is not None else None
-            ),  # Disable repetition penalty for the token codebook
-            **sampling_kwargs,
+            logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
         )[0]
     ]
+
     x = x.hidden_states

     # Cleanup the cache
@@ -136,12 +141,18 @@ def decode_one_token_naive(
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
+    logits = x.token_logits
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    sampling_kwargs_main["temperature"] = 0.1
+    sampling_kwargs_main["top_p"] = 0.1
+    sampling_kwargs_main["repetition_penalty"] = 1.0

     codebooks = [
         sample(
-            x.token_logits,
+            logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs,
+            **sampling_kwargs_main,
         )[0]
     ]