Lengyue 1 год назад
Родитель
Commit
03449e5526
2 измененных файлов с 26 добавлено и 11 удалено
  1. 18 7
      tools/llama/generate.py
  2. 8 4
      tools/vqgan/extract_vq.py

+ 18 - 7
tools/llama/generate.py

@@ -494,7 +494,9 @@ def generate_long(
         logger.info(f"Encoded text: {text}")
 
     for sample_idx in range(num_samples):
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
         global_encoded = []
         all_codes = []
         seg_idx = 0
@@ -548,7 +550,9 @@ def generate_long(
             if sample_idx == 0 and seg_idx == 0 and compile:
                 logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
-            torch.cuda.synchronize()
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+
             t = time.perf_counter() - t0
 
             tokens_generated = y.size(1) - prompt_length
@@ -559,9 +563,11 @@ def generate_long(
             logger.info(
                 f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
             )
-            logger.info(
-                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
-            )
+
+            if torch.cuda.is_available():
+                logger.info(
+                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
+                )
 
             # Put the generated tokens
             # since there is <im_end> and <eos> tokens, we remove last 2 tokens
@@ -702,7 +708,10 @@ def main(
     model, decode_one_token = load_model(
         config_name, checkpoint_path, device, precision, max_length, compile=compile
     )
-    torch.cuda.synchronize()
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
     logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     prompt_tokens = (
@@ -713,7 +722,9 @@ def main(
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
 
     generator = generate_long(
         model=model,

+ 8 - 4
tools/vqgan/extract_vq.py

@@ -140,11 +140,15 @@ def main(
 
         logger.info(f"Spawning {num_workers} workers")
 
-        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-        if visible_devices is None:
-            visible_devices = list(range(torch.cuda.device_count()))
+        if torch.cuda.is_available():
+            visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+            if visible_devices is None:
+                visible_devices = list(range(torch.cuda.device_count()))
+            else:
+                visible_devices = visible_devices.split(",")
         else:
-            visible_devices = visible_devices.split(",")
+            # Set to empty string to avoid using GPU
+            visible_devices = [""]
 
         processes = []
         for i in range(num_workers):