zhaohaipeng 2 недель назад
Родитель
Коммит
8b8af56bd7
4 изменённых файла: 13 добавлено и 10 удалено
  1. 1 1
      .env
  2. 3 3
      fish_speech/models/text2semantic/inference.py
  3. 8 6
      tools/api_server.py
  4. 1 0
      tools/server/views.py

+ 1 - 1
.env

@@ -1,4 +1,4 @@
 API_PORT=8080
 API_PORT=8080
 COMPILE=1
 COMPILE=1
-HALF=1
+HALF=0
 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

+ 3 - 3
fish_speech/models/text2semantic/inference.py

@@ -723,12 +723,12 @@ def generate_long(
             yield GenerateResponse(action="sample", codes=codes, text=batch_text)
             yield GenerateResponse(action="sample", codes=codes, text=batch_text)
 
 
             MAX_HISTORY_TURNS = 2  # 只保留最近 2 轮 user/assistant
             MAX_HISTORY_TURNS = 2  # 只保留最近 2 轮 user/assistant
-            assistant_indices = [i for i, m in enumerate(conversation) if m.role == "assistant"]
+            assistant_indices = [i for i, m in enumerate(conversation.messages) if m.role == "assistant"]
             if len(assistant_indices) > MAX_HISTORY_TURNS:
             if len(assistant_indices) > MAX_HISTORY_TURNS:
                 drop = assistant_indices[0]
                 drop = assistant_indices[0]
                 # 移除最早的 user+assistant 对,保留 system 消息
                 # 移除最早的 user+assistant 对,保留 system 消息
-                conversation = [m for i, m in enumerate(conversation)
-                                if i not in (drop - 1, drop)]
+                conversation = Conversation([m for i, m in enumerate(conversation.messages)
+                                             if i not in (drop - 1, drop)])
 
 
             # Cleanup
             # Cleanup
             del y, encoded
             del y, encoded

+ 8 - 6
tools/api_server.py

@@ -99,12 +99,14 @@ class API(ExceptionHandler):
         logger.info(f"Startup done, listening server at http://{self.args.listen}")
         logger.info(f"Startup done, listening server at http://{self.args.listen}")
 
 
     def args_print(self):
     def args_print(self):
-        if not self.args:
-            return
-        print("Loaded arguments:")
-        for key, value in vars(self.args).items():
-            print(f"  self.args.{key}: {value}")
-
+        if self.args:
+            logger.info("Loaded arguments:")
+            for key, value in vars(self.args).items():
+                logger.info(f"  self.args.{key}: {value}")
+
+        logger.info("environment:")
+        for key in os.environ.keys():
+            logger.info(f"    env.{key}: {os.getenv(key)}")
 
 
 
 
 def create_app():
 def create_app():

+ 1 - 0
tools/server/views.py

@@ -146,6 +146,7 @@ async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
     """
     """
     Generate speech from text using TTS model.
     Generate speech from text using TTS model.
     """
     """
+    logger.info(f"/v1/tts param: {req}")
     try:
     try:
         # Get the model from the app
         # Get the model from the app
         app_state = request.app.state
         app_state = request.app.state