Просмотр исходного кода

Add an option to specify data server address.

Lengyue 2 лет назад
Родитель
Сommit
61a8ab15c7
1 измененных файлов с 15 добавлено и 9 удалено
  1. 15 9
      data_server/src/main.rs

+ 15 - 9
data_server/src/main.rs

@@ -73,7 +73,11 @@ impl MyDataService {
 
         info!("Loaded {} groups", groups.len());
 
-        Ok(MyDataService { groups, weights, causual_sampling })
+        Ok(MyDataService {
+            groups,
+            weights,
+            causual_sampling,
+        })
     }
 }
 
@@ -105,7 +109,7 @@ impl DataService for MyDataService {
             let max = group.sentences.len() - num_samples;
             if max <= 0 {
                 return Ok(Response::new(SampledData {
-                    name: group.name.clone(), 
+                    name: group.name.clone(),
                     source: group.source.clone(),
                     samples: group.sentences.clone(),
                 }));
@@ -113,14 +117,12 @@ impl DataService for MyDataService {
 
             let start = rng.gen_range(0..max);
             Ok(Response::new(SampledData {
-                name: group.name.clone(), 
+                name: group.name.clone(),
                 source: group.source.clone(),
                 samples: group.sentences[start..start + num_samples].to_vec(),
             }))
         } else {
-            let sentences_ref = group
-                .sentences
-                .choose_multiple(&mut rng, num_samples);
+            let sentences_ref = group.sentences.choose_multiple(&mut rng, num_samples);
 
             let sentences: Vec<Sentence> = sentences_ref
                 .into_iter()
@@ -128,7 +130,7 @@ impl DataService for MyDataService {
                 .collect();
 
             Ok(Response::new(SampledData {
-                name: group.name.clone(), 
+                name: group.name.clone(),
                 source: group.source.clone(),
                 samples: sentences,
             }))
@@ -146,7 +148,11 @@ struct Args {
 
     /// Causual sampling
     #[clap(short, long, default_value = "false")]
-    causal: bool
+    causal: bool,
+
+    /// Address to bind to
+    #[clap(short, long, default_value = "127.0.0.1:50051")]
+    address: String,
 }
 
 #[tokio::main]
@@ -157,7 +163,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let args = Args::parse();
     info!("Arguments: {:?}", args);
 
-    let addr = "127.0.0.1:50051".parse()?;
+    let addr = args.address.parse()?;
     let data_service = MyDataService::new(args.files, args.causal)?;
 
     info!("Starting server at {}", addr);