
Add half precision inference and document

Lengyue · 2 years ago
Parent commit c583555995

+ 3 - 0  docs/en/inference.md

@@ -48,6 +48,9 @@ This command will create a `codes_N` file in the working directory, where N is a
    You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~500 tokens/second).
    Correspondingly, if you do not plan to use acceleration, you can comment out the `--compile` parameter.

+!!! info
+    For GPUs that do not support bf16, you may need to use the `--half` parameter.
+
### 3. Generate vocals from semantic tokens:
```bash
python tools/vqgan/inference.py \

+ 3 - 0  docs/zh/inference.md

@@ -53,6 +53,9 @@ python tools/llama/generate.py \
    您可能希望使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 个 token/秒 -> ~500 个 token/秒).
    对应的, 如果你不打算使用加速, 你可以注释掉 `--compile` 参数.

+!!! info
+    对于不支持 bf16 的 GPU, 你可能需要使用 `--half` 参数.
+
### 3. 从语义 token 生成人声:
```bash
python tools/vqgan/inference.py \

+ 6 - 2  fish_speech/models/text2semantic/llama.py

@@ -110,7 +110,7 @@ class Transformer(nn.Module):
        self.max_batch_size = -1
        self.max_seq_len = -1

-    def setup_caches(self, max_batch_size, max_seq_len):
+    def setup_caches(self, max_batch_size, max_seq_len, dtype=torch.bfloat16):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

@@ -121,7 +121,11 @@

        for b in self.layers:
            b.attention.kv_cache = KVCache(
-                max_batch_size, max_seq_len, self.config.n_local_heads, head_dim
+                max_batch_size,
+                max_seq_len,
+                self.config.n_local_heads,
+                head_dim,
+                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
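
For context on the new `dtype` parameter, below is a minimal sketch of a gpt-fast-style `KVCache` whose buffers are allocated in the requested dtype, which is the constructor signature this hunk implies; the actual class in `fish_speech` may differ in detail.

```python
import torch
from torch import Tensor, nn


class KVCache(nn.Module):
    """Preallocated key/value cache whose buffers honor the requested dtype."""

    def __init__(self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        # Allocating in `dtype` is the point of this commit: fp16 buffers let
        # GPUs without bf16 support run inference at half precision.
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: (seq,); k_val / v_val: (batch, n_heads, seq, head_dim)
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache
```

With buffers allocated this way, `setup_caches(..., dtype=torch.half)` produces fp16 caches instead of the bf16 default.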

+ 7 - 2  tools/llama/generate.py

@@ -207,6 +207,7 @@ def generate(
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
+    precision: torch.dtype = torch.bfloat16,
    **sampling_kwargs,
) -> torch.Tensor:
    """
@@ -228,7 +229,7 @@

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_len=T_new)
+        model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
@@ -381,6 +382,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 @click.option("--use-g2p/--no-g2p", default=True)
 @click.option("--use-g2p/--no-g2p", default=True)
 @click.option("--seed", type=int, default=42)
 @click.option("--seed", type=int, default=42)
 @click.option("--speaker", type=str, default=None)
 @click.option("--speaker", type=str, default=None)
+@click.option("--half/--no-half", default=False)
def main(
    text: str,
    prompt_text: Optional[str],
@@ -398,9 +400,11 @@ def main(
    use_g2p: bool,
    seed: int,
    speaker: Optional[str],
+    half: bool,
) -> None:
    device = "cuda"
-    precision = torch.bfloat16
+
+    precision = torch.half if half else torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
@@ -445,6 +449,7 @@ def main(
            prompt=encoded,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
+            precision=precision,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,