manage.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008
  1. from __future__ import annotations
  2. import html
  3. import json
  4. import os
  5. import platform
  6. import shutil
  7. import signal
  8. import subprocess
  9. import sys
  10. import webbrowser
  11. from pathlib import Path
  12. import gradio as gr
  13. import psutil
  14. import yaml
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.i18n import i18n
  18. from fish_speech.webui.launch_utils import Seafoam, versions_html
# Interpreter used to spawn every helper/training subprocess; PYTHON_FOLDERPATH
# lets a bundled environment override the system interpreter.
PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
sys.path.insert(0, "")
print(sys.path)
cur_work_dir = Path(os.getcwd()).resolve()
print("You are in ", str(cur_work_dir))
# Fine-tune config files whose values seed the UI defaults below.
config_path = cur_work_dir / "fish_speech" / "configs"
vqgan_yml_path = config_path / "vqgan_finetune.yaml"
llama_yml_path = config_path / "text2semantic_finetune.yaml"
# Environment passed to child processes; local addresses bypass any proxy.
env = os.environ.copy()
env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
seafoam = Seafoam()
  30. def build_html_error_message(error):
  31. return f"""
  32. <div style="color: red; font-weight: bold;">
  33. {html.escape(error)}
  34. </div>
  35. """
  36. def build_html_ok_message(msg):
  37. return f"""
  38. <div style="color: green; font-weight: bold;">
  39. {html.escape(msg)}
  40. </div>
  41. """
  42. def load_data_in_raw(path):
  43. with open(path, "r", encoding="utf-8") as file:
  44. data = file.read()
  45. return str(data)
  46. def kill_proc_tree(pid, including_parent=True):
  47. try:
  48. parent = psutil.Process(pid)
  49. except psutil.NoSuchProcess:
  50. # Process already terminated
  51. return
  52. children = parent.children(recursive=True)
  53. for child in children:
  54. try:
  55. os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
  56. except OSError:
  57. pass
  58. if including_parent:
  59. try:
  60. os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
  61. except OSError:
  62. pass
# Cached OS name; selects taskkill (Windows) vs POSIX signals in kill_process().
system = platform.system()
# Handles to the child processes this UI can spawn (None means "not running").
p_label = None
p_infer = None
p_tensorboard = None
  67. def kill_process(pid):
  68. if system == "Windows":
  69. cmd = "taskkill /t /f /pid %s" % pid
  70. # os.system(cmd)
  71. subprocess.run(cmd)
  72. else:
  73. kill_proc_tree(pid)
  74. def change_label(if_label):
  75. global p_label
  76. if if_label == True:
  77. # 设置要访问的URL
  78. url = "https://text-labeler.pages.dev/"
  79. webbrowser.open(url)
  80. yield i18n("Opened labeler in browser")
  81. elif if_label == False:
  82. p_label = None
  83. yield "Nothing"
  84. def change_infer(
  85. if_infer,
  86. host,
  87. port,
  88. infer_vqgan_model,
  89. infer_llama_model,
  90. infer_llama_config,
  91. infer_compile,
  92. ):
  93. global p_infer
  94. if if_infer == True and p_infer == None:
  95. env = os.environ.copy()
  96. env["GRADIO_SERVER_NAME"] = host
  97. env["GRADIO_SERVER_PORT"] = port
  98. # 启动第二个进程
  99. url = f"http://{host}:{port}"
  100. yield build_html_ok_message(
  101. i18n("Inferring interface is launched at {}").format(url)
  102. )
  103. p_infer = subprocess.Popen(
  104. [
  105. PYTHON,
  106. "tools/webui.py",
  107. "--vqgan-checkpoint-path",
  108. infer_vqgan_model,
  109. "--llama-checkpoint-path",
  110. infer_llama_model,
  111. "--llama-config-name",
  112. infer_llama_config,
  113. "--tokenizer",
  114. "checkpoints",
  115. ]
  116. + (["--compile"] if infer_compile == "Yes" else []),
  117. env=env,
  118. )
  119. elif if_infer == False and p_infer != None:
  120. kill_process(p_infer.pid)
  121. p_infer = None
  122. yield build_html_error_message(i18n("Infer interface is closed"))
# Front-end assets injected into the gradio Blocks UI below.
js = load_data_in_raw("fish_speech/webui/js/animate.js")
css = load_data_in_raw("fish_speech/webui/css/style.css")
# Where preprocessed training data and trained checkpoints end up.
data_pre_output = (cur_work_dir / "data").resolve()
default_model_output = (cur_work_dir / "results").resolve()
default_filelist = data_pre_output / "detect.list"
data_pre_output.mkdir(parents=True, exist_ok=True)
# Registry of user-added data sources: ordered list of paths plus
# per-path metadata (type / method / label language).
items = []
dict_items = {}
  131. def load_yaml_data_in_fact(yml_path):
  132. with open(yml_path, "r", encoding="utf-8") as file:
  133. yml = yaml.safe_load(file)
  134. return yml
  135. def write_yaml_data_in_fact(yml, yml_path):
  136. with open(yml_path, "w", encoding="utf-8") as file:
  137. yaml.safe_dump(yml, file, allow_unicode=True)
  138. return yml
  139. def generate_tree(directory, depth=0, max_depth=None, prefix=""):
  140. if max_depth is not None and depth > max_depth:
  141. return ""
  142. tree_str = ""
  143. files = []
  144. directories = []
  145. for item in os.listdir(directory):
  146. if os.path.isdir(os.path.join(directory, item)):
  147. directories.append(item)
  148. else:
  149. files.append(item)
  150. entries = directories + files
  151. for i, entry in enumerate(entries):
  152. connector = "├── " if i < len(entries) - 1 else "└── "
  153. tree_str += f"{prefix}{connector}{entry}<br />"
  154. if i < len(directories):
  155. extension = "│ " if i < len(entries) - 1 else " "
  156. tree_str += generate_tree(
  157. os.path.join(directory, entry),
  158. depth + 1,
  159. max_depth,
  160. prefix=prefix + extension,
  161. )
  162. return tree_str
  163. def new_explorer(data_path, max_depth):
  164. return gr.Markdown(
  165. elem_classes=["scrollable-component"],
  166. value=generate_tree(data_path, max_depth=max_depth),
  167. )
  168. def add_item(folder: str, method: str, label_lang: str):
  169. folder = folder.strip(" ").strip('"')
  170. folder_path = Path(folder)
  171. if folder and folder not in items and data_pre_output not in folder_path.parents:
  172. if folder_path.is_dir():
  173. items.append(folder)
  174. dict_items[folder] = dict(
  175. type="folder", method=method, label_lang=label_lang
  176. )
  177. elif folder:
  178. err = folder
  179. return gr.Checkboxgroup(choices=items), build_html_error_message(
  180. f"添加文件夹路径无效: {err}"
  181. )
  182. formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
  183. logger.info(formatted_data)
  184. return gr.Checkboxgroup(choices=items), build_html_ok_message("添加文件(夹)路径成功!")
  185. def remove_items(selected_items):
  186. global items, dict_items
  187. to_remove = [item for item in items if item in selected_items]
  188. for item in to_remove:
  189. del dict_items[item]
  190. items = [item for item in items if item in dict_items.keys()]
  191. formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
  192. logger.info(formatted_data)
  193. return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
  194. "删除文件(夹)路径成功!"
  195. )
  196. def show_selected(options):
  197. selected_options = ", ".join(options)
  198. return f"你选中了: {selected_options}" if options else "你没有选中任何选项"
def list_copy(list_file_path, method):
    """Copy or move audio/.lab pairs referenced by a filelist into data/.

    Each line of *list_file_path* is ``wav_path|speaker|language|text``.
    *method* is "复制一份" (copy) or "直接移动" (move); in move mode the
    filelist is rewritten so its paths point at the new locations.
    Returns a status HTML snippet.
    """
    wav_root = data_pre_output
    lst = []
    with list_file_path.open("r", encoding="utf-8") as file:
        for line in tqdm(file, desc="Processing audio/transcript"):
            wav_path, speaker_name, language, text = line.strip().split("|")
            original_wav_path = Path(wav_path)
            # Target keeps only the immediate parent (speaker) folder name.
            target_wav_path = (
                wav_root / original_wav_path.parent.name / original_wav_path.name
            )
            lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
            if target_wav_path.is_file():
                continue  # already transferred on a previous run
            target_wav_path.parent.mkdir(parents=True, exist_ok=True)
            if method == "复制一份":
                shutil.copy(original_wav_path, target_wav_path)
            else:
                shutil.move(original_wav_path, target_wav_path.parent)
            # Transfer the matching .lab transcript alongside the audio.
            original_lab_path = original_wav_path.with_suffix(".lab")
            target_lab_path = (
                wav_root
                / original_wav_path.parent.name
                / original_wav_path.with_suffix(".lab").name
            )
            if target_lab_path.is_file():
                continue
            if method == "复制一份":
                shutil.copy(original_lab_path, target_lab_path)
            else:
                shutil.move(original_lab_path, target_lab_path.parent)
    if method == "直接移动":
        # Moving invalidated the old paths: rewrite the filelist in place.
        with list_file_path.open("w", encoding="utf-8") as file:
            file.writelines("\n".join(lst))
    del lst
    return build_html_ok_message("使用filelist")
def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
    """Transcribe (optionally) and transfer every registered data source.

    For folder sources: run whisper ASR over the folder (unless the user
    chose "不打标", mapped to the "WTF" sentinel), then copy or move it under
    *data_path*. For file sources: delegate to :func:`list_copy`.
    Returns (status HTML, refreshed file-tree widget).
    """
    dict_to_language = {"中文": "ZH", "英文": "EN", "日文": "JP", "不打标": "WTF"}
    global dict_items
    data_path = Path(data_path)
    for item, content in dict_items.items():
        item_path = Path(item)
        tar_path = data_path / item_path.name
        if content["type"] == "folder" and item_path.is_dir():
            cur_lang = dict_to_language[content["label_lang"]]
            if cur_lang != "WTF":  # "WTF" == user opted out of labelling
                try:
                    subprocess.run(
                        [
                            PYTHON,
                            "tools/whisper_asr.py",
                            "--model-size",
                            label_model,
                            "--device",
                            label_device,
                            "--audio-dir",
                            item_path,
                            "--save-dir",
                            item_path,
                            "--language",
                            cur_lang,
                        ],
                        env=env,
                    )
                except Exception:
                    print("Transcription error occurred")
            if content["method"] == "复制一份":
                os.makedirs(tar_path, exist_ok=True)
                shutil.copytree(
                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
                )
            elif not tar_path.is_dir():
                # Move only when the destination doesn't exist yet.
                shutil.move(src=str(item_path), dst=str(tar_path))
        elif content["type"] == "file" and item_path.is_file():
            list_copy(item_path, content["method"])
    return build_html_ok_message("文件移动完毕"), new_explorer(data_path, max_depth=max_depth)
def train_process(
    data_path: str,
    option: str,
    # vq-gan config
    vqgan_lr,
    vqgan_maxsteps,
    vqgan_data_num_workers,
    vqgan_data_batch_size,
    vqgan_data_val_batch_size,
    vqgan_precision,
    vqgan_check_interval,
    # llama config
    llama_base_config,
    llama_lr,
    llama_maxsteps,
    llama_data_num_workers,
    llama_data_batch_size,
    llama_data_max_length,
    llama_precision,
    llama_check_interval,
    llama_grad_batches,
    llama_use_speaker,
    llama_use_lora,
):
    """Run the selected training pipeline(s) synchronously via subprocesses.

    *option* chooses "VQGAN", "LLAMA" or "all". Each run gets a fresh
    timestamped project name; hyperparameters are forwarded to
    ``fish_speech/train.py`` as Hydra-style ``key=value`` overrides.
    Returns a status HTML snippet once every subprocess has finished.
    """
    import datetime

    def generate_folder_name():
        # Timestamped name so successive runs never collide.
        now = datetime.datetime.now()
        folder_name = now.strftime("%Y%m%d_%H%M%S")
        return folder_name

    # nccl needs CUDA+Linux; gloo is the portable fallback backend.
    backend = "nccl" if sys.platform == "linux" else "gloo"
    new_project = generate_folder_name()
    print("New Project Name: ", new_project)
    if option == "VQGAN" or option == "all":
        # Split the preprocessed data into train/val filelists first.
        subprocess.run(
            [
                PYTHON,
                "tools/vqgan/create_train_split.py",
                str(data_pre_output.relative_to(cur_work_dir)),
            ]
        )
        train_cmd = [
            PYTHON,
            "fish_speech/train.py",
            "--config-name",
            "vqgan_finetune",
            f"project={'vqgan_' + new_project}",
            f"trainer.strategy.process_group_backend={backend}",
            f"model.optimizer.lr={vqgan_lr}",
            f"trainer.max_steps={vqgan_maxsteps}",
            f"data.num_workers={vqgan_data_num_workers}",
            f"data.batch_size={vqgan_data_batch_size}",
            f"data.val_batch_size={vqgan_data_val_batch_size}",
            f"trainer.precision={vqgan_precision}",
            f"trainer.val_check_interval={vqgan_check_interval}",
            f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
            f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
        ]
        logger.info(train_cmd)
        subprocess.run(train_cmd)
    if option == "LLAMA" or option == "all":
        # Quantise the audio into VQ tokens the LLAMA stage trains on.
        subprocess.run(
            [
                PYTHON,
                "tools/vqgan/extract_vq.py",
                str(data_pre_output),
                "--num-workers",
                "1",
                "--batch-size",
                "16",
                "--config-name",
                "vqgan_pretrain",
                "--checkpoint-path",
                "checkpoints/vq-gan-group-fsq-2x1024.pth",
            ]
        )
        # Pack tokens + .lab transcripts into the protobuf dataset.
        subprocess.run(
            [
                PYTHON,
                "tools/llama/build_dataset.py",
                "--input",
                str(data_pre_output),
                "--text-extension",
                ".lab",
                "--num-workers",
                "16",
            ]
        )
        # Base checkpoint matches the selected model size.
        ckpt_path = (
            "text2semantic-pretrain-medium-2k-v1.pth"
            if llama_base_config == "dual_ar_2_codebook_medium"
            else "text2semantic-sft-large-v1-4k.pth"
        )
        train_cmd = [
            PYTHON,
            "fish_speech/train.py",
            "--config-name",
            "text2semantic_finetune",
            f"project={'text2semantic_' + new_project}",
            f"ckpt_path=checkpoints/{ckpt_path}",
            f"trainer.strategy.process_group_backend={backend}",
            f"model@model.model={llama_base_config}",
            "tokenizer.pretrained_model_name_or_path=checkpoints",
            f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
            f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
            f"model.optimizer.lr={llama_lr}",
            f"trainer.max_steps={llama_maxsteps}",
            f"data.num_workers={llama_data_num_workers}",
            f"data.batch_size={llama_data_batch_size}",
            f"max_length={llama_data_max_length}",
            f"trainer.precision={llama_precision}",
            f"trainer.val_check_interval={llama_check_interval}",
            f"trainer.accumulate_grad_batches={llama_grad_batches}",
            f"train_dataset.use_speaker={llama_use_speaker}",
        ] + ([f"+lora@model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
        logger.info(train_cmd)
        subprocess.run(train_cmd)
    return build_html_ok_message("训练终止")
  391. def tensorboard_process(
  392. if_tensorboard: bool,
  393. tensorboard_dir: str,
  394. host: str,
  395. port: str,
  396. ):
  397. global p_tensorboard
  398. if if_tensorboard == True and p_tensorboard == None:
  399. yield build_html_ok_message(f"Tensorboard界面已开启, 访问 http://{host}:{port}")
  400. p_tensorboard = subprocess.Popen(
  401. [
  402. "fishenv/python.exe",
  403. "fishenv/Scripts/tensorboard.exe"
  404. if Path("fishenv").exists()
  405. else "tensorboard",
  406. "--logdir",
  407. tensorboard_dir,
  408. "--host",
  409. host,
  410. "--port",
  411. port,
  412. "--reload_interval",
  413. "120",
  414. ]
  415. )
  416. elif if_tensorboard == False and p_tensorboard != None:
  417. kill_process(p_tensorboard.pid)
  418. p_tensorboard = None
  419. yield build_html_error_message("Tensorboard界面已关闭")
  420. def fresh_tb_dir():
  421. return gr.Dropdown(
  422. choices=[str(p) for p in Path("results").glob("**/tensorboard/version_*/")]
  423. )
  424. def fresh_vqgan_model():
  425. return gr.Dropdown(
  426. choices=[init_vqgan_yml["ckpt_path"]]
  427. + [str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")]
  428. )
  429. def fresh_llama_model():
  430. return gr.Dropdown(
  431. choices=[init_llama_yml["ckpt_path"]]
  432. + [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
  433. )
  434. def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
  435. if (
  436. lora_weight is None
  437. or not Path(lora_weight).exists()
  438. or not Path(llama_weight).exists()
  439. ):
  440. return build_html_error_message("路径错误,请检查模型文件是否存在于对应路径")
  441. merge_cmd = [
  442. PYTHON,
  443. "tools/llama/merge_lora.py",
  444. "--llama-config",
  445. "dual_ar_2_codebook_large",
  446. "--lora-config",
  447. "r_8_alpha_16",
  448. "--llama-weight",
  449. llama_weight,
  450. "--lora-weight",
  451. lora_weight,
  452. "--output",
  453. llama_lora_output,
  454. ]
  455. logger.info(merge_cmd)
  456. subprocess.run(merge_cmd)
  457. return build_html_ok_message("融合终止")
# Load the current fine-tune configs once; their values seed the UI defaults.
init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
with gr.Blocks(
    head="<style>\n" + css + "\n</style>",
    js=js,
    theme=seafoam,
    analytics_enabled=False,
    title="Fish-Speech 鱼语",
) as demo:
    with gr.Row():
        # Left column: dataset preparation / training config / inference tabs.
        with gr.Column():
            with gr.Tab("\U0001F4D6 数据集准备"):
                with gr.Row():
                    textbox = gr.Textbox(
                        label="\U0000270F 输入音频&转写源文件夹路径",
                        info="音频装在一个以说话人命名的文件夹内作为区分",
                        interactive=True,
                    )
                with gr.Row(equal_height=False):
                    with gr.Column():
                        output_radio = gr.Radio(
                            label="\U0001F4C1 选择源文件(夹)处理方式",
                            choices=["复制一份", "直接移动"],
                            value="复制一份",
                            interactive=True,
                        )
                    with gr.Column():
                        error = gr.HTML(label="错误信息")
                        if_label = gr.Checkbox(
                            label="是否开启打标WebUI", scale=0, show_label=True
                        )
                with gr.Row():
                    add_button = gr.Button("\U000027A1提交到处理区", variant="primary")
                    remove_button = gr.Button("\U000026D4 取消所选内容")
                with gr.Row():
                    label_device = gr.Dropdown(
                        label="打标设备",
                        info="建议使用cuda, 实在是低配置再用cpu",
                        choices=["cpu", "cuda"],
                        value="cuda",
                        interactive=True,
                    )
                    label_model = gr.Dropdown(
                        label="打标模型大小",
                        info="显存10G以上用large, 5G用medium, 2G用small",
                        choices=["large", "medium", "small"],
                        value="small",
                        interactive=True,
                    )
                    label_radio = gr.Dropdown(
                        label="(可选)打标语言",
                        info="如果没有音频对应的文本,则进行辅助打标, 支持.txt或.lab格式",
                        choices=["中文", "日文", "英文", "不打标"],
                        value="不打标",
                        interactive=True,
                    )
            with gr.Tab("\U0001F6E0 训练配置项"):  # hammer
                with gr.Row():
                    model_type_radio = gr.Radio(
                        label="选择要训练的模型类型",
                        interactive=True,
                        choices=["VQGAN", "LLAMA", "all"],
                        value="all",
                    )
                with gr.Row():
                    with gr.Tab(label="VQGAN配置项"):
                        with gr.Row(equal_height=False):
                            vqgan_lr_slider = gr.Slider(
                                label="初始学习率",
                                interactive=True,
                                minimum=1e-5,
                                maximum=1e-4,
                                step=1e-5,
                                value=init_vqgan_yml["model"]["optimizer"]["lr"],
                            )
                            vqgan_maxsteps_slider = gr.Slider(
                                label="训练最大步数",
                                interactive=True,
                                minimum=1000,
                                maximum=100000,
                                step=1000,
                                value=init_vqgan_yml["trainer"]["max_steps"],
                            )
                        with gr.Row(equal_height=False):
                            vqgan_data_num_workers_slider = gr.Slider(
                                label="num_workers",
                                interactive=True,
                                minimum=1,
                                maximum=16,
                                step=1,
                                value=init_vqgan_yml["data"]["num_workers"],
                            )
                            vqgan_data_batch_size_slider = gr.Slider(
                                label="batch_size",
                                interactive=True,
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=init_vqgan_yml["data"]["batch_size"],
                            )
                        with gr.Row(equal_height=False):
                            vqgan_data_val_batch_size_slider = gr.Slider(
                                label="val_batch_size",
                                interactive=True,
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=init_vqgan_yml["data"]["val_batch_size"],
                            )
                            vqgan_precision_dropdown = gr.Dropdown(
                                label="训练精度",
                                interactive=True,
                                choices=["32", "bf16-true", "bf16-mixed"],
                                value=str(init_vqgan_yml["trainer"]["precision"]),
                            )
                        with gr.Row(equal_height=False):
                            vqgan_check_interval_slider = gr.Slider(
                                label="每n步保存一个模型",
                                interactive=True,
                                minimum=500,
                                maximum=10000,
                                step=500,
                                value=init_vqgan_yml["trainer"]["val_check_interval"],
                            )
                    with gr.Tab(label="LLAMA配置项"):
                        with gr.Row(equal_height=False):
                            llama_use_lora = gr.Checkbox(
                                label="使用lora训练?",
                                value=True,
                            )
                        with gr.Row(equal_height=False):
                            llama_lr_slider = gr.Slider(
                                label="初始学习率",
                                interactive=True,
                                minimum=1e-5,
                                maximum=1e-4,
                                step=1e-5,
                                value=init_llama_yml["model"]["optimizer"]["lr"],
                            )
                            llama_maxsteps_slider = gr.Slider(
                                label="训练最大步数",
                                interactive=True,
                                minimum=1000,
                                maximum=100000,
                                step=1000,
                                value=init_llama_yml["trainer"]["max_steps"],
                            )
                        with gr.Row(equal_height=False):
                            llama_base_config = gr.Dropdown(
                                label="模型基础属性",
                                choices=[
                                    "dual_ar_2_codebook_large",
                                    "dual_ar_2_codebook_medium",
                                ],
                                value="dual_ar_2_codebook_large",
                            )
                            llama_data_num_workers_slider = gr.Slider(
                                label="num_workers",
                                minimum=0,
                                maximum=16,
                                step=1,
                                # Worker processes only help on Linux here.
                                value=init_llama_yml["data"]["num_workers"]
                                if sys.platform == "linux"
                                else 0,
                            )
                        with gr.Row(equal_height=False):
                            llama_data_batch_size_slider = gr.Slider(
                                label="batch_size",
                                interactive=True,
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=init_llama_yml["data"]["batch_size"],
                            )
                            llama_data_max_length_slider = gr.Slider(
                                label="max_length",
                                interactive=True,
                                minimum=1024,
                                maximum=4096,
                                step=128,
                                value=init_llama_yml["max_length"],
                            )
                        with gr.Row(equal_height=False):
                            llama_precision_dropdown = gr.Dropdown(
                                label="训练精度",
                                interactive=True,
                                choices=["32", "bf16-true", "16-mixed"],
                                value="bf16-true",
                            )
                            llama_check_interval_slider = gr.Slider(
                                label="每n步保存一个模型",
                                interactive=True,
                                minimum=500,
                                maximum=10000,
                                step=500,
                                value=init_llama_yml["trainer"]["val_check_interval"],
                            )
                        with gr.Row(equal_height=False):
                            llama_grad_batches = gr.Slider(
                                label="accumulate_grad_batches",
                                interactive=True,
                                minimum=1,
                                maximum=20,
                                step=1,
                                value=init_llama_yml["trainer"][
                                    "accumulate_grad_batches"
                                ],
                            )
                            llama_use_speaker = gr.Slider(
                                label="use_speaker_ratio",
                                interactive=True,
                                minimum=0.1,
                                maximum=1.0,
                                step=0.05,
                                value=init_llama_yml["train_dataset"]["use_speaker"],
                            )
                    with gr.Tab(label="LLAMA_lora融合"):
                        with gr.Row(equal_height=False):
                            llama_weight = gr.Dropdown(
                                label="要融入的原模型",
                                info="输入路径,或者下拉选择",
                                choices=[init_llama_yml["ckpt_path"]],
                                value=init_llama_yml["ckpt_path"],
                                allow_custom_value=True,
                                interactive=True,
                            )
                        with gr.Row(equal_height=False):
                            lora_weight = gr.Dropdown(
                                label="要融入的lora模型",
                                info="输入路径,或者下拉选择",
                                choices=[
                                    str(p)
                                    for p in Path("results").glob("text2*ar/**/*.ckpt")
                                ],
                                allow_custom_value=True,
                                interactive=True,
                            )
                        with gr.Row(equal_height=False):
                            llama_lora_output = gr.Dropdown(
                                label="输出的lora模型",
                                info="输出路径",
                                value="checkpoints/merged.ckpt",
                                choices=["checkpoints/merged.ckpt"],
                                allow_custom_value=True,
                                interactive=True,
                            )
                        with gr.Row(equal_height=False):
                            llama_lora_merge_btn = gr.Button(
                                value="开始融合", variant="primary"
                            )
                    with gr.Tab(label="Tensorboard"):
                        with gr.Row(equal_height=False):
                            tb_host = gr.Textbox(
                                label="Tensorboard Host", value="127.0.0.1"
                            )
                            tb_port = gr.Textbox(
                                label="Tensorboard Port", value="11451"
                            )
                        with gr.Row(equal_height=False):
                            tb_dir = gr.Dropdown(
                                label="Tensorboard 日志文件夹",
                                allow_custom_value=True,
                                choices=[
                                    str(p)
                                    for p in Path("results").glob(
                                        "**/tensorboard/version_*/"
                                    )
                                ],
                            )
                        with gr.Row(equal_height=False):
                            if_tb = gr.Checkbox(
                                label="是否打开tensorboard?",
                            )
            with gr.Tab("\U0001F9E0 进入推理界面"):
                with gr.Column():
                    with gr.Row():
                        with gr.Accordion(label="\U0001F5A5 推理服务器配置", open=False):
                            with gr.Row():
                                infer_host_textbox = gr.Textbox(
                                    label="Webui启动服务器地址", value="127.0.0.1"
                                )
                                infer_port_textbox = gr.Textbox(
                                    label="Webui启动服务器端口", value="7862"
                                )
                            with gr.Row():
                                infer_vqgan_model = gr.Dropdown(
                                    label="VQGAN模型位置",
                                    info="填写pth/ckpt文件路径",
                                    value=init_vqgan_yml["ckpt_path"],
                                    choices=[init_vqgan_yml["ckpt_path"]]
                                    + [
                                        str(p)
                                        for p in Path("results").glob(
                                            "vqgan*/**/*.ckpt"
                                        )
                                    ],
                                    allow_custom_value=True,
                                )
                            with gr.Row():
                                infer_llama_model = gr.Dropdown(
                                    label="LLAMA模型位置",
                                    info="填写pth/ckpt文件路径",
                                    value=init_llama_yml["ckpt_path"],
                                    choices=[init_llama_yml["ckpt_path"]]
                                    + [
                                        str(p)
                                        for p in Path("results").glob(
                                            "text2sem*/**/*.ckpt"
                                        )
                                    ],
                                    allow_custom_value=True,
                                )
                            with gr.Row():
                                infer_compile = gr.Radio(
                                    label="是否编译模型?", choices=["Yes", "No"], value="Yes"
                                )
                                infer_llama_config = gr.Dropdown(
                                    label="LLAMA模型基础属性",
                                    choices=[
                                        "dual_ar_2_codebook_large",
                                        "dual_ar_2_codebook_medium",
                                    ],
                                    value="dual_ar_2_codebook_large",
                                    allow_custom_value=True,
                                )
                    with gr.Row():
                        infer_checkbox = gr.Checkbox(label="是否打开推理界面")
                        infer_error = gr.HTML(label="推理界面错误信息")
        # Right column: data-source list, output paths, file tree, actions.
        with gr.Column():
            train_error = gr.HTML(label="训练时的报错信息")
            checkbox_group = gr.CheckboxGroup(
                label="\U0001F4CA 数据源列表",
                info="左侧输入文件夹所在路径或filelist。无论是否勾选,在此列表中都会被用以后续训练。",
                elem_classes=["data_src"],
            )
            train_box = gr.Textbox(
                label="数据预处理文件夹路径", value=str(data_pre_output), interactive=False
            )
            model_box = gr.Textbox(
                label="\U0001F4BE 模型输出路径",
                value=str(default_model_output),
                interactive=False,
            )
            with gr.Accordion(
                "查看预处理文件夹状态 (滑块为显示深度大小)",
                elem_classes=["scrollable-component"],
                elem_id="file_accordion",
            ):
                tree_slider = gr.Slider(
                    minimum=0,
                    maximum=3,
                    value=0,
                    step=1,
                    show_label=False,
                    container=False,
                )
                file_markdown = new_explorer(str(data_pre_output), 0)
            with gr.Row(equal_height=False):
                admit_btn = gr.Button(
                    "\U00002705 文件预处理", scale=0, min_width=160, variant="primary"
                )
                fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
                help_button = gr.Button("\U00002753", scale=0, min_width=80)  # question
                train_btn = gr.Button("训练启动!", variant="primary")
    footer = load_data_in_raw("fish_speech/webui/html/footer.html")
    footer = footer.format(
        versions=versions_html(),
        api_docs="https://speech.fish.audio/inference/#http-api",
    )
    gr.HTML(footer, elem_id="footer")
    # --- Event wiring: connect widgets to the handlers defined above. ---
    add_button.click(
        fn=add_item,
        inputs=[textbox, output_radio, label_radio],
        outputs=[checkbox_group, error],
    )
    remove_button.click(
        fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
    )
    checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
    help_button.click(
        fn=None,
        js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
        'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
    )
    if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
    train_btn.click(
        fn=train_process,
        inputs=[
            train_box,
            model_type_radio,
            # vq-gan config
            vqgan_lr_slider,
            vqgan_maxsteps_slider,
            vqgan_data_num_workers_slider,
            vqgan_data_batch_size_slider,
            vqgan_data_val_batch_size_slider,
            vqgan_precision_dropdown,
            vqgan_check_interval_slider,
            # llama config
            llama_base_config,
            llama_lr_slider,
            llama_maxsteps_slider,
            llama_data_num_workers_slider,
            llama_data_batch_size_slider,
            llama_data_max_length_slider,
            llama_precision_dropdown,
            llama_check_interval_slider,
            llama_grad_batches,
            llama_use_speaker,
            llama_use_lora,
        ],
        outputs=[train_error],
    )
    if_tb.change(
        fn=tensorboard_process,
        inputs=[if_tb, tb_dir, tb_host, tb_port],
        outputs=[train_error],
    )
    tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
    infer_vqgan_model.change(
        fn=fresh_vqgan_model, inputs=[], outputs=[infer_vqgan_model]
    )
    infer_llama_model.change(
        fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
    )
    llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
    admit_btn.click(
        fn=check_files,
        inputs=[train_box, tree_slider, label_model, label_device],
        outputs=[error, file_markdown],
    )
    fresh_btn.click(
        fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
    )
    llama_lora_merge_btn.click(
        fn=llama_lora_merge,
        inputs=[llama_weight, lora_weight, llama_lora_output],
        outputs=[train_error],
    )
    infer_checkbox.change(
        fn=change_infer,
        inputs=[
            infer_checkbox,
            infer_host_textbox,
            infer_port_textbox,
            infer_vqgan_model,
            infer_llama_model,
            infer_llama_config,
            infer_compile,
        ],
        outputs=[infer_error],
    )
demo.launch(inbrowser=True)