소스 검색

Merge branch 'dev-xym-update' of Server/features-tools into master

xueyiming 1 개월 전
부모
커밋
c127b88389

+ 26 - 0
src/main/java/com/tzld/piaoquan/featurestools/job/SummarizeUnderstandingJob.java

@@ -5,6 +5,7 @@ import com.tzld.piaoquan.featurestools.dao.mapper.CreativeVideoSummarizeMapper;
 import com.tzld.piaoquan.featurestools.dao.mapper.CreativeVideoUnderstanderMapper;
 import com.tzld.piaoquan.featurestools.model.bo.EmbeddingResult;
 import com.tzld.piaoquan.featurestools.model.po.CreativeVideoSummarize;
+import com.tzld.piaoquan.featurestools.model.po.CreativeVideoSummarizeExample;
 import com.tzld.piaoquan.featurestools.model.po.CreativeVideoUnderstander;
 import com.tzld.piaoquan.featurestools.model.po.CreativeVideoUnderstanderExample;
 import com.tzld.piaoquan.featurestools.service.CreativeVideoSummarizeService;
@@ -41,6 +42,9 @@ public class SummarizeUnderstandingJob {
     @Autowired
     private CreativeVideoSummarizeService creativeVideoSummarizeService;
 
+    @Autowired
+    private CreativeVideoSummarizeMapper creativeVideoSummarizeMapper;
+
 
     @XxlJob("summarizeUnderstandingJob")
     public ReturnT<String> summarizeUnderstanding(String param) throws InterruptedException {
@@ -149,5 +153,27 @@ public class SummarizeUnderstandingJob {
                 valueCreativeVideoSummarize, urgencyCreativeVideoSummarize);
     }
 
+    @XxlJob("refreshEmbeddingJob") // job handler name: "refreshEmbeddingJob"
+    public ReturnT<String> refreshEmbedding(String param) throws InterruptedException { // Re-computes embedding and NLP word split for CreativeVideoSummarize rows, page by page; `param` is not read in this body and always returns ReturnT.SUCCESS unless an exception propagates
+        long l = creativeVideoSummarizeMapper.countByExample(new CreativeVideoSummarizeExample()); // example has no criteria set here, so presumably this counts every row - TODO confirm
+        int pageSize = 1000; // rows fetched per query
+        long pageNum = l / pageSize + 1; // integer division plus one: when l is an exact multiple of pageSize this requests one extra page, which the empty check below skips
+        for (int i = 0; i < pageNum; i++) {
+            CreativeVideoSummarizeExample example = new CreativeVideoSummarizeExample();
+            example.setPage(new Page<>(i + 1, pageSize)); // page index passed as i + 1, i.e. presumably 1-based - verify against Page
+            List<CreativeVideoSummarize> creativeVideoSummarizes = creativeVideoSummarizeMapper.selectByExample(example); // NOTE(review): no ordering is set on the example in this method; confirm page contents are stable while rows are being updated
+            if (CollectionUtils.isEmpty(creativeVideoSummarizes)) {
+                continue; // an empty page is skipped rather than ending the loop
+            }
+            for (CreativeVideoSummarize creativeVideoSummarize : creativeVideoSummarizes) {
+                EmbeddingResult result = textEmbeddingService.getEmbedding(creativeVideoSummarize.getAiWordSplit()); // embedding is derived from the row's aiWordSplit text
+                creativeVideoSummarize.setEmbedding(result.getEmbeddingRes()); // NOTE(review): result is dereferenced without a null check - confirm getEmbedding never returns null
+                creativeVideoSummarize.setNlpWordSplit(result.getWords());
+                creativeVideoSummarizeMapper.updateByPrimaryKeyWithBLOBs(creativeVideoSummarize); // one update call per row
+            }
+        }
+        return ReturnT.SUCCESS;
+    }
+
 
 }

+ 10 - 10
src/main/java/com/tzld/piaoquan/featurestools/service/impl/TextEmbeddingServiceImpl.java

@@ -49,9 +49,9 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
         JSONArray embeddings = result.getJSONArray("embeddings");
         List<String> words = result.getJSONArray("tokens").toJavaList(String.class);
 
-        List<List<Double>> filteredEmbeddings = basicFilterEmbeddings(words, embeddings);
+        List<List<Float>> filteredEmbeddings = basicFilterEmbeddings(words, embeddings);
         List<String> filterWords = basicFilterWords(words);
-        List<Double> doubleList = averageEmbeddings(filteredEmbeddings);
+        List<Float> doubleList = averageEmbeddings(filteredEmbeddings);
         String embeddingStr = doubleList.stream()
                 .map(String::valueOf)
                 .collect(Collectors.joining("|"));
@@ -69,10 +69,10 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
 
 
 
-    public List<List<Double>> basicFilterEmbeddings(List<String> words, JSONArray embeddingsArray) {
-        List<List<Double>> saveEmbeddings = new ArrayList<>();
+    public List<List<Float>> basicFilterEmbeddings(List<String> words, JSONArray embeddingsArray) {
+        List<List<Float>> saveEmbeddings = new ArrayList<>();
         for (int i = 0; i < embeddingsArray.size(); i++) {
-            List<Double> embeddings = embeddingsArray.getJSONArray(i).toJavaList(Double.class);
+            List<Float> embeddings = embeddingsArray.getJSONArray(i).toJavaList(Float.class);
             String word = words.get(i);
             if (!stopwords.contains(word)) {
                 saveEmbeddings.add(embeddings);
@@ -91,20 +91,20 @@ public class TextEmbeddingServiceImpl implements TextEmbeddingService {
         return saveWord;
     }
 
-    public List<Double> averageEmbeddings(List<List<Double>> embeddings) {
+    public List<Float> averageEmbeddings(List<List<Float>> embeddings) {
         if (embeddings.isEmpty()) {
             return new ArrayList<>();
         }
         int length = embeddings.get(0).size();
-        double[] sums = new double[length];
-        for (List<Double> embedding : embeddings) {
+        float[] sums = new float[length];
+        for (List<Float> embedding : embeddings) {
             for (int i = 0; i < length; i++) {
                 sums[i] += embedding.get(i);
             }
         }
         int numArrays = embeddings.size();
-        List<Double> averages = new ArrayList<>();
-        for (double sum : sums) {
+        List<Float> averages = new ArrayList<>();
+        for (float sum : sums) {
             averages.add(sum / numArrays);
         }
         return averages;