zhaohaipeng 1 месяц назад
Родитель
Сommit
1b593b6f78
1 измененных файлов с 2 добавлено и 1 удалено
  1. 2 1
      fish_speech/models/text2semantic/inference.py

+ 2 - 1
fish_speech/models/text2semantic/inference.py

@@ -720,6 +720,8 @@ def generate_long(
     # Build base conversation with system message
     # Build base conversation with system message
     base_conversation = Conversation()
     base_conversation = Conversation()
 
 
+    all_codes = torch.cat([c for c in prompt_tokens], dim=1)
+
     if use_prompt:
     if use_prompt:
         # Auto-add speaker tags to prompt texts that don't have them
         # Auto-add speaker tags to prompt texts that don't have them
         tagged_prompt_text = []
         tagged_prompt_text = []
@@ -738,7 +740,6 @@ def generate_long(
         reference_text = "\n".join(tagged_prompt_text)
         reference_text = "\n".join(tagged_prompt_text)
         system_parts.append(TextPart(text=reference_text, cal_loss=False))
         system_parts.append(TextPart(text=reference_text, cal_loss=False))
         system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
         system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
-        all_codes = torch.cat([c for c in prompt_tokens], dim=1)
         system_parts.append(VQPart(codes=all_codes, cal_loss=False))
         system_parts.append(VQPart(codes=all_codes, cal_loss=False))
         # torch.save(all_codes, "debug_vq_codes.pt")
         # torch.save(all_codes, "debug_vq_codes.pt")
     else:
     else: