@@ -90,22 +90,21 @@ def sample(


 def decode_one_token_ar(
-    model: NaiveTransformer,
+    model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
-    assert input_pos.shape[-1] == 1
-
-    x, logits = model.forward_generate_slow(x, input_pos)
+    x = model.forward_generate(x, input_pos)
     codebooks = [
         sample(
-            logits,
+            x.logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs,
         )[0]
     ]
+    x = x.hidden_states

     # Cleanup the cache
     for layer in model.fast_layers:
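Reviewer note: this hunk assumes `forward_generate` returns a single result object rather than the old `(x, logits)` tuple, since the rewritten code reads both `x.logits` and `x.hidden_states` off it. A minimal sketch of that assumed shape; the field names come from the diff, while the class name and shape comments are illustrative only:

```python
# Sketch only: the return shape decode_one_token_ar expects from
# DualARTransformer.forward_generate. Field names mirror the diff;
# the class name and shape comments are assumptions, not project API.
from dataclasses import dataclass

import torch


@dataclass
class ForwardResult:
    logits: torch.Tensor  # logits for the first (slow) codebook at this step
    hidden_states: torch.Tensor  # features consumed by the fast codebook layers
```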
@@ -137,12 +136,11 @@ def decode_one_token_naive(
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
-    assert input_pos.shape[-1] == 1
+    x = model.forward_generate(x, input_pos)

-    x, logits = model.forward_generate_slow(x, input_pos)
     codebooks = [
         sample(
-            logits,
+            x.token_logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs,
         )[0]
@@ -151,7 +149,7 @@ def decode_one_token_naive(
     for i in range(model.config.num_codebooks):
         codebooks.append(
             sample(
-                logits.codebook_logits[:, :, i],
+                x.codebook_logits[:, :, i],
                 previous_tokens=previous_tokens[i + 1]
                 if previous_tokens is not None
                 else None,
@@ -343,11 +341,13 @@ def encode_tokens(
     return prompt


-def load_model(config_name, checkpoint_path, device, precision):
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
-        cfg = compose(config_name=config_name)
+def load_model(config_name, checkpoint_path, device, precision, max_length):
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
+        cfg = compose(
+            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
+        )

-    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg.model).model
+    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)

     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")
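For context, the reworked `load_model` is the standard Hydra compose-with-overrides pattern: point `initialize` at the model config directory, inject the requested sequence length as an override, and `instantiate` the composed config directly. A self-contained sketch; the config name and override key mirror the diff, but the YAML behind them is assumed here:

```python
# Hydra compose/instantiate pattern used by the new load_model.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
    cfg = compose(
        config_name="dual_ar_8_codebook_small",
        overrides=["config.max_seq_len=4096"],  # what --max-length feeds in
    )

model = instantiate(cfg)  # builds the transformer the YAML describes
```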
@@ -421,7 +421,7 @@ def split_text(text, min_length):
     type=click.Path(path_type=Path, exists=True),
     default="results/text2semantic_400m_finetune/step_000002000.pth",
 )
-@click.option("--config-name", type=str, default="text2semantic_finetune")
+@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
 @click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
 @click.option("--compile/--no-compile", default=False)
 @click.option("--use-g2p/--no-g2p", default=True)
@@ -430,6 +430,8 @@ def split_text(text, min_length):
 @click.option("--order", type=str, default="zh,jp,en")
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=False)
+@click.option("--max-length", type=int, default=2048)
+@click.option("--chunk-length", type=int, default=30)
 def main(
     text: str,
     prompt_text: Optional[str],
@@ -450,6 +452,8 @@ def main(
     order: str,
     half: bool,
     iterative_prompt: bool,
+    max_length: int,
+    chunk_length: int,
 ) -> None:
     device = "cuda"

@@ -457,7 +461,7 @@ def main(

     logger.info("Loading model ...")
     t0 = time.time()
-    model, cfg = load_model(config_name, checkpoint_path, device, precision)
+    model, cfg = load_model(config_name, checkpoint_path, device, precision, max_length)
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)

     torch.cuda.synchronize()
@@ -472,7 +476,7 @@ def main(

     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
-    texts = split_text(text, 30) if iterative_prompt else [text]
+    texts = split_text(text, chunk_length) if iterative_prompt else [text]
     for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
@@ -506,13 +510,15 @@ def main(
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)

-    decode_one_token = (
-        decode_one_token_ar
-        if isinstance(model, DualARTransformer)
-        else decode_one_token_naive
-    )
+    if isinstance(model, DualARTransformer):
+        decode_one_token = decode_one_token_ar
+        logger.info("Using DualARTransformer")
+    else:
+        decode_one_token = decode_one_token_naive
+        logger.info("Using NaiveTransformer")

     if compile:
+        logger.info("Compiling function...")
         decode_one_token = torch.compile(
             decode_one_token, mode="reduce-overhead", fullgraph=True
         )
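The new "Compiling function..." log is worth having: the first call into a `torch.compile`d function can stall for a long time while compilation runs, and this makes that visible. A toy illustration of the same call pattern (`decode_step` is a stand-in, not the project's function):

```python
import torch


def decode_step(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for decode_one_token_ar / decode_one_token_naive.
    return torch.argmax(x, dim=-1)


# mode="reduce-overhead" targets small graphs called in a tight loop;
# fullgraph=True errors out instead of silently falling back on graph breaks.
decode_step = torch.compile(decode_step, mode="reduce-overhead", fullgraph=True)
```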
@@ -528,11 +534,12 @@ def main(
         global_encoded.append(seg)

         lengths = reversed([seg.size(1) for seg in global_encoded])
+
         # Pick last 2000 tokens
         count = 0
         for i, length in enumerate(lengths):
             count += length
-            if count >= 2000:
+            if count + length > max_length - 1024:
                 break

         if i != 0 and i % 2 == 0:
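Two things to note here: the history cutoff now tracks the model's context window (`max_length` minus 1024 tokens reserved for generation) instead of a hard-coded 2000, and the untouched `# Pick last 2000 tokens` comment above it is now stale. A self-contained sketch of the intended windowing, with a hypothetical helper name and plain int lengths:

```python
def pick_recent(lengths: list[int], max_length: int) -> int:
    """Return how many trailing segments fit in the history window.

    Keeps the newest segments whose total length stays within
    max_length - 1024, leaving ~1024 tokens of headroom for generation.
    Function name and list-of-ints interface are illustrative.
    """
    budget = max_length - 1024
    count = kept = 0
    for length in reversed(lengths):
        if count + length > budget:
            break
        count += length
        kept += 1
    return kept


# e.g. pick_recent([900, 900, 900], max_length=2048) -> 1 (only newest fits)
```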
@@ -561,7 +568,7 @@ def main(
             repetition_penalty=repetition_penalty,
         )

-        if idx == 0 and compile:
+        if idx == 0 and seg_idx == 0 and compile:
             logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

         torch.cuda.synchronize()