Explorar el Código

Fix dataset and docker

Lengyue hace 2 años
padre
commit
c7cf70e904
Se han modificado 2 ficheros con 56 adiciones y 24 borrados
  1. dockerfile (+5 -5)
  2. fish_speech/datasets/text.py (+51 -19)

+ 5 - 5
dockerfile

@@ -26,14 +26,14 @@ RUN pip3 install --upgrade pip && \
     pip3 install ninja packaging && \
     pip3 install git+https://github.com/Dao-AILab/flash-attention.git
 
-# Setup rust-data-server
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
-    cd data_server && cargo build --release
-
 # Project Env
 WORKDIR /exp
-
 COPY . .
+
+# Setup rust-data-server
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
+    cd data_server && $HOME/.cargo/bin/cargo build --release
+
 RUN pip3 install -e .
 
 CMD /bin/zsh

+ 51 - 19
fish_speech/datasets/text.py

@@ -1,9 +1,6 @@
-import json
 import random
-import re
 from dataclasses import dataclass
 from itertools import chain
-from pathlib import Path
 from random import Random
 from typing import Optional, Union
 
@@ -65,12 +62,16 @@ class StreamTextDataset(IterableDataset):
         seed: int = 42,
         parquet_batch_size: int = 10000,
         repo: str = "uonlp/CulturaX",
+        max_length: int = 1024,
+        tokenizer: AutoTokenizer = None,
     ):
         super().__init__()
 
         self.seed = seed
         self.parquet_batch_size = parquet_batch_size
         self.repo = repo
+        self.max_length = max_length
+        self.tokenizer = tokenizer
 
         if files is None and prefix is None:
             raise ValueError("Either files or prefix must be specified")
@@ -105,21 +106,48 @@ class StreamTextDataset(IterableDataset):
     def parse_data(self, filename: str):
         for data in self.parse_data_internal(filename):
             text = data["text"]
-            expression = re.compile(r"\[INST\] (.*) \[/INST\] (.*) </s>")
-            match = expression.match(text)
 
-            if match is None:
-                continue
+            # 30% modeling phones
+            if random.random() < 0.3:
+                text = " ".join(
+                    [
+                        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
+                        for i in text
+                    ]
+                )
+
+            # encode
+            tokens = self.tokenizer.encode(
+                text,
+                add_special_tokens=False,
+                truncation=False,
+                max_length=10**6,
+            )
+
+            # Random choice self.max_length
+            if len(tokens) > self.max_length:
+                start = random.randint(0, len(tokens) - self.max_length)
+                tokens = tokens[start : start + self.max_length - 1]
 
-            text = match.group(1)
-            semantic = match.group(2)
+            tokens = (
+                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
+            )
+            # Pad dims
+            placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
 
-            # Convert semantic to ids
-            expression = re.compile(r"<semantic_(\d+)>")
-            # 0 and 1 are reserved for <s> and </s>
-            semantic = [0] + [int(i) + 2 for i in expression.findall(semantic)] + [1]
+            tokens = torch.concat(
+                [
+                    torch.tensor([tokens], dtype=torch.long),
+                    placeholder_multi_codebook,
+                ],
+                dim=0,
+            )
+            labels = tokens.clone()
+            tokens = tokens[:, :-1]
+            labels = labels[:, 1:]
+            labels[1:] = -100  # remove all placeholders
 
-            yield {"text": text, "semantic": [semantic]}
+            yield {"tokens": tokens, "labels": labels}
 
     def parse_data_internal(self, filename: str):
         url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
@@ -190,8 +218,6 @@ class AutoAugTextDataset(IterableDataset):
                     for i in phones
                 ]
             )
-        else:
-            sentence = clean_text(sentence)
 
         tokens = self.tokenizer.encode(
             f"{sentence}",
@@ -415,7 +441,12 @@ if __name__ == "__main__":
 
     from tqdm import tqdm
 
-    ds = AutoAugTextDataset(
+    # ds = AutoAugTextDataset(
+    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
+    # )
+
+    ds = StreamTextDataset(
+        prefix="en/",
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
     )
 
@@ -423,10 +454,11 @@ if __name__ == "__main__":
         train_dataset=ds,
         val_dataset=ds,
         tokenizer=ds.tokenizer,
-        batch_size=16,
+        batch_size=2,
         max_length=1024,
         num_workers=0,
     )
 
     for batch in tqdm(dm.train_dataloader()):
-        pass
+        print(batch)
+        break