main.rs 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. use clap::Parser;
  2. use log::info;
  3. use prost::Message;
  4. use rand::seq::SliceRandom;
  5. use rand::{thread_rng, Rng};
  6. use std::fs::{self, File};
  7. use std::io::{self, BufReader, Read, Result as IoResult};
  8. use std::path::{Path, PathBuf};
  9. use std::{env, vec};
  10. use tonic::{transport::Server, Request, Response, Status};
  /// Generated protobuf / gRPC bindings for the `text_data` package, pulled in at
  /// compile time by `tonic::include_proto!`.
  pub mod text_data {
      tonic::include_proto!("text_data");
  }
  14. use text_data::{
  15. data_service_server::{DataService, DataServiceServer},
  16. SampleDataRequest, SampledData, Sentence, TextData,
  17. };
  /// Compact in-memory form of one protobuf `Sentence`.
  #[derive(Clone, Debug, Default)]
  pub struct RSSentence {
      /// Sentence text, copied from the protobuf message.
      text: String,
      /// Phone sequence, copied from the protobuf message.
      phones: Vec<String>,
      /// One inner vector per protobuf `Semantics` entry, with each value stored as a `u8`.
      semantics: Vec<Vec<u8>>,
  }
  24. #[derive(Default, Debug, Clone)]
  25. pub struct RSTextData {
  26. source: String,
  27. name: String,
  28. languages: Vec<String>,
  29. sentences: Vec<RSSentence>,
  30. }
  31. #[derive(Default)]
  32. pub struct MyDataService {
  33. groups: Vec<RSTextData>,
  34. causual_sampling: bool,
  35. weights: Vec<f32>,
  36. }
  37. async fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextData>> {
  38. let mut text_data_list = Vec::new();
  39. let mut index = 0;
  40. let mut total_vq_frames = 0;
  41. loop {
  42. let mut size_buf = [0u8; 4];
  43. match reader.read_exact(&mut size_buf) {
  44. Ok(()) => (),
  45. Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => break, // End of file
  46. Err(e) => return Err(e),
  47. }
  48. let size = u32::from_le_bytes(size_buf) as usize;
  49. let mut message_buf = vec![0u8; size];
  50. reader.read_exact(&mut message_buf)?;
  51. let text_data = TextData::decode(&message_buf[..])
  52. .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
  53. text_data.sentences.iter().for_each(|sentence| {
  54. total_vq_frames += sentence.semantics[0].values.len();
  55. // Check that all values are in the range 0-255
  56. sentence
  57. .semantics
  58. .iter()
  59. .for_each(|semantics| semantics.values.iter().for_each(|v| assert!(*v <= 255)));
  60. });
  61. text_data_list.push(RSTextData {
  62. source: text_data.source.clone(),
  63. name: text_data.name.clone(),
  64. languages: text_data.languages.clone(),
  65. sentences: text_data
  66. .sentences
  67. .iter()
  68. .map(|sentence| RSSentence {
  69. text: sentence.text.clone(),
  70. phones: sentence.phones.clone(),
  71. semantics: sentence
  72. .semantics
  73. .iter()
  74. .map(|semantics| semantics.values.iter().map(|v| *v as u8).collect())
  75. .collect(),
  76. })
  77. .collect(),
  78. });
  79. index += 1;
  80. if index % 1000 == 0 {
  81. info!("Loaded {} groups, total vq frames: {}", index, total_vq_frames);
  82. }
  83. }
  84. info!("Worker loaded {} groups, total vq frames: {}", index, total_vq_frames);
  85. Ok(text_data_list)
  86. }
  /// Recursively collects the paths of all non-directory entries below `path`.
  ///
  /// Entries are visited in the order `fs::read_dir` yields them, with each
  /// sub-directory expanded in place.
  ///
  /// # Panics
  ///
  /// Panics if a directory cannot be read, if an entry cannot be inspected, or if a
  /// path is not valid UTF-8.
  fn list_files(path: PathBuf) -> Vec<String> {
      fs::read_dir(path)
          .unwrap()
          .flat_map(|entry| {
              let child = entry.unwrap().path();
              if child.is_dir() {
                  list_files(child)
              } else {
                  vec![child.to_str().unwrap().to_string()]
              }
          })
          .collect()
  }
  100. impl MyDataService {
  101. pub async fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
  102. let mut groups = Vec::new();
  103. let mut weights = Vec::new();
  104. let mut handles = Vec::new();
  105. let start_time = std::time::Instant::now();
  106. // Expand files if some are directories
  107. let mut new_files = Vec::new();
  108. for filename in files.iter() {
  109. let path = Path::new(filename);
  110. if path.is_dir() {
  111. // run recursively on all files in the directory
  112. for entry in list_files(path.to_path_buf()) {
  113. if entry.ends_with(".protos") {
  114. new_files.push(entry);
  115. }
  116. }
  117. } else {
  118. new_files.push(filename.clone());
  119. }
  120. }
  121. log::info!("Loading files: {:?}", new_files.len());
  122. for filename in new_files.iter() {
  123. // Tokio launch multiple tasks to read the files in parallel
  124. let copied_filename = filename.clone();
  125. let handle = tokio::spawn(async move {
  126. let file = File::open(copied_filename)?;
  127. let reader = BufReader::new(file);
  128. read_pb_stream(reader).await
  129. });
  130. handles.push(handle);
  131. }
  132. for handle in handles {
  133. let text_data_list = handle.await??;
  134. for text_data in text_data_list {
  135. weights.push(text_data.sentences.len() as f32);
  136. groups.push(text_data);
  137. }
  138. }
  139. log::info!("All workers finished, total groups: {}, used time: {:?}", groups.len(), start_time.elapsed());
  140. Ok(MyDataService {
  141. groups,
  142. weights,
  143. causual_sampling,
  144. })
  145. }
  146. }
  147. fn cast_rs_sentence(sentence: &RSSentence) -> Sentence {
  148. Sentence {
  149. text: sentence.text.clone(),
  150. phones: sentence.phones.clone(),
  151. semantics: sentence
  152. .semantics
  153. .iter()
  154. .map(|semantics| text_data::Semantics {
  155. values: semantics.iter().map(|v| *v as u32).collect(),
  156. })
  157. .collect(),
  158. }
  159. }
  160. #[tonic::async_trait]
  161. impl DataService for MyDataService {
  162. async fn sample_data(
  163. &self,
  164. request: Request<SampleDataRequest>,
  165. ) -> Result<Response<SampledData>, Status> {
  166. let mut num_samples = request.into_inner().num_samples as usize;
  167. let mut rng = thread_rng();
  168. let group = self
  169. .groups
  170. .choose_weighted(&mut rng, |item| item.sentences.len() as f32);
  171. if group.is_err() {
  172. return Err(Status::internal("Failed to select a group"));
  173. }
  174. let group = group.unwrap();
  175. if self.causual_sampling {
  176. if num_samples > group.sentences.len() {
  177. num_samples = group.sentences.len();
  178. }
  179. // Random number between 0 and group.sentences.len() - num_samples
  180. let max = group.sentences.len() - num_samples;
  181. if max <= 0 {
  182. return Ok(Response::new(SampledData {
  183. name: group.name.clone(),
  184. source: group.source.clone(),
  185. samples: (&group.sentences).into_iter().map(cast_rs_sentence).collect(),
  186. }));
  187. }
  188. let start = rng.gen_range(0..max);
  189. Ok(Response::new(SampledData {
  190. name: group.name.clone(),
  191. source: group.source.clone(),
  192. samples: group.sentences[start..start + num_samples].iter().map(cast_rs_sentence).collect(),
  193. }))
  194. } else {
  195. let sentences_ref = group.sentences.choose_multiple(&mut rng, num_samples);
  196. let sentences: Vec<Sentence> = sentences_ref
  197. .into_iter()
  198. .map(cast_rs_sentence)
  199. .collect();
  200. Ok(Response::new(SampledData {
  201. name: group.name.clone(),
  202. source: group.source.clone(),
  203. samples: sentences,
  204. }))
  205. }
  206. }
  207. }
  208. /// My Data Service Application
  209. #[derive(Parser, Debug)]
  210. #[clap(author, version, about, long_about = None)]
  211. struct Args {
  212. /// Files to process
  213. #[clap(short, long, value_name = "FILE", required = true)]
  214. files: Vec<String>,
  215. /// Causual sampling
  216. #[clap(short, long, default_value = "false")]
  217. causal: bool,
  218. /// Address to bind to
  219. #[clap(short, long, default_value = "127.0.0.1:50051")]
  220. address: String,
  221. }
  222. #[tokio::main(flavor = "multi_thread", worker_threads = 8)]
  223. async fn main() -> Result<(), Box<dyn std::error::Error>> {
  224. if env::var("RUST_LOG").is_err() {
  225. env::set_var("RUST_LOG", "info")
  226. }
  227. env_logger::init();
  228. // Parse command-line arguments
  229. let args = Args::parse();
  230. info!("Arguments: {:?}", args);
  231. let addr = args.address.parse()?;
  232. let data_service = MyDataService::new(args.files, args.causal).await?;
  233. info!("Starting server at {}", addr);
  234. Server::builder()
  235. .add_service(DataServiceServer::new(data_service))
  236. .serve(addr)
  237. .await?;
  238. Ok(())
  239. }