Przeglądaj źródła

Clean unused code & reorganize

Lengyue 2 lat temu
rodzic
commit
afc854ad85

+ 1 - 1
fish_speech/models/vqgan/modules/encoders.py

@@ -330,7 +330,7 @@ class VQEncoder(nn.Module):
     def decode(self, indices):
         q = self.vq.get_output_from_indices(indices)
 
-        if q.shape[1] != indices.shape[1]:
+        if q.shape[1] != indices.shape[1] and indices.ndim != 4:
             q = q.view(q.shape[0], indices.shape[1], -1)
         q = q.mT
 

+ 0 - 101
tools/build_vq_dataset.py

@@ -1,101 +0,0 @@
-from functools import lru_cache
-from pathlib import Path
-
-import numpy as np
-from datasets import Dataset, DatasetDict
-
-
-@lru_cache(maxsize=1)
-def get_phonemes():
-    phones = {}
-    phones.update(np.load("dump/phoneme_dev.npy", allow_pickle=True).item())
-    phones.update(np.load("dump/phoneme_train.npy", allow_pickle=True).item())
-    phones.update(
-        np.load(
-            "/home/fish/hubert-vq-vits/dump/phoneme_dev.npy", allow_pickle=True
-        ).item()
-    )
-    phones.update(
-        np.load(
-            "/home/fish/hubert-vq-vits/dump/phoneme_train.npy", allow_pickle=True
-        ).item()
-    )
-    print("Loaded phonemes")
-
-    return phones
-
-
-def parse_data(items):
-    results = []
-    phones = get_phonemes()
-
-    for item_name, semantic_audio in zip(items["item_name"], items["semantic_audio"]):
-        file_name = item_name
-        if item_name.startswith("/wenet-speech-vocals"):
-            file_name = "/home/fish/wenetspeech/dsall" + item_name
-
-        wav_file = Path(file_name)
-        text_file = wav_file.with_suffix(".txt")
-
-        if not text_file.exists():
-            text_file = wav_file.with_suffix(".lab")
-
-        if not text_file.exists():
-            print(f"Missing {text_file}")
-            return None
-
-        text = text_file.read_text().strip()
-        semantic = [f"<semantic_{x}>" for x in semantic_audio.split(" ")]
-        semantic = " ".join(semantic)
-        results.append(f"[INST] {text} [/INST] {semantic} </s>")
-        results.append(f"[INST] {phones[item_name]} [/INST] {semantic} </s>")
-
-    return {
-        "text": results,
-    }
-
-
-if __name__ == "__main__":
-    test_dataset = Dataset.from_csv(
-        ["dump/semantic_dev.tsv", "/home/fish/hubert-vq-vits/dump/semantic_dev.tsv"],
-        delimiter="\t",
-        split="test",
-    )
-    test_dataset = test_dataset.map(
-        parse_data,
-        num_proc=32,
-        remove_columns=test_dataset.column_names,
-        batched=True,
-        batch_size=10000,
-    )
-
-    train_dataset = Dataset.from_csv(
-        [
-            "dump/semantic_train.tsv",
-            "/home/fish/hubert-vq-vits/dump/semantic_train.tsv",
-        ],
-        delimiter="\t",
-        split="train",
-    )
-    train_dataset = train_dataset.map(
-        parse_data,
-        num_proc=32,
-        remove_columns=train_dataset.column_names,
-        batched=True,
-        batch_size=10000,
-    )
-
-    dataset = DatasetDict(
-        {
-            "train": train_dataset,
-            "test": test_dataset,
-        }
-    )
-
-    print(
-        f"There are {len(dataset['train'])} training examples and {len(dataset['test'])} test examples"
-    )
-    print(dataset["train"][0])
-    print(dataset["test"][1])
-
-    dataset.push_to_hub("fishaudio/cn-hubert-25hz-vq", private=True)

+ 0 - 26
tools/extract_whisper_vq_weights.py

@@ -1,26 +0,0 @@
-from pathlib import Path
-
-import click
-import torch
-from loguru import logger
-
-
-@click.command()
-@click.argument(
-    "input-file",
-    type=click.Path(exists=True, dir_okay=False, file_okay=True, path_type=Path),
-)
-@click.argument(
-    "output-file",
-    type=click.Path(exists=False, dir_okay=False, file_okay=True, path_type=Path),
-)
-def extract(input_file: Path, output_file: Path):
-    model = torch.load(input_file, map_location="cpu")["model"]
-    state_dict = {k: v for k, v in model.items() if k.startswith("whisper") is False}
-
-    torch.save(state_dict, output_file)
-    logger.info(f"Saved {len(state_dict)} keys to {output_file}")
-
-
-if __name__ == "__main__":
-    extract()

+ 2 - 2
tools/llama/extract_model.py

@@ -1,7 +1,7 @@
 import torch
 
 state_dict = torch.load(
-    "results/text2semantic_400m/checkpoints/step_000035000.ckpt", map_location="cpu"
+    "results/text2semantic_400m/checkpoints/step_000095000.ckpt", map_location="cpu"
 )["state_dict"]
 state_dict = {
     state_dict.replace("model.", ""): value
@@ -9,4 +9,4 @@ state_dict = {
     if state_dict.startswith("model.")
 }
 
-torch.save(state_dict, "results/text2semantic_400m/step_000035000_weights.ckpt")
+torch.save(state_dict, "results/text2semantic_400m/step_000095000_weights.ckpt")

+ 11 - 10
fish_speech/models/text2semantic/generate.py → tools/llama/generate.py

@@ -21,10 +21,10 @@ torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce c
 
 
 from fish_speech.models.text2semantic.llama import ModelArgs, Transformer
-from fish_speech.models.text2semantic.tp import maybe_init_dist
 from fish_speech.text import g2p
 from fish_speech.text.symbols import pad as pad_symbol
 from fish_speech.text.symbols import pu_symbols
+from tools.llama.tp import maybe_init_dist
 
 
 def multinomial_sample_one_no_sync(
@@ -175,9 +175,9 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    # if T + max_new_tokens > 1024:
-    #     max_new_tokens = 1024 - T
-    #     print(f"Truncating max_new_tokens to {max_new_tokens}")
+    if T + max_new_tokens > model.config.max_seq_len:
+        max_new_tokens = model.config.max_seq_len - T
+        print(f"Truncating max_new_tokens to {max_new_tokens}")
 
     T_new = T + max_new_tokens
     if interactive:
@@ -221,7 +221,7 @@ def generate(
 def encode_tokens(tokenizer, string, bos=True, device="cuda"):
     # data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_04.npy
 
-    prompt = g2p("算啦,虽然他罪无可恕,但也有可怜的地方嘛。" + string)
+    prompt = g2p("<zh>算啦,虽然他罪无可恕,但也有可怜的地方嘛。</zh> {string}")
     prompt = [
         (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
         for _, i in prompt
@@ -425,7 +425,8 @@ def main(
         t = time.perf_counter() - t0
 
         if not interactive:
-            print(tokenizer.decode(y[0].tolist()))
+            print(tokenizer.decode(y[0, :prompt_length:].tolist()))
+            print(f"Generated {y.size(1) - prompt_length} tokens")
             # Find all <s:2769>
             codes = y[0, prompt_length:-1]
             codes = codes - 32311
@@ -459,7 +460,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prompt",
         type=str,
-        default="在情感分析功能中,我们让大语言模型分析了一段经典散文。可以看到虽然分析的角度比较浅显,但没有逻辑错误,还是可以自洽的。这也不能怪AI,如果我们提前告知它作者所处的时代背景,相信它一定可以回答得更好。而在这个中文翻译功能中,英特尔大语言模型的表现就更加令我意外了,",
+        default="感情分析関数では、大規模言語モデルに古典的な散文を分析させます。 分析の観点は比較的単純ですが、論理的な誤りはなく、依然として自己一貫性があることがわかります。",
         help="Input prompt.",
     )
     parser.add_argument(
@@ -469,18 +470,18 @@ if __name__ == "__main__":
     )
     parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.")
     parser.add_argument(
-        "--max_new_tokens", type=int, default=1024, help="Maximum number of new tokens."
+        "--max_new_tokens", type=int, default=4096, help="Maximum number of new tokens."
     )
     parser.add_argument("--top_k", type=int, default=50, help="Top-k for sampling.")
     parser.add_argument("--top_p", type=int, default=0.95, help="Top-k for sampling.")
-    parser.add_argument("--repetition_penalty", type=float, default=1.1)
+    parser.add_argument("--repetition_penalty", type=float, default=1.0)
     parser.add_argument(
         "--temperature", type=float, default=0.8, help="Temperature for sampling."
     )
     parser.add_argument(
         "--checkpoint_path",
         type=Path,
-        default=Path("results/text2semantic_400m/step_000090000_weights.ckpt"),
+        default=Path("results/text2semantic_400m/step_000095000_weights.ckpt"),
         help="Model checkpoint path.",
     )
     parser.add_argument(

+ 0 - 56
tools/llama/init_model.py

@@ -1,56 +0,0 @@
-from transformers import AutoTokenizer, LlamaConfig, LlamaModel
-
-# reuse the tokenizer from the llama
-model_type = "meta-llama/Llama-2-7b-hf"
-tokenizer = AutoTokenizer.from_pretrained(model_type)
-
-# new tokens
-new_tokens = [f"<semantic_{i}>" for i in range(4096)]
-tokenizer.add_tokens(new_tokens + ["<pad>"])
-
-# pad token
-tokenizer.pad_token = "<pad>"
-tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
-
-print(f"Vocab size: {len(tokenizer)}")
-
-hidden_size = 1024
-intermediate_size = hidden_size * (11 / 3)
-# then round to the nearest multiple of 8
-intermediate_size = round(intermediate_size / 8) * 8
-print(f"Hidden size: {hidden_size}")
-print(f"Intermediate size: {intermediate_size}")
-
-model = LlamaModel(
-    LlamaConfig(
-        vocab_size=tokenizer.vocab_size,
-        hidden_size=hidden_size,
-        intermediate_size=intermediate_size,
-        num_hidden_layers=20,
-        num_attention_heads=16,
-        max_position_embeddings=4096,
-    )
-)
-
-model = model.bfloat16()
-
-# Resize the token embeddings to include the new tokens
-# Make sure it's a multiple of 8 for faster training
-model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
-
-total_params = sum(p.numel() for p in model.parameters())
-print(f"Total parameters: {total_params / 1e6:.2f}M")
-
-# Try tokenizing a new sequence
-sequence = "Test <semantic_0> <semantic_1023> <pad>"
-encoded = tokenizer.encode(sequence)
-print("Test encoding....")
-print(f"\tSentence: {sequence}")
-print(f"\tEncoded: {encoded}")
-print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
-
-# model.save_pretrained("./checkpoints/speech-lm-300m-init")
-# tokenizer.save_pretrained("./checkpoints/speech-lm-300m-init")
-
-model.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="init")
-tokenizer.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="init")

+ 0 - 0
fish_speech/models/text2semantic/quantize.py → tools/llama/quantize.py


+ 0 - 0
fish_speech/models/text2semantic/tp.py → tools/llama/tp.py


+ 0 - 88
tools/prepare_dataset.py

@@ -1,88 +0,0 @@
-import json
-import os
-from pathlib import Path
-
-import librosa
-import torch
-from datasets import Dataset
-from multiprocess import set_start_method
-from transformers import AutoProcessor, EncodecModel
-
-set_start_method("spawn", force=True)
-
-encodec_name = "facebook/encodec_24khz"
-encodec_processor = AutoProcessor.from_pretrained(encodec_name)
-encodec_model = EncodecModel.from_pretrained(encodec_name)
-encodec_model.eval()
-
-
-def tokenize(text, audio, sr=None, speaker=None):
-    assert sr is None or sr == encodec_processor.sampling_rate
-
-    if isinstance(audio, (str, Path)):
-        audio, sr = librosa.load(audio, sr=sr, mono=True)
-
-    prompt = "[INST] "
-    if speaker:
-        prompt += f"[SPK] {speaker} [/SPK] "
-    prompt += f"{text} [/INST] "
-
-    inputs = encodec_processor(
-        raw_audio=audio, sampling_rate=sr, return_tensors="pt"
-    ).to(encodec_model.device)
-    outputs = encodec_model.encode(
-        inputs["input_values"], inputs["padding_mask"], bandwidth=1.5, return_dict=True
-    )
-
-    assert outputs.audio_codes.dim() == 4  # [batch, channel, codebook, code]
-    assert outputs.audio_codes.shape[0] == outputs.audio_codes.shape[1] == 1
-
-    codes = outputs.audio_codes[0, 0, 0, :].long()
-    codes_str = " ".join([f"<encodec_{int(c)}>" for c in codes.tolist()])
-    prompt += codes_str
-
-    return {
-        "prompt": prompt,
-        "codes": codes,
-    }
-
-
-def wrap_tokenize(x):
-    device = torch.device("cuda", 0)
-
-    if encodec_model.device != device:
-        encodec_model.to(device)
-
-    return tokenize(
-        text=x["text"],
-        audio=x["raw_path"],
-        sr=encodec_processor.sampling_rate,
-        speaker=x["speaker"],
-    )
-
-
-def generator_libritts_r():
-    base = Path("dataset/tts/LibriTTS_R")
-
-    for i in base.rglob("*.wav"):
-        text_file = i.with_suffix(".normalized.txt")
-        if not text_file.exists():
-            continue
-
-        text = text_file.read_text().strip()
-
-        yield {
-            "text": text,
-            "speaker": f"libritts_{i.parent.parent.name}",
-            "raw_path": str(i),
-            "path": str(i.relative_to(base)),
-        }
-
-
-if __name__ == "__main__":
-    dataset = Dataset.from_generator(generator_libritts_r)
-    dataset = dataset.map(wrap_tokenize, num_proc=12)
-    dataset = dataset.remove_columns(["raw_path"])
-
-    dataset.save_to_disk("dataset/tts/libritts-r-encodec")
-    dataset.push_to_hub("fishaudio/libritts-r-encodec", private=True)

+ 0 - 33
tools/split_filelist.py

@@ -1,33 +0,0 @@
-import random
-from pathlib import Path
-
-import click
-from loguru import logger
-
-
-@click.command()
-@click.argument(
-    "list-file",
-    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
-)
-@click.option("--train-proportion", "-p", type=float, default=0.95)
-def main(list_file, train_proportion):
-    lines = list_file.read_text().splitlines()
-    logger.info(f"Found {len(lines)} lines in {list_file}")
-
-    random.shuffle(lines)
-
-    train_size = int(len(lines) * train_proportion)
-
-    train_file = list_file.with_suffix(f".train{list_file.suffix}")
-    train_file.write_text("\n".join(lines[:train_size]))
-
-    test_file = list_file.with_suffix(f".test{list_file.suffix}")
-    test_file.write_text("\n".join(lines[train_size:]))
-
-    logger.info(f"Wrote {len(lines[:train_size])} lines to {train_file}")
-    logger.info(f"Wrote {len(lines[train_size:])} lines to {test_file}")
-
-
-if __name__ == "__main__":
-    main()