소스 검색

Fix api (#321)

spicysama 1년 전
부모
커밋
61b6609981
2개의 변경된 파일, 4개의 추가 및 21개의 삭제
  1. +3 -9
      tools/api.py
  2. +1 -12
      tools/post_api.py

+ 3 - 9
tools/api.py

@@ -200,11 +200,9 @@ def inference(req: InvokeRequest):
     lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
 
     if lab_path and wav_path:
-        with open(wav_path, "rb") as wav_file:
-            audio_bytes = wav_file.read()
         with open(lab_path, "r", encoding="utf-8") as lab_file:
             ref_text = lab_file.read()
-        req.reference_audio = base64.b64encode(audio_bytes).decode("utf-8")
+        req.reference_audio = wav_path
         req.reference_text = ref_text
         logger.info("ref_path: " + str(wav_path))
         logger.info("ref_text: " + ref_text)
@@ -212,11 +210,7 @@ def inference(req: InvokeRequest):
     # Parse reference audio aka prompt
     prompt_tokens = encode_reference(
         decoder_model=decoder_model,
-        reference_audio=(
-            io.BytesIO(base64.b64decode(req.reference_audio))
-            if req.reference_audio is not None
-            else None
-        ),
+        reference_audio=req.reference_audio,
         enable_reference_audio=req.reference_audio is not None,
     )
 
@@ -423,7 +417,7 @@ if __name__ == "__main__":
                 text="Hello world.",
                 reference_text=None,
                 reference_audio=None,
-                max_new_tokens=1024,
+                max_new_tokens=0,
                 top_p=0.7,
                 repetition_penalty=1.2,
                 temperature=0.7,

+ 1 - 12
tools/post_api.py

@@ -6,15 +6,6 @@ import pyaudio
 import requests
 
 
-def wav_to_base64(file_path):
-    if not file_path:
-        return None
-    with open(file_path, "rb") as wav_file:
-        wav_content = wav_file.read()
-        base64_encoded = base64.b64encode(wav_content)
-        return base64_encoded.decode("utf-8")
-
-
 def play_audio(audio_content, format, channels, rate):
     p = pyaudio.PyAudio()
     stream = p.open(format=format, channels=channels, rate=rate, output=True)
@@ -88,12 +79,10 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    base64_audio = wav_to_base64(args.reference_audio)
-
     data = {
         "text": args.text,
         "reference_text": args.reference_text,
-        "reference_audio": base64_audio,
+        "reference_audio": args.reference_audio,
         "max_new_tokens": args.max_new_tokens,
         "chunk_length": args.chunk_length,
         "top_p": args.top_p,