manage.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073
  1. from __future__ import annotations
  2. import html
  3. import json
  4. import os
  5. import platform
  6. import shutil
  7. import signal
  8. import subprocess
  9. import sys
  10. import webbrowser
  11. from pathlib import Path
  12. import gradio as gr
  13. import psutil
  14. import yaml
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.i18n import i18n
  18. from fish_speech.webui.launch_utils import Seafoam, versions_html
# Interpreter used for all spawned tool/training subprocesses; honours an
# optional PYTHON_FOLDERPATH override so a bundled environment can be used.
PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
sys.path.insert(0, "")
print(sys.path)
cur_work_dir = Path(os.getcwd()).resolve()
print("You are in ", str(cur_work_dir))
# Fine-tune configuration files shipped with the repository.
config_path = cur_work_dir / "fish_speech" / "configs"
vqgan_yml_path = config_path / "vqgan_finetune.yaml"
llama_yml_path = config_path / "text2semantic_finetune.yaml"
# Child processes must bypass any local proxy for loopback traffic.
env = os.environ.copy()
env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
seafoam = Seafoam()
  30. def build_html_error_message(error):
  31. return f"""
  32. <div style="color: red; font-weight: bold;">
  33. {html.escape(error)}
  34. </div>
  35. """
  36. def build_html_ok_message(msg):
  37. return f"""
  38. <div style="color: green; font-weight: bold;">
  39. {html.escape(msg)}
  40. </div>
  41. """
  42. def load_data_in_raw(path):
  43. with open(path, "r", encoding="utf-8") as file:
  44. data = file.read()
  45. return str(data)
  46. def kill_proc_tree(pid, including_parent=True):
  47. try:
  48. parent = psutil.Process(pid)
  49. except psutil.NoSuchProcess:
  50. # Process already terminated
  51. return
  52. children = parent.children(recursive=True)
  53. for child in children:
  54. try:
  55. os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
  56. except OSError:
  57. pass
  58. if including_parent:
  59. try:
  60. os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
  61. except OSError:
  62. pass
# Cached OS name, used to pick the process-kill strategy in kill_process().
system = platform.system()
# Handles of the optional child processes spawned by the UI (None = not running).
p_label = None
p_infer = None
p_tensorboard = None
  67. def kill_process(pid):
  68. if system == "Windows":
  69. cmd = "taskkill /t /f /pid %s" % pid
  70. # os.system(cmd)
  71. subprocess.run(cmd)
  72. else:
  73. kill_proc_tree(pid)
  74. def change_label(if_label):
  75. global p_label
  76. if if_label == True:
  77. # 设置要访问的URL
  78. url = "https://text-labeler.pages.dev/"
  79. webbrowser.open(url)
  80. yield i18n("Opened labeler in browser")
  81. elif if_label == False:
  82. p_label = None
  83. yield "Nothing"
  84. def change_infer(
  85. if_infer,
  86. host,
  87. port,
  88. infer_vqgan_model,
  89. infer_llama_model,
  90. infer_llama_config,
  91. infer_compile,
  92. ):
  93. global p_infer
  94. if if_infer == True and p_infer == None:
  95. env = os.environ.copy()
  96. env["GRADIO_SERVER_NAME"] = host
  97. env["GRADIO_SERVER_PORT"] = port
  98. # 启动第二个进程
  99. url = f"http://{host}:{port}"
  100. yield build_html_ok_message(
  101. i18n("Inferring interface is launched at {}").format(url)
  102. )
  103. p_infer = subprocess.Popen(
  104. [
  105. PYTHON,
  106. "tools/webui.py",
  107. "--vqgan-checkpoint-path",
  108. infer_vqgan_model,
  109. "--llama-checkpoint-path",
  110. infer_llama_model,
  111. "--llama-config-name",
  112. infer_llama_config,
  113. "--tokenizer",
  114. "checkpoints",
  115. ]
  116. + (["--compile"] if infer_compile == "Yes" else []),
  117. env=env,
  118. )
  119. elif if_infer == False and p_infer != None:
  120. kill_process(p_infer.pid)
  121. p_infer = None
  122. yield build_html_error_message(i18n("Infer interface is closed"))
# Static assets injected into the Gradio page.
js = load_data_in_raw("fish_speech/webui/js/animate.js")
css = load_data_in_raw("fish_speech/webui/css/style.css")
# Where preprocessed training data and model checkpoints end up.
data_pre_output = (cur_work_dir / "data").resolve()
default_model_output = (cur_work_dir / "results").resolve()
default_filelist = data_pre_output / "detect.list"
data_pre_output.mkdir(parents=True, exist_ok=True)
# Source folders queued for preprocessing, and their per-item options
# (items keeps insertion order; dict_items maps path -> settings).
items = []
dict_items = {}
  131. def load_yaml_data_in_fact(yml_path):
  132. with open(yml_path, "r", encoding="utf-8") as file:
  133. yml = yaml.safe_load(file)
  134. return yml
  135. def write_yaml_data_in_fact(yml, yml_path):
  136. with open(yml_path, "w", encoding="utf-8") as file:
  137. yaml.safe_dump(yml, file, allow_unicode=True)
  138. return yml
  139. def generate_tree(directory, depth=0, max_depth=None, prefix=""):
  140. if max_depth is not None and depth > max_depth:
  141. return ""
  142. tree_str = ""
  143. files = []
  144. directories = []
  145. for item in os.listdir(directory):
  146. if os.path.isdir(os.path.join(directory, item)):
  147. directories.append(item)
  148. else:
  149. files.append(item)
  150. entries = directories + files
  151. for i, entry in enumerate(entries):
  152. connector = "├── " if i < len(entries) - 1 else "└── "
  153. tree_str += f"{prefix}{connector}{entry}<br />"
  154. if i < len(directories):
  155. extension = "│ " if i < len(entries) - 1 else " "
  156. tree_str += generate_tree(
  157. os.path.join(directory, entry),
  158. depth + 1,
  159. max_depth,
  160. prefix=prefix + extension,
  161. )
  162. return tree_str
  163. def new_explorer(data_path, max_depth):
  164. return gr.Markdown(
  165. elem_classes=["scrollable-component"],
  166. value=generate_tree(data_path, max_depth=max_depth),
  167. )
  168. def add_item(folder: str, method: str, label_lang: str):
  169. folder = folder.strip(" ").strip('"')
  170. folder_path = Path(folder)
  171. if folder and folder not in items and data_pre_output not in folder_path.parents:
  172. if folder_path.is_dir():
  173. items.append(folder)
  174. dict_items[folder] = dict(
  175. type="folder", method=method, label_lang=label_lang
  176. )
  177. elif folder:
  178. err = folder
  179. return gr.Checkboxgroup(choices=items), build_html_error_message(
  180. i18n("Invalid path: {}").format(err)
  181. )
  182. formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
  183. logger.info(formatted_data)
  184. return gr.Checkboxgroup(choices=items), build_html_ok_message(
  185. i18n("Added path successfully!")
  186. )
  187. def remove_items(selected_items):
  188. global items, dict_items
  189. to_remove = [item for item in items if item in selected_items]
  190. for item in to_remove:
  191. del dict_items[item]
  192. items = [item for item in items if item in dict_items.keys()]
  193. formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
  194. logger.info(formatted_data)
  195. return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
  196. i18n("Removed path successfully!")
  197. )
  198. def show_selected(options):
  199. selected_options = ", ".join(options)
  200. if options:
  201. return i18n("Selected: {}").format(selected_options)
  202. else:
  203. return i18n("No selected options")
def list_copy(list_file_path, method):
    """Copy or move the audio/.lab pairs referenced by a filelist.

    Each line of *list_file_path* has the form
    ``wav_path|speaker_name|language|text``. Files land under
    ``data_pre_output/<parent folder name>/``; for the Move method the
    filelist is rewritten afterwards so its paths point at the new locations.
    """
    wav_root = data_pre_output
    lst = []  # rewritten filelist lines (always built, only written for Move)
    with list_file_path.open("r", encoding="utf-8") as file:
        for line in tqdm(file, desc="Processing audio/transcript"):
            wav_path, speaker_name, language, text = line.strip().split("|")
            original_wav_path = Path(wav_path)
            # Keep the immediate parent folder name (speaker) in the target path.
            target_wav_path = (
                wav_root / original_wav_path.parent.name / original_wav_path.name
            )
            lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
            if target_wav_path.is_file():
                continue  # audio already transferred on an earlier run
            target_wav_path.parent.mkdir(parents=True, exist_ok=True)
            if method == i18n("Copy"):
                shutil.copy(original_wav_path, target_wav_path)
            else:
                shutil.move(original_wav_path, target_wav_path.parent)
            # Transfer the transcript (.lab) sitting next to the audio file.
            original_lab_path = original_wav_path.with_suffix(".lab")
            target_lab_path = (
                wav_root
                / original_wav_path.parent.name
                / original_wav_path.with_suffix(".lab").name
            )
            if target_lab_path.is_file():
                continue
            if method == i18n("Copy"):
                shutil.copy(original_lab_path, target_lab_path)
            else:
                shutil.move(original_lab_path, target_lab_path.parent)
    if method == i18n("Move"):
        # Point the filelist at the relocated audio so later steps still work.
        with list_file_path.open("w", encoding="utf-8") as file:
            file.writelines("\n".join(lst))
    del lst
    return build_html_ok_message(i18n("Use filelist"))
def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
    """Transcribe (if requested) and transfer every registered data source.

    Folder sources optionally run through Whisper ASR first (when a label
    language was chosen), then are copied or moved under *data_path*;
    filelist sources are delegated to :func:`list_copy`. Returns an HTML
    status message plus a refreshed folder-tree widget.
    """
    global dict_items
    data_path = Path(data_path)
    for item, content in dict_items.items():
        item_path = Path(item)
        tar_path = data_path / item_path.name
        if content["type"] == "folder" and item_path.is_dir():
            cur_lang = content["label_lang"]
            if cur_lang != "IGNORE":
                try:
                    # Auto-label audio that lacks transcripts.
                    subprocess.run(
                        [
                            PYTHON,
                            "tools/whisper_asr.py",
                            "--model-size",
                            label_model,
                            "--device",
                            label_device,
                            "--audio-dir",
                            item_path,
                            "--save-dir",
                            item_path,
                            "--language",
                            cur_lang,
                        ],
                        env=env,
                    )
                except Exception:
                    # Best effort: a failed ASR run should not block the transfer.
                    print("Transcription error occurred")
            if content["method"] == i18n("Copy"):
                os.makedirs(tar_path, exist_ok=True)
                shutil.copytree(
                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
                )
            elif not tar_path.is_dir():
                shutil.move(src=str(item_path), dst=str(tar_path))
        elif content["type"] == "file" and item_path.is_file():
            list_copy(item_path, content["method"])
    return build_html_ok_message(i18n("Move files successfully")), new_explorer(
        data_path, max_depth=max_depth
    )
def train_process(
    data_path: str,
    option: str,
    # vq-gan config
    vqgan_lr,
    vqgan_maxsteps,
    vqgan_data_num_workers,
    vqgan_data_batch_size,
    vqgan_data_val_batch_size,
    vqgan_precision,
    vqgan_check_interval,
    # llama config
    llama_base_config,
    llama_lr,
    llama_maxsteps,
    llama_data_num_workers,
    llama_data_batch_size,
    llama_data_max_length,
    llama_precision,
    llama_check_interval,
    llama_grad_batches,
    llama_use_speaker,
    llama_use_lora,
):
    """Run the selected fine-tuning pipeline(s) as blocking subprocesses.

    *option* selects "VQGAN", "LLAMA" or "all". Each run gets a fresh,
    timestamp-named project folder. UI slider/dropdown values are forwarded
    to ``fish_speech/train.py`` as Hydra-style overrides. Returns an HTML
    status message once every launched subprocess has finished.
    """
    import datetime

    def generate_folder_name():
        # Timestamped project name keeps successive runs separate.
        now = datetime.datetime.now()
        folder_name = now.strftime("%Y%m%d_%H%M%S")
        return folder_name

    # NCCL is only available on Linux; fall back to gloo elsewhere.
    backend = "nccl" if sys.platform == "linux" else "gloo"
    new_project = generate_folder_name()
    print("New Project Name: ", new_project)
    if option == "VQGAN" or option == "all":
        # Split the preprocessed data into train/val filelists first.
        subprocess.run(
            [
                PYTHON,
                "tools/vqgan/create_train_split.py",
                str(data_pre_output.relative_to(cur_work_dir)),
            ]
        )
        train_cmd = [
            PYTHON,
            "fish_speech/train.py",
            "--config-name",
            "vqgan_finetune",
            f"project={'vqgan_' + new_project}",
            f"trainer.strategy.process_group_backend={backend}",
            f"model.optimizer.lr={vqgan_lr}",
            f"trainer.max_steps={vqgan_maxsteps}",
            f"data.num_workers={vqgan_data_num_workers}",
            f"data.batch_size={vqgan_data_batch_size}",
            f"data.val_batch_size={vqgan_data_val_batch_size}",
            f"trainer.precision={vqgan_precision}",
            f"trainer.val_check_interval={vqgan_check_interval}",
            f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
            f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
        ]
        logger.info(train_cmd)
        subprocess.run(train_cmd)
    if option == "LLAMA" or option == "all":
        # Quantize the audio into VQ tokens for the semantic model.
        subprocess.run(
            [
                PYTHON,
                "tools/vqgan/extract_vq.py",
                str(data_pre_output),
                "--num-workers",
                "1",
                "--batch-size",
                "16",
                "--config-name",
                "vqgan_pretrain",
                "--checkpoint-path",
                "checkpoints/vq-gan-group-fsq-2x1024.pth",
            ]
        )
        # Pack text + VQ tokens into the protobuf training dataset.
        subprocess.run(
            [
                PYTHON,
                "tools/llama/build_dataset.py",
                "--input",
                str(data_pre_output),
                "--text-extension",
                ".lab",
                "--num-workers",
                "16",
            ]
        )
        # Pick the pretrained checkpoint matching the chosen model size.
        ckpt_path = (
            "text2semantic-pretrain-medium-2k-v1.pth"
            if llama_base_config == "dual_ar_2_codebook_medium"
            else "text2semantic-sft-large-v1-4k.pth"
        )
        train_cmd = [
            PYTHON,
            "fish_speech/train.py",
            "--config-name",
            "text2semantic_finetune",
            f"project={'text2semantic_' + new_project}",
            f"ckpt_path=checkpoints/{ckpt_path}",
            f"trainer.strategy.process_group_backend={backend}",
            f"model@model.model={llama_base_config}",
            "tokenizer.pretrained_model_name_or_path=checkpoints",
            f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
            f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
            f"model.optimizer.lr={llama_lr}",
            f"trainer.max_steps={llama_maxsteps}",
            f"data.num_workers={llama_data_num_workers}",
            f"data.batch_size={llama_data_batch_size}",
            f"max_length={llama_data_max_length}",
            f"trainer.precision={llama_precision}",
            f"trainer.val_check_interval={llama_check_interval}",
            f"trainer.accumulate_grad_batches={llama_grad_batches}",
            f"train_dataset.use_speaker={llama_use_speaker}",
        ] + ([f"+lora@model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
        logger.info(train_cmd)
        subprocess.run(train_cmd)
    return build_html_ok_message(i18n("Training stopped"))
  397. def tensorboard_process(
  398. if_tensorboard: bool,
  399. tensorboard_dir: str,
  400. host: str,
  401. port: str,
  402. ):
  403. global p_tensorboard
  404. if if_tensorboard == True and p_tensorboard == None:
  405. url = f"http://{host}:{port}"
  406. yield build_html_ok_message(
  407. i18n("Tensorboard interface is launched at {}").format(url)
  408. )
  409. prefix = ["tensorboard"]
  410. if Path("fishenv").exists():
  411. prefix = ["fishenv/python.exe", "fishenv/Scripts/tensorboard.exe"]
  412. p_tensorboard = subprocess.Popen(
  413. prefix
  414. + [
  415. "--logdir",
  416. tensorboard_dir,
  417. "--host",
  418. host,
  419. "--port",
  420. port,
  421. "--reload_interval",
  422. "120",
  423. ]
  424. )
  425. elif if_tensorboard == False and p_tensorboard != None:
  426. kill_process(p_tensorboard.pid)
  427. p_tensorboard = None
  428. yield build_html_error_message(i18n("Tensorboard interface is closed"))
  429. def fresh_tb_dir():
  430. return gr.Dropdown(
  431. choices=[str(p) for p in Path("results").glob("**/tensorboard/version_*/")]
  432. )
  433. def fresh_vqgan_model():
  434. return gr.Dropdown(
  435. choices=[init_vqgan_yml["ckpt_path"]]
  436. + [str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")]
  437. )
  438. def fresh_llama_model():
  439. return gr.Dropdown(
  440. choices=[init_llama_yml["ckpt_path"]]
  441. + [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
  442. )
  443. def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
  444. if (
  445. lora_weight is None
  446. or not Path(lora_weight).exists()
  447. or not Path(llama_weight).exists()
  448. ):
  449. return build_html_error_message(
  450. i18n(
  451. "Path error, please check the model file exists in the corresponding path"
  452. )
  453. )
  454. merge_cmd = [
  455. PYTHON,
  456. "tools/llama/merge_lora.py",
  457. "--llama-config",
  458. "dual_ar_2_codebook_large",
  459. "--lora-config",
  460. "r_8_alpha_16",
  461. "--llama-weight",
  462. llama_weight,
  463. "--lora-weight",
  464. lora_weight,
  465. "--output",
  466. llama_lora_output,
  467. ]
  468. logger.info(merge_cmd)
  469. subprocess.run(merge_cmd)
  470. return build_html_ok_message(i18n("Merge successfully"))
# Initial YAML configs provide the default values shown in the UI widgets.
init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
init_llama_yml = load_yaml_data_in_fact(llama_yml_path)

with gr.Blocks(
    head="<style>\n" + css + "\n</style>",
    js=js,
    theme=seafoam,
    analytics_enabled=False,
    title="Fish Speech",
) as demo:
    with gr.Row():
        # --- Left column: preprocessing / training / inference tabs ---
        with gr.Column():
            with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
                with gr.Row():
                    textbox = gr.Textbox(
                        label="\U0000270F "
                        + i18n("Input Audio & Source Path for Transcription"),
                        info=i18n("Speaker is identified by the folder name"),
                        interactive=True,
                    )
                with gr.Row(equal_height=False):
                    with gr.Column():
                        output_radio = gr.Radio(
                            label="\U0001F4C1 "
                            + i18n("Select source file processing method"),
                            choices=[i18n("Copy"), i18n("Move")],
                            value=i18n("Copy"),
                            interactive=True,
                        )
                    with gr.Column():
                        error = gr.HTML(label=i18n("Error Message"))
                        if_label = gr.Checkbox(
                            label=i18n("Open Labeler WebUI"), scale=0, show_label=True
                        )
                with gr.Row():
                    add_button = gr.Button(
                        "\U000027A1 " + i18n("Add to Processing Area"),
                        variant="primary",
                    )
                    remove_button = gr.Button(
                        "\U000026D4 " + i18n("Remove Selected Data")
                    )
                with gr.Row():
                    label_device = gr.Dropdown(
                        label=i18n("Labeling Device"),
                        info=i18n(
                            "It is recommended to use CUDA, if you have low configuration, use CPU"
                        ),
                        choices=["cpu", "cuda"],
                        value="cuda",
                        interactive=True,
                    )
                    label_model = gr.Dropdown(
                        label=i18n("Whisper Model"),
                        info=i18n(
                            "Use large for 10G+ GPU, medium for 5G, small for 2G"
                        ),
                        choices=["large", "medium", "small"],
                        value="small",
                        interactive=True,
                    )
                    label_radio = gr.Dropdown(
                        label=i18n("Optional Label Language"),
                        info=i18n(
                            "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
                        ),
                        choices=[
                            (i18n("Chinese"), "ZH"),
                            (i18n("English"), "EN"),
                            (i18n("Japanese"), "JA"),
                            (i18n("Disabled"), "IGNORE"),
                        ],
                        value="IGNORE",
                        interactive=True,
                    )
            with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
                with gr.Row():
                    model_type_radio = gr.Radio(
                        label=i18n("Select the model to be trained"),
                        interactive=True,
                        choices=["VQGAN", "LLAMA", "all"],
                        value="all",
                    )
                with gr.Row():
                    with gr.Tab(label=i18n("VQGAN Configuration")):
                        with gr.Row(equal_height=False):
                            vqgan_lr_slider = gr.Slider(
                                label=i18n("Initial Learning Rate"),
                                interactive=True,
                                minimum=1e-5,
                                maximum=1e-4,
                                step=1e-5,
                                value=init_vqgan_yml["model"]["optimizer"]["lr"],
                            )
                            vqgan_maxsteps_slider = gr.Slider(
                                label=i18n("Maximum Training Steps"),
                                interactive=True,
                                minimum=1000,
                                maximum=100000,
                                step=1000,
                                value=init_vqgan_yml["trainer"]["max_steps"],
                            )
                        with gr.Row(equal_height=False):
                            vqgan_data_num_workers_slider = gr.Slider(
                                label=i18n("Number of Workers"),
                                interactive=True,
                                minimum=1,
                                maximum=16,
                                step=1,
                                value=init_vqgan_yml["data"]["num_workers"],
                            )
                            vqgan_data_batch_size_slider = gr.Slider(
                                label=i18n("Batch Size"),
                                interactive=True,
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=init_vqgan_yml["data"]["batch_size"],
                            )
                        with gr.Row(equal_height=False):
                            vqgan_data_val_batch_size_slider = gr.Slider(
                                label=i18n("Validation Batch Size"),
                                interactive=True,
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=init_vqgan_yml["data"]["val_batch_size"],
                            )
                            vqgan_precision_dropdown = gr.Dropdown(
                                label=i18n("Precision"),
                                interactive=True,
                                choices=["32", "bf16-true", "bf16-mixed"],
                                info=i18n(
                                    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
                                ),
                                value=str(init_vqgan_yml["trainer"]["precision"]),
                            )
                        with gr.Row(equal_height=False):
                            vqgan_check_interval_slider = gr.Slider(
                                label=i18n("Save model every n steps"),
                                interactive=True,
                                minimum=500,
                                maximum=10000,
                                step=500,
                                value=init_vqgan_yml["trainer"]["val_check_interval"],
                            )
                    with gr.Tab(label=i18n("LLAMA Configuration")):
                        with gr.Row(equal_height=False):
                            llama_use_lora = gr.Checkbox(
                                label=i18n("Use LoRA"),
                                info=i18n(
                                    "Use LoRA can save GPU memory, but may reduce the quality of the model"
                                ),
                                value=True,
                            )
                        with gr.Row(equal_height=False):
                            llama_lr_slider = gr.Slider(
                                label=i18n("Initial Learning Rate"),
                                interactive=True,
                                minimum=1e-5,
                                maximum=1e-4,
                                step=1e-5,
                                value=init_llama_yml["model"]["optimizer"]["lr"],
                            )
                            llama_maxsteps_slider = gr.Slider(
                                label=i18n("Maximum Training Steps"),
                                interactive=True,
                                minimum=1000,
                                maximum=100000,
                                step=1000,
                                value=init_llama_yml["trainer"]["max_steps"],
                            )
                        with gr.Row(equal_height=False):
                            llama_base_config = gr.Dropdown(
                                label=i18n("Model Size"),
                                choices=[
                                    "dual_ar_2_codebook_large",
                                    "dual_ar_2_codebook_medium",
                                ],
                                value="dual_ar_2_codebook_large",
                            )
                            # Multi-process data loading is unreliable outside Linux.
                            llama_data_num_workers_slider = gr.Slider(
                                label=i18n("Number of Workers"),
                                minimum=0,
                                maximum=16,
                                step=1,
                                value=init_llama_yml["data"]["num_workers"]
                                if sys.platform == "linux"
                                else 0,
                            )
                        with gr.Row(equal_height=False):
                            llama_data_batch_size_slider = gr.Slider(
                                label=i18n("Batch Size"),
                                interactive=True,
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=init_llama_yml["data"]["batch_size"],
                            )
                            llama_data_max_length_slider = gr.Slider(
                                label=i18n("Maximum Length per Sample"),
                                interactive=True,
                                minimum=1024,
                                maximum=4096,
                                step=128,
                                value=init_llama_yml["max_length"],
                            )
                        with gr.Row(equal_height=False):
                            llama_precision_dropdown = gr.Dropdown(
                                label=i18n("Precision"),
                                info=i18n(
                                    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
                                ),
                                interactive=True,
                                choices=["32", "bf16-true", "16-mixed"],
                                value="bf16-true",
                            )
                            llama_check_interval_slider = gr.Slider(
                                label=i18n("Save model every n steps"),
                                interactive=True,
                                minimum=500,
                                maximum=10000,
                                step=500,
                                value=init_llama_yml["trainer"]["val_check_interval"],
                            )
                        with gr.Row(equal_height=False):
                            llama_grad_batches = gr.Slider(
                                label=i18n("Accumulate Gradient Batches"),
                                interactive=True,
                                minimum=1,
                                maximum=20,
                                step=1,
                                value=init_llama_yml["trainer"][
                                    "accumulate_grad_batches"
                                ],
                            )
                            llama_use_speaker = gr.Slider(
                                label=i18n("Probability of applying Speaker Condition"),
                                interactive=True,
                                minimum=0.1,
                                maximum=1.0,
                                step=0.05,
                                value=init_llama_yml["train_dataset"]["use_speaker"],
                            )
                    with gr.Tab(label=i18n("Merge LoRA")):
                        with gr.Row(equal_height=False):
                            llama_weight = gr.Dropdown(
                                label=i18n("Base LLAMA Model"),
                                info=i18n("Type the path or select from the dropdown"),
                                choices=[init_llama_yml["ckpt_path"]],
                                value=init_llama_yml["ckpt_path"],
                                allow_custom_value=True,
                                interactive=True,
                            )
                        with gr.Row(equal_height=False):
                            lora_weight = gr.Dropdown(
                                label=i18n("LoRA Model to be merged"),
                                info=i18n("Type the path or select from the dropdown"),
                                choices=[
                                    str(p)
                                    for p in Path("results").glob("text2*ar/**/*.ckpt")
                                ],
                                allow_custom_value=True,
                                interactive=True,
                            )
                        with gr.Row(equal_height=False):
                            llama_lora_output = gr.Dropdown(
                                label=i18n("Output Path"),
                                info=i18n("Type the path or select from the dropdown"),
                                value="checkpoints/merged.ckpt",
                                choices=["checkpoints/merged.ckpt"],
                                allow_custom_value=True,
                                interactive=True,
                            )
                        with gr.Row(equal_height=False):
                            llama_lora_merge_btn = gr.Button(
                                value=i18n("Merge"), variant="primary"
                            )
                    with gr.Tab(label="Tensorboard"):
                        with gr.Row(equal_height=False):
                            tb_host = gr.Textbox(
                                label=i18n("Tensorboard Host"), value="127.0.0.1"
                            )
                            tb_port = gr.Textbox(
                                label=i18n("Tensorboard Port"), value="11451"
                            )
                        with gr.Row(equal_height=False):
                            tb_dir = gr.Dropdown(
                                label=i18n("Tensorboard Log Path"),
                                allow_custom_value=True,
                                choices=[
                                    str(p)
                                    for p in Path("results").glob(
                                        "**/tensorboard/version_*/"
                                    )
                                ],
                            )
                        with gr.Row(equal_height=False):
                            if_tb = gr.Checkbox(
                                label=i18n("Open Tensorboard"),
                            )
            with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
                with gr.Column():
                    with gr.Row():
                        with gr.Accordion(
                            label="\U0001F5A5 "
                            + i18n("Inference Server Configuration"),
                            open=False,
                        ):
                            with gr.Row():
                                infer_host_textbox = gr.Textbox(
                                    label=i18n("WebUI Host"), value="127.0.0.1"
                                )
                                infer_port_textbox = gr.Textbox(
                                    label=i18n("WebUI Port"), value="7862"
                                )
                            with gr.Row():
                                infer_vqgan_model = gr.Dropdown(
                                    label=i18n("VQGAN Model Path"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    value=init_vqgan_yml["ckpt_path"],
                                    choices=[init_vqgan_yml["ckpt_path"]]
                                    + [
                                        str(p)
                                        for p in Path("results").glob(
                                            "vqgan*/**/*.ckpt"
                                        )
                                    ],
                                    allow_custom_value=True,
                                )
                            with gr.Row():
                                infer_llama_model = gr.Dropdown(
                                    label=i18n("LLAMA Model Path"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    value=init_llama_yml["ckpt_path"],
                                    choices=[init_llama_yml["ckpt_path"]]
                                    + [
                                        str(p)
                                        for p in Path("results").glob(
                                            "text2sem*/**/*.ckpt"
                                        )
                                    ],
                                    allow_custom_value=True,
                                )
                            with gr.Row():
                                infer_compile = gr.Radio(
                                    label=i18n("Compile Model"),
                                    info=i18n(
                                        "Compile the model can significantly reduce the inference time, but will increase cold start time"
                                    ),
                                    choices=["Yes", "No"],
                                    value="Yes",
                                )
                                infer_llama_config = gr.Dropdown(
                                    label=i18n("LLAMA Model Config"),
                                    choices=[
                                        "dual_ar_2_codebook_large",
                                        "dual_ar_2_codebook_medium",
                                    ],
                                    value="dual_ar_2_codebook_large",
                                    allow_custom_value=True,
                                )
                    with gr.Row():
                        infer_checkbox = gr.Checkbox(
                            label=i18n("Open Inference Server")
                        )
                        infer_error = gr.HTML(label=i18n("Inference Server Error"))
        # --- Right column: status, registered sources, folder preview ---
        with gr.Column():
            train_error = gr.HTML(label=i18n("Training Error"))
            checkbox_group = gr.CheckboxGroup(
                label="\U0001F4CA " + i18n("Data Source"),
                info=i18n(
                    "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
                ),
                elem_classes=["data_src"],
            )
            train_box = gr.Textbox(
                label=i18n("Data Preprocessing Path"),
                value=str(data_pre_output),
                interactive=False,
            )
            model_box = gr.Textbox(
                label="\U0001F4BE " + i18n("Model Output Path"),
                value=str(default_model_output),
                interactive=False,
            )
            with gr.Accordion(
                i18n(
                    "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
                ),
                elem_classes=["scrollable-component"],
                elem_id="file_accordion",
            ):
                tree_slider = gr.Slider(
                    minimum=0,
                    maximum=3,
                    value=0,
                    step=1,
                    show_label=False,
                    container=False,
                )
                file_markdown = new_explorer(str(data_pre_output), 0)
            with gr.Row(equal_height=False):
                admit_btn = gr.Button(
                    "\U00002705 " + i18n("File Preprocessing"),
                    scale=0,
                    min_width=160,
                    variant="primary",
                )
                fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
                help_button = gr.Button("\U00002753", scale=0, min_width=80)  # question
                train_btn = gr.Button(i18n("Start Training"), variant="primary")

    footer = load_data_in_raw("fish_speech/webui/html/footer.html")
    footer = footer.format(
        versions=versions_html(),
        api_docs="https://speech.fish.audio/inference/#http-api",
    )
    gr.HTML(footer, elem_id="footer")

    # --- Event wiring: connect widgets to the handler functions above ---
    add_button.click(
        fn=add_item,
        inputs=[textbox, output_radio, label_radio],
        outputs=[checkbox_group, error],
    )
    remove_button.click(
        fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
    )
    checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
    # Client-side only: open the docs in a popup window.
    help_button.click(
        fn=None,
        js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
        'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
    )
    if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
    train_btn.click(
        fn=train_process,
        inputs=[
            train_box,
            model_type_radio,
            # vq-gan config
            vqgan_lr_slider,
            vqgan_maxsteps_slider,
            vqgan_data_num_workers_slider,
            vqgan_data_batch_size_slider,
            vqgan_data_val_batch_size_slider,
            vqgan_precision_dropdown,
            vqgan_check_interval_slider,
            # llama config
            llama_base_config,
            llama_lr_slider,
            llama_maxsteps_slider,
            llama_data_num_workers_slider,
            llama_data_batch_size_slider,
            llama_data_max_length_slider,
            llama_precision_dropdown,
            llama_check_interval_slider,
            llama_grad_batches,
            llama_use_speaker,
            llama_use_lora,
        ],
        outputs=[train_error],
    )
    if_tb.change(
        fn=tensorboard_process,
        inputs=[if_tb, tb_dir, tb_host, tb_port],
        outputs=[train_error],
    )
    tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
    infer_vqgan_model.change(
        fn=fresh_vqgan_model, inputs=[], outputs=[infer_vqgan_model]
    )
    infer_llama_model.change(
        fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
    )
    llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
    admit_btn.click(
        fn=check_files,
        inputs=[train_box, tree_slider, label_model, label_device],
        outputs=[error, file_markdown],
    )
    fresh_btn.click(
        fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
    )
    llama_lora_merge_btn.click(
        fn=llama_lora_merge,
        inputs=[llama_weight, lora_weight, llama_lora_output],
        outputs=[train_error],
    )
    infer_checkbox.change(
        fn=change_infer,
        inputs=[
            infer_checkbox,
            infer_host_textbox,
            infer_port_textbox,
            infer_vqgan_model,
            infer_llama_model,
            infer_llama_config,
            infer_compile,
        ],
        outputs=[infer_error],
    )

demo.launch(inbrowser=True)