# app.py — Gradio web UI for fish-speech text-to-speech.
  1. import html
  2. import io
  3. import os
  4. import traceback
  5. import wave
  6. from pathlib import Path
  7. import gradio as gr
  8. import librosa
  9. import numpy as np
  10. import requests
  11. from fish_speech.text import parse_text_to_segments
# Markdown banner rendered at the top of the page.
HEADER_MD = """
# Fish Speech
基于 VQ-GAN 和 Llama 的多语种语音合成. 感谢 Rcell 的 GPT-VITS 提供的思路.
"""

# Placeholder text for the main input textbox. It documents (in Chinese) how
# automatic phoneme conversion segments mixed-language input and how the
# <jp>...</jp> tag overrides language priority for a span.
TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。
会被转换为:
<Segment ZH: '测试一下' -> 'c e4 sh ir4 y i2 x ia4'>
<Segment EN: ' Hugging face, BGM' -> 'HH AH1 G IH0 NG F EY1 S , B AE1 G M'>
<Segment ZH: '声音很大吗?那我改一下.' -> 'sh eng1 y in1 h en3 d a4 m a5 ? n a4 w o2 g ai3 y i2 x ia4 .'>
<Segment ZH: '世界,' -> 'sh ir4 j ie4 ,'>
<Segment JP: 'こんにちは.' -> 'k o N n i ch i w a .'>
如你所见, 最后的句子被分割为了两个部分, 因为该日文包含了汉字, 你可以使用 <jp>...</jp> 标签来指定日文优先级. 例如:
测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>
可以看到, 日文部分被正确地分割了出来:
...
<Segment JP: '世界,こんにちは.' -> 's e k a i , k o N n i ch i w a .'>
"""
  30. def build_html_error_message(error):
  31. return f"""
  32. <div style="color: red; font-weight: bold;">
  33. {html.escape(error)}
  34. </div>
  35. """
  36. def prepare_text(
  37. text,
  38. input_mode,
  39. language0,
  40. language1,
  41. language2,
  42. enable_reference_audio,
  43. reference_text,
  44. ):
  45. lines = text.splitlines()
  46. languages = [language0, language1, language2]
  47. languages = [
  48. {
  49. "中文": "ZH",
  50. "日文": "JP",
  51. "英文": "EN",
  52. }[language]
  53. for language in languages
  54. ]
  55. if len(set(languages)) != len(languages):
  56. return [], build_html_error_message("语言优先级不能重复.")
  57. if enable_reference_audio:
  58. reference_text = reference_text.strip() + " "
  59. else:
  60. reference_text = ""
  61. if input_mode != "自动音素":
  62. return [
  63. [idx, reference_text + line, "-", "-"]
  64. for idx, line in enumerate(lines)
  65. if line.strip() != ""
  66. ], None
  67. rows = []
  68. for idx, line in enumerate(lines):
  69. if line.strip() == "":
  70. continue
  71. try:
  72. segments = parse_text_to_segments(reference_text + line, order=languages)
  73. except Exception:
  74. traceback.print_exc()
  75. err = traceback.format_exc()
  76. return [], build_html_error_message(f"解析 '{line}' 时发生错误. \n\n{err}")
  77. for segment in segments:
  78. rows.append([idx, segment.text, segment.language, " ".join(segment.phones)])
  79. return rows, None
  80. def load_model(
  81. server_url,
  82. llama_ckpt_path,
  83. llama_config_name,
  84. tokenizer,
  85. vqgan_ckpt_path,
  86. vqgan_config_name,
  87. device,
  88. precision,
  89. compile_model,
  90. ):
  91. payload = {
  92. "device": device,
  93. "llama": {
  94. "config_name": llama_config_name,
  95. "checkpoint_path": llama_ckpt_path,
  96. "precision": precision,
  97. "tokenizer": tokenizer,
  98. "compile": compile_model,
  99. },
  100. "vqgan": {
  101. "config_name": vqgan_config_name,
  102. "checkpoint_path": vqgan_ckpt_path,
  103. },
  104. }
  105. try:
  106. resp = requests.put(f"{server_url}/v1/models/default", json=payload)
  107. resp.raise_for_status()
  108. except Exception:
  109. traceback.print_exc()
  110. err = traceback.format_exc()
  111. return build_html_error_message(f"加载模型时发生错误. \n\n{err}")
  112. return "模型加载成功."
def build_model_config_block():
    """Render the model-configuration tab and wire its load-model button.

    Returns:
        The server-URL textbox component, so other tabs can feed the chosen
        server address into their own callbacks.
    """
    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")
    with gr.Row():
        with gr.Column(scale=1):
            device = gr.Dropdown(
                label="设备",
                choices=["cpu", "cuda"],
                value="cuda",
            )
        with gr.Column(scale=1):
            precision = gr.Dropdown(
                label="精度",
                choices=["bfloat16", "float16"],
                value="float16",
            )
        with gr.Column(scale=1):
            compile_model = gr.Checkbox(
                label="编译模型",
                value=True,
            )
    # Offer every checkpoint found under results/ (training runs) and
    # checkpoints/ (released weights); custom paths may also be typed in.
    llama_ckpt_path = gr.Dropdown(
        label="Llama 模型路径",
        value=str(Path("checkpoints/text2semantic-400m-v0.3-4k.pth")),
        choices=[
            str(pth_file) for pth_file in Path("results").rglob("**/text*/**/*.ckpt")
        ]
        + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*text*.pth")],
        allow_custom_value=True,
    )
    llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
    tokenizer = gr.Dropdown(
        label="Tokenizer",
        value="fishaudio/speech-lm-v1",
        choices=["fishaudio/speech-lm-v1", "checkpoints"],
    )
    vqgan_ckpt_path = gr.Dropdown(
        label="VQGAN 模型路径",
        value=str(Path("checkpoints/vqgan-v1.pth")),
        choices=[
            str(pth_file) for pth_file in Path("results").rglob("**/vqgan*/**/*.ckpt")
        ]
        + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*vqgan*.pth")],
        allow_custom_value=True,
    )
    vqgan_config_name = gr.Dropdown(
        label="VQGAN 配置文件",
        value="vqgan_pretrain",
        choices=["vqgan_pretrain", "vqgan_finetune"],
    )
    load_model_btn = gr.Button(value="加载模型", variant="primary")
    error = gr.HTML(label="错误信息")
    # Clicking the button sends the selected configuration to the server via
    # load_model; any failure string is rendered into the HTML error box.
    load_model_btn.click(
        load_model,
        [
            server_url,
            llama_ckpt_path,
            llama_config_name,
            tokenizer,
            vqgan_ckpt_path,
            vqgan_config_name,
            device,
            precision,
            compile_model,
        ],
        [error],
    )
    return server_url
  180. def inference(
  181. server_url,
  182. text,
  183. input_mode,
  184. language0,
  185. language1,
  186. language2,
  187. enable_reference_audio,
  188. reference_audio,
  189. reference_text,
  190. max_new_tokens,
  191. top_k,
  192. top_p,
  193. repetition_penalty,
  194. temperature,
  195. speaker,
  196. ):
  197. languages = [language0, language1, language2]
  198. languages = [
  199. {
  200. "中文": "zh",
  201. "日文": "jp",
  202. "英文": "en",
  203. }[language]
  204. for language in languages
  205. ]
  206. if len(set(languages)) != len(languages):
  207. return [], build_html_error_message("语言优先级不能重复.")
  208. order = ",".join(languages)
  209. payload = {
  210. "text": text,
  211. "prompt_text": reference_text if enable_reference_audio else None,
  212. "prompt_tokens": reference_audio if enable_reference_audio else None,
  213. "max_new_tokens": int(max_new_tokens),
  214. "top_k": int(top_k) if top_k > 0 else None,
  215. "top_p": top_p,
  216. "repetition_penalty": repetition_penalty,
  217. "temperature": temperature,
  218. "order": order,
  219. "use_g2p": input_mode == "自动音素",
  220. "seed": None,
  221. "speaker": speaker if speaker.strip() != "" else None,
  222. }
  223. try:
  224. resp = requests.post(f"{server_url}/v1/models/default/invoke", json=payload)
  225. resp.raise_for_status()
  226. except Exception:
  227. traceback.print_exc()
  228. err = traceback.format_exc()
  229. return [], build_html_error_message(f"推理时发生错误. \n\n{err}")
  230. content = io.BytesIO(resp.content)
  231. content.seek(0)
  232. content, sr = librosa.load(content, sr=None, mono=True)
  233. return (sr, content), None
  234. def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=22050):
  235. # copy and paste
  236. wav_buf = io.BytesIO()
  237. with wave.open(wav_buf, "wb") as vfout:
  238. vfout.setnchannels(channels)
  239. vfout.setsampwidth(sample_width)
  240. vfout.setframerate(sample_rate)
  241. vfout.writeframes(frame_input)
  242. wav_buf.seek(0)
  243. return wav_buf.read()
  244. def inference_stream(
  245. server_url,
  246. text,
  247. input_mode,
  248. language0,
  249. language1,
  250. language2,
  251. enable_reference_audio,
  252. reference_audio,
  253. reference_text,
  254. max_new_tokens,
  255. top_k,
  256. top_p,
  257. repetition_penalty,
  258. temperature,
  259. speaker,
  260. ):
  261. languages = [language0, language1, language2]
  262. languages = [
  263. {
  264. "中文": "zh",
  265. "日文": "jp",
  266. "英文": "en",
  267. }[language]
  268. for language in languages
  269. ]
  270. if len(set(languages)) != len(languages):
  271. return []
  272. order = ",".join(languages)
  273. payload = {
  274. "text": text,
  275. "prompt_text": reference_text if enable_reference_audio else None,
  276. "prompt_tokens": reference_audio if enable_reference_audio else None,
  277. "max_new_tokens": int(max_new_tokens),
  278. "top_k": int(top_k) if top_k > 0 else None,
  279. "top_p": top_p,
  280. "repetition_penalty": repetition_penalty,
  281. "temperature": temperature,
  282. "order": order,
  283. "use_g2p": input_mode == "自动音素",
  284. "seed": None,
  285. "speaker": speaker if speaker.strip() != "" else None,
  286. }
  287. resp = requests.post(
  288. f"{server_url}/v1/models/default/invoke_stream", json=payload, stream=True
  289. )
  290. resp.raise_for_status()
  291. yield wave_header_chunk(), None
  292. for chunk in resp.iter_content(chunk_size=None):
  293. if chunk:
  294. content = io.BytesIO(chunk)
  295. content.seek(0)
  296. audio, sr = librosa.load(content, sr=None, mono=True)
  297. print(audio.shape, sr)
  298. yield (np.concatenate([audio], 0) * 32768).astype(np.int16).tobytes(), None
# Top-level application layout: configuration tabs on the left, parsed-text
# preview and audio outputs on the right, event wiring at the bottom.
with gr.Blocks(theme=gr.themes.Base()) as app:
    gr.Markdown(HEADER_MD)

    # Use light theme by default: rewrite the query string on first load
    # unless the user already picked a theme explicitly.
    app.load(
        None,
        None,
        js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
    )

    # Inference
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Tab(label="模型配置"):
                server_url = build_model_config_block()
            with gr.Tab(label="推理配置"):
                text = gr.Textbox(
                    label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
                )
                with gr.Row():
                    with gr.Tab(label="合成参数"):
                        gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")
                        input_mode = gr.Dropdown(
                            choices=["文本", "自动音素"],
                            value="文本",
                            label="输入模式",
                        )
                        max_new_tokens = gr.Slider(
                            label="最大生成 Token 数",
                            minimum=0,
                            maximum=4096,
                            value=0,  # 0 means no limit
                            step=8,
                        )
                        top_k = gr.Slider(
                            label="Top-K", minimum=0, maximum=100, value=0, step=1
                        )
                        top_p = gr.Slider(
                            label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
                        )
                        repetition_penalty = gr.Slider(
                            label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
                        )
                        temperature = gr.Slider(
                            label="温度", minimum=0, maximum=2, value=0.7, step=0.01
                        )
                        speaker = gr.Textbox(
                            label="说话人",
                            placeholder="说话人",
                            lines=1,
                        )
                    with gr.Tab(label="语言优先级"):
                        gr.Markdown("该参数只在自动音素转换时生效.")
                        with gr.Column(scale=1):
                            language0 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 1",
                                value="中文",
                            )
                        with gr.Column(scale=1):
                            language1 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 2",
                                value="日文",
                            )
                        with gr.Column(scale=1):
                            language2 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 3",
                                value="英文",
                            )
                    with gr.Tab(label="参考音频"):
                        gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")
                        enable_reference_audio = gr.Checkbox(
                            label="启用参考音频", value=False
                        )
                        reference_audio = gr.Audio(
                            label="参考音频",
                            value="docs/assets/audios/0_input.wav",
                            type="filepath",
                        )
                        reference_text = gr.Textbox(
                            label="参考文本",
                            placeholder="参考文本",
                            lines=1,
                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                        )
        with gr.Column(scale=3):
            with gr.Row():
                error = gr.HTML(label="错误信息")
            with gr.Row():
                parsed_text = gr.Dataframe(
                    label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
                )
            with gr.Row():
                audio_once = gr.Audio(label="一次合成音频", type="numpy")
            with gr.Row():
                audio_stream = gr.Audio(
                    label="流式合成音频",
                    autoplay=True,
                    streaming=True,
                    show_label=True,
                    interactive=False,
                )
            with gr.Row():
                with gr.Column(scale=3):
                    generate = gr.Button(value="\U0001F3A7 合成", variant="primary")
                    stream_generate = gr.Button(
                        value="\U0001F4A7 流式合成", variant="primary"
                    )
                with gr.Column(scale=1):
                    audio_download = gr.Button(
                        value="\U0001F449 下载流式音频", elem_id="audio_download"
                    )
                    # NOTE(review): no handler is wired to this button in the
                    # visible source -- confirm whether "clear" is implemented.
                    clear = gr.Button(value="清空")

    # Language & Text Parsing: any change to the text or the parsing options
    # re-runs prepare_text to refresh the preview dataframe.
    # NOTE(review): reference_text is an input but its .change() is not wired,
    # so the preview goes stale when only the reference text is edited --
    # confirm whether that is intentional.
    kwargs = dict(
        inputs=[
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_text,
        ],
        outputs=[parsed_text, error],
        trigger_mode="always_last",
    )
    text.change(prepare_text, **kwargs)
    input_mode.change(prepare_text, **kwargs)
    language0.change(prepare_text, **kwargs)
    language1.change(prepare_text, **kwargs)
    language2.change(prepare_text, **kwargs)
    enable_reference_audio.change(prepare_text, **kwargs)

    # Submit
    generate.click(
        inference,
        [
            server_url,
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            top_k,
            top_p,
            repetition_penalty,
            temperature,
            speaker,
        ],
        [audio_once, error],
    )
    # After streaming finishes, re-enable the text input.
    stream_generate.click(
        inference_stream,
        [
            server_url,
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            top_k,
            top_p,
            repetition_penalty,
            temperature,
            speaker,
        ],
        [audio_stream, error],
    ).then(lambda: gr.update(interactive=True), None, [text], queue=False)

    # Client-side only: briefly disable the button, then open (and auto-close)
    # a popup pointing at the server's stream-download endpoint.
    audio_download.click(
        None,
        js="() => { "
        'var btn = document.getElementById("audio_download"); '
        "btn.disabled = true; "
        "setTimeout(() => { btn.disabled = false; }, 1000); "
        'var win = window.open("http://localhost:8000/v1/models/default/download", '
        '"newwindow", "height=100, width=400, toolbar=no, menubar=no, scrollbars=no, '
        'resizable=no, location=no, status=no"); '
        "setTimeout(function() { win.close(); }, 1000);"
        "}",
    )

if __name__ == "__main__":
    app.launch(show_api=False)