Browse Source

修改embedding位数问题

xueyiming 1 tháng trước
mục cha
commit
9a0da29306

+ 10 - 10
src/main/java/com/tzld/piaoquan/featurestools/service/impl/TextEmbeddingServiceImpl.java

@@ -49,9 +49,9 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
         JSONArray embeddings = result.getJSONArray("embeddings");
         List<String> words = result.getJSONArray("tokens").toJavaList(String.class);
 
-        List<List<Double>> filteredEmbeddings = basicFilterEmbeddings(words, embeddings);
+        List<List<Float>> filteredEmbeddings = basicFilterEmbeddings(words, embeddings);
         List<String> filterWords = basicFilterWords(words);
-        List<Double> doubleList = averageEmbeddings(filteredEmbeddings);
+        List<Float> doubleList = averageEmbeddings(filteredEmbeddings);
         String embeddingStr = doubleList.stream()
                 .map(String::valueOf)
                 .collect(Collectors.joining("|"));
@@ -69,10 +69,10 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
 
 
 
-    public List<List<Double>> basicFilterEmbeddings(List<String> words, JSONArray embeddingsArray) {
-        List<List<Double>> saveEmbeddings = new ArrayList<>();
+    public List<List<Float>> basicFilterEmbeddings(List<String> words, JSONArray embeddingsArray) {
+        List<List<Float>> saveEmbeddings = new ArrayList<>();
         for (int i = 0; i < embeddingsArray.size(); i++) {
-            List<Double> embeddings = embeddingsArray.getJSONArray(i).toJavaList(Double.class);
+            List<Float> embeddings = embeddingsArray.getJSONArray(i).toJavaList(Float.class);
             String word = words.get(i);
             if (!stopwords.contains(word)) {
                 saveEmbeddings.add(embeddings);
@@ -91,20 +91,20 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
         return saveWord;
     }
 
-    public List<Double> averageEmbeddings(List<List<Double>> embeddings) {
+    public List<Float> averageEmbeddings(List<List<Float>> embeddings) {
         if (embeddings.isEmpty()) {
             return new ArrayList<>();
         }
         int length = embeddings.get(0).size();
-        double[] sums = new double[length];
-        for (List<Double> embedding : embeddings) {
+        float[] sums = new float[length];
+        for (List<Float> embedding : embeddings) {
             for (int i = 0; i < length; i++) {
                 sums[i] += embedding.get(i);
             }
         }
         int numArrays = embeddings.size();
-        List<Double> averages = new ArrayList<>();
-        for (double sum : sums) {
+        List<Float> averages = new ArrayList<>();
+        for (float sum : sums) {
             averages.add(sum / numArrays);
         }
         return averages;