# tts_test.py (4.6 KB)
from __future__ import print_function

import os
from typing import Any, Dict, List, Optional

import pandas as pd
import requests
import volcenginesdkcore
import volcenginesdkspeechsaasprod
import volcenginesdkspeechsaasprod20250521
from volcenginesdkcore.rest import ApiException

from helper.MySQLHelper import MySQLHelper
  10. mysql_helper = MySQLHelper(
  11. host="rm-t4na9qj85v7790tf84o.mysql.singapore.rds.aliyuncs.com",
  12. username="readonly",
  13. password="HdkZ4TDmeK6SQ3BRtJBk",
  14. database="aigc-admin-prod"
  15. )
  16. ak = "AKLTZWIxNWRkMzUyYjBmNGU2Yjk5MTFiYWVmNmNiY2Q1Njg"
  17. sk = "WW1NM1l6TTJNRFZrT0dFMk5HSXhZamt5TnpFd1kyWTNPR0V6TURZd056Yw=="
  18. configuration = volcenginesdkcore.Configuration()
  19. configuration.ak = ak
  20. configuration.sk = sk
  21. configuration.region = "cn-beijing"
  22. # set default configuration
  23. volcenginesdkcore.Configuration.set_default(configuration)
  24. def add_speaker(audio_url: str, speaker: str):
  25. url = "https://aigc-api.aiddit.com/aigc/resources/aiAccount/saveTts"
  26. payload = {
  27. "params": {
  28. "model": 34,
  29. "ttsName": speaker,
  30. "trainAudioUrl": audio_url
  31. },
  32. "baseInfo": {
  33. "token": "80ce2034892c4428ab5b6e39ec0a9e2d",
  34. }
  35. }
  36. response = requests.post(url, headers={}, json=payload)
  37. response.raise_for_status() # 如果状态码不是 2xx,抛出异常
  38. return response.json()
  39. def speaker_insert_db(speaker_id: str):
  40. sql = f'INSERT INTO `aigc-admin-prod`.volcengine_tts_speaker (speaker_id, type, status, create_time, update_time) VALUES ("{speaker_id}", 2, 0, "2026-05-13 21:08:00", "2026-05-13 21:08:00");'
  41. print(sql)
  42. def get_all_speakers() -> List[Dict[str, Any]]:
  43. api_instance = volcenginesdkspeechsaasprod20250521.SPEECHSAASPROD20250521Api()
  44. batch_list_mega_tts_train_status_request = volcenginesdkspeechsaasprod20250521.BatchListMegaTTSTrainStatusRequest(
  45. state='Success',
  46. page_size=100,
  47. page_number=1,
  48. project_name="aiddit",
  49. )
  50. try:
  51. # 复制代码运行示例,请自行打印API返回值。
  52. response = api_instance.batch_list_mega_tts_train_status(batch_list_mega_tts_train_status_request)
  53. result = []
  54. for item in response.statuses:
  55. result.append({
  56. "speaker_id": item.speaker_id,
  57. "alias": item.alias,
  58. "instance_no": item.instance_no,
  59. })
  60. return result
  61. except ApiException as e:
  62. print("Exception when calling api: %s\n" % e)
  63. return []
  64. def update_tts_alias(tts_id: str, alias: str):
  65. # use global default configuration
  66. api_instance = volcenginesdkspeechsaasprod.SPEECHSAASPRODApi()
  67. alias_resource_pack_request = volcenginesdkspeechsaasprod.AliasResourcePackRequest(
  68. alias=alias,
  69. instance_number="",
  70. project_name="aiddit",
  71. train_id=tts_id,
  72. )
  73. try:
  74. response = api_instance.alias_resource_pack(alias_resource_pack_request)
  75. # response.raise_for_status() # 如果状态码不是 2xx,抛出异常
  76. return {}
  77. except ApiException as e:
  78. print("Exception when calling api: %s\n" % e)
  79. return {}
  80. def get_volc_engine_tts_info(tts_name: str, mode: int) -> Dict[str, Any]:
  81. sql = f"select * from ai_model_tts where speaker = '{tts_name}' and model = {mode};"
  82. results = mysql_helper.execute_query(sql)
  83. if results:
  84. return results[0]
  85. else:
  86. return {}
  87. def get_tts_info(tts_id: str) -> Dict[str, Any]:
  88. sql = f"select * from ai_model_tts where id = '{tts_id}'"
  89. results = mysql_helper.execute_query(sql)
  90. return results[0]
  91. def read_tts_id() -> List[str]:
  92. df = pd.read_csv("/Users/zhao/Desktop/fish_tts.csv")
  93. return df['tts_id'].tolist()
  94. def main():
  95. speaker_id_alias_map = {}
  96. for item in get_all_speakers():
  97. speaker_id = item['speaker_id']
  98. alias = item['alias']
  99. speaker_id_alias_map[speaker_id] = alias
  100. df = pd.read_csv("/Users/zhao/Desktop/aigc_admin_prod_ai_model_tts.csv")
  101. dict_list = df.to_dict(orient='records')
  102. for item in dict_list:
  103. speaker_id = item['speaker_id']
  104. alias = item['speaker']
  105. if speaker_id not in speaker_id_alias_map:
  106. print(f"{speaker_id} not in speaker_id_alias_map")
  107. continue
  108. volc_alias = speaker_id_alias_map[speaker_id]
  109. if volc_alias == alias:
  110. print(f"{volc_alias} == {alias}")
  111. continue
  112. print(f'更新 {speaker_id} 的别名为 {alias}')
  113. response = update_tts_alias(speaker_id, alias)
  114. print(f'{speaker_id} -> {alias} -> {response}')
  115. if __name__ == '__main__':
  116. main()