|
|
@@ -720,7 +720,7 @@ def generate_long(
|
|
|
# Build base conversation with system message
|
|
|
base_conversation = Conversation()
|
|
|
|
|
|
- all_codes = torch.cat([c for c in prompt_tokens], dim=1)
|
|
|
+ all_codes = None
|
|
|
|
|
|
if use_prompt:
|
|
|
# Auto-add speaker tags to prompt texts that don't have them
|
|
|
@@ -740,6 +740,7 @@ def generate_long(
|
|
|
reference_text = "\n".join(tagged_prompt_text)
|
|
|
system_parts.append(TextPart(text=reference_text, cal_loss=False))
|
|
|
system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
|
|
|
+ all_codes = torch.cat([c for c in prompt_tokens], dim=1)
|
|
|
system_parts.append(VQPart(codes=all_codes, cal_loss=False))
|
|
|
# torch.save(all_codes, "debug_vq_codes.pt")
|
|
|
else:
|