|
@@ -1,4 +1,5 @@
|
|
|
import io
|
|
import io
|
|
|
|
|
+import re
|
|
|
from hashlib import sha256
|
|
from hashlib import sha256
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from typing import Callable, Literal, Tuple
|
|
from typing import Callable, Literal, Tuple
|
|
@@ -16,6 +17,8 @@ from fish_speech.utils.file import (
|
|
|
)
|
|
)
|
|
|
from fish_speech.utils.schema import ServeReferenceAudio
|
|
from fish_speech.utils.schema import ServeReferenceAudio
|
|
|
|
|
|
|
|
|
|
+_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_ ]+$")
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class ReferenceLoader:
|
|
class ReferenceLoader:
|
|
|
def __init__(self) -> None:
|
|
def __init__(self) -> None:
|
|
@@ -48,11 +51,21 @@ class ReferenceLoader:
|
|
|
except (ImportError, ModuleNotFoundError):
|
|
except (ImportError, ModuleNotFoundError):
|
|
|
self.backend = "soundfile"
|
|
self.backend = "soundfile"
|
|
|
|
|
|
|
|
|
|
+ @staticmethod
|
|
|
|
|
+ def _validate_id(id: str) -> None:
|
|
|
|
|
+ if not _ID_PATTERN.match(id) or len(id) > 255:
|
|
|
|
|
+ raise ValueError(
|
|
|
|
|
+ "Reference ID contains invalid characters or is too long. "
|
|
|
|
|
+ "Only alphanumeric, hyphens, underscores, and spaces are allowed."
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
def load_by_id(
|
|
def load_by_id(
|
|
|
self,
|
|
self,
|
|
|
id: str,
|
|
id: str,
|
|
|
use_cache: Literal["on", "off"],
|
|
use_cache: Literal["on", "off"],
|
|
|
) -> Tuple:
|
|
) -> Tuple:
|
|
|
|
|
+ self._validate_id(id)
|
|
|
|
|
+
|
|
|
# Load the references audio and text by id
|
|
# Load the references audio and text by id
|
|
|
ref_folder = Path("references") / id
|
|
ref_folder = Path("references") / id
|
|
|
ref_folder.mkdir(parents=True, exist_ok=True)
|
|
ref_folder.mkdir(parents=True, exist_ok=True)
|
|
@@ -189,18 +202,7 @@ class ReferenceLoader:
|
|
|
FileNotFoundError: If the audio file doesn't exist
|
|
FileNotFoundError: If the audio file doesn't exist
|
|
|
OSError: If file operations fail
|
|
OSError: If file operations fail
|
|
|
"""
|
|
"""
|
|
|
- # Validate ID format
|
|
|
|
|
- import re
|
|
|
|
|
-
|
|
|
|
|
- if not re.match(r"^[a-zA-Z0-9\-_ ]+$", id):
|
|
|
|
|
- raise ValueError(
|
|
|
|
|
- "Reference ID contains invalid characters. Only alphanumeric, hyphens, underscores, and spaces are allowed."
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- if len(id) > 255:
|
|
|
|
|
- raise ValueError(
|
|
|
|
|
- "Reference ID is too long. Maximum length is 255 characters."
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ self._validate_id(id)
|
|
|
|
|
|
|
|
# Check if reference already exists
|
|
# Check if reference already exists
|
|
|
ref_dir = Path("references") / id
|
|
ref_dir = Path("references") / id
|
|
@@ -260,7 +262,8 @@ class ReferenceLoader:
|
|
|
FileNotFoundError: If the reference ID doesn't exist
|
|
FileNotFoundError: If the reference ID doesn't exist
|
|
|
OSError: If file operations fail
|
|
OSError: If file operations fail
|
|
|
"""
|
|
"""
|
|
|
- # Check if reference exists
|
|
|
|
|
|
|
+ self._validate_id(id)
|
|
|
|
|
+
|
|
|
ref_dir = Path("references") / id
|
|
ref_dir = Path("references") / id
|
|
|
if not ref_dir.exists():
|
|
if not ref_dir.exists():
|
|
|
raise FileNotFoundError(f"Reference ID '{id}' does not exist")
|
|
raise FileNotFoundError(f"Reference ID '{id}' does not exist")
|