|
|
@@ -17,14 +17,29 @@ use text_data::{
|
|
|
SampleDataRequest, SampledData, Sentence, TextData,
|
|
|
};
|
|
|
|
|
|
+#[derive(Default, Debug, Clone)]
|
|
|
+pub struct RSSentence {
|
|
|
+ text: String,
|
|
|
+ phones: Vec<String>,
|
|
|
+ semantics: Vec<Vec<u8>>,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Default, Debug, Clone)]
|
|
|
+pub struct RSTextData {
|
|
|
+ source: String,
|
|
|
+ name: String,
|
|
|
+ languages: Vec<String>,
|
|
|
+ sentences: Vec<RSSentence>,
|
|
|
+}
|
|
|
+
|
|
|
#[derive(Default)]
|
|
|
pub struct MyDataService {
|
|
|
- groups: Vec<TextData>,
|
|
|
+ groups: Vec<RSTextData>,
|
|
|
causual_sampling: bool,
|
|
|
weights: Vec<f32>,
|
|
|
}
|
|
|
|
|
|
-fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>> {
|
|
|
+fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<RSTextData>> {
|
|
|
let mut text_data_list = Vec::new();
|
|
|
let mut index = 0;
|
|
|
let mut total_vq_frames = 0;
|
|
|
@@ -47,9 +62,32 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>
|
|
|
|
|
|
text_data.sentences.iter().for_each(|sentence| {
|
|
|
total_vq_frames += sentence.semantics[0].values.len();
|
|
|
+
|
|
|
+ // Check that all values are in the range 0-255
|
|
|
+ sentence
|
|
|
+ .semantics
|
|
|
+ .iter()
|
|
|
+ .for_each(|semantics| semantics.values.iter().for_each(|v| assert!(*v <= 255)));
|
|
|
});
|
|
|
|
|
|
- text_data_list.push(text_data);
|
|
|
+ text_data_list.push(RSTextData {
|
|
|
+ source: text_data.source.clone(),
|
|
|
+ name: text_data.name.clone(),
|
|
|
+ languages: text_data.languages.clone(),
|
|
|
+ sentences: text_data
|
|
|
+ .sentences
|
|
|
+ .iter()
|
|
|
+ .map(|sentence| RSSentence {
|
|
|
+ text: sentence.text.clone(),
|
|
|
+ phones: sentence.phones.clone(),
|
|
|
+ semantics: sentence
|
|
|
+ .semantics
|
|
|
+ .iter()
|
|
|
+ .map(|semantics| semantics.values.iter().map(|v| *v as u8).collect())
|
|
|
+ .collect(),
|
|
|
+ })
|
|
|
+ .collect(),
|
|
|
+ });
|
|
|
|
|
|
index += 1;
|
|
|
|
|
|
@@ -74,8 +112,8 @@ impl MyDataService {
|
|
|
|
|
|
// Assuming read_pb_stream is implemented and it returns an iterator over TextData
|
|
|
for text_data in read_pb_stream(reader)? {
|
|
|
- groups.push(text_data.clone());
|
|
|
weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
|
|
|
+ groups.push(text_data);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -87,6 +125,20 @@ impl MyDataService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+fn cast_rs_sentence(sentence: &RSSentence) -> Sentence {
|
|
|
+ Sentence {
|
|
|
+ text: sentence.text.clone(),
|
|
|
+ phones: sentence.phones.clone(),
|
|
|
+ semantics: sentence
|
|
|
+ .semantics
|
|
|
+ .iter()
|
|
|
+ .map(|semantics| text_data::Semantics {
|
|
|
+ values: semantics.iter().map(|v| *v as u32).collect(),
|
|
|
+ })
|
|
|
+ .collect(),
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
#[tonic::async_trait]
|
|
|
impl DataService for MyDataService {
|
|
|
async fn sample_data(
|
|
|
@@ -117,7 +169,7 @@ impl DataService for MyDataService {
|
|
|
return Ok(Response::new(SampledData {
|
|
|
name: group.name.clone(),
|
|
|
source: group.source.clone(),
|
|
|
- samples: group.sentences.clone(),
|
|
|
+ samples: (&group.sentences).into_iter().map(cast_rs_sentence).collect(),
|
|
|
}));
|
|
|
}
|
|
|
|
|
|
@@ -125,14 +177,14 @@ impl DataService for MyDataService {
|
|
|
Ok(Response::new(SampledData {
|
|
|
name: group.name.clone(),
|
|
|
source: group.source.clone(),
|
|
|
- samples: group.sentences[start..start + num_samples].to_vec(),
|
|
|
+ samples: group.sentences[start..start + num_samples].iter().map(cast_rs_sentence).collect(),
|
|
|
}))
|
|
|
} else {
|
|
|
let sentences_ref = group.sentences.choose_multiple(&mut rng, num_samples);
|
|
|
|
|
|
let sentences: Vec<Sentence> = sentences_ref
|
|
|
.into_iter()
|
|
|
- .cloned() // Clone each &Sentence to get Sentence
|
|
|
+ .map(cast_rs_sentence)
|
|
|
.collect();
|
|
|
|
|
|
Ok(Response::new(SampledData {
|