@@ -440,30 +440,42 @@ def generate_long(
     max_length: int = 2048,
     chunk_length: int = 150,
     speaker: Optional[str] = None,
-    prompt_text: Optional[str] = None,
-    prompt_tokens: Optional[torch.Tensor] = None,
+    prompt_text: Optional[str | list[str]] = None,
+    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
 ):
     assert 0 < top_p <= 1, "top_p must be in (0, 1]"
     assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
     assert 0 < temperature < 2, "temperature must be in (0, 2)"
 
+    use_prompt = prompt_text is not None and prompt_tokens is not None
+    if use_prompt and isinstance(prompt_text, str):
+        prompt_text = [prompt_text]
+        prompt_tokens = [prompt_tokens]
+
+    assert use_prompt is False or len(prompt_text) == len(
+        prompt_tokens
+    ), "Prompt text and tokens must have the same length"
+
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
-    use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
+    encoded_prompts = []
 
     if use_prompt:
-        encoded_prompts = encode_tokens(
-            tokenizer,
-            prompt_text,
-            prompt_tokens=prompt_tokens,
-            bos=True,
-            device=device,
-            speaker=speaker,
-            num_codebooks=model.config.num_codebooks,
-        )
+        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
+            encoded_prompts.append(
+                encode_tokens(
+                    tokenizer,
+                    string=t,
+                    bos=idx == 0,
+                    device=device,
+                    prompt_tokens=c,
+                    speaker=speaker,
+                    num_codebooks=model.config.num_codebooks,
+                )
+            )
 
     for idx, text in enumerate(texts):
         encoded.append(
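Note: a minimal sketch of the prompt normalization and BOS handling introduced
above. normalize_prompts and encode_stub are hypothetical stand-ins, not the
real encode_tokens; the point is why only the first reference prompt is encoded
with bos=True: all prompts are concatenated into one stream, which should carry
a single BOS token.

    from typing import Union

    def normalize_prompts(prompt_text: Union[str, list[str]]) -> list[str]:
        # A lone string is wrapped into a one-element list so the encoding
        # loop can treat both call styles uniformly.
        return [prompt_text] if isinstance(prompt_text, str) else prompt_text

    def encode_stub(text: str, bos: bool) -> list[str]:
        # Stand-in for encode_tokens: prepends BOS only when asked.
        return (["<bos>"] if bos else []) + text.split()

    prompts = normalize_prompts(["ref one", "ref two"])
    stream = [t for i, p in enumerate(prompts) for t in encode_stub(p, bos=i == 0)]
    print(stream)  # ['<bos>', 'ref', 'one', 'ref', 'two']; one stream, one BOS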
@@ -507,7 +519,9 @@ def generate_long(
             count = 0
             for i, length in enumerate(lengths):
                 count += length
-                if count + length > max_length - 1024:
+                if count + length > max_length - 1024 - sum(
+                    t.shape[1] for t in encoded_prompts
+                ):
                     break
 
             if i != 0 and i % 2 == 0:
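Note: the hunk above shrinks the rolling-window budget by the combined length
of the encoded reference prompts, so the references always fit alongside the
text segments. A worked example with illustrative numbers (not taken from the
code):

    max_length = 2048
    reserve = 1024  # tokens the function already held back before this change
    prompt_lengths = [200, 200]  # e.g. two encoded reference prompts
    budget = max_length - reserve - sum(prompt_lengths)
    print(budget)  # 624 tokens remain for the rolling text-segment window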
@@ -520,7 +534,7 @@ def generate_long(
                 partial_encoded = global_encoded
 
             if use_prompt:
-                partial_encoded = [encoded_prompts] + partial_encoded
+                partial_encoded = encoded_prompts + partial_encoded
 
             cat_encoded = torch.cat(partial_encoded, dim=1)
             prompt_length = cat_encoded.size(1)
@@ -643,9 +657,12 @@ def launch_thread_safe_queue(
     type=str,
     default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
 )
-@click.option("--prompt-text", type=str, default=None)
+@click.option("--prompt-text", type=str, default=None, multiple=True)
 @click.option(
-    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
+    "--prompt-tokens",
+    type=click.Path(path_type=Path, exists=True),
+    default=None,
+    multiple=True,
 )
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
@@ -665,11 +682,11 @@ def launch_thread_safe_queue(
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
 @click.option("--max-length", type=int, default=2048)
-@click.option("--chunk-length", type=int, default=30)
+@click.option("--chunk-length", type=int, default=150)
 def main(
     text: str,
-    prompt_text: Optional[str],
-    prompt_tokens: Optional[Path],
+    prompt_text: Optional[list[str]],
+    prompt_tokens: Optional[list[Path]],
     num_samples: int,
     max_new_tokens: int,
     top_p: int,
@@ -690,6 +707,11 @@ def main(
 
     precision = torch.half if half else torch.bfloat16
 
+    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
+        raise ValueError(
+            f"Number of prompt texts ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) must be the same"
+        )
+
     logger.info("Loading model ...")
     t0 = time.time()
     model, decode_one_token = load_model(
@@ -701,11 +723,8 @@ def main(
 
     logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
 
-    prompt_tokens = (
-        torch.from_numpy(np.load(prompt_tokens)).to(device)
-        if prompt_tokens is not None
-        else None
-    )
+    if prompt_tokens is not None:
+        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     torch.manual_seed(seed)
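Note: with click's multiple=True, repeating a flag collects all of its values
into a tuple, so several reference prompts can now be passed by repeating both
options. A hypothetical invocation (script path and file names are
placeholders):

    python tools/llama/generate.py \
        --text "Hello world" \
        --prompt-text "transcript of the first reference" \
        --prompt-tokens ref1.npy \
        --prompt-text "transcript of the second reference" \
        --prompt-tokens ref2.npy

The i-th --prompt-text is paired with the i-th --prompt-tokens, which is why
main() now checks that both options are given the same number of times.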