Lengyue 2 года назад
Родитель
Commit
490176ece5
3 измененных файлов с 118 добавлено и 13 удалено
  1. 64 13
      data_server/src/main.rs
  2. 10 0
      fish_speech/datasets/protos/text_data_stream.py
  3. 44 0
      tools/split_protos.py

+ 64 - 13
data_server/src/main.rs

@@ -3,9 +3,10 @@ use log::info;
 use prost::Message;
 use rand::seq::SliceRandom;
 use rand::{thread_rng, Rng};
-use std::fs::File;
+use std::fs::{self, File};
 use std::io::{self, BufReader, Read, Result as IoResult};
-use std::vec;
+use std::path::{Path, PathBuf};
+use std::{env, vec};
 use tonic::{transport::Server, Request, Response, Status};
 
 pub mod text_data {
@@ -39,7 +40,7 @@ pub struct MyDataService {
     weights: Vec<f32>,
 }
 
-fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextData>> {
+async fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextData>> {
     let mut text_data_list = Vec::new();
     let mut index = 0;
     let mut total_vq_frames = 0;
@@ -91,32 +92,78 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextDat
 
         index += 1;
 
-        if index % 10000 == 0 {
+        if index % 1000 == 0 {
             info!("Loaded {} groups, total vq frames: {}", index, total_vq_frames);
         }
     }
 
-    info!("Loaded {} groups, total vq frames: {}", index, total_vq_frames);
+    info!("Worker loaded {} groups, total vq frames: {}", index, total_vq_frames);
 
     Ok(text_data_list)
 }
 
/// Recursively collect the paths of all regular files under `path`.
///
/// Traversal is depth-first in directory-entry order. Panics if a directory
/// cannot be read or if a path is not valid UTF-8 (same contract as before:
/// data directories are expected to be readable and UTF-8-named).
fn list_files(path: PathBuf) -> Vec<String> {
    fs::read_dir(path)
        .unwrap()
        .flat_map(|entry| {
            let child = entry.unwrap().path();
            if child.is_dir() {
                // Descend into subdirectories in place (depth-first).
                list_files(child)
            } else {
                vec![child.to_str().unwrap().to_string()]
            }
        })
        .collect()
}
+
 impl MyDataService {
-    pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
+    pub async fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
         let mut groups = Vec::new();
         let mut weights = Vec::new();
+        let mut handles = Vec::new();
+        let start_time = std::time::Instant::now();
 
+        // Expand files if some are directories
+        let mut new_files = Vec::new();
         for filename in files.iter() {
-            let file = File::open(filename)?;
-            let reader = BufReader::new(file);
+            let path = Path::new(filename);
+            if path.is_dir() {
+                // run recursively on all files in the directory
+                for entry in list_files(path.to_path_buf()) {
+                    if entry.ends_with(".protos") {
+                        new_files.push(entry);
+                    }
+                }
+            } else {
+                new_files.push(filename.clone());
+            }
+        }
+
+        log::info!("Loading files: {:?}", new_files.len());
+
+        for filename in new_files.iter() {
+            // Tokio launch multiple tasks to read the files in parallel
+            let copied_filename = filename.clone();
+            let handle = tokio::spawn(async move {
+                let file = File::open(copied_filename)?;
+                let reader = BufReader::new(file);
+                read_pb_stream(reader).await
+            });
+            handles.push(handle);
+        }
 
-            // Assuming read_pb_stream is implemented and it returns an iterator over TextData
-            for text_data in read_pb_stream(reader)? {
-                weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
+        for handle in handles {
+            let text_data_list = handle.await??;
+            for text_data in text_data_list {
+                weights.push(text_data.sentences.len() as f32);
                 groups.push(text_data);
             }
         }
 
+        log::info!("All workers finished, total groups: {}, used time: {:?}", groups.len(), start_time.elapsed());
+
         Ok(MyDataService {
             groups,
             weights,
@@ -213,8 +260,12 @@ struct Args {
     address: String,
 }
 
-#[tokio::main]
+#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    if env::var("RUST_LOG").is_err() {
+        env::set_var("RUST_LOG", "info")
+    }
+
     env_logger::init();
 
     // Parse command-line arguments
@@ -222,7 +273,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     info!("Arguments: {:?}", args);
 
     let addr = args.address.parse()?;
-    let data_service = MyDataService::new(args.files, args.causal)?;
+    let data_service = MyDataService::new(args.files, args.causal).await?;
 
     info!("Starting server at {}", addr);
 

+ 10 - 0
fish_speech/datasets/protos/text_data_stream.py

@@ -24,3 +24,13 @@ def write_pb_stream(f, text_data):
 def pack_pb_stream(text_data):
     buf = text_data.SerializeToString()
     return struct.pack("I", len(buf)) + buf
+
+
def split_pb_stream(f):
    """Iterate over a length-prefixed protobuf stream, yielding raw records.

    Each record in the stream is a 4-byte native-endian unsigned length
    header (matching `pack_pb_stream`) followed by the serialized payload.
    Yields each record *including* its header, so chunks can be written
    back out verbatim.

    Args:
        f: a binary file-like object opened for reading.

    Yields:
        bytes: header + payload for each record.

    Raises:
        EOFError: if the stream ends mid-record (partial header or
            payload shorter than the declared size).
    """
    while True:
        head = f.read(4)
        if len(head) == 0:
            # Clean end of stream.
            break
        if len(head) < 4:
            # 1-3 bytes left: the header itself is truncated.
            raise EOFError("truncated record header in protobuf stream")
        size = struct.unpack("I", head)[0]
        buf = f.read(size)
        if len(buf) < size:
            # Declared size exceeds remaining bytes: do not yield garbage.
            raise EOFError("truncated record payload in protobuf stream")
        yield head + buf

+ 44 - 0
tools/split_protos.py

@@ -0,0 +1,44 @@
+from pathlib import Path
+
+import click
+from loguru import logger
+
+from fish_speech.datasets.protos.text_data_stream import split_pb_stream
+
+
@click.command()
@click.argument("input", type=click.Path(exists=True, path_type=Path))
@click.argument("output", type=click.Path(path_type=Path))
@click.option("--chunk-size", type=int, default=1024**3)  # 1GB per output file
def main(input, output, chunk_size):
    """Split a length-prefixed .protos stream INPUT into files of at most
    roughly --chunk-size bytes each, written to the OUTPUT directory as
    `<stem>.0000.protos`, `<stem>.0001.protos`, ...

    Records are never split across files: a record that would overflow the
    current chunk starts a new one, so files may slightly exceed chunk_size
    only when a single record is larger than the limit.
    """
    chunk_idx = 0
    current_size = 0
    current_file = None

    # Create the output directory (and parents) if it does not exist yet;
    # exist_ok avoids the check-then-create race of `exists()` + `mkdir()`.
    output.mkdir(parents=True, exist_ok=True)

    with open(input, "rb") as f:
        for chunk in split_pb_stream(f):
            # Roll over to a new output file before this record would
            # push the current one past the size limit.
            if current_file is None or current_size + len(chunk) > chunk_size:
                if current_file is not None:
                    current_file.close()

                current_file = open(
                    output / f"{input.stem}.{chunk_idx:04d}.protos", "wb"
                )
                chunk_idx += 1
                current_size = 0
                logger.info(f"Writing to {current_file.name}")

            current_file.write(chunk)
            current_size += len(chunk)

    if current_file is not None:
        current_file.close()

    logger.info(f"Split {input} into {chunk_idx} files")


if __name__ == "__main__":
    main()