|
|
@@ -3,9 +3,10 @@ use log::info;
|
|
|
use prost::Message;
|
|
|
use rand::seq::SliceRandom;
|
|
|
use rand::{thread_rng, Rng};
|
|
|
-use std::fs::File;
|
|
|
+use std::fs::{self, File};
|
|
|
use std::io::{self, BufReader, Read, Result as IoResult};
|
|
|
-use std::vec;
|
|
|
+use std::path::{Path, PathBuf};
|
|
|
+use std::{env, vec};
|
|
|
use tonic::{transport::Server, Request, Response, Status};
|
|
|
|
|
|
pub mod text_data {
|
|
|
@@ -39,7 +40,7 @@ pub struct MyDataService {
|
|
|
weights: Vec<f32>,
|
|
|
}
|
|
|
|
|
|
-fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextData>> {
|
|
|
+async fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextData>> {
|
|
|
let mut text_data_list = Vec::new();
|
|
|
let mut index = 0;
|
|
|
let mut total_vq_frames = 0;
|
|
|
@@ -91,32 +92,78 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextDat
|
|
|
|
|
|
index += 1;
|
|
|
|
|
|
- if index % 10000 == 0 {
|
|
|
+ if index % 1000 == 0 {
|
|
|
info!("Loaded {} groups, total vq frames: {}", index, total_vq_frames);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- info!("Loaded {} groups, total vq frames: {}", index, total_vq_frames);
|
|
|
+ info!("Worker loaded {} groups, total vq frames: {}", index, total_vq_frames);
|
|
|
|
|
|
Ok(text_data_list)
|
|
|
}
|
|
|
|
|
|
+fn list_files(path: PathBuf) -> Vec<String> {
|
|
|
+ let mut files = Vec::new();
|
|
|
+
|
|
|
+ for entry in fs::read_dir(path).unwrap() {
|
|
|
+ let entry = entry.unwrap();
|
|
|
+ let path = entry.path();
|
|
|
+ if path.is_dir() {
|
|
|
+ files.extend(list_files(path));
|
|
|
+ } else {
|
|
|
+ files.push(path.to_str().unwrap().to_string());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ files
|
|
|
+}
|
|
|
+
|
|
|
impl MyDataService {
|
|
|
- pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
|
|
|
+ pub async fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
|
|
|
let mut groups = Vec::new();
|
|
|
let mut weights = Vec::new();
|
|
|
+ let mut handles = Vec::new();
|
|
|
+ let start_time = std::time::Instant::now();
|
|
|
|
|
|
+ // Expand files if some are directories
|
|
|
+ let mut new_files = Vec::new();
|
|
|
for filename in files.iter() {
|
|
|
- let file = File::open(filename)?;
|
|
|
- let reader = BufReader::new(file);
|
|
|
+ let path = Path::new(filename);
|
|
|
+ if path.is_dir() {
|
|
|
+ // run recursively on all files in the directory
|
|
|
+ for entry in list_files(path.to_path_buf()) {
|
|
|
+ if entry.ends_with(".protos") {
|
|
|
+ new_files.push(entry);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ new_files.push(filename.clone());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ log::info!("Loading files: {:?}", new_files.len());
|
|
|
+
|
|
|
+ for filename in new_files.iter() {
|
|
|
+ // Tokio launch multiple tasks to read the files in parallel
|
|
|
+ let copied_filename = filename.clone();
|
|
|
+ let handle = tokio::spawn(async move {
|
|
|
+ let file = File::open(copied_filename)?;
|
|
|
+ let reader = BufReader::new(file);
|
|
|
+ read_pb_stream(reader).await
|
|
|
+ });
|
|
|
+ handles.push(handle);
|
|
|
+ }
|
|
|
|
|
|
- // Assuming read_pb_stream is implemented and it returns an iterator over TextData
|
|
|
- for text_data in read_pb_stream(reader)? {
|
|
|
- weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
|
|
|
+ for handle in handles {
|
|
|
+ let text_data_list = handle.await??;
|
|
|
+ for text_data in text_data_list {
|
|
|
+ weights.push(text_data.sentences.len() as f32);
|
|
|
groups.push(text_data);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ log::info!("All workers finished, total groups: {}, used time: {:?}", groups.len(), start_time.elapsed());
|
|
|
+
|
|
|
Ok(MyDataService {
|
|
|
groups,
|
|
|
weights,
|
|
|
@@ -213,8 +260,12 @@ struct Args {
|
|
|
address: String,
|
|
|
}
|
|
|
|
|
|
-#[tokio::main]
|
|
|
+#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
|
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
+ if env::var("RUST_LOG").is_err() {
|
|
|
+ env::set_var("RUST_LOG", "info")
|
|
|
+ }
|
|
|
+
|
|
|
env_logger::init();
|
|
|
|
|
|
// Parse command-line arguments
|
|
|
@@ -222,7 +273,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
info!("Arguments: {:?}", args);
|
|
|
|
|
|
let addr = args.address.parse()?;
|
|
|
- let data_service = MyDataService::new(args.files, args.causal)?;
|
|
|
+ let data_service = MyDataService::new(args.files, args.causal).await?;
|
|
|
|
|
|
info!("Starting server at {}", addr);
|
|
|
|