Lengyue 1 年間 前
コミット
762562b7e6
1 ファイル変更1 行追加0 行削除
  1. 1 0
      tools/llama/generate.py

+ 1 - 0
tools/llama/generate.py

@@ -94,6 +94,7 @@ def decode_one_token_ar(
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
+    logits = x.token_logits
 
     sampling_kwargs_main = sampling_kwargs.copy()
     sampling_kwargs_main["temperature"] = 0.1