Parcourir la source

feat:控制prompt增长

zhaohaipeng il y a 2 semaines
Parent
commit
11ee914699
2 fichiers modifiés avec 14 ajouts et 3 suppressions
  1. 1 1
      .env
  2. 13 2
      fish_speech/models/text2semantic/inference.py

+ 1 - 1
.env

@@ -1,4 +1,4 @@
-API_PORT=443
+API_PORT=8080
 COMPILE=1
 HALF=1
 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

+ 13 - 2
fish_speech/models/text2semantic/inference.py

@@ -34,9 +34,7 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
 from torch.nn.attention import SDPBackend, sdpa_kernel

 from fish_speech.models.text2semantic.llama import (
-    BaseTransformer,
     DualARTransformer,
-    NaiveTransformer,
 )
 
 
 
 
@@ -724,9 +722,22 @@ def generate_long(
 
 
             yield GenerateResponse(action="sample", codes=codes, text=batch_text)
             yield GenerateResponse(action="sample", codes=codes, text=batch_text)
 
 
+            MAX_HISTORY_TURNS = 2  # 只保留最近 2 轮 user/assistant
+            assistant_indices = [i for i, m in enumerate(conversation) if m.role == "assistant"]
+            if len(assistant_indices) > MAX_HISTORY_TURNS:
+                drop = assistant_indices[0]
+                # 移除最早的 user+assistant 对,保留 system 消息
+                conversation = [m for i, m in enumerate(conversation)
+                                if i not in (drop - 1, drop)]
+
             # Cleanup
             # Cleanup
             del y, encoded
             del y, encoded
 
 
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            import gc
+            gc.collect()
+
         if torch.cuda.is_available():
         if torch.cuda.is_available():
             logger.info(
             logger.info(
                 f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                 f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"