
Init preprocessing tools

Lengyue 2 years ago
commit
db0aa1a99a
8 changed files with 444 additions and 0 deletions
  1. .gitignore (+1 −0)
  2. ds_config.json (+22 −0)
  3. fine-tune.py (+112 −0)
  4. finetune.sh (+31 −0)
  5. prepare_dataset.py (+88 −0)
  6. preparing_data/to_flac.py (+46 −0)
  7. preparing_data/wenet_clean/clean_wenet_speech.py (+117 −0)
  8. preparing_data/wenet_clean/launch.py (+27 −0)

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+.pgx.*

+ 22 - 0
ds_config.json

@@ -0,0 +1,22 @@
+{
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu" :"auto",
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": 1.0,
+    "bf16": {
+        "enabled": "auto"
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "overlap_comm": true,
+        "stage3_gather_16bit_weights_on_model_save": true
+    },
+    "flops_profiler": {
+        "enabled": false,
+        "profile_step": 1,
+        "module_depth": -1,
+        "top_modules": 1,
+        "detailed": true,
+        "output_file": null
+    }
+}
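
The "auto" values above are placeholders: when this file is attached via --deepspeed, the Hugging Face Trainer's DeepSpeed integration fills them in from the matching TrainingArguments. A minimal sketch of that mapping, assuming the flags used in finetune.sh below:

    from transformers import TrainingArguments

    # "auto" fields in ds_config.json resolve from these arguments:
    #   train_micro_batch_size_per_gpu <- per_device_train_batch_size
    #   gradient_accumulation_steps    <- gradient_accumulation_steps
    #   bf16.enabled                   <- bf16
    args = TrainingArguments(
        output_dir="results",
        deepspeed="ds_config.json",
        per_device_train_batch_size=32,
        gradient_accumulation_steps=1,
        bf16=True,
    )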

+ 112 - 0
fine-tune.py

@@ -0,0 +1,112 @@
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Optional
+
+from datasets import load_dataset, load_from_disk
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    HfArgumentParser,
+    Trainer,
+)
+from transformers import TrainingArguments as _TrainingArguments
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
+
+
+@dataclass
+class DataArguments:
+    data_path: str = field(
+        default=None, metadata={"help": "Path to the training data."}
+    )
+
+
+@dataclass
+class TrainingArguments(_TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    model_max_length: int = field(
+        default=512,
+        metadata={
+            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+        },
+    )
+    use_lora: bool = field(default=False)
+
+
+def dataset_transform(batch, tokenizer: AutoTokenizer = None):
+    outputs = tokenizer(batch["prompt"], padding="longest", truncation=True, max_length=512, return_tensors="pt")
+    labels = outputs.input_ids.clone()
+
+    # Mask padded positions with -100 so they are ignored by the loss
+    labels[outputs.attention_mask == 0] = -100
+
+    return {
+        "input_ids": outputs.input_ids,
+        "attention_mask": outputs.attention_mask,
+        "labels": labels,
+    }
+
+def train():
+    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=True,
+        cache_dir=training_args.cache_dir,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        use_fast=False,
+        trust_remote_code=True,
+        model_max_length=training_args.model_max_length,
+        cache_dir=training_args.cache_dir,
+    )
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    if training_args.use_lora:
+        from peft import LoraConfig, TaskType, get_peft_model
+
+        peft_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            target_modules=["W_pack"],
+            inference_mode=False,
+            r=16,
+            lora_alpha=64,
+            lora_dropout=0.1,
+        )
+        model.enable_input_require_grads()
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+
+    try:
+        dataset = load_from_disk(data_args.data_path)
+        if 'train' in dataset:
+            dataset = dataset['train']
+    except Exception:
+        dataset = load_dataset(data_args.data_path, split="train")
+
+    dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
+    dataset = dataset.train_test_split(test_size=1000, seed=42)
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["test"],
+        tokenizer=tokenizer,
+        data_collator=DataCollatorWithPadding(tokenizer),
+    )
+    trainer.train()
+    trainer.save_state()
+    trainer.save_model(output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+    train()
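
Not part of the commit, but a quick sanity check of the masking in dataset_transform (a minimal sketch, assuming the tokenizer loaded in train() is in scope; the prompts are stand-ins). Note that labels starts as a copy of input_ids, so the loss covers the text portion of the prompt as well as the Encodec codes:

    # Two prompts of different lengths force padding on the shorter one
    batch = {"prompt": ["[INST] hello [/INST] <encodec_1> <encodec_2>", "[INST] hi [/INST]"]}
    out = dataset_transform(batch, tokenizer=tokenizer)

    # Every padded position is masked out of the loss...
    mask = out["attention_mask"]
    assert (out["labels"][mask == 0] == -100).all()
    # ...and every real position keeps its token id as its label
    assert (out["labels"][mask == 1] == out["input_ids"][mask == 1]).all()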

+ 31 - 0
finetune.sh

@@ -0,0 +1,31 @@
+export NCCL_P2P_DISABLE=1
+
+hostfile=""
+deepspeed --hostfile=$hostfile tools/tts/fine-tune.py \
+    --deepspeed tools/tts/ds_config.json \
+    --report_to "tensorboard" \
+    --data_path "fishaudio/libritts-r-encodec" \
+    --model_name_or_path "checkpoints/llama2-tiny-init" \
+    --output_dir "results" \
+    --model_max_length 4096 \
+    --max_steps 500000 \
+    --per_device_train_batch_size 32 \
+    --gradient_accumulation_steps 1 \
+    --save_strategy steps \
+    --save_steps 10000 \
+    --evaluation_strategy steps \
+    --eval_steps 10000 \
+    --learning_rate 1e-3 \
+    --lr_scheduler_type cosine \
+    --adam_beta1 0.9 \
+    --adam_beta2 0.98 \
+    --adam_epsilon 1e-8 \
+    --max_grad_norm 1.0 \
+    --weight_decay 1e-4 \
+    --warmup_steps 10000 \
+    --logging_steps 1 \
+    --gradient_checkpointing True \
+    --remove_unused_columns False \
+    --use_lora False \
+    --bf16 True \
+    --tf32 True

+ 88 - 0
prepare_dataset.py

@@ -0,0 +1,88 @@
+import json
+import os
+from pathlib import Path
+
+import librosa
+import torch
+from datasets import Dataset
+from multiprocess import set_start_method
+from transformers import AutoProcessor, EncodecModel
+
+set_start_method("spawn", force=True)
+
+encodec_name = "facebook/encodec_24khz"
+encodec_processor = AutoProcessor.from_pretrained(encodec_name)
+encodec_model = EncodecModel.from_pretrained(encodec_name)
+encodec_model.eval()
+
+
+def tokenize(text, audio, sr=None, speaker=None):
+    assert sr is None or sr == encodec_processor.sampling_rate
+
+    if isinstance(audio, (str, Path)):
+        audio, sr = librosa.load(audio, sr=sr, mono=True)
+
+    prompt = "[INST] "
+    if speaker:
+        prompt += f"[SPK] {speaker} [/SPK] "
+    prompt += f"{text} [/INST] "
+
+    inputs = encodec_processor(
+        raw_audio=audio, sampling_rate=sr, return_tensors="pt"
+    ).to(encodec_model.device)
+    outputs = encodec_model.encode(
+        inputs["input_values"], inputs["padding_mask"], bandwidth=1.5, return_dict=True
+    )
+
+    assert outputs.audio_codes.dim() == 4  # [chunks, batch, codebook, time]
+    assert outputs.audio_codes.shape[0] == outputs.audio_codes.shape[1] == 1
+
+    # Keep only the first codebook of the 1.5 kbps stream
+    codes = outputs.audio_codes[0, 0, 0, :].long()
+    codes_str = " ".join([f"<encodec_{int(c)}>" for c in codes.tolist()])
+    prompt += codes_str
+
+    return {
+        "prompt": prompt,
+        "codes": codes,
+    }
+
+
+def wrap_tokenize(x):
+    device = torch.device("cuda", 0)
+
+    if encodec_model.device != device:
+        encodec_model.to(device)
+
+    return tokenize(
+        text=x["text"],
+        audio=x["raw_path"],
+        sr=encodec_processor.sampling_rate,
+        speaker=x["speaker"],
+    )
+
+
+def generator_libritts_r():
+    base = Path("dataset/tts/LibriTTS_R")
+
+    for i in base.rglob("*.wav"):
+        text_file = i.with_suffix(".normalized.txt")
+        if not text_file.exists():
+            continue
+
+        text = text_file.read_text().strip()
+
+        yield {
+            "text": text,
+            "speaker": f"libritts_{i.parent.parent.name}",
+            "raw_path": str(i),
+            "path": str(i.relative_to(base)),
+        }
+
+
+if __name__ == "__main__":
+    dataset = Dataset.from_generator(generator_libritts_r)
+    dataset = dataset.map(wrap_tokenize, num_proc=12)
+    dataset = dataset.remove_columns(["raw_path"])
+
+    dataset.save_to_disk("dataset/tts/libritts-r-encodec")
+    dataset.push_to_hub("fishaudio/libritts-r-encodec", private=True)
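
Not part of the commit, but the <encodec_N> tokens can be parsed back and decoded as a round-trip check. A minimal sketch, reusing the encodec_model loaded above; decode_prompt is a hypothetical helper, and since only the first codebook of the 1.5 kbps stream is kept, the reconstruction is lower fidelity than the source audio:

    import re

    def decode_prompt(prompt: str) -> torch.Tensor:
        # Recover the integer codes from the "<encodec_N>" tokens
        codes = [int(m) for m in re.findall(r"<encodec_(\d+)>", prompt)]
        # decode() expects [chunks, batch, codebooks, time]; one codebook was stored
        audio_codes = torch.tensor(codes).view(1, 1, 1, -1)
        with torch.no_grad():
            decoded = encodec_model.decode(audio_codes, [None])
        return decoded.audio_values[0, 0]  # mono waveform at 24 kHz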

+ 46 - 0
preparing_data/to_flac.py

@@ -0,0 +1,46 @@
+import random
+import subprocess
+from multiprocessing import Pool, cpu_count
+from pathlib import Path
+from tqdm import tqdm
+
+def convert_to_flac(src_file_path):
+    dst_file_path = src_file_path.with_suffix(".flac")
+    dst_file_path.parent.mkdir(parents=True, exist_ok=True)
+
+    try:
+        subprocess.check_call(
+            ["ffmpeg", "-y", "-i", str(src_file_path), "-acodec", "flac", "-threads", "0", str(dst_file_path)],
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.DEVNULL,
+        )
+
+        # remove the input file
+        src_file_path.unlink()
+        return True
+    except subprocess.CalledProcessError:
+        return False
+
+
+if __name__ == "__main__":
+    src_dir = Path("dataset/tts/WenetSpeech/cleaned")
+
+    wav_files = list(src_dir.rglob("*.wav"))
+    random.shuffle(wav_files)
+    print(f"Found {len(wav_files)} wav files")
+
+    success_counter = 0
+    fail_counter = 0
+
+    with Pool(processes=cpu_count(), maxtasksperchild=100) as pool:
+        with tqdm(pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)) as pbar:
+            for success in pbar:
+                if success:
+                    success_counter += 1
+                else:
+                    fail_counter += 1
+
+                # Refresh the tally after every file instead of once at the end
+                pbar.set_description(f"Success: {success_counter}, Fail: {fail_counter}")
+
+    print(f"Successfully converted: {success_counter}")
+    print(f"Failed conversions: {fail_counter}")

+ 117 - 0
preparing_data/wenet_clean/clean_wenet_speech.py

@@ -0,0 +1,117 @@
+import json
+from pathlib import Path
+import subprocess
+
+import librosa
+import soundfile as sf
+import torch
+import torchaudio
+from fish_audio_preprocess.utils.separate_audio import (
+    separate_audio,
+    merge_tracks,
+    init_model,
+)
+from tqdm import tqdm
+import time
+import os
+import tempfile
+
+rank = int(os.environ.get("SLURM_PROCID", 0))
+world_size = int(os.environ.get("SLURM_NTASKS", 1))
+device = torch.device("cuda:0")
+print(f"Rank {rank}/{world_size} on {device}")
+
+
+def main():
+    meta_path = Path("dataset/tts/WenetSpeech/WenetSpeech.json")
+    dataset_path = Path("dataset/tts/WenetSpeech")
+    cleaned_path = Path("dataset/tts/WenetSpeech/cleaned")
+    if not cleaned_path.exists():
+        cleaned_path.mkdir(parents=True)
+
+    demucs = init_model("htdemucs", device)
+    print("Model loaded")
+
+    with open(meta_path) as f:
+        dataset = json.load(f)["audios"]
+
+    print(f"Dataset loaded, {len(dataset)} samples")
+    dataset = dataset[rank::world_size]
+    print(f"Dataset split, {len(dataset)} samples")
+
+    for data_idx, data in enumerate(dataset):
+        done_path = cleaned_path / data["aid"] / "done"
+        done_path.parent.mkdir(parents=True, exist_ok=True)
+
+        if done_path.exists():
+            continue
+
+        print(f"Processing {data_idx}/{len(dataset)} at rank {rank}")
+
+        try:
+            with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+                subprocess.check_call(
+                    [
+                        "ffmpeg",
+                        "-y",
+                        "-i",
+                        str(dataset_path / data["path"]),
+                        "-c:a",
+                        "pcm_s16le",
+                        "-threads",
+                        "0",
+                        "-ar",
+                        "24000",
+                        str(f.name),
+                    ],
+                    stdout=subprocess.DEVNULL,
+                    stderr=subprocess.DEVNULL,
+                )
+                raw_audio, sr = librosa.load(f.name, sr=None, mono=True)
+
+            raw_audio = torch.from_numpy(raw_audio[None]).to(device)
+            audio = torchaudio.functional.resample(
+                raw_audio, orig_freq=sr, new_freq=demucs.samplerate
+            )
+            # Make it 2 channels
+            audio = torch.cat([audio, audio], dim=0)
+            tracks = separate_audio(demucs, audio, shifts=1, num_workers=0, progress=False)
+            audio = merge_tracks(tracks, filter=["vocals"])[0]
+            sr = 24000
+            vocals = torchaudio.functional.resample(
+                audio, orig_freq=demucs.samplerate, new_freq=sr
+            )
+            vocals = vocals.cpu().numpy()
+
+            for idx, segment in enumerate(data["segments"]):
+                if segment["confidence"] <= 0.95:
+                    continue
+
+                # Load audio
+                begin = int(segment["begin_time"] * sr)
+                end = int(segment["end_time"] * sr)
+                segment_audio = vocals[begin:end]
+
+                # Write audio
+                temp_path = cleaned_path / data["aid"] / f"S{idx:05d}.wav"
+                temp_path.parent.mkdir(parents=True, exist_ok=True)
+                sf.write(temp_path, segment_audio, samplerate=sr)
+
+                # Write text
+                temp_path = temp_path.with_suffix(".txt")
+                temp_path.write_text(segment["text"])
+
+            # Write done file
+            done_path.write_text("")
+        except Exception as e:
+            print(f"Error {e} on {data_idx}/{len(dataset)} at rank {rank}")
+            time.sleep(10)
+            continue
+
+    print("Done")
+
+
+if __name__ == "__main__":
+    main()

+ 27 - 0
preparing_data/wenet_clean/launch.py

@@ -0,0 +1,27 @@
+import os
+import subprocess as sp
+import sys
+
+SLURM_NTASKS = 6
+
+processes = []
+for i in range(SLURM_NTASKS):
+    env = os.environ.copy()
+    env["SLURM_PROCID"] = str(i)
+    env["SLURM_NTASKS"] = str(SLURM_NTASKS)
+    env["CUDA_VISIBLE_DEVICES"] = str(i % 8)
+
+    processes.append(
+        sp.Popen(
+            f"python preparing_data/wenet_clean/clean_wenet_speech.py",
+            shell=True,
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
+    )
+
+
+for p in processes:
+    # Workers inherit stdout/stderr, so just wait for them to finish
+    p.wait()