ソースを参照

Improve loading speed & reduce memory to 1/4

Lengyue 2 年前
コミット
f6a35e09df
1 ファイル変更、59 行追加、7 行削除
  1. 59 7
      data_server/src/main.rs

+ 59 - 7
data_server/src/main.rs

@@ -17,14 +17,29 @@ use text_data::{
     SampleDataRequest, SampledData, Sentence, TextData,
 };
 
+#[derive(Default, Debug, Clone)]
+pub struct RSSentence {
+    text: String,
+    phones: Vec<String>,
+    semantics: Vec<Vec<u8>>,
+}
+
+#[derive(Default, Debug, Clone)]
+pub struct RSTextData {
+    source: String,
+    name: String,
+    languages: Vec<String>,
+    sentences: Vec<RSSentence>,
+}
+
 #[derive(Default)]
 pub struct MyDataService {
-    groups: Vec<TextData>,
+    groups: Vec<RSTextData>,
     causual_sampling: bool,
     weights: Vec<f32>,
 }
 
-fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>> {
+fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextData>> {
     let mut text_data_list = Vec::new();
     let mut index = 0;
     let mut total_vq_frames = 0;
@@ -47,9 +62,32 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>
 
         text_data.sentences.iter().for_each(|sentence| {
             total_vq_frames += sentence.semantics[0].values.len();
+
+            // Check that all values are in the range 0-255
+            sentence
+                .semantics
+                .iter()
+                .for_each(|semantics| semantics.values.iter().for_each(|v| assert!(*v <= 255)));
         });
         
-        text_data_list.push(text_data);
+        text_data_list.push(RSTextData {
+            source: text_data.source.clone(),
+            name: text_data.name.clone(),
+            languages: text_data.languages.clone(),
+            sentences: text_data
+                .sentences
+                .iter()
+                .map(|sentence| RSSentence {
+                    text: sentence.text.clone(),
+                    phones: sentence.phones.clone(),
+                    semantics: sentence
+                        .semantics
+                        .iter()
+                        .map(|semantics| semantics.values.iter().map(|v| *v as u8).collect())
+                        .collect(),
+                })
+                .collect(),
+        });
 
         index += 1;
 
@@ -74,8 +112,8 @@ impl MyDataService {
 
             // Assuming read_pb_stream is implemented and it returns an iterator over TextData
             for text_data in read_pb_stream(reader)? {
-                groups.push(text_data.clone());
                 weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
+                groups.push(text_data);
             }
         }
 
@@ -87,6 +125,20 @@ impl MyDataService {
     }
 }
 
+fn cast_rs_sentence(sentence: &RSSentence) -> Sentence {
+    Sentence {
+        text: sentence.text.clone(),
+        phones: sentence.phones.clone(),
+        semantics: sentence
+            .semantics
+            .iter()
+            .map(|semantics| text_data::Semantics {
+                values: semantics.iter().map(|v| *v as u32).collect(),
+            })
+            .collect(),
+    }
+}
+
 #[tonic::async_trait]
 impl DataService for MyDataService {
     async fn sample_data(
@@ -117,7 +169,7 @@ impl DataService for MyDataService {
                 return Ok(Response::new(SampledData {
                     name: group.name.clone(),
                     source: group.source.clone(),
-                    samples: group.sentences.clone(),
+                    samples: (&group.sentences).into_iter().map(cast_rs_sentence).collect(),
                 }));
             }
 
@@ -125,14 +177,14 @@ impl DataService for MyDataService {
             Ok(Response::new(SampledData {
                 name: group.name.clone(),
                 source: group.source.clone(),
-                samples: group.sentences[start..start + num_samples].to_vec(),
+                samples: group.sentences[start..start + num_samples].iter().map(cast_rs_sentence).collect(),
             }))
         } else {
             let sentences_ref = group.sentences.choose_multiple(&mut rng, num_samples);
 
             let sentences: Vec<Sentence> = sentences_ref
                 .into_iter()
-                .cloned() // Clone each &Sentence to get Sentence
+                .map(cast_rs_sentence)
                 .collect();
 
             Ok(Response::new(SampledData {