|
|
@@ -1,11 +1,14 @@
|
|
|
import base64
|
|
|
import io
|
|
|
+import json
|
|
|
import queue
|
|
|
+import random
|
|
|
import threading
|
|
|
import traceback
|
|
|
import wave
|
|
|
from argparse import ArgumentParser
|
|
|
from http import HTTPStatus
|
|
|
+from pathlib import Path
|
|
|
from typing import Annotated, Literal, Optional
|
|
|
|
|
|
import librosa
|
|
|
@@ -163,6 +166,41 @@ def decode_vq_tokens(
|
|
|
routes = MultimethodRoutes(base_class=HttpView)
|
|
|
|
|
|
|
|
|
+def get_random_paths(base_path, data, speaker, emotion):
|
|
|
+ if base_path and data and speaker and emotion and (Path(base_path).exists()):
|
|
|
+ if speaker in data and emotion in data[speaker]:
|
|
|
+ files = data[speaker][emotion]
|
|
|
+ lab_files = [f for f in files if f.endswith(".lab")]
|
|
|
+ wav_files = [f for f in files if f.endswith(".wav")]
|
|
|
+
|
|
|
+ if lab_files and wav_files:
|
|
|
+ selected_lab = random.choice(lab_files)
|
|
|
+ selected_wav = random.choice(wav_files)
|
|
|
+
|
|
|
+ lab_path = Path(base_path) / speaker / emotion / selected_lab
|
|
|
+ wav_path = Path(base_path) / speaker / emotion / selected_wav
|
|
|
+ if lab_path.exists() and wav_path.exists():
|
|
|
+ return lab_path, wav_path
|
|
|
+
|
|
|
+ return None, None
|
|
|
+
|
|
|
+
|
|
|
+def load_json(json_file):
|
|
|
+ if not json_file:
|
|
|
+ logger.info("Not using a json file")
|
|
|
+ return None
|
|
|
+ try:
|
|
|
+ with open(json_file, "r", encoding="utf-8") as file:
|
|
|
+ data = json.load(file)
|
|
|
+ except FileNotFoundError:
|
|
|
+ logger.warning(f"ref json not found: {json_file}")
|
|
|
+ data = None
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"Loading json failed: {e}")
|
|
|
+ data = None
|
|
|
+ return data
|
|
|
+
|
|
|
+
|
|
|
class InvokeRequest(BaseModel):
|
|
|
text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
|
|
|
reference_text: Optional[str] = None
|
|
|
@@ -173,8 +211,22 @@ class InvokeRequest(BaseModel):
|
|
|
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.5
|
|
|
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
speaker: Optional[str] = None
|
|
|
+ emotion: Optional[str] = None
|
|
|
format: Literal["wav", "mp3", "flac"] = "wav"
|
|
|
streaming: bool = False
|
|
|
+ ref_json: Optional[str] = "ref_data.json"
|
|
|
+ ref_base: Optional[str] = "ref_data"
|
|
|
+
|
|
|
+
|
|
|
+def get_content_type(audio_format):
|
|
|
+ if audio_format == "wav":
|
|
|
+ return "audio/wav"
|
|
|
+ elif audio_format == "flac":
|
|
|
+ return "audio/flac"
|
|
|
+ elif audio_format == "mp3":
|
|
|
+ return "audio/mpeg"
|
|
|
+ else:
|
|
|
+ return "application/octet-stream"
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
@@ -182,6 +234,21 @@ def inference(req: InvokeRequest):
|
|
|
# Parse reference audio aka prompt
|
|
|
prompt_tokens = None
|
|
|
|
|
|
+ ref_data = load_json(req.ref_json)
|
|
|
+ ref_base = req.ref_base
|
|
|
+
|
|
|
+ lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
|
|
|
+
|
|
|
+ if lab_path and wav_path:
|
|
|
+ with open(wav_path, "rb") as wav_file:
|
|
|
+ audio_bytes = wav_file.read()
|
|
|
+ with open(lab_path, "r", encoding="utf-8") as lab_file:
|
|
|
+ ref_text = lab_file.read()
|
|
|
+ req.reference_audio = base64.b64encode(audio_bytes).decode("utf-8")
|
|
|
+ req.reference_text = ref_text
|
|
|
+ logger.info("ref_path: " + str(wav_path))
|
|
|
+ logger.info("ref_text: " + ref_text)
|
|
|
+
|
|
|
# Parse reference audio aka prompt
|
|
|
prompt_tokens, reference_embedding = encode_reference(
|
|
|
decoder_model=decoder_model,
|
|
|
@@ -294,7 +361,7 @@ def api_invoke_model(
|
|
|
headers={
|
|
|
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
},
|
|
|
- content_type="audio/wav",
|
|
|
+ content_type=get_content_type(req.format),
|
|
|
)
|
|
|
else:
|
|
|
fake_audios = next(generator)
|
|
|
@@ -306,7 +373,7 @@ def api_invoke_model(
|
|
|
headers={
|
|
|
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
},
|
|
|
- content_type="audio/wav",
|
|
|
+ content_type=get_content_type(req.format),
|
|
|
)
|
|
|
|
|
|
|
|
|
@@ -404,7 +471,10 @@ if __name__ == "__main__":
|
|
|
repetition_penalty=1.5,
|
|
|
temperature=0.7,
|
|
|
speaker=None,
|
|
|
+ emotion=None,
|
|
|
format="wav",
|
|
|
+ ref_base=None,
|
|
|
+ ref_json=None,
|
|
|
)
|
|
|
)
|
|
|
)
|