Просмотр исходного кода

Parallelized loading in MyDataService::new (#128)

2 года назад
Родитель
Коммит
053169e55f
3 изменённых файла: 56 добавлений и 11 удалений
  1. 26 0
      data_server/Cargo.lock
  2. 1 0
      data_server/Cargo.toml
  3. 29 11
      data_server/src/main.rs

+ 26 - 0
data_server/Cargo.lock

@@ -271,6 +271,7 @@ dependencies = [
  "bytes",
  "clap",
  "env_logger",
+ "futures-lite",
  "log",
  "prost",
  "rand",
@@ -347,6 +348,25 @@ version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
 
+[[package]]
+name = "futures-io"
+version = "0.3.30"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
+
+[[package]]
+name = "futures-lite"
+version = "2.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5"
+dependencies = [
+ "fastrand",
+ "futures-core",
+ "futures-io",
+ "parking",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "futures-sink"
 version = "0.3.29"
@@ -659,6 +679,12 @@ version = "1.19.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
 
+[[package]]
+name = "parking"
+version = "2.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.1"

+ 1 - 0
data_server/Cargo.toml

@@ -9,6 +9,7 @@ edition = "2021"
 bytes = "1.5.0"
 clap = { version = "4.4.11", features = ["derive"] }
 env_logger = "0.10.1"
+futures-lite = "2.3.0"
 log = "0.4.20"
 prost = "0.12.3"
 rand = "0.8.5"

+ 29 - 11
data_server/src/main.rs

@@ -1,10 +1,13 @@
 use clap::Parser;
+use futures_lite::future::block_on;
 use log::info;
 use prost::Message;
 use rand::seq::SliceRandom;
 use rand::{thread_rng, Rng};
 use std::fs::File;
 use std::io::{self, BufReader, Read, Result as IoResult};
+use std::sync::Arc;
+use std::sync::Mutex;
 use std::vec;
 use tonic::{transport::Server, Request, Response, Status};
 
@@ -57,20 +60,35 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>
 
 impl MyDataService {
     pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
-        let mut groups = Vec::new();
-        let mut weights = Vec::new();
-
-        for filename in files.iter() {
-            let file = File::open(filename)?;
-            let reader = BufReader::new(file);
+        let groups = Vec::new();
+        let weights = Vec::new();
+
+        let guarded = Arc::new(Mutex::new((groups, weights)));
+
+        let mut joins = Vec::with_capacity(files.len());
+        for filename in files {
+            let g = guarded.clone();
+            joins.push(tokio::task::spawn_blocking(move || {
+                let file = File::open(filename)?;
+                let reader = BufReader::new(file);
+
+                // Assuming read_pb_stream is implemented and it returns an iterator over TextData
+                for text_data in read_pb_stream(reader)? {
+                    let (groups, weights) = &mut *g.lock().unwrap();
+                    groups.push(text_data.clone());
+                    weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
+                }
+
+                Ok::<_, io::Error>(())
+            }));
+        }
 
-            // Assuming read_pb_stream is implemented and it returns an iterator over TextData
-            for text_data in read_pb_stream(reader)? {
-                groups.push(text_data.clone());
-                weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
-            }
+        for join in joins {
+            block_on(join)??;
         }
 
+        let (groups, weights) = Arc::into_inner(guarded).unwrap().into_inner().unwrap();
+
         info!("Loaded {} groups", groups.len());
 
         Ok(MyDataService {