Просмотр исходного кода

API support: ref audio (#241)

* New Package

* Update batch file: Ensure ASCII

* Download necessary tools

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* API support: ref audio

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert flag

* Updated api ref files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 год назад
Родитель
Сommit
1e83d31959
3 измененных файлов с 109 добавлено и 2 удалено
  1. 1 0
      .gitignore
  2. 72 2
      tools/api.py
  3. 36 0
      tools/gen_ref.py

+ 1 - 0
.gitignore

@@ -24,3 +24,4 @@ asr-label*
 /fishenv
 /.locale
 /demo-audios
+ref_data*

+ 72 - 2
tools/api.py

@@ -1,11 +1,14 @@
 import base64
 import io
+import json
 import queue
+import random
 import threading
 import traceback
 import wave
 from argparse import ArgumentParser
 from http import HTTPStatus
+from pathlib import Path
 from typing import Annotated, Literal, Optional
 
 import librosa
@@ -163,6 +166,41 @@ def decode_vq_tokens(
 routes = MultimethodRoutes(base_class=HttpView)
 
 
+def get_random_paths(base_path, data, speaker, emotion):
+    if base_path and data and speaker and emotion and (Path(base_path).exists()):
+        if speaker in data and emotion in data[speaker]:
+            files = data[speaker][emotion]
+            lab_files = [f for f in files if f.endswith(".lab")]
+            wav_files = [f for f in files if f.endswith(".wav")]
+
+            if lab_files and wav_files:
+                selected_lab = random.choice(lab_files)
+                selected_wav = random.choice(wav_files)
+
+                lab_path = Path(base_path) / speaker / emotion / selected_lab
+                wav_path = Path(base_path) / speaker / emotion / selected_wav
+                if lab_path.exists() and wav_path.exists():
+                    return lab_path, wav_path
+
+    return None, None
+
+
+def load_json(json_file):
+    if not json_file:
+        logger.info("Not using a json file")
+        return None
+    try:
+        with open(json_file, "r", encoding="utf-8") as file:
+            data = json.load(file)
+    except FileNotFoundError:
+        logger.warning(f"ref json not found: {json_file}")
+        data = None
+    except Exception as e:
+        logger.warning(f"Loading json failed: {e}")
+        data = None
+    return data
+
+
 class InvokeRequest(BaseModel):
     text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
     reference_text: Optional[str] = None
@@ -173,8 +211,22 @@ class InvokeRequest(BaseModel):
     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.5
     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     speaker: Optional[str] = None
+    emotion: Optional[str] = None
     format: Literal["wav", "mp3", "flac"] = "wav"
     streaming: bool = False
+    ref_json: Optional[str] = "ref_data.json"
+    ref_base: Optional[str] = "ref_data"
+
+
+def get_content_type(audio_format):
+    if audio_format == "wav":
+        return "audio/wav"
+    elif audio_format == "flac":
+        return "audio/flac"
+    elif audio_format == "mp3":
+        return "audio/mpeg"
+    else:
+        return "application/octet-stream"
 
 
 @torch.inference_mode()
@@ -182,6 +234,21 @@ def inference(req: InvokeRequest):
     # Parse reference audio aka prompt
     prompt_tokens = None
 
+    ref_data = load_json(req.ref_json)
+    ref_base = req.ref_base
+
+    lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
+
+    if lab_path and wav_path:
+        with open(wav_path, "rb") as wav_file:
+            audio_bytes = wav_file.read()
+        with open(lab_path, "r", encoding="utf-8") as lab_file:
+            ref_text = lab_file.read()
+        req.reference_audio = base64.b64encode(audio_bytes).decode("utf-8")
+        req.reference_text = ref_text
+        logger.info("ref_path: " + str(wav_path))
+        logger.info("ref_text: " + ref_text)
+
     # Parse reference audio aka prompt
     prompt_tokens, reference_embedding = encode_reference(
         decoder_model=decoder_model,
@@ -294,7 +361,7 @@ def api_invoke_model(
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
-            content_type="audio/wav",
+            content_type=get_content_type(req.format),
         )
     else:
         fake_audios = next(generator)
@@ -306,7 +373,7 @@ def api_invoke_model(
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
-            content_type="audio/wav",
+            content_type=get_content_type(req.format),
         )
 
 
@@ -404,7 +471,10 @@ if __name__ == "__main__":
                 repetition_penalty=1.5,
                 temperature=0.7,
                 speaker=None,
+                emotion=None,
                 format="wav",
+                ref_base=None,
+                ref_json=None,
             )
         )
     )

+ 36 - 0
tools/gen_ref.py

@@ -0,0 +1,36 @@
+import json
+from pathlib import Path
+
+
+def scan_folder(base_path):
+    wav_lab_pairs = {}
+
+    base = Path(base_path)
+    for suf in ["wav", "lab"]:
+        for f in base.rglob(f"*.{suf}"):
+            relative_path = f.relative_to(base)
+            parts = relative_path.parts
+            print(parts)
+            if len(parts) >= 3:
+                character = parts[0]
+                emotion = parts[1]
+
+                if character not in wav_lab_pairs:
+                    wav_lab_pairs[character] = {}
+                if emotion not in wav_lab_pairs[character]:
+                    wav_lab_pairs[character][emotion] = []
+                wav_lab_pairs[character][emotion].append(str(f.name))
+
+    return wav_lab_pairs
+
+
+def save_to_json(data, output_file):
+    with open(output_file, "w", encoding="utf-8") as file:
+        json.dump(data, file, ensure_ascii=False, indent=2)
+
+
+base_path = "ref_data"
+out_ref_file = "ref_data.json"
+
+wav_lab_pairs = scan_folder(base_path)
+save_to_json(wav_lab_pairs, out_ref_file)