Просмотр исходного кода

feat:去除控制prompt的长度

zhaohaipeng 2 недели назад
Родитель
Commit
006d51a5e5
2 изменённых файла: 20 добавлено и 15 удалено
  1. 7 7
      fish_speech/models/text2semantic/inference.py
  2. 13 8
      tools/api_server.py

+ 7 - 7
fish_speech/models/text2semantic/inference.py

@@ -722,13 +722,13 @@ def generate_long(
 
             yield GenerateResponse(action="sample", codes=codes, text=batch_text)
 
-            MAX_HISTORY_TURNS = 2  # 只保留最近 2 轮 user/assistant
-            assistant_indices = [i for i, m in enumerate(conversation.messages) if m.role == "assistant"]
-            if len(assistant_indices) > MAX_HISTORY_TURNS:
-                drop = assistant_indices[0]
-                # 移除最早的 user+assistant 对,保留 system 消息
-                conversation = Conversation([m for i, m in enumerate(conversation.messages)
-                                             if i not in (drop - 1, drop)])
+            # MAX_HISTORY_TURNS = 2  # 只保留最近 2 轮 user/assistant
+            # assistant_indices = [i for i, m in enumerate(conversation.messages) if m.role == "assistant"]
+            # if len(assistant_indices) > MAX_HISTORY_TURNS:
+            #     drop = assistant_indices[0]
+            #     # 移除最早的 user+assistant 对,保留 system 消息
+            #     conversation = Conversation([m for i, m in enumerate(conversation.messages)
+            #                                  if i not in (drop - 1, drop)])
 
             # Cleanup
             del y, encoded

+ 13 - 8
tools/api_server.py

@@ -78,12 +78,14 @@ class API(ExceptionHandler):
         self.app.state.device = self.args.device
         self.app.state.max_text_length = self.args.max_text_length
 
+        # Args print
+        self.args_print()
+        # Environment print
+        self.environment_print()
+
         # Associate the app with the model manager
         self.app.on_startup(self.initialize_app)
 
-        # Print args
-        self.args_print()
-
     async def initialize_app(self, app: Kui):
         # Make the ModelManager available to the views
         app.state.model_manager = ModelManager(
@@ -99,11 +101,14 @@ class API(ExceptionHandler):
         logger.info(f"Startup done, listening server at http://{self.args.listen}")
 
     def args_print(self):
-        if self.args:
-            logger.info("Loaded arguments:")
-            for key, value in vars(self.args).items():
-                logger.info(f"  self.args.{key}: {value}")
-
+        if not self.args:
+            return
+        logger.info("Loaded arguments:")
+        for key, value in vars(self.args).items():
+            logger.info(f"  self.args.{key}: {value}")
+
+    @staticmethod
+    def environment_print():
         logger.info("environment:")
         for key in os.environ.keys():
             logger.info(f"    env.{key}: {os.getenv(key)}")