Procházet zdrojové kódy

Optimize compute graph for dynamic params

Lengyue před 1 rokem
rodič
revize
813868fc87
3 změnil soubory, kde provedl 27 přidání a 38 odebrání
  1. 0 3
      tools/api.py
  2. 27 27
      tools/llama/generate.py
  3. 0 8
      tools/webui.py

+ 0 - 3
tools/api.py

@@ -67,7 +67,6 @@ class InvokeRequest(BaseModel):
     reference_audio: Optional[str] = None
     max_new_tokens: int = 0
     chunk_length: int = 30
-    top_k: int = 0
     top_p: float = 0.7
     repetition_penalty: float = 1.5
     temperature: float = 0.7
@@ -104,7 +103,6 @@ def inference(req: InvokeRequest):
         device=vqgan_model.device,
         max_new_tokens=req.max_new_tokens,
         text=req.text,
-        top_k=int(req.top_k) if req.top_k > 0 else None,
         top_p=req.top_p,
         repetition_penalty=req.repetition_penalty,
         temperature=req.temperature,
@@ -281,7 +279,6 @@ if __name__ == "__main__":
             reference_audio=None,
             max_new_tokens=0,
             chunk_length=30,
-            top_k=0,
             top_p=0.7,
             repetition_penalty=1.5,
             temperature=0.7,

+ 27 - 27
tools/llama/generate.py

@@ -42,12 +42,12 @@ def multinomial_sample_one_no_sync(
 def logits_to_probs(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None,
-    top_p: Optional[int] = None,
-    repetition_penalty: float = 1.0,
-):
-    if previous_tokens is not None and repetition_penalty != 1.0:
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
         previous_tokens = previous_tokens.long()
         score = torch.gather(logits, dim=0, index=previous_tokens)
         score = torch.where(
@@ -55,25 +55,18 @@ def logits_to_probs(
         )
         logits.scatter_(dim=0, index=previous_tokens, src=score)
 
-    if top_p is not None and top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-        )
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
-        )
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
 
     logits = logits / max(temperature, 1e-5)
 
-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
 
@@ -449,7 +442,6 @@ def generate_long(
     text: str,
     num_samples: int = 1,
     max_new_tokens: int = 0,
-    top_k: int = None,
     top_p: int = 0.7,
     repetition_penalty: float = 1.5,
     temperature: float = 0.7,
@@ -462,6 +454,10 @@ def generate_long(
     prompt_tokens: Optional[torch.Tensor] = None,
     is_streaming: bool = False,
 ):
+    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
+    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
+    assert 0 < temperature < 2, "temperature must be in (0, 2)"
+
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
@@ -493,6 +489,14 @@ def generate_long(
         )
         logger.info(f"Encoded text: {text}")
 
+    # Move temperature, top_p, repetition_penalty to device
+    # This is important so that changing params doesn't trigger recompile
+    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
+    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
+    repetition_penalty = torch.tensor(
+        repetition_penalty, device=device, dtype=torch.float
+    )
+
     for sample_idx in range(num_samples):
         if torch.cuda.is_available():
             torch.cuda.synchronize()
@@ -542,7 +546,6 @@ def generate_long(
                 im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
-                top_k=top_k,
                 top_p=top_p,
                 repetition_penalty=repetition_penalty,
             )
@@ -660,7 +663,6 @@ def launch_thread_safe_queue(
 )
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
-@click.option("--top-k", type=int, default=None)
 @click.option("--top-p", type=float, default=0.7)
 @click.option("--repetition-penalty", type=float, default=1.5)
 @click.option("--temperature", type=float, default=0.7)
@@ -684,7 +686,6 @@ def main(
     prompt_tokens: Optional[Path],
     num_samples: int,
     max_new_tokens: int,
-    top_k: int,
     top_p: int,
     repetition_penalty: float,
     temperature: float,
@@ -733,7 +734,6 @@ def main(
         text=text,
         num_samples=num_samples,
         max_new_tokens=max_new_tokens,
-        top_k=top_k,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         temperature=temperature,

+ 0 - 8
tools/webui.py

@@ -66,7 +66,6 @@ def inference(
     reference_text,
     max_new_tokens,
     chunk_length,
-    top_k,
     top_p,
     repetition_penalty,
     temperature,
@@ -107,7 +106,6 @@ def inference(
         device=vqgan_model.device,
         max_new_tokens=max_new_tokens,
         text=text,
-        top_k=int(top_k) if top_k > 0 else None,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         temperature=temperature,
@@ -193,10 +191,6 @@ def build_app():
                             step=8,
                         )
 
-                        top_k = gr.Slider(
-                            label="Top-K", minimum=0, maximum=100, value=0, step=1
-                        )
-
                         top_p = gr.Slider(
                             label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
                         )
@@ -266,7 +260,6 @@ def build_app():
                 reference_text,
                 max_new_tokens,
                 chunk_length,
-                top_k,
                 top_p,
                 repetition_penalty,
                 temperature,
@@ -337,7 +330,6 @@ if __name__ == "__main__":
         reference_text="",
         max_new_tokens=0,
         chunk_length=0,
-        top_k=0,  # 0 means no limit
         top_p=0.7,
         repetition_penalty=1.5,
         temperature=0.7,