app.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. import html
  2. import io
  3. import traceback
  4. from pathlib import Path
  5. import gradio as gr
  6. import librosa
  7. import requests
  8. from fish_speech.text import parse_text_to_segments
# Markdown rendered at the top of the Gradio page.
HEADER_MD = """
# Fish Speech
基于 VQ-GAN 和 Llama 的多语种语音合成. 感谢 Rcell 的 GPT-VITS 提供的思路.
"""

# Placeholder for the main input textbox.  It demonstrates how text is split
# into per-language phoneme segments when "自动音素" (auto-phoneme) mode is
# enabled, and how <jp>...</jp> tags force Japanese priority for a span.
TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。
会被转换为:
<Segment ZH: '测试一下' -> 'c e4 sh ir4 y i2 x ia4'>
<Segment EN: ' Hugging face, BGM' -> 'HH AH1 G IH0 NG F EY1 S , B AE1 G M'>
<Segment ZH: '声音很大吗?那我改一下.' -> 'sh eng1 y in1 h en3 d a4 m a5 ? n a4 w o2 g ai3 y i2 x ia4 .'>
<Segment ZH: '世界,' -> 'sh ir4 j ie4 ,'>
<Segment JP: 'こんにちは.' -> 'k o N n i ch i w a .'>
如你所见, 最后的句子被分割为了两个部分, 因为该日文包含了汉字, 你可以使用 <jp>...</jp> 标签来指定日文优先级. 例如:
测试一下 Hugging face, BGM声音很大吗?那我改一下. <jp>世界、こんにちは。</jp>
可以看到, 日文部分被正确地分割了出来:
...
<Segment JP: '世界,こんにちは.' -> 's e k a i , k o N n i ch i w a .'>
"""
  27. def build_html_error_message(error):
  28. return f"""
  29. <div style="color: red; font-weight: bold;">
  30. {html.escape(error)}
  31. </div>
  32. """
  33. def prepare_text(
  34. text,
  35. input_mode,
  36. language0,
  37. language1,
  38. language2,
  39. enable_reference_audio,
  40. reference_text,
  41. ):
  42. lines = text.splitlines()
  43. languages = [language0, language1, language2]
  44. languages = [
  45. {
  46. "中文": "ZH",
  47. "日文": "JP",
  48. "英文": "EN",
  49. }[language]
  50. for language in languages
  51. ]
  52. if len(set(languages)) != len(languages):
  53. return [], build_html_error_message("语言优先级不能重复.")
  54. if enable_reference_audio:
  55. reference_text = reference_text.strip() + " "
  56. else:
  57. reference_text = ""
  58. if input_mode != "自动音素":
  59. return [
  60. [idx, reference_text + line, "-", "-"]
  61. for idx, line in enumerate(lines)
  62. if line.strip() != ""
  63. ], None
  64. rows = []
  65. for idx, line in enumerate(lines):
  66. if line.strip() == "":
  67. continue
  68. try:
  69. segments = parse_text_to_segments(reference_text + line, order=languages)
  70. except Exception:
  71. traceback.print_exc()
  72. err = traceback.format_exc()
  73. return [], build_html_error_message(f"解析 '{line}' 时发生错误. \n\n{err}")
  74. for segment in segments:
  75. rows.append([idx, segment.text, segment.language, " ".join(segment.phones)])
  76. return rows, None
  77. def load_model(
  78. server_url,
  79. llama_ckpt_path,
  80. llama_config_name,
  81. tokenizer,
  82. vqgan_ckpt_path,
  83. vqgan_config_name,
  84. device,
  85. precision,
  86. compile_model,
  87. ):
  88. payload = {
  89. "device": device,
  90. "llama": {
  91. "config_name": llama_config_name,
  92. "checkpoint_path": llama_ckpt_path,
  93. "precision": precision,
  94. "tokenizer": tokenizer,
  95. "compile": compile_model,
  96. },
  97. "vqgan": {
  98. "config_name": vqgan_config_name,
  99. "checkpoint_path": vqgan_ckpt_path,
  100. },
  101. }
  102. try:
  103. resp = requests.put(f"{server_url}/v1/models/default", json=payload)
  104. resp.raise_for_status()
  105. except Exception:
  106. traceback.print_exc()
  107. err = traceback.format_exc()
  108. return build_html_error_message(f"加载模型时发生错误. \n\n{err}")
  109. return "模型加载成功."
def build_model_config_block():
    """Build the "模型配置" (model configuration) tab of the UI.

    Creates the server-address textbox, device/precision/compile options,
    checkpoint and config pickers for the Llama and VQGAN models, and a
    "load model" button wired to :func:`load_model`.  Returns the
    ``server_url`` textbox so the caller can feed it to other callbacks.
    """
    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")

    with gr.Row():
        with gr.Column(scale=1):
            device = gr.Dropdown(
                label="设备",
                choices=["cpu", "cuda"],
                value="cuda",
            )
        with gr.Column(scale=1):
            precision = gr.Dropdown(
                label="精度",
                choices=["bfloat16", "float16"],
                value="float16",
            )
        with gr.Column(scale=1):
            compile_model = gr.Checkbox(
                label="编译模型",
                value=True,
            )

    # Checkpoint choices are scanned from both training output (results/)
    # and pre-downloaded checkpoints (checkpoints/); custom paths allowed.
    llama_ckpt_path = gr.Dropdown(
        label="Llama 模型路径",
        value=str(Path("checkpoints/text2semantic-400m-v0.3-4k.pth")),
        choices=[str(pth_file) for pth_file in Path("results").rglob("*text*/*.ckpt")]
        + [str(pth_file) for pth_file in Path("checkpoints").rglob("*text*.pth")],
        allow_custom_value=True,
    )
    llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
    tokenizer = gr.Dropdown(
        label="Tokenizer",
        value="fishaudio/speech-lm-v1",
        choices=["fishaudio/speech-lm-v1", "checkpoints"],
    )
    vqgan_ckpt_path = gr.Dropdown(
        label="VQGAN 模型路径",
        value=str(Path("checkpoints/vqgan-v1.pth")),
        choices=[str(pth_file) for pth_file in Path("results").rglob("*vqgan*/*.ckpt")]
        + [str(pth_file) for pth_file in Path("checkpoints").rglob("*vqgan*.pth")],
        allow_custom_value=True,
    )
    vqgan_config_name = gr.Dropdown(
        label="VQGAN 配置文件",
        value="vqgan_pretrain",
        choices=["vqgan_pretrain", "vqgan_finetune"],
    )

    load_model_btn = gr.Button(value="加载模型", variant="primary")
    error = gr.HTML(label="错误信息")

    # Clicking the button sends all settings to load_model; its return value
    # (success message or HTML error) lands in the error HTML box.
    load_model_btn.click(
        load_model,
        [
            server_url,
            llama_ckpt_path,
            llama_config_name,
            tokenizer,
            vqgan_ckpt_path,
            vqgan_config_name,
            device,
            precision,
            compile_model,
        ],
        [error],
    )

    return server_url
  173. def inference(
  174. server_url,
  175. text,
  176. input_mode,
  177. language0,
  178. language1,
  179. language2,
  180. enable_reference_audio,
  181. reference_audio,
  182. reference_text,
  183. max_new_tokens,
  184. top_k,
  185. top_p,
  186. repetition_penalty,
  187. temperature,
  188. speaker,
  189. ):
  190. languages = [language0, language1, language2]
  191. languages = [
  192. {
  193. "中文": "zh",
  194. "日文": "jp",
  195. "英文": "en",
  196. }[language]
  197. for language in languages
  198. ]
  199. if len(set(languages)) != len(languages):
  200. return [], build_html_error_message("语言优先级不能重复.")
  201. order = ",".join(languages)
  202. payload = {
  203. "text": text,
  204. "prompt_text": reference_text if enable_reference_audio else None,
  205. "prompt_tokens": reference_audio if enable_reference_audio else None,
  206. "max_new_tokens": int(max_new_tokens),
  207. "top_k": int(top_k) if top_k > 0 else None,
  208. "top_p": top_p,
  209. "repetition_penalty": repetition_penalty,
  210. "temperature": temperature,
  211. "order": order,
  212. "use_g2p": input_mode == "自动音素",
  213. "seed": None,
  214. "speaker": speaker if speaker.strip() != "" else None,
  215. }
  216. try:
  217. resp = requests.post(f"{server_url}/v1/models/default/invoke", json=payload)
  218. resp.raise_for_status()
  219. except Exception:
  220. traceback.print_exc()
  221. err = traceback.format_exc()
  222. return [], build_html_error_message(f"推理时发生错误. \n\n{err}")
  223. content = io.BytesIO(resp.content)
  224. content.seek(0)
  225. content, sr = librosa.load(content, sr=None, mono=True)
  226. return (sr, content), None
# ---------------------------------------------------------------------------
# Top-level UI: layout and event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Base()) as app:
    gr.Markdown(HEADER_MD)

    # Use light theme by default: on first load, set the __theme query
    # parameter and reload unless the user already chose one.
    app.load(
        None,
        None,
        js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
    )

    # Inference
    with gr.Row():
        # Left column: model config tab and inference config tab.
        with gr.Column(scale=3):
            with gr.Tab(label="模型配置"):
                server_url = build_model_config_block()
            with gr.Tab(label="推理配置"):
                text = gr.Textbox(
                    label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
                )
                with gr.Row():
                    # Sampling / generation parameters.
                    with gr.Tab(label="合成参数"):
                        gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")
                        input_mode = gr.Dropdown(
                            choices=["文本", "自动音素"],
                            value="文本",
                            label="输入模式",
                        )
                        max_new_tokens = gr.Slider(
                            label="最大生成 Token 数",
                            minimum=0,
                            maximum=4096,
                            value=0,  # 0 means no limit
                            step=8,
                        )
                        top_k = gr.Slider(
                            label="Top-K", minimum=0, maximum=100, value=0, step=1
                        )
                        top_p = gr.Slider(
                            label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
                        )
                        repetition_penalty = gr.Slider(
                            label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
                        )
                        temperature = gr.Slider(
                            label="温度", minimum=0, maximum=2, value=0.7, step=0.01
                        )
                        speaker = gr.Textbox(
                            label="说话人",
                            placeholder="说话人",
                            lines=1,
                        )
                    # Language priority (only used by auto-phoneme g2p).
                    with gr.Tab(label="语言优先级"):
                        gr.Markdown("该参数只在自动音素转换时生效.")
                        with gr.Column(scale=1):
                            language0 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 1",
                                value="中文",
                            )
                        with gr.Column(scale=1):
                            language1 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 2",
                                value="日文",
                            )
                        with gr.Column(scale=1):
                            language2 = gr.Dropdown(
                                choices=["中文", "日文", "英文"],
                                label="语言 3",
                                value="英文",
                            )
                    # Optional reference audio for voice cloning.
                    with gr.Tab(label="参考音频"):
                        gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")
                        enable_reference_audio = gr.Checkbox(
                            label="启用参考音频", value=False
                        )
                        reference_audio = gr.Audio(
                            label="参考音频",
                            value="docs/assets/audios/0_input.wav",
                            type="filepath",
                        )
                        reference_text = gr.Textbox(
                            label="参考文本",
                            placeholder="参考文本",
                            lines=1,
                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                        )
                with gr.Row():
                    with gr.Column(scale=2):
                        generate = gr.Button(value="合成", variant="primary")
                    with gr.Column(scale=1):
                        clear = gr.Button(value="清空")
        # Right column: outputs (errors, parsed-segment preview, audio).
        with gr.Column(scale=3):
            error = gr.HTML(label="错误信息")
            parsed_text = gr.Dataframe(
                label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
            )
            audio = gr.Audio(label="合成音频", type="numpy")

    # Language & Text Parsing: re-run the segmentation preview whenever any
    # input that affects it changes.  Shared kwargs for all the listeners;
    # trigger_mode="always_last" coalesces rapid changes (e.g. typing).
    kwargs = dict(
        inputs=[
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_text,
        ],
        outputs=[parsed_text, error],
        trigger_mode="always_last",
    )
    text.change(prepare_text, **kwargs)
    input_mode.change(prepare_text, **kwargs)
    language0.change(prepare_text, **kwargs)
    language1.change(prepare_text, **kwargs)
    language2.change(prepare_text, **kwargs)
    enable_reference_audio.change(prepare_text, **kwargs)

    # Submit
    # NOTE(review): the "清空" (clear) button is created but never wired to
    # any handler — confirm whether that is intentional.
    generate.click(
        inference,
        [
            server_url,
            text,
            input_mode,
            language0,
            language1,
            language2,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            top_k,
            top_p,
            repetition_penalty,
            temperature,
            speaker,
        ],
        [audio, error],
    )

if __name__ == "__main__":
    # show_api=False hides the auto-generated API docs page.
    app.launch(show_api=False)