@@ -219,7 +219,6 @@ def generate(
     eos_token_id: int = 2,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
-    precision: torch.dtype = torch.bfloat16,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -241,7 +240,9 @@ def generate(
 
     device, dtype = prompt.device, prompt.dtype
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)
+        model.setup_caches(
+            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
+        )
 
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
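Note on the dtype change above: with the precision argument gone, the cache dtype is derived from the model's own parameters, so caches and weights can no longer drift apart. A minimal standalone sketch of the idiom (example names are ours, not part of the patch):

    import torch
    from torch import nn

    model = nn.Linear(4, 4).to(torch.bfloat16)

    # next(model.parameters()) returns the first parameter tensor;
    # its dtype reflects whatever precision the model was cast to,
    # so caches allocated with it always match the weights.
    cache_dtype = next(model.parameters()).dtype
    assert cache_dtype == torch.bfloat16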
@@ -250,7 +251,13 @@ def generate(
     seq = empty
     input_pos = torch.arange(0, T, device=device)
 
-    next_token = decode_one_token(
+    # Use non-accelerated version for now, to avoid compilation overhead
+    prefill_decode = (
+        decode_one_token_naive
+        if isinstance(model, NaiveTransformer)
+        else decode_one_token_ar
+    )
+    next_token = prefill_decode(
         model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
     )
     seq[:, T : T + 1] = next_token
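Why the prefill step avoids the possibly-compiled decode_one_token: functions wrapped by torch.compile specialize on input shapes, and the prompt length differs on every call, while the decode loop always feeds a single token at a fixed shape. A standalone illustration (behaviour hedged; recompilation details depend on PyTorch version and backend):

    import torch

    @torch.compile(mode="reduce-overhead", fullgraph=True)
    def double(x):
        return x * 2

    # The first call for each new input shape may pay compilation cost;
    # reusing a fixed shape, as the single-token decode loop does, amortizes it.
    double(torch.randn(8))
    double(torch.randn(16))  # new shape: may trigger a recompile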
@@ -338,7 +345,9 @@ def encode_tokens(
     return prompt
 
 
-def load_model(config_name, checkpoint_path, device, precision, max_length):
+def load_model(
+    config_name, checkpoint_path, device, precision, max_length, compile=False
+):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
         cfg = compose(
            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
        )
@@ -379,7 +388,20 @@ def load_model(config_name, checkpoint_path, device, precision, max_length):
     model = model.to(device=device, dtype=precision)
     logger.info("Restored model from checkpoint")
 
-    return model.eval(), cfg
+    if isinstance(model, DualARTransformer):
+        decode_one_token = decode_one_token_ar
+        logger.info("Using DualARTransformer")
+    else:
+        decode_one_token = decode_one_token_naive
+        logger.info("Using NaiveTransformer")
+
+    if compile:
+        logger.info("Compiling function...")
+        decode_one_token = torch.compile(
+            decode_one_token, mode="reduce-overhead", fullgraph=True
+        )
+
+    return model.eval(), decode_one_token
 
 
 def split_text(text, min_length):
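After this hunk, load_model owns both the transformer-type dispatch and the optional compilation, returning a ready-to-use decode step instead of the Hydra cfg. A usage sketch under the patch's own CLI defaults (assumes the module's existing imports):

    model, decode_one_token = load_model(
        "dual_ar_8_codebook_small",
        "results/text2semantic_400m_finetune/step_000002000.pth",
        "cuda",
        torch.bfloat16,
        2048,
        compile=True,  # wraps decode_one_token in torch.compile
    )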
@@ -401,76 +423,28 @@ def split_text(text, min_length):
     return segments
 
 
-@click.command()
-@click.option(
-    "--text",
-    type=str,
-    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
-)
-@click.option("--prompt-text", type=str, default=None)
-@click.option(
-    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
-)
-@click.option("--num-samples", type=int, default=1)
-@click.option("--max-new-tokens", type=int, default=0)
-@click.option("--top-k", type=int, default=None)
-@click.option("--top-p", type=float, default=0.7)
-@click.option("--repetition-penalty", type=float, default=1.5)
-@click.option("--temperature", type=float, default=0.7)
-@click.option(
-    "--checkpoint-path",
-    type=click.Path(path_type=Path, exists=True),
-    default="results/text2semantic_400m_finetune/step_000002000.pth",
-)
-@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
-@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
-@click.option("--compile/--no-compile", default=False)
-@click.option("--seed", type=int, default=42)
-@click.option("--speaker", type=str, default=None)
-@click.option("--half/--no-half", default=False)
-@click.option("--iterative-prompt/--no-iterative-prompt", default=False)
-@click.option("--max-length", type=int, default=2048)
-@click.option("--chunk-length", type=int, default=30)
-def main(
+def generate_long(
+    *,
+    model,
+    tokenizer: callable,
+    device: str | torch.device,
+    decode_one_token: callable,
     text: str,
-    prompt_text: Optional[str],
-    prompt_tokens: Optional[Path],
-    num_samples: int,
-    max_new_tokens: int,
-    top_k: int,
-    top_p: int,
-    repetition_penalty: float,
-    temperature: float,
-    checkpoint_path: Path,
-    config_name: str,
-    tokenizer: str,
-    compile: bool,
-    seed: int,
-    speaker: Optional[str],
-    half: bool,
-    iterative_prompt: bool,
-    max_length: int,
-    chunk_length: int,
-) -> None:
-    device = "cuda"
-
-    precision = torch.half if half else torch.bfloat16
-
-    logger.info("Loading model ...")
-    t0 = time.time()
-    model, cfg = load_model(config_name, checkpoint_path, device, precision, max_length)
+    num_samples: int = 1,
+    max_new_tokens: int = 0,
+    top_k: int = None,
+    top_p: int = 0.7,
+    repetition_penalty: float = 1.5,
+    temperature: float = 0.7,
+    compile: bool = False,
+    iterative_prompt: bool = True,
+    max_length: int = 2048,
+    chunk_length: int = 30,
+    speaker: Optional[str] = None,
+    prompt_text: Optional[str] = None,
+    prompt_tokens: Optional[torch.Tensor] = None,
+):
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-    torch.cuda.synchronize()
-    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
-
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-    prompt_tokens = (
-        torch.from_numpy(np.load(prompt_tokens)).to(device)
-        if prompt_tokens is not None
-        else None
-    )
-
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
     use_prompt = prompt_text is not None and prompt_tokens is not None
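The bare * that opens the new generate_long signature makes every parameter keyword-only, which keeps the long call sites self-documenting. A standalone illustration of the mechanism:

    def f(*, a, b=1):
        return a + b

    f(a=2)  # ok, returns 3
    # f(2)  -> TypeError: f() takes 0 positional arguments but 1 was given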
@@ -502,29 +476,17 @@ def main(
 
         encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
 
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-    if isinstance(model, DualARTransformer):
-        decode_one_token = decode_one_token_ar
-        logger.info("Using DualARTransformer")
-    else:
-        decode_one_token = decode_one_token_naive
-        logger.info("Using NaiveTransformer")
-
-    if compile:
-        logger.info("Compiling function...")
-        decode_one_token = torch.compile(
-            decode_one_token, mode="reduce-overhead", fullgraph=True
-        )
-
-    for idx in range(num_samples):
+    for sample_idx in range(num_samples):
         torch.cuda.synchronize()
         global_encoded = []
         all_codes = []
         seg_idx = 0
 
         while seg_idx < len(encoded):
+            logger.info(
+                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
+            )
+
             seg = encoded[seg_idx]
             global_encoded.append(seg)
 
@@ -557,14 +519,13 @@ def main(
                 eos_token_id=tokenizer.eos_token_id,
                 im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
-                precision=precision,
                 temperature=temperature,
                 top_k=top_k,
                 top_p=top_p,
                 repetition_penalty=repetition_penalty,
             )
 
-            if idx == 0 and seg_idx == 0 and compile:
+            if sample_idx == 0 and seg_idx == 0 and compile:
                 logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
             torch.cuda.synchronize()
@@ -607,6 +568,104 @@ def main(
         codes = torch.cat(all_codes, dim=1)
         assert (codes >= 0).all(), f"Negative code found: {codes}"
 
+        yield codes
+
+
+@click.command()
+@click.option(
+    "--text",
+    type=str,
+    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
+@click.option("--prompt-text", type=str, default=None)
+@click.option(
+    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
+)
+@click.option("--num-samples", type=int, default=1)
+@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--top-k", type=int, default=None)
+@click.option("--top-p", type=float, default=0.7)
+@click.option("--repetition-penalty", type=float, default=1.5)
+@click.option("--temperature", type=float, default=0.7)
+@click.option(
+    "--checkpoint-path",
+    type=click.Path(path_type=Path, exists=True),
+    default="results/text2semantic_400m_finetune/step_000002000.pth",
+)
+@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
+@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
+@click.option("--compile/--no-compile", default=False)
+@click.option("--seed", type=int, default=42)
+@click.option("--speaker", type=str, default=None)
+@click.option("--half/--no-half", default=False)
+@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
+@click.option("--max-length", type=int, default=2048)
+@click.option("--chunk-length", type=int, default=30)
+def main(
+    text: str,
+    prompt_text: Optional[str],
+    prompt_tokens: Optional[Path],
+    num_samples: int,
+    max_new_tokens: int,
+    top_k: int,
+    top_p: int,
+    repetition_penalty: float,
+    temperature: float,
+    checkpoint_path: Path,
+    config_name: str,
+    tokenizer: str,
+    compile: bool,
+    seed: int,
+    speaker: Optional[str],
+    half: bool,
+    iterative_prompt: bool,
+    max_length: int,
+    chunk_length: int,
+) -> None:
+    device = "cuda"
+
+    precision = torch.half if half else torch.bfloat16
+
+    logger.info("Loading model ...")
+    t0 = time.time()
+    model, decode_one_token = load_model(
+        config_name, checkpoint_path, device, precision, max_length, compile=compile
+    )
+    torch.cuda.synchronize()
+    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+    prompt_tokens = (
+        torch.from_numpy(np.load(prompt_tokens)).to(device)
+        if prompt_tokens is not None
+        else None
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    generator = generate_long(
+        model=model,
+        device=device,
+        decode_one_token=decode_one_token,
+        text=text,
+        num_samples=num_samples,
+        max_new_tokens=max_new_tokens,
+        top_k=top_k,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        temperature=temperature,
+        tokenizer=tokenizer,
+        compile=compile,
+        speaker=speaker,
+        iterative_prompt=iterative_prompt,
+        max_length=max_length,
+        chunk_length=chunk_length,
+        prompt_text=prompt_text,
+        prompt_tokens=prompt_tokens,
+    )
+
+    for idx, codes in enumerate(generator):
         np.save(f"codes_{idx}.npy", codes.cpu().numpy())
         logger.info(f"Saved codes to codes_{idx}.npy")
 
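Since generate_long yields one codes tensor per finished sample, callers outside the CLI can consume results incrementally. A minimal programmatic sketch, assuming model, tokenizer, and decode_one_token were obtained as in main() above; the text value is a placeholder:

    for i, codes in enumerate(
        generate_long(
            model=model,
            tokenizer=tokenizer,
            device="cuda",
            decode_one_token=decode_one_token,
            text="Hello world",
            num_samples=2,
        )
    ):
        np.save(f"codes_{i}.npy", codes.cpu().numpy())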