|
|
@@ -73,7 +73,11 @@ impl MyDataService {
|
|
|
|
|
|
info!("Loaded {} groups", groups.len());
|
|
|
|
|
|
- Ok(MyDataService { groups, weights, causual_sampling })
|
|
|
+ Ok(MyDataService {
|
|
|
+ groups,
|
|
|
+ weights,
|
|
|
+ causual_sampling,
|
|
|
+ })
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -105,7 +109,7 @@ impl DataService for MyDataService {
|
|
|
let max = group.sentences.len() - num_samples;
|
|
|
if max <= 0 {
|
|
|
return Ok(Response::new(SampledData {
|
|
|
- name: group.name.clone(),
|
|
|
+ name: group.name.clone(),
|
|
|
source: group.source.clone(),
|
|
|
samples: group.sentences.clone(),
|
|
|
}));
|
|
|
@@ -113,14 +117,12 @@ impl DataService for MyDataService {
|
|
|
|
|
|
let start = rng.gen_range(0..max);
|
|
|
Ok(Response::new(SampledData {
|
|
|
- name: group.name.clone(),
|
|
|
+ name: group.name.clone(),
|
|
|
source: group.source.clone(),
|
|
|
samples: group.sentences[start..start + num_samples].to_vec(),
|
|
|
}))
|
|
|
} else {
|
|
|
- let sentences_ref = group
|
|
|
- .sentences
|
|
|
- .choose_multiple(&mut rng, num_samples);
|
|
|
+ let sentences_ref = group.sentences.choose_multiple(&mut rng, num_samples);
|
|
|
|
|
|
let sentences: Vec<Sentence> = sentences_ref
|
|
|
.into_iter()
|
|
|
@@ -128,7 +130,7 @@ impl DataService for MyDataService {
|
|
|
.collect();
|
|
|
|
|
|
Ok(Response::new(SampledData {
|
|
|
- name: group.name.clone(),
|
|
|
+ name: group.name.clone(),
|
|
|
source: group.source.clone(),
|
|
|
samples: sentences,
|
|
|
}))
|
|
|
@@ -146,7 +148,11 @@ struct Args {
|
|
|
|
|
|
/// Causual sampling
|
|
|
#[clap(short, long, default_value = "false")]
|
|
|
- causal: bool
|
|
|
+ causal: bool,
|
|
|
+
|
|
|
+ /// Address to bind to
|
|
|
+ #[clap(short, long, default_value = "127.0.0.1:50051")]
|
|
|
+ address: String,
|
|
|
}
|
|
|
|
|
|
#[tokio::main]
|
|
|
@@ -157,7 +163,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
let args = Args::parse();
|
|
|
info!("Arguments: {:?}", args);
|
|
|
|
|
|
- let addr = "127.0.0.1:50051".parse()?;
|
|
|
+ let addr = args.address.parse()?;
|
|
|
let data_service = MyDataService::new(args.files, args.causal)?;
|
|
|
|
|
|
info!("Starting server at {}", addr);
|