Procházet zdrojové kódy

Avoid cuda-dependent code for CPU-only inference (#499)

* Avoid cuda-dependent code for CPU-only inference

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sergey Aleynikov před 1 rokem
rodič
revize
9e2f5e6b3a
1 změnil soubory, kde provedl 7 přidání a 2 odebrání
  1. 7 2
      tools/llama/generate.py

+ 7 - 2
tools/llama/generate.py

@@ -2,6 +2,7 @@ import os
 import queue
 import threading
 import time
+from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Literal, Optional, Tuple, Union
@@ -181,8 +182,12 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]
 
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=False, enable_mem_efficient=False, enable_math=True
+        with (
+            torch.backends.cuda.sdp_kernel(
+                enable_flash=False, enable_mem_efficient=False, enable_math=True
+            )
+            if torch.cuda.is_available()
+            else nullcontext()
         ):  # Actually better for Inductor to codegen attention here
             next_token = decode_one_token(
                 model=model,