main.rs 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. use clap::Parser;
  2. use futures_lite::future::block_on;
  3. use log::info;
  4. use prost::Message;
  5. use rand::seq::SliceRandom;
  6. use rand::{thread_rng, Rng};
  7. use std::fs::File;
  8. use std::io::{self, BufReader, Read, Result as IoResult};
  9. use std::sync::Arc;
  10. use std::sync::Mutex;
  11. use std::vec;
  12. use tonic::{transport::Server, Request, Response, Status};
  // Code generated from the `text_data` protobuf package is pulled in here by
  // `tonic::include_proto!`; the message and service types imported below
  // (`TextData`, `Sentence`, `DataService`, ...) come from this module.
  pub mod text_data {
      tonic::include_proto!("text_data");
  }
  16. use text_data::{
  17. data_service_server::{DataService, DataServiceServer},
  18. SampleDataRequest, SampledData, Sentence, TextData,
  19. };
  /// State behind the `DataService` gRPC implementation.
  #[derive(Default)]
  pub struct MyDataService {
      /// Records loaded from the input files by `new`; `sample_data` picks one
      /// of them per request.
      groups: Vec<TextData>,
      /// When true, `sample_data` returns a contiguous run of sentences from the
      /// chosen group; when false it returns a random subset.
      causual_sampling: bool,
      /// Sentence count of each loaded group, stored as `f32` by `new`.
      // NOTE(review): nothing visible in this file reads this field;
      // `sample_data` recomputes the weights from `groups` on every call.
      weights: Vec<f32>,
  }
  26. fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>> {
  27. let mut text_data_list = Vec::new();
  28. let mut index = 0;
  29. loop {
  30. let mut size_buf = [0u8; 4];
  31. match reader.read_exact(&mut size_buf) {
  32. Ok(()) => (),
  33. Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => break, // End of file
  34. Err(e) => return Err(e),
  35. }
  36. let size = u32::from_le_bytes(size_buf) as usize;
  37. let mut message_buf = vec![0u8; size];
  38. reader.read_exact(&mut message_buf)?;
  39. let text_data = TextData::decode(&message_buf[..])
  40. .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
  41. text_data_list.push(text_data);
  42. index += 1;
  43. if index % 10000 == 0 {
  44. info!("Loaded {} groups", index);
  45. }
  46. }
  47. Ok(text_data_list)
  48. }
  49. impl MyDataService {
  50. pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
  51. let groups = Vec::new();
  52. let weights = Vec::new();
  53. let guarded = Arc::new(Mutex::new((groups, weights)));
  54. let mut joins = Vec::with_capacity(files.len());
  55. for filename in files {
  56. let g = guarded.clone();
  57. joins.push(tokio::task::spawn_blocking(move || {
  58. let file = File::open(filename)?;
  59. let reader = BufReader::new(file);
  60. // Assuming read_pb_stream is implemented and it returns an iterator over TextData
  61. for text_data in read_pb_stream(reader)? {
  62. let (groups, weights) = &mut *g.lock().unwrap();
  63. groups.push(text_data.clone());
  64. weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
  65. }
  66. Ok::<_, io::Error>(())
  67. }));
  68. }
  69. for join in joins {
  70. block_on(join)??;
  71. }
  72. let (groups, weights) = Arc::into_inner(guarded).unwrap().into_inner().unwrap();
  73. info!("Loaded {} groups", groups.len());
  74. Ok(MyDataService {
  75. groups,
  76. weights,
  77. causual_sampling,
  78. })
  79. }
  80. }
  81. #[tonic::async_trait]
  82. impl DataService for MyDataService {
  83. async fn sample_data(
  84. &self,
  85. request: Request<SampleDataRequest>,
  86. ) -> Result<Response<SampledData>, Status> {
  87. let mut num_samples = request.into_inner().num_samples as usize;
  88. let mut rng = thread_rng();
  89. let group = self
  90. .groups
  91. .choose_weighted(&mut rng, |item| item.sentences.len() as f32);
  92. if group.is_err() {
  93. return Err(Status::internal("Failed to select a group"));
  94. }
  95. let group = group.unwrap();
  96. if self.causual_sampling {
  97. if num_samples > group.sentences.len() {
  98. num_samples = group.sentences.len();
  99. }
  100. // Random number between 0 and group.sentences.len() - num_samples
  101. let max = group.sentences.len() - num_samples;
  102. if max <= 0 {
  103. return Ok(Response::new(SampledData {
  104. name: group.name.clone(),
  105. source: group.source.clone(),
  106. samples: group.sentences.clone(),
  107. }));
  108. }
  109. let start = rng.gen_range(0..max);
  110. Ok(Response::new(SampledData {
  111. name: group.name.clone(),
  112. source: group.source.clone(),
  113. samples: group.sentences[start..start + num_samples].to_vec(),
  114. }))
  115. } else {
  116. let sentences_ref = group.sentences.choose_multiple(&mut rng, num_samples);
  117. let sentences: Vec<Sentence> = sentences_ref
  118. .into_iter()
  119. .cloned() // Clone each &Sentence to get Sentence
  120. .collect();
  121. Ok(Response::new(SampledData {
  122. name: group.name.clone(),
  123. source: group.source.clone(),
  124. samples: sentences,
  125. }))
  126. }
  127. }
  128. }
  // Command-line options. The `///` lines below are picked up by clap's derive
  // as help text, so they are runtime strings and are kept exactly as written.
  /// My Data Service Application
  #[derive(Parser, Debug)]
  #[clap(author, version, about, long_about = None)]
  struct Args {
      /// Files to process
      // Passed to `MyDataService::new` as the list of input files; marked required.
      #[clap(short, long, value_name = "FILE", required = true)]
      files: Vec<String>,
      /// Causual sampling
      // Passed to `MyDataService::new` as its `causual_sampling` argument.
      #[clap(short, long, default_value = "false")]
      causal: bool,
      /// Address to bind to
      // Parsed in `main` and given to the server's `serve` call.
      #[clap(short, long, default_value = "127.0.0.1:50051")]
      address: String,
  }
  143. #[tokio::main]
  144. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  145. env_logger::init();
  146. // Parse command-line arguments
  147. let args = Args::parse();
  148. info!("Arguments: {:?}", args);
  149. let addr = args.address.parse()?;
  150. let data_service = MyDataService::new(args.files, args.causal)?;
  151. info!("Starting server at {}", addr);
  152. Server::builder()
  153. .add_service(DataServiceServer::new(data_service))
  154. .serve(addr)
  155. .await?;
  156. Ok(())
  157. }