|
|
@@ -1,125 +1,60 @@
|
|
|
-import logging
|
|
|
-from io import BytesIO
|
|
|
+from typing import List, Dict, Any
|
|
|
|
|
|
-import pandas as pd
|
|
|
-import requests
|
|
|
+from client.FishClient import FishClient
|
|
|
+from helper.MySQLHelper import MySQLHelper
|
|
|
|
|
|
-logger = logging.getLogger(__name__)
|
|
|
+official_fish_client = FishClient("https://api.fish.audio")
|
|
|
|
|
|
-from typing import Dict, Any
|
|
|
+mysql_helper = MySQLHelper(
|
|
|
+ host="rm-t4na9qj85v7790tf84o.mysql.singapore.rds.aliyuncs.com",
|
|
|
+ username="readonly",
|
|
|
+ password="HdkZ4TDmeK6SQ3BRtJBk",
|
|
|
+ database="aigc-admin-prod"
|
|
|
+)
|
|
|
|
|
|
|
|
|
-def build_common_header_map():
|
|
|
- return {
|
|
|
- "Authorization": "Bearer 0891f3f93a2640428f9988e267aa57e1",
|
|
|
- "Content-Type": "application/json"
|
|
|
- }
|
|
|
+def get_fish_pq_ip() -> List[str]:
|
|
|
+ sql = "select * from base_config where config_key = 'fish_pq_ip_list';"
|
|
|
+ result = mysql_helper.execute_query(sql)
|
|
|
+ if not result:
|
|
|
+ return []
|
|
|
+ value = result[0]['config_value']
|
|
|
+ return value.split(',')
|
|
|
|
|
|
|
|
|
-def get_model_info(reference_id: str) -> Dict[str, Any]:
|
|
|
- api_url = f"https://api.fish.audio/model/{reference_id}"
|
|
|
- headers = build_common_header_map()
|
|
|
- response = requests.get(
|
|
|
- api_url,
|
|
|
- headers=headers,
|
|
|
- timeout=(10, 1800) # connect timeout, read timeout
|
|
|
- )
|
|
|
- return response.json()
|
|
|
-
|
|
|
-
|
|
|
-def add_reference(
|
|
|
- base_url: str,
|
|
|
- reference_id: str,
|
|
|
- audio_url: str,
|
|
|
- text: str,
|
|
|
- timeout: int = 30,
|
|
|
- download_timeout: int = 60,
|
|
|
-) -> Dict[str, Any]:
|
|
|
- """
|
|
|
- 从 URL 下载音频并调用添加参考音频的接口
|
|
|
-
|
|
|
- Args:
|
|
|
- base_url: 服务端基础地址,例如 "http://localhost:8000"
|
|
|
- reference_id: 参考音频唯一标识
|
|
|
- audio_url: 音频文件的 URL(支持 http/https)
|
|
|
- text: 音频对应的文本内容
|
|
|
- timeout: 上传请求的超时时间(秒)
|
|
|
- download_timeout: 下载音频文件的超时时间(秒)
|
|
|
-
|
|
|
- Returns:
|
|
|
- 服务端返回的 JSON 响应(字典)
|
|
|
-
|
|
|
- Raises:
|
|
|
- requests.RequestException: 下载或上传请求失败
|
|
|
- ValueError: 服务端返回错误响应或下载内容为空
|
|
|
- """
|
|
|
- # 1. 从 URL 下载音频内容
|
|
|
- try:
|
|
|
- resp = requests.get(audio_url, timeout=download_timeout)
|
|
|
- resp.raise_for_status() # 检查 HTTP 错误
|
|
|
- audio_content = resp.content
|
|
|
- if not audio_content:
|
|
|
- raise ValueError("从 URL 下载的音频文件为空")
|
|
|
- except requests.exceptions.RequestException as e:
|
|
|
- raise requests.RequestException(f"下载音频失败: {e}") from e
|
|
|
-
|
|
|
- # 2. 构造请求 URL
|
|
|
- url = f"{base_url.rstrip('/')}/v1/references/add"
|
|
|
-
|
|
|
- # 3. 准备表单数据和文件(从内存中的字节构造文件)
|
|
|
- data = {
|
|
|
- "id": reference_id,
|
|
|
- "text": text,
|
|
|
- }
|
|
|
- # 从 URL 中提取文件名(如果 URL 没有明确文件名,可以自定义)
|
|
|
- file_name = audio_url.split('/')[-1] or "audio.wav"
|
|
|
- # 使用 BytesIO 包装音频内容
|
|
|
- files = {
|
|
|
- "audio": (file_name, BytesIO(audio_content), "audio/wav"),
|
|
|
- }
|
|
|
-
|
|
|
- headers = {
|
|
|
- "Accept": "application/json",
|
|
|
- }
|
|
|
- # 4. 发送上传请求
|
|
|
- try:
|
|
|
- response = requests.post(url, data=data, files=files, headers=headers, timeout=timeout)
|
|
|
- resp_json = response.json()
|
|
|
- except requests.exceptions.RequestException as e:
|
|
|
- raise requests.RequestException(f"上传请求失败: {e}") from e
|
|
|
- finally:
|
|
|
- # 关闭 BytesIO(可选,因为内存对象会自动回收)
|
|
|
- files["audio"][1].close()
|
|
|
-
|
|
|
- # 5. 检查响应
|
|
|
- if response.status_code != 200:
|
|
|
- raise ValueError(f"服务端返回错误 (HTTP {response.status_code}): {resp_json.get('message', '未知错误')}")
|
|
|
-
|
|
|
- if not resp_json.get("success", False):
|
|
|
- raise ValueError(f"业务失败: {resp_json.get('message', '未知错误')}")
|
|
|
-
|
|
|
- return resp_json
|
|
|
-
|
|
|
+def get_all_reference_by_db() -> List[Dict[str, Any]]:
|
|
|
+ sql = "select * from ai_model_tts where model = 33;"
|
|
|
+ return mysql_helper.execute_query(sql)
|
|
|
|
|
|
|
|
|
def _main():
|
|
|
- df = pd.read_csv("/Users/zhao/Desktop/aigc_admin_prod_ai_model_tts.csv")
|
|
|
- base_url = "http://192.168.245.146:8080/"
|
|
|
- for row in df.itertuples():
|
|
|
- reference_id = row.speaker_id
|
|
|
- audio_url = row.audio_url
|
|
|
- if reference_id in ['6e2d9e58b26c424db6d564ea56983f4d']:
|
|
|
- continue
|
|
|
-
|
|
|
- model_info = get_model_info(reference_id)
|
|
|
- text = model_info['samples'][0]['text']
|
|
|
- add_reference(
|
|
|
- base_url=base_url,
|
|
|
- reference_id=reference_id,
|
|
|
- audio_url=audio_url,
|
|
|
- text=text,
|
|
|
- timeout=30,
|
|
|
- )
|
|
|
+ db_all_reference = get_all_reference_by_db()
|
|
|
+ reference_id_and_text_map = {}
|
|
|
+ all_ip = get_fish_pq_ip()
|
|
|
+ print(f"当前配置的Fish服务器IP列表为: {all_ip}")
|
|
|
+ for ip in all_ip:
|
|
|
+ print(f"开始将音频同步到实例【{ip}】")
|
|
|
+ fish_client = FishClient(f"http://{ip}:8080")
|
|
|
+ exist_references_ids = fish_client.get_all_references_id()
|
|
|
+ for reference_info in db_all_reference:
|
|
|
+ reference_id = reference_info['speaker_id']
|
|
|
+ try:
|
|
|
+ if reference_id in exist_references_ids:
|
|
|
+ print(f"音频ID【{reference_id}】在实例【{ip}】上已经存在,跳过")
|
|
|
+ continue
|
|
|
+
|
|
|
+ if reference_id not in reference_id_and_text_map:
|
|
|
+ model_info = official_fish_client.get_model_info_by_id(reference_id)
|
|
|
+ text = model_info['samples'][0]['text']
|
|
|
+ reference_id_and_text_map[reference_id] = text
|
|
|
+
|
|
|
+ audio_url = reference_info['audio_url']
|
|
|
+ reference_text = reference_id_and_text_map[reference_id]
|
|
|
+ fish_client.add_reference_id_by_url(reference_id=reference_id, reference_text=reference_text, audio_url=audio_url)
|
|
|
+ print(f"音频ID【{reference_id}】同步到实例【{ip}】上完成")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"音频ID【{reference_id}】同步到实例【{ip}】上异常 {str(e)}")
|
|
|
+ print(f"将音频同步到实例【{ip}】完成")
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|