|
|
@@ -42,12 +42,12 @@ def multinomial_sample_one_no_sync(
|
|
|
def logits_to_probs(
|
|
|
logits,
|
|
|
previous_tokens: Optional[torch.Tensor] = None,
|
|
|
- temperature: float = 1.0,
|
|
|
- top_k: Optional[int] = None,
|
|
|
- top_p: Optional[int] = None,
|
|
|
- repetition_penalty: float = 1.0,
|
|
|
-):
|
|
|
- if previous_tokens is not None and repetition_penalty != 1.0:
|
|
|
+ temperature: torch.Tensor = 1.0,
|
|
|
+ top_p: torch.Tensor = 1.0,
|
|
|
+ repetition_penalty: torch.Tensor = 1.0,
|
|
|
+) -> torch.Tensor:
|
|
|
+ # Apply repetition penalty
|
|
|
+ if previous_tokens is not None:
|
|
|
previous_tokens = previous_tokens.long()
|
|
|
score = torch.gather(logits, dim=0, index=previous_tokens)
|
|
|
score = torch.where(
|
|
|
@@ -55,25 +55,18 @@ def logits_to_probs(
|
|
|
)
|
|
|
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
|
|
|
|
|
- if top_p is not None and top_p < 1.0:
|
|
|
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
|
|
- cum_probs = torch.cumsum(
|
|
|
- torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
|
|
- )
|
|
|
- sorted_indices_to_remove = cum_probs > top_p
|
|
|
- sorted_indices_to_remove[0] = False # keep at least one option
|
|
|
- indices_to_remove = sorted_indices_to_remove.scatter(
|
|
|
- dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
|
|
- )
|
|
|
- logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
|
|
+ # Apply top-p sampling
|
|
|
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
|
|
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
|
|
+ sorted_indices_to_remove = cum_probs > top_p
|
|
|
+ sorted_indices_to_remove[0] = False # keep at least one option
|
|
|
+ indices_to_remove = sorted_indices_to_remove.scatter(
|
|
|
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
|
|
+ )
|
|
|
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
|
|
|
|
|
logits = logits / max(temperature, 1e-5)
|
|
|
|
|
|
- if top_k is not None:
|
|
|
- v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
|
|
- pivot = v.select(-1, -1).unsqueeze(-1)
|
|
|
- logits = torch.where(logits < pivot, -float("Inf"), logits)
|
|
|
-
|
|
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
|
|
return probs
|
|
|
|
|
|
@@ -449,7 +442,6 @@ def generate_long(
|
|
|
text: str,
|
|
|
num_samples: int = 1,
|
|
|
max_new_tokens: int = 0,
|
|
|
- top_k: int = None,
|
|
|
top_p: int = 0.7,
|
|
|
repetition_penalty: float = 1.5,
|
|
|
temperature: float = 0.7,
|
|
|
@@ -462,6 +454,10 @@ def generate_long(
|
|
|
prompt_tokens: Optional[torch.Tensor] = None,
|
|
|
is_streaming: bool = False,
|
|
|
):
|
|
|
+ assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
|
|
+ assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
|
|
|
+ assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
|
|
+
|
|
|
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
|
|
|
|
|
@@ -493,6 +489,14 @@ def generate_long(
|
|
|
)
|
|
|
logger.info(f"Encoded text: {text}")
|
|
|
|
|
|
+ # Move temperature, top_p, repetition_penalty to device
|
|
|
+ # This is important so that changing params doesn't trigger recompile
|
|
|
+ temperature = torch.tensor(temperature, device=device, dtype=torch.float)
|
|
|
+ top_p = torch.tensor(top_p, device=device, dtype=torch.float)
|
|
|
+ repetition_penalty = torch.tensor(
|
|
|
+ repetition_penalty, device=device, dtype=torch.float
|
|
|
+ )
|
|
|
+
|
|
|
for sample_idx in range(num_samples):
|
|
|
if torch.cuda.is_available():
|
|
|
torch.cuda.synchronize()
|
|
|
@@ -542,7 +546,6 @@ def generate_long(
|
|
|
im_end_id=im_end_id,
|
|
|
decode_one_token=decode_one_token,
|
|
|
temperature=temperature,
|
|
|
- top_k=top_k,
|
|
|
top_p=top_p,
|
|
|
repetition_penalty=repetition_penalty,
|
|
|
)
|
|
|
@@ -660,7 +663,6 @@ def launch_thread_safe_queue(
|
|
|
)
|
|
|
@click.option("--num-samples", type=int, default=1)
|
|
|
@click.option("--max-new-tokens", type=int, default=0)
|
|
|
-@click.option("--top-k", type=int, default=None)
|
|
|
@click.option("--top-p", type=float, default=0.7)
|
|
|
@click.option("--repetition-penalty", type=float, default=1.5)
|
|
|
@click.option("--temperature", type=float, default=0.7)
|
|
|
@@ -684,7 +686,6 @@ def main(
|
|
|
prompt_tokens: Optional[Path],
|
|
|
num_samples: int,
|
|
|
max_new_tokens: int,
|
|
|
- top_k: int,
|
|
|
top_p: int,
|
|
|
repetition_penalty: float,
|
|
|
temperature: float,
|
|
|
@@ -733,7 +734,6 @@ def main(
|
|
|
text=text,
|
|
|
num_samples=num_samples,
|
|
|
max_new_tokens=max_new_tokens,
|
|
|
- top_k=top_k,
|
|
|
top_p=top_p,
|
|
|
repetition_penalty=repetition_penalty,
|
|
|
temperature=temperature,
|