main.rs 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. use clap::Parser;
  2. use log::info;
  3. use prost::Message;
  4. use rand::prelude::IteratorRandom;
  5. use rand::seq::SliceRandom;
  6. use rand::thread_rng;
  7. use std::fs::File;
  8. use std::io::{self, BufReader, Read, Result as IoResult};
  9. use std::vec;
  10. use tonic::{transport::Server, Request, Response, Status};
/// Code generated from the `text_data` protobuf package, pulled in at compile
/// time by `tonic::include_proto!`.
pub mod text_data {
    tonic::include_proto!("text_data");
}
  14. use text_data::{
  15. data_service_server::{DataService, DataServiceServer},
  16. SampleDataRequest, SampledData, Sentence, TextData,
  17. };
/// Data service that answers sampling requests from `TextData` groups kept in
/// memory.
#[derive(Default)]
pub struct MyDataService {
    /// Every group loaded by `MyDataService::new`, in load order.
    groups: Vec<TextData>,
    /// One weight per entry of `groups` (its sentence count as `f32`), filled by
    /// `MyDataService::new`.
    // NOTE(review): no code in this section reads `weights`; `sample_data`
    // recomputes the weight from the group itself — confirm whether the field is
    // still needed.
    weights: Vec<f32>,
}
  23. fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>> {
  24. let mut text_data_list = Vec::new();
  25. let mut index = 0;
  26. loop {
  27. let mut size_buf = [0u8; 4];
  28. match reader.read_exact(&mut size_buf) {
  29. Ok(()) => (),
  30. Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => break, // End of file
  31. Err(e) => return Err(e),
  32. }
  33. let size = u32::from_le_bytes(size_buf) as usize;
  34. let mut message_buf = vec![0u8; size];
  35. reader.read_exact(&mut message_buf)?;
  36. let text_data = TextData::decode(&message_buf[..])
  37. .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
  38. text_data_list.push(text_data);
  39. index += 1;
  40. if index % 10000 == 0 {
  41. info!("Loaded {} groups", index);
  42. }
  43. }
  44. Ok(text_data_list)
  45. }
  46. impl MyDataService {
  47. pub fn new(files: Vec<String>) -> IoResult<Self> {
  48. let mut groups = Vec::new();
  49. let mut weights = Vec::new();
  50. for filename in files.iter() {
  51. let file = File::open(filename)?;
  52. let reader = BufReader::new(file);
  53. // Assuming read_pb_stream is implemented and it returns an iterator over TextData
  54. for text_data in read_pb_stream(reader)? {
  55. groups.push(text_data.clone());
  56. weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
  57. }
  58. }
  59. info!("Loaded {} groups", groups.len());
  60. Ok(MyDataService { groups, weights })
  61. }
  62. }
  63. #[tonic::async_trait]
  64. impl DataService for MyDataService {
  65. async fn sample_data(
  66. &self,
  67. request: Request<SampleDataRequest>,
  68. ) -> Result<Response<SampledData>, Status> {
  69. let mut num_samples = request.into_inner().num_samples as usize;
  70. let mut rng = thread_rng();
  71. let group = self
  72. .groups
  73. .choose_weighted(&mut rng, |item| item.sentences.len() as f32);
  74. if group.is_ok() {
  75. let group = group.unwrap();
  76. if num_samples > group.sentences.len() {
  77. num_samples = group.sentences.len();
  78. }
  79. let sentences_ref = group
  80. .sentences
  81. .iter()
  82. .choose_multiple(&mut rng, num_samples);
  83. let sentences: Vec<Sentence> = sentences_ref
  84. .into_iter()
  85. .cloned() // Clone each &Sentence to get Sentence
  86. .collect();
  87. Ok(Response::new(SampledData {
  88. name: group.name.clone(),
  89. source: group.source.clone(),
  90. samples: sentences
  91. }))
  92. } else {
  93. Err(Status::internal("Failed to select a group"))
  94. }
  95. }
  96. }
// Command-line arguments, parsed by clap's derive API.
// NOTE: clap reads the `///` doc comments below as help text, so rewording them
// changes what the program prints for `--help`.
/// My Data Service Application
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Files to process
    // Required. `short` and `long` derive the flag names from the field name.
    #[clap(short, long, value_name = "FILE", required = true)]
    files: Vec<String>,
}
  105. #[tokio::main]
  106. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  107. env_logger::init();
  108. // Parse command-line arguments
  109. let args = Args::parse();
  110. let addr = "127.0.0.1:50051".parse()?;
  111. let data_service = MyDataService::new(args.files)?;
  112. info!("Starting server at {}", addr);
  113. Server::builder()
  114. .add_service(DataServiceServer::new(data_service))
  115. .serve(addr)
  116. .await?;
  117. Ok(())
  118. }