@@ -456,6 +456,7 @@ def generate_long(
     speaker: Optional[str] = None,
     prompt_text: Optional[str] = None,
     prompt_tokens: Optional[torch.Tensor] = None,
+    is_streaming: bool = False,
 ):
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
@@ -580,13 +581,22 @@ def generate_long(
 
         # But for global encoding, we should keep the <im_end> token
         global_encoded.append(decoded)
 
-        all_codes.append(codes)
-        seg_idx += 1
-    codes = torch.cat(all_codes, dim=1)
-    assert (codes >= 0).all(), f"Negative code found: {codes}"
+        if is_streaming:
+            assert (codes >= 0).all(), f"Negative code found: {codes}"
+            yield codes
+        else:
+            all_codes.append(codes)
 
-    yield codes
+        seg_idx += 1
+
+    if is_streaming:
+        # This indicates the end of the current sample
+        yield None
+    else:
+        all_codes = torch.cat(all_codes, dim=1)
+        assert (all_codes >= 0).all(), f"Negative code found: {all_codes}"
+        yield all_codes
 
 
 @click.command()
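
For context, a minimal sketch (not part of the patch) of how a caller might consume the new is_streaming mode. It assumes model and tokenizer are already loaded and that generate_long accepts a text argument; those names and the sample text are placeholders, not the confirmed signature. In streaming mode each yield is one segment's code tensor and a trailing None marks the end of the sample; with is_streaming=False the generator yields the concatenated codes exactly once.

    import torch

    # Streaming: collect per-segment codes until the None sentinel arrives.
    segments = []
    for codes in generate_long(
        model=model,           # assumed: an already-initialized model
        tokenizer=tokenizer,   # assumed: tokenizer with <|im_end|> registered
        text="Hello, world.",  # placeholder argument name and value
        is_streaming=True,
    ):
        if codes is None:
            break  # end-of-sample marker, yielded only when is_streaming=True
        segments.append(codes)  # one tensor of codes per generated segment

    # Concatenate along dim=1, matching the non-streaming path in the patch.
    streamed_codes = torch.cat(segments, dim=1)

    # Non-streaming (default): a single yield with the codes already concatenated.
    full_codes = next(
        generate_long(model=model, tokenizer=tokenizer, text="Hello, world.")
    )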