
Format code

Lengyue committed 2 years ago
parent commit 4255b1d217
3 changed files with 40 additions and 17 deletions
  1. fine-tune.py (+12 -5)
  2. preparing_data/to_flac.py (+19 -5)
  3. preparing_data/wenet_clean/clean_wenet_speech.py (+9 -7)

fine-tune.py (+12 -5)

@@ -38,8 +38,14 @@ class TrainingArguments(_TrainingArguments):
     use_lora: bool = field(default=False)
 
 
-def dataset_transform(batch, tokenizer: AutoTokenizer=None):
-    outputs = tokenizer(batch["prompt"], padding="longest", truncation=True, max_length=512, return_tensors="pt")
+def dataset_transform(batch, tokenizer: AutoTokenizer = None):
+    outputs = tokenizer(
+        batch["prompt"],
+        padding="longest",
+        truncation=True,
+        max_length=512,
+        return_tensors="pt",
+    )
     labels = outputs.input_ids.clone()
 
     # Set padding labels to -100 so those positions are ignored by the loss
@@ -51,6 +57,7 @@ def dataset_transform(batch, tokenizer: AutoTokenizer=None):
         "labels": labels,
     }
 
+
 def train():
     parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
@@ -87,11 +94,11 @@ def train():
 
     try:
         dataset = load_from_disk(data_args.data_path)
-        if 'train' in dataset:
-            dataset = dataset['train']
+        if "train" in dataset:
+            dataset = dataset["train"]
     except:
         dataset = load_dataset(data_args.data_path, split="train")
-    
+
     dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
     dataset = dataset.train_test_split(test_size=1000, seed=42)
 
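The hunk above clones input_ids into labels; the comment inside it refers to the Hugging Face convention of marking positions with -100, the default ignore_index of torch.nn.CrossEntropyLoss, so that padding does not contribute to the loss. A minimal sketch of that convention, using a GPT-2 tokenizer purely for illustration (the checkpoint and prompts are not from this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

outputs = tokenizer(
    ["short prompt", "a somewhat longer prompt"],
    padding="longest",
    truncation=True,
    max_length=512,
    return_tensors="pt",
)
labels = outputs.input_ids.clone()
# attention_mask == 0 marks padding; -100 tells the loss to skip those tokens.
labels[outputs.attention_mask == 0] = -100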

preparing_data/to_flac.py (+19 -5)

@@ -1,8 +1,10 @@
-from pathlib import Path
+import random
 import subprocess
 from multiprocessing import Pool, cpu_count
+from pathlib import Path
+
 from tqdm import tqdm
-import random
+
 
 def convert_to_flac(src_file_path):
     dst_file_path = src_file_path.with_suffix(".flac")
@@ -10,7 +12,17 @@ def convert_to_flac(src_file_path):
 
     try:
         subprocess.check_call(
-            ["ffmpeg", "-y", "-i", str(src_file_path), "-acodec", "flac", "-threads", "0", str(dst_file_path)],
+            [
+                "ffmpeg",
+                "-y",
+                "-i",
+                str(src_file_path),
+                "-acodec",
+                "flac",
+                "-threads",
+                "0",
+                str(dst_file_path),
+            ],
             stdout=subprocess.DEVNULL,
             stderr=subprocess.DEVNULL,
         )
@@ -33,13 +45,15 @@ if __name__ == "__main__":
     fail_counter = 0
 
     with Pool(processes=cpu_count(), maxtasksperchild=100) as pool:
-        with tqdm(pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)) as pbar:
+        with tqdm(
+            pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)
+        ) as pbar:
             for success in pbar:
                 if success:
                     success_counter += 1
                 else:
                     fail_counter += 1
-            
+
             pbar.set_description(f"Success: {success_counter}, Fail: {fail_counter}")
 
     print(f"Successfully converted: {success_counter}")
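For context, the block being re-wrapped here is the common pattern of streaming results from a process pool through tqdm while each worker shells out to ffmpeg. A condensed, self-contained sketch of the same pattern (the source directory is a placeholder, and convert_one is a simplified stand-in for convert_to_flac):

import subprocess
from multiprocessing import Pool, cpu_count
from pathlib import Path

from tqdm import tqdm


def convert_one(src: Path) -> bool:
    dst = src.with_suffix(".flac")
    try:
        # ffmpeg exits non-zero on failure; map that to False instead of raising.
        subprocess.check_call(
            ["ffmpeg", "-y", "-i", str(src), "-acodec", "flac", str(dst)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return True
    except subprocess.CalledProcessError:
        return False


if __name__ == "__main__":
    wav_files = list(Path("data").rglob("*.wav"))  # hypothetical source directory
    ok = fail = 0
    # maxtasksperchild recycles workers periodically, bounding memory growth.
    with Pool(processes=cpu_count(), maxtasksperchild=100) as pool:
        results = pool.imap_unordered(convert_one, wav_files)
        for success in tqdm(results, total=len(wav_files)):
            ok += success
            fail += not success
    print(f"Converted: {ok}, failed: {fail}")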

preparing_data/wenet_clean/clean_wenet_speech.py (+9 -7)

@@ -1,20 +1,20 @@
 import json
-from pathlib import Path
+import os
 import subprocess
+import tempfile
+import time
+from pathlib import Path
 
 import librosa
 import soundfile as sf
 import torch
 import torchaudio
 from fish_audio_preprocess.utils.separate_audio import (
-    separate_audio,
-    merge_tracks,
     init_model,
+    merge_tracks,
+    separate_audio,
 )
 from tqdm import tqdm
-import time
-import os
-import tempfile
 
 rank = int(os.environ.get("SLURM_PROCID", 0))
 world_size = int(os.environ.get("SLURM_NTASKS", 1))
@@ -75,7 +75,9 @@ def main():
             )
             # Make it 2 channels
             audio = torch.cat([audio, audio], dim=0)
-            tracks = separate_audio(demucs, audio, shifts=1, num_workers=0, progress=False)
+            tracks = separate_audio(
+                demucs, audio, shifts=1, num_workers=0, progress=False
+            )
             audio = merge_tracks(tracks, filter=["vocals"])[0]
             vocals, sr = (
                 torchaudio.functional.resample(
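The first hunk of this file reorders imports around two module-level reads of SLURM_PROCID and SLURM_NTASKS, which give each SLURM task a rank and the total task count so the dataset can be split across jobs. The commit does not show how the split is performed; a common strided-sharding sketch consistent with those variables (the dataset root is a placeholder):

import os
from pathlib import Path

# SLURM sets these per task; the defaults keep the script runnable standalone.
rank = int(os.environ.get("SLURM_PROCID", 0))
world_size = int(os.environ.get("SLURM_NTASKS", 1))

files = sorted(Path("WenetSpeech").rglob("*.flac"))  # hypothetical dataset root
shard = files[rank::world_size]  # task i takes every world_size-th file, no overlap
print(f"rank {rank}/{world_size}: {len(shard)} files")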