main.rs 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. use clap::Parser;
  2. use log::info;
  3. use prost::Message;
  4. use rand::seq::SliceRandom;
  5. use rand::{thread_rng, Rng};
  6. use std::fs::File;
  7. use std::io::{self, BufReader, Read, Result as IoResult};
  8. use std::vec;
  9. use tonic::{transport::Server, Request, Response, Status};
  /// Protobuf message types and gRPC service definitions for the `text_data` package.
  ///
  /// NOTE(review): contents are pulled in by `tonic::include_proto!`, presumably from code
  /// generated by a build script (tonic-build) — confirm against build.rs.
  pub mod text_data {
      tonic::include_proto!("text_data");
  }
  13. use text_data::{
  14. data_service_server::{DataService, DataServiceServer},
  15. SampleDataRequest, SampledData, Sentence, TextData,
  16. };
  17. #[derive(Default)]
  18. pub struct MyDataService {
  19. groups: Vec<TextData>,
  20. causual_sampling: bool,
  21. weights: Vec<f32>,
  22. }
  23. fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>> {
  24. let mut text_data_list = Vec::new();
  25. let mut index = 0;
  26. loop {
  27. let mut size_buf = [0u8; 4];
  28. match reader.read_exact(&mut size_buf) {
  29. Ok(()) => (),
  30. Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => break, // End of file
  31. Err(e) => return Err(e),
  32. }
  33. let size = u32::from_le_bytes(size_buf) as usize;
  34. let mut message_buf = vec![0u8; size];
  35. reader.read_exact(&mut message_buf)?;
  36. let text_data = TextData::decode(&message_buf[..])
  37. .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
  38. text_data_list.push(text_data);
  39. index += 1;
  40. if index % 10000 == 0 {
  41. info!("Loaded {} groups", index);
  42. }
  43. }
  44. Ok(text_data_list)
  45. }
  46. impl MyDataService {
  47. pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
  48. let mut groups = Vec::new();
  49. let mut weights = Vec::new();
  50. for filename in files.iter() {
  51. let file = File::open(filename)?;
  52. let reader = BufReader::new(file);
  53. // Assuming read_pb_stream is implemented and it returns an iterator over TextData
  54. for text_data in read_pb_stream(reader)? {
  55. groups.push(text_data.clone());
  56. weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
  57. }
  58. }
  59. info!("Loaded {} groups", groups.len());
  60. Ok(MyDataService {
  61. groups,
  62. weights,
  63. causual_sampling,
  64. })
  65. }
  66. }
  67. #[tonic::async_trait]
  68. impl DataService for MyDataService {
  69. async fn sample_data(
  70. &self,
  71. request: Request<SampleDataRequest>,
  72. ) -> Result<Response<SampledData>, Status> {
  73. let mut num_samples = request.into_inner().num_samples as usize;
  74. let mut rng = thread_rng();
  75. let group = self
  76. .groups
  77. .choose_weighted(&mut rng, |item| item.sentences.len() as f32);
  78. if group.is_err() {
  79. return Err(Status::internal("Failed to select a group"));
  80. }
  81. let group = group.unwrap();
  82. if self.causual_sampling {
  83. if num_samples > group.sentences.len() {
  84. num_samples = group.sentences.len();
  85. }
  86. // Random number between 0 and group.sentences.len() - num_samples
  87. let max = group.sentences.len() - num_samples;
  88. if max <= 0 {
  89. return Ok(Response::new(SampledData {
  90. name: group.name.clone(),
  91. source: group.source.clone(),
  92. samples: group.sentences.clone(),
  93. }));
  94. }
  95. let start = rng.gen_range(0..max);
  96. Ok(Response::new(SampledData {
  97. name: group.name.clone(),
  98. source: group.source.clone(),
  99. samples: group.sentences[start..start + num_samples].to_vec(),
  100. }))
  101. } else {
  102. let sentences_ref = group.sentences.choose_multiple(&mut rng, num_samples);
  103. let sentences: Vec<Sentence> = sentences_ref
  104. .into_iter()
  105. .cloned() // Clone each &Sentence to get Sentence
  106. .collect();
  107. Ok(Response::new(SampledData {
  108. name: group.name.clone(),
  109. source: group.source.clone(),
  110. samples: sentences,
  111. }))
  112. }
  113. }
  114. }
  115. /// My Data Service Application
  116. #[derive(Parser, Debug)]
  117. #[clap(author, version, about, long_about = None)]
  118. struct Args {
  119. /// Files to process
  120. #[clap(short, long, value_name = "FILE", required = true)]
  121. files: Vec<String>,
  122. /// Causual sampling
  123. #[clap(short, long, default_value = "false")]
  124. causal: bool,
  125. /// Address to bind to
  126. #[clap(short, long, default_value = "127.0.0.1:50051")]
  127. address: String,
  128. }
  129. #[tokio::main]
  130. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  131. env_logger::init();
  132. // Parse command-line arguments
  133. let args = Args::parse();
  134. info!("Arguments: {:?}", args);
  135. let addr = args.address.parse()?;
  136. let data_service = MyDataService::new(args.files, args.causal)?;
  137. info!("Starting server at {}", addr);
  138. Server::builder()
  139. .add_service(DataServiceServer::new(data_service))
  140. .serve(addr)
  141. .await?;
  142. Ok(())
  143. }