|
@@ -34,9 +34,7 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
|
|
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
|
|
|
|
|
|
from fish_speech.models.text2semantic.llama import (
|
|
from fish_speech.models.text2semantic.llama import (
|
|
|
- BaseTransformer,
|
|
|
|
|
DualARTransformer,
|
|
DualARTransformer,
|
|
|
- NaiveTransformer,
|
|
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@@ -724,9 +722,22 @@ def generate_long(
|
|
|
|
|
|
|
|
yield GenerateResponse(action="sample", codes=codes, text=batch_text)
|
|
yield GenerateResponse(action="sample", codes=codes, text=batch_text)
|
|
|
|
|
|
|
|
|
|
+ MAX_HISTORY_TURNS = 2 # 只保留最近 2 轮 user/assistant
|
|
|
|
|
+ assistant_indices = [i for i, m in enumerate(conversation) if m.role == "assistant"]
|
|
|
|
|
+ if len(assistant_indices) > MAX_HISTORY_TURNS:
|
|
|
|
|
+ drop = assistant_indices[0]
|
|
|
|
|
+ # 移除最早的 user+assistant 对,保留 system 消息
|
|
|
|
|
+ conversation = [m for i, m in enumerate(conversation)
|
|
|
|
|
+ if i not in (drop - 1, drop)]
|
|
|
|
|
+
|
|
|
# Cleanup
|
|
# Cleanup
|
|
|
del y, encoded
|
|
del y, encoded
|
|
|
|
|
|
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
|
+ torch.cuda.empty_cache()
|
|
|
|
|
+ import gc
|
|
|
|
|
+ gc.collect()
|
|
|
|
|
+
|
|
|
if torch.cuda.is_available():
|
|
if torch.cuda.is_available():
|
|
|
logger.info(
|
|
logger.info(
|
|
|
f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
|
|
f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
|