Преглед изворни кода

feat:添加fish参考音频同步任务

zhaohaipeng пре 2 месеци
родитељ
комит
16f8394119
4 измењених фајлова са 145 додато и 115 уклоњено
  1. 96 0
      client/FishClient.py
  2. 2 2
      helper/MySQLHelper.py
  3. 0 1
      script/feature_spark_analyse.py
  4. 47 112
      script/fish_reference_audio_sync.py

+ 96 - 0
client/FishClient.py

@@ -0,0 +1,96 @@
+from io import BytesIO
+from typing import Dict, Any, List
+
+import requests
+
+
+class FishClient(object):
+    """Small HTTP client for a Fish TTS service (self-hosted instance or the official API).
+
+    Args:
+        base_url: Root URL of the service, e.g. "http://10.0.0.1:8080". A trailing
+            slash is tolerated and removed.
+        api_key: Bearer token sent by ``get_model_info_by_id``. Optional; when
+            omitted the built-in ``DEFAULT_API_KEY`` is used.
+    """
+
+    # SECURITY NOTE(review): this token is committed to source control. It is kept
+    # only so existing callers keep working; rotate it and pass ``api_key`` instead.
+    DEFAULT_API_KEY = "0891f3f93a2640428f9988e267aa57e1"
+
+    # Timeout in seconds for the list / download / upload requests, so that an
+    # unreachable host cannot block the caller forever.
+    DEFAULT_TIMEOUT = 30
+
+    def __init__(self, base_url: str, api_key: str = None):
+        # Without the strip, "http://host/" would produce "http://host//v1/...".
+        self.base_url = base_url.rstrip("/")
+        self.api_key = api_key or self.DEFAULT_API_KEY
+
+    def get_all_references_id(self) -> List[str]:
+        """
+        Fetch the IDs of the reference audios that already exist on the service.
+
+        Returns:
+            List of reference audio IDs (empty when the service reports none).
+
+        Raises:
+            requests.RequestException: the request failed or returned an HTTP error status
+            ValueError: the response body is not valid JSON
+        """
+        url = f"{self.base_url}/v1/references/list"
+        headers = self.build_common_header()
+        response = requests.get(url, headers=headers, timeout=self.DEFAULT_TIMEOUT)
+        # Fail loudly on 4xx/5xx instead of treating an error page as "no references",
+        # which would make the caller re-upload everything.
+        response.raise_for_status()
+        return response.json().get("reference_ids") or []
+
+    def add_reference_id_by_url(self, reference_id: str, reference_text: str, audio_url: str):
+        """
+        Download an audio file from a URL and register it as a reference audio.
+
+        Args:
+            reference_id: unique identifier of the reference audio
+            reference_text: transcript of the audio
+            audio_url: URL of the audio file (http/https)
+
+        Returns:
+            The JSON body returned by the service (a dict).
+
+        Raises:
+            requests.RequestException: the download or the upload request failed
+            ValueError: the downloaded file is empty, or the service answered with
+                an HTTP error or an unsuccessful result
+        """
+        try:
+            resp = requests.get(audio_url, timeout=self.DEFAULT_TIMEOUT)
+            resp.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            raise requests.RequestException(f"下载音频失败: {e}") from e
+
+        audio_content = resp.content
+        if not audio_content:
+            raise ValueError("从 URL 下载的音频文件为空")
+
+        url = f"{self.base_url}/v1/references/add"
+        payload = {
+            "id": reference_id,
+            "text": reference_text,
+        }
+
+        # Derive the upload file name from the URL path only, so that a signed URL
+        # ("...a.wav?Expires=...") does not leak its query string into the name.
+        url_path = audio_url.split("?", 1)[0].split("#", 1)[0]
+        file_name = url_path.split('/')[-1] or "audio.wav"
+
+        headers = self.build_common_header()
+
+        # The context manager closes the in-memory buffer whatever happens.
+        with BytesIO(audio_content) as audio_buffer:
+            files = {
+                "audio": (file_name, audio_buffer, "audio/wav"),
+            }
+            try:
+                response = requests.post(
+                    url, data=payload, files=files, headers=headers, timeout=self.DEFAULT_TIMEOUT
+                )
+            except requests.exceptions.RequestException as e:
+                raise requests.RequestException(f"上传请求失败: {e}") from e
+
+        # Parse the body separately from the transport call: an error response with a
+        # non-JSON body must still be reported with its HTTP status below.
+        try:
+            resp_json = response.json()
+        except ValueError:
+            resp_json = {}
+        if not isinstance(resp_json, dict):
+            resp_json = {}
+
+        if response.status_code != 200:
+            raise ValueError(f"服务端返回错误 (HTTP {response.status_code}): {resp_json.get('message', '未知错误')}")
+
+        if not resp_json.get("success", False):
+            raise ValueError(f"业务失败: {resp_json.get('message', '未知错误')}")
+
+        return resp_json
+
+    def get_model_info_by_id(self, reference_id: str) -> Dict[str, Any]:
+        """
+        Fetch the model description for ``reference_id`` from ``{base_url}/model/{id}``.
+
+        Returns:
+            The decoded JSON body.
+
+        Raises:
+            requests.RequestException: the request failed or returned an HTTP error status
+            ValueError: the response body is not valid JSON
+        """
+        url = f"{self.base_url}/model/{reference_id}"
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json"
+        }
+        # (connect timeout, read timeout) in seconds.
+        response = requests.get(url, headers=headers, timeout=(10, 1800))
+        # Surface auth / not-found errors here rather than as a KeyError in the caller.
+        response.raise_for_status()
+        return response.json()
+
+    def is_official_api(self):
+        """Return True when ``base_url`` points at the official api.fish.audio host."""
+        return "api.fish.audio" in self.base_url
+
+    @staticmethod
+    def build_common_header() -> Dict[str, Any]:
+        """Return the headers shared by the reference list/add requests."""
+        return {
+            "Accept": "application/json",
+        }

+ 2 - 2
helper/MySQLHelper.py

@@ -1,4 +1,5 @@
 import logging
+from typing import Dict, List, Any
 
 import pymysql
 import pymysql.cursors
@@ -38,7 +39,7 @@ class MySQLHelper:
             logging.error(f"数据库连接失败: {e}")
             raise
 
-    def execute_query(self, sql: str, params: tuple = None) -> list[dict]:
+    def execute_query(self, sql: str, params: tuple = None) -> List[Dict[str, Any]]:
         """
         执行 SELECT 查询。
 
@@ -106,4 +107,3 @@ class MySQLHelper:
     def __del__(self):
         """Ensure the connection is closed when the object is destroyed."""
         self.close()
-

+ 0 - 1
script/feature_spark_analyse.py

@@ -1,6 +1,5 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import List, Dict
 
 import pandas as pd
 

+ 47 - 112
script/fish_reference_audio_sync.py

@@ -1,125 +1,60 @@
-import logging
-from io import BytesIO
+import os
+from typing import List, Dict, Any
 
-import pandas as pd
-import requests
+from client.FishClient import FishClient
+from helper.MySQLHelper import MySQLHelper
 
-logger = logging.getLogger(__name__)
+# Client for the official Fish Audio API; _main() uses it to look up each reference's sample text.
+official_fish_client = FishClient("https://api.fish.audio")
 
-from typing import Dict, Any
+# SECURITY NOTE(review): production database credentials are committed here. The
+# literals are kept only as fallbacks so the script behaves as before; rotate the
+# password and supply the values through the FISH_SYNC_DB_* environment variables.
+mysql_helper = MySQLHelper(
+    host=os.environ.get("FISH_SYNC_DB_HOST", "rm-t4na9qj85v7790tf84o.mysql.singapore.rds.aliyuncs.com"),
+    username=os.environ.get("FISH_SYNC_DB_USERNAME", "readonly"),
+    password=os.environ.get("FISH_SYNC_DB_PASSWORD", "HdkZ4TDmeK6SQ3BRtJBk"),
+    database=os.environ.get("FISH_SYNC_DB_DATABASE", "aigc-admin-prod"),
+)
 
 
-def build_common_header_map():
-    return {
-        "Authorization": "Bearer 0891f3f93a2640428f9988e267aa57e1",
-        "Content-Type": "application/json"
-    }
+def get_fish_pq_ip() -> List[str]:
+    """Read the comma-separated Fish server IP list from ``base_config``.
+
+    Returns:
+        The entries of the ``fish_pq_ip_list`` config value, with surrounding
+        whitespace removed and empty entries dropped. Empty list when the config
+        row is missing or its value is NULL/blank.
+    """
+    sql = "select * from base_config where config_key = 'fish_pq_ip_list';"
+    result = mysql_helper.execute_query(sql)
+    if not result:
+        return []
+    # A NULL column would otherwise crash on .split().
+    value = result[0]['config_value'] or ""
+    # "a, b," must yield ["a", "b"], not ["a", " b", ""] - each entry becomes a URL host.
+    return [ip.strip() for ip in value.split(',') if ip.strip()]
 
 
-def get_model_info(reference_id: str) -> Dict[str, Any]:
-    api_url = f"https://api.fish.audio/model/{reference_id}"
-    headers = build_common_header_map()
-    response = requests.get(
-        api_url,
-        headers=headers,
-        timeout=(10, 1800)  # connect timeout, read timeout
-    )
-    return response.json()
-
-
-def add_reference(
-        base_url: str,
-        reference_id: str,
-        audio_url: str,
-        text: str,
-        timeout: int = 30,
-        download_timeout: int = 60,
-) -> Dict[str, Any]:
-    """
-    从 URL 下载音频并调用添加参考音频的接口
-
-    Args:
-        base_url: 服务端基础地址,例如 "http://localhost:8000"
-        reference_id: 参考音频唯一标识
-        audio_url: 音频文件的 URL(支持 http/https)
-        text: 音频对应的文本内容
-        timeout: 上传请求的超时时间(秒)
-        download_timeout: 下载音频文件的超时时间(秒)
-
-    Returns:
-        服务端返回的 JSON 响应(字典)
-
-    Raises:
-        requests.RequestException: 下载或上传请求失败
-        ValueError: 服务端返回错误响应或下载内容为空
-    """
-    # 1. 从 URL 下载音频内容
-    try:
-        resp = requests.get(audio_url, timeout=download_timeout)
-        resp.raise_for_status()  # 检查 HTTP 错误
-        audio_content = resp.content
-        if not audio_content:
-            raise ValueError("从 URL 下载的音频文件为空")
-    except requests.exceptions.RequestException as e:
-        raise requests.RequestException(f"下载音频失败: {e}") from e
-
-    # 2. 构造请求 URL
-    url = f"{base_url.rstrip('/')}/v1/references/add"
-
-    # 3. 准备表单数据和文件(从内存中的字节构造文件)
-    data = {
-        "id": reference_id,
-        "text": text,
-    }
-    # 从 URL 中提取文件名(如果 URL 没有明确文件名,可以自定义)
-    file_name = audio_url.split('/')[-1] or "audio.wav"
-    # 使用 BytesIO 包装音频内容
-    files = {
-        "audio": (file_name, BytesIO(audio_content), "audio/wav"),
-    }
-
-    headers = {
-        "Accept": "application/json",
-    }
-    # 4. 发送上传请求
-    try:
-        response = requests.post(url, data=data, files=files, headers=headers, timeout=timeout)
-        resp_json = response.json()
-    except requests.exceptions.RequestException as e:
-        raise requests.RequestException(f"上传请求失败: {e}") from e
-    finally:
-        # 关闭 BytesIO(可选,因为内存对象会自动回收)
-        files["audio"][1].close()
-
-    # 5. 检查响应
-    if response.status_code != 200:
-        raise ValueError(f"服务端返回错误 (HTTP {response.status_code}): {resp_json.get('message', '未知错误')}")
-
-    if not resp_json.get("success", False):
-        raise ValueError(f"业务失败: {resp_json.get('message', '未知错误')}")
-
-    return resp_json
-
+def get_all_reference_by_db() -> List[Dict[str, Any]]:
+    """Return every row of ``ai_model_tts`` whose ``model`` column equals 33."""
+    return mysql_helper.execute_query("select * from ai_model_tts where model = 33;")
 
 
 def _main():
-    df = pd.read_csv("/Users/zhao/Desktop/aigc_admin_prod_ai_model_tts.csv")
-    base_url = "http://192.168.245.146:8080/"
-    for row in df.itertuples():
-        reference_id = row.speaker_id
-        audio_url = row.audio_url
-        if reference_id in ['6e2d9e58b26c424db6d564ea56983f4d']:
-            continue
-
-        model_info = get_model_info(reference_id)
-        text = model_info['samples'][0]['text']
-        add_reference(
-            base_url=base_url,
-            reference_id=reference_id,
-            audio_url=audio_url,
-            text=text,
-            timeout=30,
-        )
+    # Sync every reference audio found in the database to each configured Fish instance.
+    db_all_reference = get_all_reference_by_db()
+    # Cache of reference_id -> sample text, so the official API is queried at most
+    # once per reference no matter how many instances are synced.
+    reference_id_and_text_map = {}
+    all_ip = get_fish_pq_ip()
+    print(f"当前配置的Fish服务器IP列表为: {all_ip}")
+    for ip in all_ip:
+        print(f"开始将音频同步到实例【{ip}】")
+        fish_client = FishClient(f"http://{ip}:8080")
+        try:
+            # A set keeps the per-reference membership test O(1).
+            exist_references_ids = set(fish_client.get_all_references_id())
+        except Exception as e:
+            # One unreachable instance must not abort the sync of the remaining ones.
+            print(f"获取实例【{ip}】已有参考音频列表异常,跳过该实例 {str(e)}")
+            continue
+        for reference_info in db_all_reference:
+            reference_id = reference_info['speaker_id']
+            try:
+                if reference_id in exist_references_ids:
+                    print(f"音频ID【{reference_id}】在实例【{ip}】上已经存在,跳过")
+                    continue
+
+                if reference_id not in reference_id_and_text_map:
+                    model_info = official_fish_client.get_model_info_by_id(reference_id)
+                    text = model_info['samples'][0]['text']
+                    reference_id_and_text_map[reference_id] = text
+
+                audio_url = reference_info['audio_url']
+                reference_text = reference_id_and_text_map[reference_id]
+                fish_client.add_reference_id_by_url(reference_id=reference_id, reference_text=reference_text, audio_url=audio_url)
+                # Remember the upload so a duplicate DB row is skipped instead of re-sent.
+                exist_references_ids.add(reference_id)
+                print(f"音频ID【{reference_id}】同步到实例【{ip}】上完成")
+            except Exception as e:
+                # Best-effort: report the failure and continue with the next reference.
+                print(f"音频ID【{reference_id}】同步到实例【{ip}】上异常 {str(e)}")
+        print(f"将音频同步到实例【{ip}】完成")
 
 
 if __name__ == '__main__':