# app.py
  1. import html
  2. import io
  3. import traceback
  4. import gradio as gr
  5. import librosa
  6. import requests
  7. from fish_speech.text import parse_text_to_segments, segments_to_phones
# Markdown banner rendered at the top of the web UI.
HEADER_MD = """
# Fish Speech
基于 VQ-GAN 和 Llama 的多语种语音合成. 感谢 Rcell 的 GPT-VITS 提供的思路.
"""

# Placeholder for the main input textbox. It demonstrates (in Chinese) how
# automatic phoneme conversion segments mixed-language text, and how the
# <jp>...</jp> tag forces Japanese priority for a span containing kanji.
TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。
会被转换为:
<Segment ZH: '测试一下' -> 'c e4 sh ir4 y i2 x ia4'>
<Segment EN: ' Hugging face, BGM' -> 'HH AH1 G IH0 NG F EY1 S , B AE1 G M'>
<Segment ZH: '声音很大吗?那我改一下.' -> 'sh eng1 y in1 h en3 d a4 m a5 ? n a4 w o2 g ai3 y i2 x ia4 .'>
<Segment ZH: '世界,' -> 'sh ir4 j ie4 ,'>
<Segment JP: 'こんにちは.' -> 'k o N n i ch i w a .'>
如你所见, 最后的句子被分割为了两个部分, 因为该日文包含了汉字, 你可以使用 <jp>...</jp> 标签来指定日文优先级. 例如:
测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>
可以看到, 日文部分被正确地分割了出来:
...
<Segment JP: '世界,こんにちは.' -> 's e k a i , k o N n i ch i w a .'>
"""
  26. def build_html_error_message(error):
  27. return f"""
  28. <div style="color: red; font-weight: bold;">
  29. {html.escape(error)}
  30. </div>
  31. """
  32. def prepare_text(
  33. text,
  34. input_mode,
  35. language0,
  36. language1,
  37. language2,
  38. enable_reference_audio,
  39. reference_text,
  40. ):
  41. lines = text.splitlines()
  42. languages = [language0, language1, language2]
  43. languages = [
  44. {
  45. "中文": "ZH",
  46. "日文": "JP",
  47. "英文": "EN",
  48. }[language]
  49. for language in languages
  50. ]
  51. if len(set(languages)) != len(languages):
  52. return [], build_html_error_message("语言优先级不能重复.")
  53. if enable_reference_audio:
  54. reference_text = reference_text.strip() + " "
  55. else:
  56. reference_text = ""
  57. if input_mode != "自动音素":
  58. return [
  59. [idx, reference_text + line, "-", "-"]
  60. for idx, line in enumerate(lines)
  61. if line.strip() != ""
  62. ], None
  63. rows = []
  64. for idx, line in enumerate(lines):
  65. if line.strip() == "":
  66. continue
  67. try:
  68. segments = parse_text_to_segments(reference_text + line, order=languages)
  69. except Exception:
  70. traceback.print_exc()
  71. err = traceback.format_exc()
  72. return [], build_html_error_message(f"解析 '{line}' 时发生错误. \n\n{err}")
  73. for segment in segments:
  74. rows.append([idx, segment.text, segment.language, " ".join(segment.phones)])
  75. return rows, None
  76. def load_model(
  77. server_url,
  78. llama_ckpt_path,
  79. llama_config_name,
  80. tokenizer,
  81. vqgan_ckpt_path,
  82. vqgan_config_name,
  83. device,
  84. precision,
  85. compile_model,
  86. ):
  87. payload = {
  88. "device": device,
  89. "llama": {
  90. "config_name": llama_config_name,
  91. "checkpoint_path": llama_ckpt_path,
  92. "precision": precision,
  93. "tokenizer": tokenizer,
  94. "compile": compile_model,
  95. },
  96. "vqgan": {
  97. "config_name": vqgan_config_name,
  98. "checkpoint_path": vqgan_ckpt_path,
  99. },
  100. }
  101. try:
  102. resp = requests.put(f"{server_url}/v1/models/default", json=payload)
  103. resp.raise_for_status()
  104. except Exception:
  105. traceback.print_exc()
  106. err = traceback.format_exc()
  107. return build_html_error_message(f"加载模型时发生错误. \n\n{err}")
  108. return "模型加载成功."
def build_model_config_block():
    """Build the "model config" tab of the UI and wire its load button.

    Creates the server-address box, device/precision/compile controls and
    checkpoint-path textboxes, then hooks the "加载模型" button to
    :func:`load_model`. Returns the ``server_url`` textbox so other
    callbacks (e.g. inference) can read the server address.

    NOTE(review): widget creation order defines the page layout, so the
    statement order below is load-bearing.
    """
    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")

    # Runtime options: device, numeric precision, and whether to compile.
    with gr.Row():
        with gr.Column(scale=1):
            device = gr.Dropdown(
                label="设备",
                choices=["cpu", "cuda"],
                value="cuda",
            )
        with gr.Column(scale=1):
            precision = gr.Dropdown(
                label="精度",
                choices=["bfloat16", "float16"],
                value="float16",
            )
        with gr.Column(scale=1):
            compile_model = gr.Checkbox(
                label="编译模型",
                value=True,
            )

    # Checkpoint / config locations for the two model components.
    llama_ckpt_path = gr.Textbox(
        label="Llama 模型路径", value="checkpoints/text2semantic-400m-v0.2-4k.pth"
    )
    llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
    tokenizer = gr.Textbox(label="Tokenizer", value="fishaudio/speech-lm-v1")
    vqgan_ckpt_path = gr.Textbox(label="VQGAN 模型路径", value="checkpoints/vqgan-v1.pth")
    vqgan_config_name = gr.Textbox(label="VQGAN 配置文件", value="vqgan_pretrain")

    load_model_btn = gr.Button(value="加载模型", variant="primary")
    error = gr.HTML(label="错误信息")

    # Input list mirrors load_model's positional parameter order exactly.
    load_model_btn.click(
        load_model,
        [
            server_url,
            llama_ckpt_path,
            llama_config_name,
            tokenizer,
            vqgan_ckpt_path,
            vqgan_config_name,
            device,
            precision,
            compile_model,
        ],
        [error],
    )

    return server_url
  154. def inference(
  155. server_url,
  156. text,
  157. input_mode,
  158. language0,
  159. language1,
  160. language2,
  161. enable_reference_audio,
  162. reference_audio,
  163. reference_text,
  164. max_new_tokens,
  165. top_k,
  166. top_p,
  167. repetition_penalty,
  168. temperature,
  169. speaker,
  170. ):
  171. languages = [language0, language1, language2]
  172. languages = [
  173. {
  174. "中文": "zh",
  175. "日文": "jp",
  176. "英文": "en",
  177. }[language]
  178. for language in languages
  179. ]
  180. if len(set(languages)) != len(languages):
  181. return [], build_html_error_message("语言优先级不能重复.")
  182. order = ",".join(languages)
  183. payload = {
  184. "text": text,
  185. "prompt_text": reference_text if enable_reference_audio else None,
  186. "prompt_tokens": reference_audio if enable_reference_audio else None,
  187. "max_new_tokens": int(max_new_tokens),
  188. "top_k": int(top_k) if top_k > 0 else None,
  189. "top_p": top_p,
  190. "repetition_penalty": repetition_penalty,
  191. "temperature": temperature,
  192. "order": order,
  193. "use_g2p": input_mode == "自动音素",
  194. "seed": None,
  195. "speaker": speaker if speaker.strip() != "" else None,
  196. }
  197. try:
  198. resp = requests.post(f"{server_url}/v1/models/default/invoke", json=payload)
  199. resp.raise_for_status()
  200. except Exception:
  201. traceback.print_exc()
  202. err = traceback.format_exc()
  203. return [], build_html_error_message(f"推理时发生错误. \n\n{err}")
  204. content = io.BytesIO(resp.content)
  205. content.seek(0)
  206. content, sr = librosa.load(content, sr=None, mono=True)
  207. return (sr, content), None
# Top-level UI definition. The nesting of Row/Column/Tab context managers
# defines the page layout: left column holds config + synthesis controls,
# right column holds error display, parse preview, and the audio player.
with gr.Blocks(theme=gr.themes.Base()) as app:
    gr.Markdown(HEADER_MD)

    # Use light theme by default
    app.load(
        None,
        None,
        js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
    )

    # Inference
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Tab(label="模型配置"):
                server_url = build_model_config_block()

            with gr.Tab(label="推理配置"):
                text = gr.Textbox(
                    label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
                )

                with gr.Row():
                    # Common sampling / generation parameters.
                    with gr.Tab(label="合成参数"):
                        gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")

                        input_mode = gr.Dropdown(
                            choices=["文本", "自动音素"],
                            value="文本",
                            label="输入模式",
                        )
                        max_new_tokens = gr.Slider(
                            label="最大生成 Token 数",
                            minimum=0,
                            maximum=4096,
                            value=0,  # 0 means no limit
                            step=8,
                        )
                        top_k = gr.Slider(
                            label="Top-K", minimum=0, maximum=100, value=0, step=1
                        )
                        top_p = gr.Slider(
                            label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
                        )
                        repetition_penalty = gr.Slider(
                            label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
                        )
                        temperature = gr.Slider(
                            label="温度", minimum=0, maximum=2, value=0.7, step=0.01
                        )
                        speaker = gr.Textbox(
                            label="说话人",
                            placeholder="说话人",
                            lines=1,
                        )

                    # Language priority used only by automatic phoneme mode.
                    with gr.Tab(label="语言优先级"):
                        gr.Markdown("该参数只在自动音素转换时生效.")

                        with gr.Column(scale=1):
                            language0 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 1",
                                value="中文",
                            )
                        with gr.Column(scale=1):
                            language1 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 2",
                                value="日文",
                            )
                        with gr.Column(scale=1):
                            language2 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 3",
                                value="英文",
                            )

                    # Optional reference audio + transcript for voice cloning.
                    with gr.Tab(label="参考音频"):
                        gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")

                        enable_reference_audio = gr.Checkbox(
                            label="启用参考音频", value=False
                        )
                        reference_audio = gr.Audio(
                            label="参考音频",
                            value="docs/assets/audios/0_input.wav",
                            type="filepath",
                        )
                        reference_text = gr.Textbox(
                            label="参考文本",
                            placeholder="参考文本",
                            lines=1,
                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                        )

                with gr.Row():
                    with gr.Column(scale=2):
                        generate = gr.Button(value="合成", variant="primary")
                    with gr.Column(scale=1):
                        # NOTE(review): this button is never wired to a
                        # callback below — presumably intentional or TODO.
                        clear = gr.Button(value="清空")

        with gr.Column(scale=3):
            error = gr.HTML(label="错误信息")
            parsed_text = gr.Dataframe(
                label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
            )
            audio = gr.Audio(label="合成音频", type="numpy")

    # Language & Text Parsing
    # Re-run the preview whenever any input that affects parsing changes.
    kwargs = dict(
        inputs=[
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_text,
        ],
        outputs=[parsed_text, error],
        trigger_mode="always_last",
    )
    text.change(prepare_text, **kwargs)
    input_mode.change(prepare_text, **kwargs)
    language0.change(prepare_text, **kwargs)
    language1.change(prepare_text, **kwargs)
    language2.change(prepare_text, **kwargs)
    enable_reference_audio.change(prepare_text, **kwargs)

    # Submit
    generate.click(
        inference,
        [
            server_url,
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            top_k,
            top_p,
            repetition_penalty,
            temperature,
            speaker,
        ],
        [audio, error],
    )

if __name__ == "__main__":
    app.launch(show_api=False)