Просмотр исходного кода

fix(inference): GPU memory leakage bug (#1040) (#1073)

* fix: GPU memory leak caused by repeated torch.compile graph builds (#1040)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
胖虎遛二狗 8 месяцев назад
Родитель
Сommit
54458f7cd5
1 измененных файлов с 102 добавлено и 59 удалено
  1. 102 59
      fish_speech/models/text2semantic/inference.py

+ 102 - 59
fish_speech/models/text2semantic/inference.py

@@ -3,10 +3,9 @@ import queue
 import threading
 import time
 import traceback
-from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Literal, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Tuple, Union
 
 import click
 import numpy as np
@@ -106,17 +105,17 @@ def decode_one_token_ar(
     repetition_penalty: torch.Tensor,
     audio_masks: torch.Tensor,
     audio_parts: torch.Tensor,
-    previous_tokens: torch.Tensor = None,
+    previous_tokens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     # print(x, torch.count_nonzero(vq_masks))
-    x = model.forward_generate(
+    forward_result = model.forward_generate(
         x,
         input_pos,
         audio_masks=audio_masks,
         audio_parts=audio_parts,
     )
-    logits = x.logits  # [:, -1:]
-    hidden_states = x.hidden_states  # [:, -1:]
+    logits = forward_result.logits  # [:, -1:]
+    hidden_states = forward_result.hidden_states  # [:, -1:]
 
     codebooks = [
         sample(
@@ -130,10 +129,11 @@ def decode_one_token_ar(
         )[0]
     ]
 
-    # Cleanup the cache
+    # Only clear cache for fast_layers, avoid clearing main model cache
     for layer in model.fast_layers:
-        layer.attention.kv_cache.k_cache.fill_(0)
-        layer.attention.kv_cache.v_cache.fill_(0)
+        if hasattr(layer, "attention") and hasattr(layer.attention, "kv_cache"):
+            layer.attention.kv_cache.k_cache.fill_(0)
+            layer.attention.kv_cache.v_cache.fill_(0)
 
     input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
     model.forward_generate_fast(hidden_states, input_pos)
@@ -167,11 +167,15 @@ def decode_one_token_ar(
         codebooks.append(a)
 
     codebooks = torch.stack(codebooks, dim=1)
+
+    # Only delete references, let Python GC handle cleanup
+    del logits, hidden_states, forward_result
+
     return codebooks.T
 
 
 def decode_n_tokens(
-    model: NaiveTransformer,
+    model: DualARTransformer,
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
@@ -220,6 +224,9 @@ def decode_n_tokens(
         if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
             break
 
+    # Only clean up the large tensor
+    del cur_token
+
     return previous_tokens[:, : i + 1]
 
 
@@ -227,7 +234,7 @@ def decode_n_tokens(
 @torch.inference_mode()
 def generate(
     *,
-    model: BaseTransformer,
+    model: DualARTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
     audio_masks: torch.Tensor,
@@ -259,28 +266,51 @@ def generate(
         max_new_tokens = T_new - T
 
     device, dtype = prompt.device, prompt.dtype
-    with torch.device(device):
-        model.setup_caches(
-            max_batch_size=num_samples,
-            max_seq_len=model.config.max_seq_len,
-            dtype=next(model.parameters()).dtype,
-        )
+
+    # Critical fix: Only set up cache on first run or when necessary
+    if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
+        with torch.device(device):
+            model.setup_caches(
+                max_batch_size=1,  # Fixed to 1, avoid dynamic changes
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
+            )
+        model._cache_setup_done = True
 
     codebook_dim = 1 + model.config.num_codebooks
-    input_pos = torch.arange(0, T, device=device)
+
+    # Create new tensor each time, but try to reuse memory
+    input_pos = torch.arange(0, T, device=device, dtype=torch.long)
     empty = torch.empty(
         (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
     )
     empty[:, :T] = prompt
     seq = empty
 
-    temperature = torch.tensor(
-        sampling_kwargs["temperature"], device=device, dtype=torch.bfloat16
+    # Use pre-created fixed parameter tensors
+    temperature = getattr(
+        model, "fixed_temperature", torch.tensor(0.8, device=device, dtype=torch.float)
     )
-    top_p = torch.tensor(sampling_kwargs["top_p"], device=device, dtype=torch.bfloat16)
-    repetition_penalty = torch.tensor(
-        sampling_kwargs["repetition_penalty"], device=device, dtype=torch.bfloat16
+    top_p = getattr(
+        model, "fixed_top_p", torch.tensor(0.8, device=device, dtype=torch.float)
     )
+    repetition_penalty = getattr(
+        model,
+        "fixed_repetition_penalty",
+        torch.tensor(1.1, device=device, dtype=torch.float),
+    )
+
+    # If different parameter values are needed, directly modify existing tensors
+    temp_val = sampling_kwargs.get("temperature", 0.7)
+    top_p_val = sampling_kwargs.get("top_p", 0.7)
+    rep_val = sampling_kwargs.get("repetition_penalty", 1.5)
+
+    if abs(temperature.item() - temp_val) > 1e-6:
+        temperature.fill_(temp_val)
+    if abs(top_p.item() - top_p_val) > 1e-6:
+        top_p.fill_(top_p_val)
+    if abs(repetition_penalty.item() - rep_val) > 1e-6:
+        repetition_penalty.fill_(rep_val)
 
     prefill_decode = decode_one_token_ar
 
@@ -296,7 +326,9 @@ def generate(
     )
     seq[:, T : T + 1] = first_token
 
+    # Recreate input_pos
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
     x = decode_n_tokens(
         model,
         first_token.view(1, codebook_dim, -1),
@@ -311,6 +343,10 @@ def generate(
     )
     seq = seq[:, : T + 1 + x.size(1)]
     seq[:, T + 1 :] = x
+
+    # Clean up temporary variables
+    del first_token, x, prompt, empty, input_pos
+
     return seq
 
 
@@ -327,19 +363,18 @@ def init_model(checkpoint_path, device, precision, compile=False):
     else:
         raise ValueError("Unsupported model type")
 
-    # Initialize cache
-    with torch.device(device):
-        model.setup_caches(
-            max_batch_size=1,
-            max_seq_len=model.config.max_seq_len,
-            dtype=next(model.parameters()).dtype,
-        )
+    # Pre-create fixed parameter tensors to avoid runtime creation
+    model.fixed_temperature = torch.tensor(0.7, device=device, dtype=torch.float)
+    model.fixed_top_p = torch.tensor(0.7, device=device, dtype=torch.float)
+    model.fixed_repetition_penalty = torch.tensor(1.5, device=device, dtype=torch.float)
+
+    # Mark whether cache has been initialized
+    model._cache_setup_done = False
 
     if compile:
         logger.info("Compiling function...")
         decode_one_token = torch.compile(
             decode_one_token,
-            # mode="max-autotune-no-cudagraphs",
             backend="inductor" if torch.cuda.is_available() else "aot_eager",
             mode="reduce-overhead" if torch.cuda.is_available() else None,
             fullgraph=True,
@@ -358,19 +393,19 @@ class GenerateResponse:
 def generate_long(
     *,
     model,
-    device: str | torch.device,
-    decode_one_token: callable,
+    device: Union[str, torch.device],
+    decode_one_token: Callable,
     text: str,
     num_samples: int = 1,
     max_new_tokens: int = 0,
-    top_p: int = 0.8,
+    top_p: float = 0.8,
     repetition_penalty: float = 1.1,
     temperature: float = 0.8,
     compile: bool = False,
     iterative_prompt: bool = True,
     chunk_length: int = 512,
-    prompt_text: Optional[str | list[str]] = None,
-    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
+    prompt_text: Optional[Union[str, list[str]]] = None,
+    prompt_tokens: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
 ):
     assert 0 < top_p <= 1, "top_p must be in (0, 1]"
     assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
@@ -381,11 +416,13 @@ def generate_long(
         prompt_text = [prompt_text]
         prompt_tokens = [prompt_tokens]
 
-    assert use_prompt is False or len(prompt_text) == len(
-        prompt_tokens
-    ), "Prompt text and tokens must have the same length"
+    if use_prompt:
+        assert len(prompt_text) == len(
+            prompt_tokens
+        ), "Prompt text and tokens must have the same length"
 
-    prompt_tokens = [i.cpu() for i in prompt_tokens]
+    if prompt_tokens:
+        prompt_tokens = [i.cpu() for i in prompt_tokens]
 
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
@@ -419,14 +456,6 @@ def generate_long(
     encoded = encoded.to(device=device)
     logger.info(f"Encoded text: {text}")
 
-    # Move temperature, top_p, repetition_penalty to device
-    # This is important so that changing params doesn't trigger recompile
-    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
-    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
-    repetition_penalty = torch.tensor(
-        repetition_penalty, device=device, dtype=torch.float
-    )
-
     for sample_idx in range(num_samples):
         if torch.cuda.is_available():
             torch.cuda.synchronize()
@@ -436,6 +465,7 @@ def generate_long(
         prompt_length = encoded.size(1)
 
         t0 = time.perf_counter()
+
         y = generate(
             model=model,
             prompt=encoded,
@@ -469,26 +499,26 @@ def generate_long(
             )
 
         # Put the generated tokens
-        # since there is <im_end>, we remove last token
         codes = y[1:, prompt_length:-1].clone()
         assert (codes >= 0).all(), f"Negative code found"
 
         decoded = y[:, prompt_length:].clone()
-        # But for global encoding, we should keep the <im_end> token
-
         global_encoded.append(decoded.cpu())
         assert (codes >= 0).all(), f"Negative code found: {codes}"
+
         yield GenerateResponse(action="sample", codes=codes, text=text)
         seg_idx += 1
 
-        # This indicates the end of the current sample
+        # Force GPU memory cleanup
+        del y, decoded, codes
+
         yield GenerateResponse(action="next")
 
 
 @dataclass
 class WrappedGenerateResponse:
     status: Literal["success", "error"]
-    response: Optional[GenerateResponse | Exception] = None
+    response: Optional[Union[GenerateResponse, Exception]] = None
 
 
 @dataclass
@@ -533,9 +563,17 @@ def launch_thread_safe_queue(
                     response_queue.put(
                         WrappedGenerateResponse(status="success", response=chunk)
                     )
+
+                # Only clear cache after complete request batch
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
             except Exception as e:
                 logger.error(traceback.format_exc())
                 response_queue.put(WrappedGenerateResponse(status="error", response=e))
+                # Clear cache on error
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
 
     threading.Thread(target=worker, daemon=True).start()
     init_event.wait()
@@ -575,8 +613,8 @@ def launch_thread_safe_queue(
 @click.option("--output-dir", type=Path, default="temp")
 def main(
     text: str,
-    prompt_text: Optional[list[str]],
-    prompt_tokens: Optional[list[Path]],
+    prompt_text: Optional[tuple[str, ...]],
+    prompt_tokens: Optional[tuple[Path, ...]],
     num_samples: int,
     max_new_tokens: int,
     top_p: int,
@@ -594,7 +632,11 @@ def main(
     os.makedirs(output_dir, exist_ok=True)
     precision = torch.half if half else torch.bfloat16
 
-    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
+    if (
+        prompt_text is not None
+        and prompt_tokens is not None
+        and len(prompt_text) != len(prompt_tokens)
+    ):
         raise ValueError(
             f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
         )
@@ -615,8 +657,9 @@ def main(
 
     logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
 
+    prompt_tokens_list = None
     if prompt_tokens is not None:
-        prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
+        prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
 
     torch.manual_seed(seed)
 
@@ -636,8 +679,8 @@ def main(
         compile=compile,
         iterative_prompt=iterative_prompt,
         chunk_length=chunk_length,
-        prompt_text=prompt_text,
-        prompt_tokens=prompt_tokens,
+        prompt_text=list(prompt_text) if prompt_text else None,
+        prompt_tokens=prompt_tokens_list,
     )
 
     idx = 0