manage.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823
  1. from __future__ import annotations
  2. import html
  3. import json
  4. import os
  5. import platform
  6. import random
  7. import shutil
  8. import signal
  9. import subprocess
  10. import sys
  11. from pathlib import Path
  12. import gradio as gr
  13. import psutil
  14. import yaml
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.webui.launch_utils import Seafoam, versions_html
  18. PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
  19. sys.path.insert(0, "")
  20. print(sys.path)
  21. cur_work_dir = Path(os.getcwd()).resolve()
  22. print("You are in ", str(cur_work_dir))
  23. config_path = cur_work_dir / "fish_speech" / "configs"
  24. vqgan_yml_path = config_path / "vqgan_finetune.yaml"
  25. llama_yml_path = config_path / "text2semantic_sft.yaml"
  26. env = os.environ.copy()
  27. env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
  28. seafoam = Seafoam()
  29. def build_html_error_message(error):
  30. return f"""
  31. <div style="color: red; font-weight: bold;">
  32. {html.escape(error)}
  33. </div>
  34. """
  35. def build_html_ok_message(msg):
  36. return f"""
  37. <div style="color: green; font-weight: bold;">
  38. {html.escape(msg)}
  39. </div>
  40. """
  41. def load_data_in_raw(path):
  42. with open(path, "r", encoding="utf-8") as file:
  43. data = file.read()
  44. return str(data)
  45. def kill_proc_tree(pid, including_parent=True):
  46. try:
  47. parent = psutil.Process(pid)
  48. except psutil.NoSuchProcess:
  49. # Process already terminated
  50. return
  51. children = parent.children(recursive=True)
  52. for child in children:
  53. try:
  54. os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
  55. except OSError:
  56. pass
  57. if including_parent:
  58. try:
  59. os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
  60. except OSError:
  61. pass
# OS name reported by platform.system(); kill_process compares it to "Windows"
# to choose between taskkill and kill_proc_tree.
system = platform.system()
# Handles of the optional child processes started from the UI. They are reset to
# None after the process is killed, so None means "not started from the UI".
p_label: subprocess.Popen | None = None  # labelling WebUI (change_label)
p_infer: subprocess.Popen | None = None  # inference WebUI (change_infer)
  65. def kill_process(pid):
  66. if system == "Windows":
  67. cmd = "taskkill /t /f /pid %s" % pid
  68. # os.system(cmd)
  69. subprocess.run(cmd)
  70. else:
  71. kill_proc_tree(pid)
  72. def change_label(if_label):
  73. global p_label
  74. if if_label == True and p_label == None:
  75. cmd = ["asr-label-win-x64.exe"]
  76. yield f"打标工具WebUI已开启, 访问:http://localhost:{3000}"
  77. p_label = subprocess.Popen(cmd, shell=True, env=env)
  78. elif if_label == False and p_label != None:
  79. kill_process(p_label.pid)
  80. p_label = None
  81. yield "打标工具WebUI已关闭"
  82. def change_infer(
  83. if_infer, host, port, infer_vqgan_model, infer_llama_model, infer_compile
  84. ):
  85. global p_infer
  86. if if_infer == True and p_infer == None:
  87. env = os.environ.copy()
  88. env["GRADIO_SERVER_NAME"] = host
  89. env["GRADIO_SERVER_PORT"] = port
  90. # 启动第二个进程
  91. yield build_html_ok_message(f"推理界面已开启, 访问 http://{host}:{port}")
  92. p_infer = subprocess.Popen(
  93. [
  94. PYTHON,
  95. "tools/webui.py",
  96. "--vqgan-checkpoint-path",
  97. infer_vqgan_model,
  98. "--llama-checkpoint-path",
  99. infer_llama_model,
  100. "--tokenizer",
  101. "checkpoints",
  102. ]
  103. + (["--compile"] if infer_compile == "Yes" else []),
  104. env=env,
  105. )
  106. elif if_infer == False and p_infer != None:
  107. kill_process(p_infer.pid)
  108. p_infer = None
  109. yield build_html_error_message("推理界面已关闭")
# Front-end assets injected into gr.Blocks below (js=..., and css inside <style>).
js = load_data_in_raw("fish_speech/webui/js/animate.js")
css = load_data_in_raw("fish_speech/webui/css/style.css")
# Preprocessing workspace (sources are copied/moved here by check_files and
# list_copy) and the displayed model output directory; both absolute paths.
data_pre_output = (cur_work_dir / "data").resolve()
default_model_output = (cur_work_dir / "results").resolve()
# NOTE(review): default_filelist appears to be unused in this file.
default_filelist = data_pre_output / "detect.list"
data_pre_output.mkdir(parents=True, exist_ok=True)
# Registered data sources: ``items`` keeps the display order for the
# CheckboxGroup; ``dict_items`` maps each path to its
# {"type", "method", "label_lang"} settings (see add_item).
items: list[str] = []
dict_items: dict[str, dict] = {}
  118. def load_yaml_data_in_fact(yml_path):
  119. with open(yml_path, "r", encoding="utf-8") as file:
  120. yml = yaml.safe_load(file)
  121. return yml
  122. def write_yaml_data_in_fact(yml, yml_path):
  123. with open(yml_path, "w", encoding="utf-8") as file:
  124. yaml.safe_dump(yml, file, allow_unicode=True)
  125. return yml
  126. def generate_tree(directory, depth=0, max_depth=None, prefix=""):
  127. if max_depth is not None and depth > max_depth:
  128. return ""
  129. tree_str = ""
  130. files = []
  131. directories = []
  132. for item in os.listdir(directory):
  133. if os.path.isdir(os.path.join(directory, item)):
  134. directories.append(item)
  135. else:
  136. files.append(item)
  137. entries = directories + files
  138. for i, entry in enumerate(entries):
  139. connector = "├── " if i < len(entries) - 1 else "└── "
  140. tree_str += f"{prefix}{connector}{entry}<br />"
  141. if i < len(directories):
  142. extension = "│ " if i < len(entries) - 1 else " "
  143. tree_str += generate_tree(
  144. os.path.join(directory, entry),
  145. depth + 1,
  146. max_depth,
  147. prefix=prefix + extension,
  148. )
  149. return tree_str
  150. def new_explorer(data_path, max_depth):
  151. return gr.Markdown(
  152. elem_classes=["scrollable-component"],
  153. value=generate_tree(data_path, max_depth=max_depth),
  154. )
  155. def add_item(folder: str, method: str, filelist: str, label_lang: str):
  156. folder = folder.strip(" ").strip('"')
  157. filelist = filelist.strip(" ").strip('"')
  158. folder_path = Path(folder)
  159. filelist_path = Path(filelist)
  160. if folder and folder not in items and data_pre_output not in folder_path.parents:
  161. if folder_path.is_dir():
  162. items.append(folder)
  163. dict_items[folder] = dict(
  164. type="folder", method=method, label_lang=label_lang
  165. )
  166. elif folder:
  167. err = folder
  168. return gr.Checkboxgroup(choices=items), build_html_error_message(
  169. f"添加文件夹路径无效: {err}"
  170. )
  171. if (
  172. filelist
  173. and filelist not in items
  174. and data_pre_output not in filelist_path.parents
  175. ):
  176. if filelist_path.is_file():
  177. items.append(filelist)
  178. dict_items[filelist] = dict(
  179. type="file", method=method, label_lang=label_lang
  180. )
  181. elif filelist:
  182. err = filelist
  183. return gr.Checkboxgroup(choices=items), build_html_error_message(
  184. f"添加文件路径无效: {err}"
  185. )
  186. formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
  187. logger.info(formatted_data)
  188. return gr.Checkboxgroup(choices=items), build_html_ok_message("添加文件(夹)路径成功!")
  189. def remove_items(selected_items):
  190. global items, dict_items
  191. to_remove = [item for item in items if item in selected_items]
  192. for item in to_remove:
  193. del dict_items[item]
  194. items = [item for item in items if item in dict_items.keys()]
  195. formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
  196. logger.info(formatted_data)
  197. return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
  198. "删除文件(夹)路径成功!"
  199. )
  200. def show_selected(options):
  201. selected_options = ", ".join(options)
  202. return f"你选中了: {selected_options}" if options else "你没有选中任何选项"
  203. def list_copy(list_file_path, method):
  204. wav_root = data_pre_output
  205. lst = []
  206. with list_file_path.open("r", encoding="utf-8") as file:
  207. for line in tqdm(file, desc="Processing audio/transcript"):
  208. wav_path, speaker_name, language, text = line.strip().split("|")
  209. original_wav_path = Path(wav_path)
  210. target_wav_path = (
  211. wav_root / original_wav_path.parent.name / original_wav_path.name
  212. )
  213. lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
  214. if target_wav_path.is_file():
  215. continue
  216. target_wav_path.parent.mkdir(parents=True, exist_ok=True)
  217. if method == "复制一份":
  218. shutil.copy(original_wav_path, target_wav_path)
  219. else:
  220. shutil.move(original_wav_path, target_wav_path.parent)
  221. original_lab_path = original_wav_path.with_suffix(".lab")
  222. target_lab_path = (
  223. wav_root
  224. / original_wav_path.parent.name
  225. / original_wav_path.with_suffix(".lab").name
  226. )
  227. if target_lab_path.is_file():
  228. continue
  229. if method == "复制一份":
  230. shutil.copy(original_lab_path, target_lab_path)
  231. else:
  232. shutil.move(original_lab_path, target_lab_path.parent)
  233. if method == "直接移动":
  234. with list_file_path.open("w", encoding="utf-8") as file:
  235. file.writelines("\n".join(lst))
  236. del lst
  237. return build_html_ok_message("使用filelist")
  238. def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
  239. dict_to_language = {"中文": "ZH", "英文": "EN", "日文": "JP", "不打标": "WTF"}
  240. global dict_items
  241. data_path = Path(data_path)
  242. for item, content in dict_items.items():
  243. item_path = Path(item)
  244. tar_path = data_path / item_path.name
  245. if content["type"] == "folder" and item_path.is_dir():
  246. cur_lang = dict_to_language[content["label_lang"]]
  247. if cur_lang != "WTF":
  248. try:
  249. subprocess.run(
  250. [
  251. PYTHON,
  252. "tools/whisper_asr.py",
  253. "--model-size",
  254. label_model,
  255. "--device",
  256. label_device,
  257. "--audio-dir",
  258. item_path,
  259. "--save-dir",
  260. item_path,
  261. "--language",
  262. cur_lang,
  263. ],
  264. env=env,
  265. )
  266. except Exception:
  267. print("Transcription error occurred")
  268. if content["method"] == "复制一份":
  269. os.makedirs(tar_path, exist_ok=True)
  270. shutil.copytree(
  271. src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
  272. )
  273. elif not tar_path.is_dir():
  274. shutil.move(src=str(item_path), dst=str(tar_path))
  275. elif content["type"] == "file" and item_path.is_file():
  276. list_copy(item_path, content["method"])
  277. return build_html_ok_message("文件移动完毕"), new_explorer(data_path, max_depth=max_depth)
  278. def train_process(
  279. data_path: str,
  280. option: str,
  281. # vq-gan config
  282. vqgan_lr,
  283. vqgan_maxsteps,
  284. vqgan_data_num_workers,
  285. vqgan_data_batch_size,
  286. vqgan_data_val_batch_size,
  287. vqgan_precision,
  288. vqgan_check_interval,
  289. # llama config
  290. llama_lr,
  291. llama_maxsteps,
  292. llama_limit_val_batches,
  293. llama_data_num_workers,
  294. llama_data_batch_size,
  295. llama_data_max_length,
  296. llama_precision,
  297. llama_check_interval,
  298. llama_grad_batches,
  299. llama_use_speaker,
  300. ):
  301. backend = "nccl" if sys.platform == "linux" else "gloo"
  302. if option == "VQGAN" or option == "all":
  303. subprocess.run(
  304. [
  305. PYTHON,
  306. "tools/vqgan/create_train_split.py",
  307. str(data_pre_output.relative_to(cur_work_dir)),
  308. ]
  309. )
  310. train_cmd = [
  311. PYTHON,
  312. "fish_speech/train.py",
  313. "--config-name",
  314. "vqgan_finetune",
  315. f"trainer.strategy.process_group_backend={backend}",
  316. f"model.optimizer.lr={vqgan_lr}",
  317. f"trainer.max_steps={vqgan_maxsteps}",
  318. f"data.num_workers={vqgan_data_num_workers}",
  319. f"data.batch_size={vqgan_data_batch_size}",
  320. f"data.val_batch_size={vqgan_data_val_batch_size}",
  321. f"trainer.precision={vqgan_precision}",
  322. f"trainer.val_check_interval={vqgan_check_interval}",
  323. f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
  324. f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
  325. ]
  326. logger.info(train_cmd)
  327. subprocess.run(train_cmd)
  328. if option == "LLAMA" or option == "all":
  329. subprocess.run(
  330. [
  331. PYTHON,
  332. "tools/vqgan/extract_vq.py",
  333. str(data_pre_output),
  334. "--num-workers",
  335. "1",
  336. "--batch-size",
  337. "16",
  338. "--config-name",
  339. "vqgan_pretrain",
  340. "--checkpoint-path",
  341. "checkpoints/vq-gan-group-fsq-2x1024.pth",
  342. ]
  343. )
  344. subprocess.run(
  345. [
  346. PYTHON,
  347. "tools/llama/build_dataset.py",
  348. "--input",
  349. str(data_pre_output),
  350. "--text-extension",
  351. ".lab",
  352. "--num-workers",
  353. "16",
  354. ]
  355. )
  356. train_cmd = [
  357. PYTHON,
  358. "fish_speech/train.py",
  359. "--config-name",
  360. "text2semantic_sft",
  361. f"trainer.strategy.process_group_backend={backend}",
  362. "model@model.model=dual_ar_2_codebook_medium",
  363. "tokenizer.pretrained_model_name_or_path=checkpoints",
  364. f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
  365. f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
  366. f"model.optimizer.lr={llama_lr}",
  367. f"trainer.max_steps={llama_maxsteps}",
  368. f"trainer.limit_val_batches={llama_limit_val_batches}",
  369. f"data.num_workers={llama_data_num_workers}",
  370. f"data.batch_size={llama_data_batch_size}",
  371. f"max_length={llama_data_max_length}",
  372. f"trainer.precision={llama_precision}",
  373. f"trainer.val_check_interval={llama_check_interval}",
  374. f"trainer.accumulate_grad_batches={llama_grad_batches}",
  375. f"train_dataset.use_speaker={llama_use_speaker}",
  376. ]
  377. logger.info(train_cmd)
  378. subprocess.run(train_cmd)
  379. return build_html_ok_message("训练终止")
# Parsed training configs; their values seed the defaults of the training-tab
# widgets built below (learning rates, step counts, batch sizes, ...).
init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
# Build the whole UI. Components are laid out in the order they are created
# inside the nested Row/Column/Tab/Accordion contexts; the footer and the event
# wiring at the bottom are attached inside the same Blocks context.
with gr.Blocks(
    head="<style>\n" + css + "\n</style>",
    js=js,
    theme=seafoam,
    analytics_enabled=False,
    title="Fish-Speech 鱼语",
) as demo:
    # Two columns: tabs on the left, data-source list and actions on the right.
    with gr.Row():
        with gr.Column():
            # --- Tab: dataset preparation ------------------------------------
            with gr.Tab("\U0001F4D6 数据集准备"):
                with gr.Row():
                    # Source folder (audio grouped in per-speaker folders).
                    textbox = gr.Textbox(
                        label="\U0000270F 输入音频&转写源文件夹路径",
                        info="音频装在一个以说话人命名的文件夹内作为区分",
                        interactive=True,
                    )
                    # Optional filelist (Bert-VITS2 / GPT-SoVITS format).
                    transcript_path = gr.Textbox(
                        label="\U0001F4DD 转写文本filelist所在路径",
                        info="支持 Bert-Vits2 / GPT-SoVITS 格式",
                        interactive=True,
                    )
                with gr.Row(equal_height=False):
                    with gr.Column():
                        # Copy vs. move; consumed by check_files / list_copy.
                        output_radio = gr.Radio(
                            label="\U0001F4C1 选择源文件(夹)处理方式",
                            choices=["复制一份", "直接移动"],
                            value="复制一份",
                            interactive=True,
                        )
                    with gr.Column():
                        # Shared status panel for the dataset-prep handlers.
                        error = gr.HTML(label="错误信息")
                        # Toggles the external labelling tool (change_label).
                        if_label = gr.Checkbox(
                            label="是否开启打标WebUI", scale=0, show_label=True
                        )
                with gr.Row():
                    add_button = gr.Button("\U000027A1提交到处理区", variant="primary")
                    remove_button = gr.Button("\U000026D4 取消所选内容")
                with gr.Row():
                    # ASR settings forwarded to tools/whisper_asr.py by check_files.
                    label_device = gr.Dropdown(
                        label="打标设备",
                        info="建议使用cuda, 实在是低配置再用cpu",
                        choices=["cpu", "cuda"],
                        value="cuda",
                        interactive=True,
                    )
                    label_model = gr.Dropdown(
                        label="打标模型大小",
                        info="显存10G以上用large, 5G用medium, 2G用small",
                        choices=["large", "medium", "small"],
                        value="small",
                        interactive=True,
                    )
                    # Stored per source by add_item; "不打标" skips transcription.
                    label_radio = gr.Dropdown(
                        label="(可选)打标语言",
                        info="如果没有音频对应的文本,则进行辅助打标, 支持.txt或.lab格式",
                        choices=["中文", "日文", "英文", "不打标"],
                        value="不打标",
                        interactive=True,
                    )
            # --- Tab: training configuration ---------------------------------
            with gr.Tab("\U0001F6E0 训练配置项"):  # hammer
                with gr.Column():
                    with gr.Row():
                        # Which pipeline(s) train_process runs.
                        model_type_radio = gr.Radio(
                            label="选择要训练的模型类型",
                            interactive=True,
                            choices=["VQGAN", "LLAMA", "all"],
                            value="all",
                        )
                    with gr.Row():
                        # VQGAN overrides; defaults come from vqgan_finetune.yaml.
                        with gr.Accordion("VQGAN配置项", open=False):
                            with gr.Row(equal_height=False):
                                vqgan_lr_slider = gr.Slider(
                                    label="初始学习率",
                                    interactive=True,
                                    minimum=1e-5,
                                    maximum=1e-4,
                                    step=1e-5,
                                    value=init_vqgan_yml["model"]["optimizer"]["lr"],
                                )
                                vqgan_maxsteps_slider = gr.Slider(
                                    label="训练最大步数",
                                    interactive=True,
                                    minimum=1000,
                                    maximum=100000,
                                    step=1000,
                                    value=init_vqgan_yml["trainer"]["max_steps"],
                                )
                            with gr.Row(equal_height=False):
                                vqgan_data_num_workers_slider = gr.Slider(
                                    label="num_workers",
                                    interactive=True,
                                    minimum=1,
                                    maximum=16,
                                    step=1,
                                    value=init_vqgan_yml["data"]["num_workers"],
                                )
                                vqgan_data_batch_size_slider = gr.Slider(
                                    label="batch_size",
                                    interactive=True,
                                    minimum=1,
                                    maximum=32,
                                    step=1,
                                    value=init_vqgan_yml["data"]["batch_size"],
                                )
                            with gr.Row(equal_height=False):
                                vqgan_data_val_batch_size_slider = gr.Slider(
                                    label="val_batch_size",
                                    interactive=True,
                                    minimum=1,
                                    maximum=32,
                                    step=1,
                                    value=init_vqgan_yml["data"]["val_batch_size"],
                                )
                                # str() so a numeric YAML value (e.g. 32) matches a choice.
                                vqgan_precision_dropdown = gr.Dropdown(
                                    label="训练精度",
                                    interactive=True,
                                    choices=["32", "bf16-true", "bf16-mixed"],
                                    value=str(init_vqgan_yml["trainer"]["precision"]),
                                )
                            with gr.Row(equal_height=False):
                                vqgan_check_interval_slider = gr.Slider(
                                    label="每n步保存一个模型",
                                    interactive=True,
                                    minimum=500,
                                    maximum=10000,
                                    step=500,
                                    value=init_vqgan_yml["trainer"][
                                        "val_check_interval"
                                    ],
                                )
                    with gr.Row():
                        # LLAMA overrides; defaults come from text2semantic_sft.yaml.
                        with gr.Accordion("LLAMA配置项", open=False):
                            with gr.Row(equal_height=False):
                                llama_lr_slider = gr.Slider(
                                    label="初始学习率",
                                    interactive=True,
                                    minimum=1e-5,
                                    maximum=1e-4,
                                    step=1e-5,
                                    value=init_llama_yml["model"]["optimizer"]["lr"],
                                )
                                llama_maxsteps_slider = gr.Slider(
                                    label="训练最大步数",
                                    interactive=True,
                                    minimum=1000,
                                    maximum=100000,
                                    step=1000,
                                    value=init_llama_yml["trainer"]["max_steps"],
                                )
                            with gr.Row(equal_height=False):
                                llama_limit_val_batches_slider = gr.Slider(
                                    label="limit_val_batches",
                                    interactive=True,
                                    minimum=1,
                                    maximum=20,
                                    step=1,
                                    value=init_llama_yml["trainer"][
                                        "limit_val_batches"
                                    ],
                                )
                                # The YAML value is used only on Linux; elsewhere the
                                # default is 0 (presumably to avoid DataLoader worker
                                # processes on Windows/macOS — confirm).
                                llama_data_num_workers_slider = gr.Slider(
                                    label="num_workers",
                                    minimum=0,
                                    maximum=16,
                                    step=1,
                                    value=init_llama_yml["data"]["num_workers"]
                                    if sys.platform == "linux"
                                    else 0,
                                )
                            with gr.Row(equal_height=False):
                                llama_data_batch_size_slider = gr.Slider(
                                    label="batch_size",
                                    interactive=True,
                                    minimum=1,
                                    maximum=32,
                                    step=1,
                                    value=init_llama_yml["data"]["batch_size"],
                                )
                                llama_data_max_length_slider = gr.Slider(
                                    label="max_length",
                                    interactive=True,
                                    minimum=1024,
                                    maximum=4096,
                                    step=128,
                                    value=init_llama_yml["max_length"],
                                )
                            with gr.Row(equal_height=False):
                                # Unlike the VQGAN dropdown, this default is hard-coded
                                # rather than read from the YAML.
                                llama_precision_dropdown = gr.Dropdown(
                                    label="训练精度",
                                    interactive=True,
                                    choices=["32", "bf16-true", "16-mixed"],
                                    value="bf16-true",
                                )
                                llama_check_interval_slider = gr.Slider(
                                    label="每n步保存一个模型",
                                    interactive=True,
                                    minimum=500,
                                    maximum=10000,
                                    step=500,
                                    value=init_llama_yml["trainer"][
                                        "val_check_interval"
                                    ],
                                )
                            with gr.Row(equal_height=False):
                                llama_grad_batches = gr.Slider(
                                    label="accumulate_grad_batches",
                                    interactive=True,
                                    minimum=1,
                                    maximum=20,
                                    step=1,
                                    value=init_llama_yml["trainer"][
                                        "accumulate_grad_batches"
                                    ],
                                )
                                llama_use_speaker = gr.Slider(
                                    label="use_speaker_ratio",
                                    interactive=True,
                                    minimum=0.1,
                                    maximum=1.0,
                                    step=0.05,
                                    value=init_llama_yml["train_dataset"][
                                        "use_speaker"
                                    ],
                                )
            # --- Tab: inference WebUI launcher ------------------------------
            with gr.Tab("\U0001F9E0 进入推理界面"):
                with gr.Column():
                    with gr.Row():
                        # Settings passed to change_infer when the checkbox is ticked.
                        with gr.Accordion(label="\U0001F5A5 推理服务器配置", open=False):
                            with gr.Row():
                                infer_host_textbox = gr.Textbox(
                                    label="Webui启动服务器地址", value="127.0.0.1"
                                )
                                infer_port_textbox = gr.Textbox(
                                    label="Webui启动服务器端口", value="7862"
                                )
                            with gr.Row():
                                infer_vqgan_model = gr.Textbox(
                                    label="VQGAN模型位置",
                                    placeholder="填写pth/ckpt文件路径",
                                    value="checkpoints/vq-gan-group-fsq-2x1024.pth",
                                )
                            with gr.Row():
                                infer_llama_model = gr.Textbox(
                                    label="LLAMA模型位置",
                                    placeholder="填写pth/ckpt文件路径",
                                    value="checkpoints/text2semantic-medium-v1-2k.pth",
                                )
                            with gr.Row():
                                infer_compile = gr.Radio(
                                    label="是否编译模型?", choices=["Yes", "No"], value="Yes"
                                )
                    with gr.Row():
                        infer_checkbox = gr.Checkbox(label="是否打开推理界面")
                        infer_error = gr.HTML(label="推理界面错误信息")
        # --- Right column: data sources, paths, file tree and actions --------
        with gr.Column():
            train_error = gr.HTML(label="训练时的报错信息")
            # Populated by add_item / remove_items; every listed source is used.
            checkbox_group = gr.CheckboxGroup(
                label="\U0001F4CA 数据源列表",
                info="左侧输入文件夹所在路径或filelist。无论是否勾选,在此列表中都会被用以后续训练。",
                elem_classes=["data_src"],
            )
            # Read-only display of the workspace; also an input to check_files.
            train_box = gr.Textbox(
                label="数据预处理文件夹路径", value=str(data_pre_output), interactive=False
            )
            model_box = gr.Textbox(
                label="\U0001F4BE 模型输出路径",
                value=str(default_model_output),
                interactive=False,
            )
            with gr.Accordion(
                "查看预处理文件夹状态 (滑块为显示深度大小)",
                elem_classes=["scrollable-component"],
                elem_id="file_accordion",
            ):
                # Tree depth for new_explorer (0 = top level only).
                tree_slider = gr.Slider(
                    minimum=0,
                    maximum=3,
                    value=0,
                    step=1,
                    show_label=False,
                    container=False,
                )
                # Initial tree view; replaced by fresh_btn / admit_btn handlers.
                file_markdown = new_explorer(str(data_pre_output), 0)
            with gr.Row(equal_height=False):
                admit_btn = gr.Button(
                    "\U00002705 文件预处理", scale=0, min_width=160, variant="primary"
                )
                fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
                help_button = gr.Button("\U00002753", scale=0, min_width=80)  # question
                train_btn = gr.Button("训练启动!", variant="primary")
    # Footer template filled with version info and the API docs link.
    footer = load_data_in_raw("fish_speech/webui/html/footer.html")
    footer = footer.format(
        versions=versions_html(),
        api_docs="https://speech.fish.audio/inference/#http-api",
    )
    gr.HTML(footer, elem_id="footer")

    # --- Event wiring ----------------------------------------------------
    add_button.click(
        fn=add_item,
        inputs=[textbox, output_radio, transcript_path, label_radio],
        outputs=[checkbox_group, error],
    )
    remove_button.click(
        fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
    )
    checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
    # Client-side only (fn=None): opens the documentation site in a new window.
    help_button.click(
        fn=None,
        js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
        'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
    )
    if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
    # Input order must match train_process's positional parameters.
    train_btn.click(
        fn=train_process,
        inputs=[
            train_box,
            model_type_radio,
            # vq-gan config
            vqgan_lr_slider,
            vqgan_maxsteps_slider,
            vqgan_data_num_workers_slider,
            vqgan_data_batch_size_slider,
            vqgan_data_val_batch_size_slider,
            vqgan_precision_dropdown,
            vqgan_check_interval_slider,
            # llama config
            llama_lr_slider,
            llama_maxsteps_slider,
            llama_limit_val_batches_slider,
            llama_data_num_workers_slider,
            llama_data_batch_size_slider,
            llama_data_max_length_slider,
            llama_precision_dropdown,
            llama_check_interval_slider,
            llama_grad_batches,
            llama_use_speaker,
        ],
        outputs=[train_error],
    )
    admit_btn.click(
        fn=check_files,
        inputs=[train_box, tree_slider, label_model, label_device],
        outputs=[error, file_markdown],
    )
    fresh_btn.click(
        fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
    )
    infer_checkbox.change(
        fn=change_infer,
        inputs=[
            infer_checkbox,
            infer_host_textbox,
            infer_port_textbox,
            infer_vqgan_model,
            infer_llama_model,
            infer_compile,
        ],
        outputs=[infer_error],
    )
  740. demo.launch(inbrowser=True)