ソースを参照

Avoid cuda-dependent code for CPU-only inference (#499)

* Avoid cuda-dependent code for CPU-only inference

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sergey Aleynikov 1年前
コミット
9e2f5e6b3a
1 ファイル変更、7 行追加、2 行削除
  1. +7 −2
      tools/llama/generate.py

+ 7 - 2
tools/llama/generate.py

@@ -2,6 +2,7 @@ import os
 import queue
 import threading
 import time
+from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Literal, Optional, Tuple, Union
@@ -181,8 +182,12 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]
 
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=False, enable_mem_efficient=False, enable_math=True
+        with (
+            torch.backends.cuda.sdp_kernel(
+                enable_flash=False, enable_mem_efficient=False, enable_math=True
+            )
+            if torch.cuda.is_available()
+            else nullcontext()
         ):  # Actually better for Inductor to codegen attention here
             next_token = decode_one_token(
                 model=model,