Browse Source

Add causual sampling

Lengyue 2 years ago
parent
commit
e8e366e6aa
1 changed files with 35 additions and 11 deletions
  1. 35 11
      data_server/src/main.rs

+ 35 - 11
data_server/src/main.rs

@@ -1,9 +1,8 @@
 use clap::Parser;
 use log::info;
 use prost::Message;
-use rand::prelude::IteratorRandom;
 use rand::seq::SliceRandom;
-use rand::thread_rng;
+use rand::{thread_rng, Rng};
 use std::fs::File;
 use std::io::{self, BufReader, Read, Result as IoResult};
 use std::vec;
@@ -21,6 +20,7 @@ use text_data::{
 #[derive(Default)]
 pub struct MyDataService {
     groups: Vec<TextData>,
+    causual_sampling: bool,
     weights: Vec<f32>,
 }
 
@@ -56,7 +56,7 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>
 }
 
 impl MyDataService {
-    pub fn new(files: Vec<String>) -> IoResult<Self> {
+    pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
         let mut groups = Vec::new();
         let mut weights = Vec::new();
 
@@ -73,7 +73,7 @@ impl MyDataService {
 
         info!("Loaded {} groups", groups.len());
 
-        Ok(MyDataService { groups, weights })
+        Ok(MyDataService { groups, weights, causual_sampling })
     }
 }
 
@@ -90,15 +90,36 @@ impl DataService for MyDataService {
             .groups
             .choose_weighted(&mut rng, |item| item.sentences.len() as f32);
 
-        if group.is_ok() {
-            let group = group.unwrap();
+        if group.is_err() {
+            return Err(Status::internal("Failed to select a group"));
+        }
+
+        let group = group.unwrap();
+
+        if self.causual_sampling {
             if num_samples > group.sentences.len() {
                 num_samples = group.sentences.len();
             }
 
+            // Random number between 0 and group.sentences.len() - num_samples
+            let max = group.sentences.len() - num_samples;
+            if max <= 0 {
+                return Ok(Response::new(SampledData {
+                    name: group.name.clone(), 
+                    source: group.source.clone(),
+                    samples: group.sentences.clone(),
+                }));
+            }
+
+            let start = rng.gen_range(0..max);
+            Ok(Response::new(SampledData {
+                name: group.name.clone(), 
+                source: group.source.clone(),
+                samples: group.sentences[start..start + num_samples].to_vec(),
+            }))
+        } else {
             let sentences_ref = group
                 .sentences
-                .iter()
                 .choose_multiple(&mut rng, num_samples);
 
             let sentences: Vec<Sentence> = sentences_ref
@@ -109,10 +130,8 @@ impl DataService for MyDataService {
             Ok(Response::new(SampledData {
                 name: group.name.clone(), 
                 source: group.source.clone(),
-                samples: sentences 
+                samples: sentences,
             }))
-        } else {
-            Err(Status::internal("Failed to select a group"))
         }
     }
 }
@@ -124,6 +143,10 @@ struct Args {
     /// Files to process
     #[clap(short, long, value_name = "FILE", required = true)]
     files: Vec<String>,
+
+    /// Causual sampling
+    #[clap(short, long, default_value = "false")]
+    causal: bool
 }
 
 #[tokio::main]
@@ -132,9 +155,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // Parse command-line arguments
     let args = Args::parse();
+    info!("Arguments: {:?}", args);
 
     let addr = "127.0.0.1:50051".parse()?;
-    let data_service = MyDataService::new(args.files)?;
+    let data_service = MyDataService::new(args.files, args.causal)?;
 
     info!("Starting server at {}", addr);