Parcourir la source

feat:添加耗时统计

zhaohaipeng il y a 1 mois
Parent
commit
e772ffdcef
1 fichiers modifiés avec 2 ajouts et 1 suppressions
  1. 2 1
      fish_speech/models/text2semantic/inference.py

+ 2 - 1
fish_speech/models/text2semantic/inference.py

@@ -720,7 +720,7 @@ def generate_long(
     # Build base conversation with system message
     base_conversation = Conversation()
 
-    all_codes = torch.cat([c for c in prompt_tokens], dim=1)
+    all_codes = None
 
     if use_prompt:
         # Auto-add speaker tags to prompt texts that don't have them
@@ -740,6 +740,7 @@ def generate_long(
         reference_text = "\n".join(tagged_prompt_text)
         system_parts.append(TextPart(text=reference_text, cal_loss=False))
         system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
+        all_codes = torch.cat([c for c in prompt_tokens], dim=1)
         system_parts.append(VQPart(codes=all_codes, cal_loss=False))
         # torch.save(all_codes, "debug_vq_codes.pt")
     else: