|
@@ -723,12 +723,12 @@ def generate_long(
|
|
|
yield GenerateResponse(action="sample", codes=codes, text=batch_text)
|
|
yield GenerateResponse(action="sample", codes=codes, text=batch_text)
|
|
|
|
|
|
|
|
MAX_HISTORY_TURNS = 2 # 只保留最近 2 轮 user/assistant
|
|
MAX_HISTORY_TURNS = 2 # 只保留最近 2 轮 user/assistant
|
|
|
- assistant_indices = [i for i, m in enumerate(conversation) if m.role == "assistant"]
|
|
|
|
|
|
|
+ assistant_indices = [i for i, m in enumerate(conversation.messages) if m.role == "assistant"]
|
|
|
if len(assistant_indices) > MAX_HISTORY_TURNS:
|
|
if len(assistant_indices) > MAX_HISTORY_TURNS:
|
|
|
drop = assistant_indices[0]
|
|
drop = assistant_indices[0]
|
|
|
# 移除最早的 user+assistant 对,保留 system 消息
|
|
# 移除最早的 user+assistant 对,保留 system 消息
|
|
|
- conversation = [m for i, m in enumerate(conversation)
|
|
|
|
|
- if i not in (drop - 1, drop)]
|
|
|
|
|
|
|
+ conversation = Conversation([m for i, m in enumerate(conversation.messages)
|
|
|
|
|
+ if i not in (drop - 1, drop)])
|
|
|
|
|
|
|
|
# Cleanup
|
|
# Cleanup
|
|
|
del y, encoded
|
|
del y, encoded
|