|
|
@@ -1,9 +1,8 @@
|
|
|
use clap::Parser;
|
|
|
use log::info;
|
|
|
use prost::Message;
|
|
|
-use rand::prelude::IteratorRandom;
|
|
|
use rand::seq::SliceRandom;
|
|
|
-use rand::thread_rng;
|
|
|
+use rand::{thread_rng, Rng};
|
|
|
use std::fs::File;
|
|
|
use std::io::{self, BufReader, Read, Result as IoResult};
|
|
|
use std::vec;
|
|
|
@@ -21,6 +20,7 @@ use text_data::{
|
|
|
#[derive(Default)]
|
|
|
pub struct MyDataService {
|
|
|
groups: Vec<TextData>,
|
|
|
+ causual_sampling: bool,
|
|
|
weights: Vec<f32>,
|
|
|
}
|
|
|
|
|
|
@@ -56,7 +56,7 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>
|
|
|
}
|
|
|
|
|
|
impl MyDataService {
|
|
|
- pub fn new(files: Vec<String>) -> IoResult<Self> {
|
|
|
+ pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
|
|
|
let mut groups = Vec::new();
|
|
|
let mut weights = Vec::new();
|
|
|
|
|
|
@@ -73,7 +73,7 @@ impl MyDataService {
|
|
|
|
|
|
info!("Loaded {} groups", groups.len());
|
|
|
|
|
|
- Ok(MyDataService { groups, weights })
|
|
|
+ Ok(MyDataService { groups, weights, causual_sampling })
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -90,15 +90,36 @@ impl DataService for MyDataService {
|
|
|
.groups
|
|
|
.choose_weighted(&mut rng, |item| item.sentences.len() as f32);
|
|
|
|
|
|
- if group.is_ok() {
|
|
|
- let group = group.unwrap();
|
|
|
+ if group.is_err() {
|
|
|
+ return Err(Status::internal("Failed to select a group"));
|
|
|
+ }
|
|
|
+
|
|
|
+ let group = group.unwrap();
|
|
|
+
|
|
|
+ if self.causual_sampling {
|
|
|
if num_samples > group.sentences.len() {
|
|
|
num_samples = group.sentences.len();
|
|
|
}
|
|
|
|
|
|
+ // Random number between 0 and group.sentences.len() - num_samples
|
|
|
+ let max = group.sentences.len() - num_samples;
|
|
|
+ if max <= 0 {
|
|
|
+ return Ok(Response::new(SampledData {
|
|
|
+ name: group.name.clone(),
|
|
|
+ source: group.source.clone(),
|
|
|
+ samples: group.sentences.clone(),
|
|
|
+ }));
|
|
|
+ }
|
|
|
+
|
|
|
+ let start = rng.gen_range(0..max);
|
|
|
+ Ok(Response::new(SampledData {
|
|
|
+ name: group.name.clone(),
|
|
|
+ source: group.source.clone(),
|
|
|
+ samples: group.sentences[start..start + num_samples].to_vec(),
|
|
|
+ }))
|
|
|
+ } else {
|
|
|
let sentences_ref = group
|
|
|
.sentences
|
|
|
- .iter()
|
|
|
.choose_multiple(&mut rng, num_samples);
|
|
|
|
|
|
let sentences: Vec<Sentence> = sentences_ref
|
|
|
@@ -109,10 +130,8 @@ impl DataService for MyDataService {
|
|
|
Ok(Response::new(SampledData {
|
|
|
name: group.name.clone(),
|
|
|
source: group.source.clone(),
|
|
|
- samples: sentences
|
|
|
+ samples: sentences,
|
|
|
}))
|
|
|
- } else {
|
|
|
- Err(Status::internal("Failed to select a group"))
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -124,6 +143,10 @@ struct Args {
|
|
|
/// Files to process
|
|
|
#[clap(short, long, value_name = "FILE", required = true)]
|
|
|
files: Vec<String>,
|
|
|
+
|
|
|
+ /// Causual sampling
|
|
|
+ #[clap(short, long, default_value = "false")]
|
|
|
+ causal: bool
|
|
|
}
|
|
|
|
|
|
#[tokio::main]
|
|
|
@@ -132,9 +155,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
|
// Parse command-line arguments
|
|
|
let args = Args::parse();
|
|
|
+ info!("Arguments: {:?}", args);
|
|
|
|
|
|
let addr = "127.0.0.1:50051".parse()?;
|
|
|
- let data_service = MyDataService::new(args.files)?;
|
|
|
+ let data_service = MyDataService::new(args.files, args.causal)?;
|
|
|
|
|
|
info!("Starting server at {}", addr);
|
|
|
|