소스 검색

fix generate

Lengyue 1 년 전
부모
커밋
762562b7e6
1개의 변경된 파일1개의 추가작업 그리고 0개의 파일을 삭제
  1. 1 0
      tools/llama/generate.py

+ 1 - 0
tools/llama/generate.py

@@ -94,6 +94,7 @@ def decode_one_token_ar(
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
+    logits = x.token_logits
 
     sampling_kwargs_main = sampling_kwargs.copy()
     sampling_kwargs_main["temperature"] = 0.1