|
@@ -49,9 +49,9 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
|
|
|
JSONArray embeddings = result.getJSONArray("embeddings");
|
|
|
List<String> words = result.getJSONArray("tokens").toJavaList(String.class);
|
|
|
|
|
|
- List<List<Double>> filteredEmbeddings = basicFilterEmbeddings(words, embeddings);
|
|
|
+ List<List<Float>> filteredEmbeddings = basicFilterEmbeddings(words, embeddings);
|
|
|
List<String> filterWords = basicFilterWords(words);
|
|
|
- List<Double> doubleList = averageEmbeddings(filteredEmbeddings);
|
|
|
+ List<Float> doubleList = averageEmbeddings(filteredEmbeddings);
|
|
|
String embeddingStr = doubleList.stream()
|
|
|
.map(String::valueOf)
|
|
|
.collect(Collectors.joining("|"));
|
|
@@ -69,10 +69,10 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
|
|
|
|
|
|
|
|
|
|
|
|
- public List<List<Double>> basicFilterEmbeddings(List<String> words, JSONArray embeddingsArray) {
|
|
|
- List<List<Double>> saveEmbeddings = new ArrayList<>();
|
|
|
+ public List<List<Float>> basicFilterEmbeddings(List<String> words, JSONArray embeddingsArray) {
|
|
|
+ List<List<Float>> saveEmbeddings = new ArrayList<>();
|
|
|
for (int i = 0; i < embeddingsArray.size(); i++) {
|
|
|
- List<Double> embeddings = embeddingsArray.getJSONArray(i).toJavaList(Double.class);
|
|
|
+ List<Float> embeddings = embeddingsArray.getJSONArray(i).toJavaList(Float.class);
|
|
|
String word = words.get(i);
|
|
|
if (!stopwords.contains(word)) {
|
|
|
saveEmbeddings.add(embeddings);
|
|
@@ -91,20 +91,20 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
|
|
|
return saveWord;
|
|
|
}
|
|
|
|
|
|
- public List<Double> averageEmbeddings(List<List<Double>> embeddings) {
|
|
|
+ public List<Float> averageEmbeddings(List<List<Float>> embeddings) {
|
|
|
if (embeddings.isEmpty()) {
|
|
|
return new ArrayList<>();
|
|
|
}
|
|
|
int length = embeddings.get(0).size();
|
|
|
- double[] sums = new double[length];
|
|
|
- for (List<Double> embedding : embeddings) {
|
|
|
+ float[] sums = new float[length];
|
|
|
+ for (List<Float> embedding : embeddings) {
|
|
|
for (int i = 0; i < length; i++) {
|
|
|
sums[i] += embedding.get(i);
|
|
|
}
|
|
|
}
|
|
|
int numArrays = embeddings.size();
|
|
|
- List<Double> averages = new ArrayList<>();
|
|
|
- for (double sum : sums) {
|
|
|
+ List<Float> averages = new ArrayList<>();
|
|
|
+ for (float sum : sums) {
|
|
|
averages.add(sum / numArrays);
|
|
|
}
|
|
|
return averages;
|