|
|
@@ -1,10 +1,13 @@
|
|
|
use clap::Parser;
|
|
|
+use futures_lite::future::block_on;
|
|
|
use log::info;
|
|
|
use prost::Message;
|
|
|
use rand::seq::SliceRandom;
|
|
|
use rand::{thread_rng, Rng};
|
|
|
use std::fs::File;
|
|
|
use std::io::{self, BufReader, Read, Result as IoResult};
|
|
|
+use std::sync::Arc;
|
|
|
+use std::sync::Mutex;
|
|
|
use std::vec;
|
|
|
use tonic::{transport::Server, Request, Response, Status};
|
|
|
|
|
|
@@ -57,20 +60,35 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>
|
|
|
|
|
|
impl MyDataService {
|
|
|
pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
|
|
|
- let mut groups = Vec::new();
|
|
|
- let mut weights = Vec::new();
|
|
|
-
|
|
|
- for filename in files.iter() {
|
|
|
- let file = File::open(filename)?;
|
|
|
- let reader = BufReader::new(file);
|
|
|
+ let groups = Vec::new();
|
|
|
+ let weights = Vec::new();
|
|
|
+
|
|
|
+ let guarded = Arc::new(Mutex::new((groups, weights)));
|
|
|
+
|
|
|
+ let mut joins = Vec::with_capacity(files.len());
|
|
|
+ for filename in files {
|
|
|
+ let g = guarded.clone();
|
|
|
+ joins.push(tokio::task::spawn_blocking(move || {
|
|
|
+ let file = File::open(filename)?;
|
|
|
+ let reader = BufReader::new(file);
|
|
|
+
|
|
|
+ // Assuming read_pb_stream is implemented and it returns an iterator over TextData
|
|
|
+ for text_data in read_pb_stream(reader)? {
|
|
|
+ let (groups, weights) = &mut *g.lock().unwrap();
|
|
|
+ groups.push(text_data.clone());
|
|
|
+ weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
|
|
|
+ }
|
|
|
+
|
|
|
+ Ok::<_, io::Error>(())
|
|
|
+ }));
|
|
|
+ }
|
|
|
|
|
|
- // Assuming read_pb_stream is implemented and it returns an iterator over TextData
|
|
|
- for text_data in read_pb_stream(reader)? {
|
|
|
- groups.push(text_data.clone());
|
|
|
- weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
|
|
|
- }
|
|
|
+ for join in joins {
|
|
|
+ block_on(join)??;
|
|
|
}
|
|
|
|
|
|
+ let (groups, weights) = Arc::into_inner(guarded).unwrap().into_inner().unwrap();
|
|
|
+
|
|
|
info!("Loaded {} groups", groups.len());
|
|
|
|
|
|
Ok(MyDataService {
|