Просмотр исходного кода

新增本地文本向量化及视频向量批量处理任务

wangyunpeng 2 дней назад
Родитель
Сommit
22785ace9e

+ 187 - 1
core/src/main/java/com/tzld/videoVector/job/VideoVectorJob.java

@@ -1,15 +1,46 @@
 package com.tzld.videoVector.job;
 
+import com.alibaba.fastjson.JSONObject;
+import com.aliyun.odps.data.Record;
+import com.tzld.videoVector.service.EmbeddingService;
+import com.tzld.videoVector.service.MilvusService;
+import com.tzld.videoVector.util.MilvusUtil;
+import com.tzld.videoVector.util.OdpsUtil;
 import com.xxl.job.core.biz.model.ReturnT;
 import com.xxl.job.core.handler.annotation.XxlJob;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Component;
 
+import javax.annotation.Resource;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
 
 @Slf4j
 @Component
 public class VideoVectorJob {
 
+    @Resource
+    private MilvusService milvusService;
+
+    @Resource
+    private MilvusUtil milvusUtil;
+
+    @Resource
+    private EmbeddingService embeddingService;
+
+    /**
+     * 集合名称
+     */
+    private static final String COLLECTION_NAME = "video_vector";
+
+    /**
+     * 每页查询数量
+     */
+    private static final int PAGE_SIZE = 1000;
 
     /**
      * 视频向量化
@@ -18,8 +49,163 @@ public class VideoVectorJob {
      */
     @XxlJob("vectorVideoJob")
     public ReturnT<String> vectorVideoJob(String param) {
+        log.info("开始执行视频向量化任务, param: {}", param);
+        
+        int totalSuccessCount = 0;
+        int totalFailCount = 0;
+        int pageNum = 0;
+        
+        try {
+            while (true) {
+                // 1. 分页查询 videoId 列表
+                List<Long> videoIds = queryVideoIdsByPage(pageNum, PAGE_SIZE);
+                if (videoIds == null || videoIds.isEmpty()) {
+                    log.info("第 {} 页没有查询到数据,分页查询结束", pageNum);
+                    break;
+                }
+                log.info("第 {} 页查询到 {} 个 videoId", pageNum, videoIds.size());
 
-        return ReturnT.SUCCESS;
+                // 2. 查询哪些 videoId 在 Milvus 中已存在
+                Set<Long> existingIds = milvusService.existsByIds(COLLECTION_NAME, videoIds);
+                log.info("已存在 {} 个 videoId,将跳过", existingIds.size());
+
+                // 3. 过滤出不存在的 videoId
+                List<Long> newVideoIds = videoIds.stream()
+                        .filter(id -> !existingIds.contains(id))
+                        .collect(Collectors.toList());
+                
+                if (!newVideoIds.isEmpty()) {
+                    log.info("第 {} 页需要处理 {} 个新的 videoId", pageNum, newVideoIds.size());
+
+                    // 4. 逐个处理新的 videoId
+                    for (Long videoId : newVideoIds) {
+                        try {
+                            // 4.1 查询视频详情
+                            JSONObject videoDetail = queryVideoDetail(videoId);
+                            if (videoDetail == null) {
+                                log.warn("videoId={} 详情查询为空,跳过", videoId);
+                                totalFailCount++;
+                                continue;
+                            }
+
+                            // 4.2 提取字段并向量化
+                            List<Float> vector = extractAndVectorize(videoDetail);
+                            if (vector == null || vector.isEmpty()) {
+                                log.warn("videoId={} 向量化失败,跳过", videoId);
+                                totalFailCount++;
+                                continue;
+                            }
+
+                            // 4.3 存储到 Milvus
+                            insertToMilvus(videoId, vector, videoDetail);
+                            totalSuccessCount++;
+                            log.debug("videoId={} 处理成功", videoId);
+
+                        } catch (Exception e) {
+                            log.error("处理 videoId={} 时发生异常: {}", videoId, e.getMessage(), e);
+                            totalFailCount++;
+                        }
+                    }
+                }
+
+                // 如果查询到的数据少于 PAGE_SIZE,说明已经是最后一页
+                if (videoIds.size() < PAGE_SIZE) {
+                    log.info("第 {} 页数据量 {} 小于 PAGE_SIZE {},分页查询结束", pageNum, videoIds.size(), PAGE_SIZE);
+                    break;
+                }
+                
+                pageNum++;
+            }
+
+            log.info("视频向量化任务完成,总成功: {}, 总失败: {}, 总页数: {}", totalSuccessCount, totalFailCount, pageNum + 1);
+            return ReturnT.SUCCESS;
+
+        } catch (Exception e) {
+            log.error("视频向量化任务执行失败: {}", e.getMessage(), e);
+            return new ReturnT<>(ReturnT.FAIL_CODE, "任务执行失败: " + e.getMessage());
+        }
+    }
+
+    /**
+     * 分页查询 videoId 列表
+     * @param pageNum 页码(从0开始)
+     * @param pageSize 每页数量
+     * @return videoId 列表
+     */
+    private List<Long> queryVideoIdsByPage(int pageNum, int pageSize) {
+        int offset = pageNum * pageSize;
+        String sql = String.format(
+                "SELECT video_id FROM your_table WHERE status = 1 ORDER BY video_id LIMIT %d, %d",
+                offset, pageSize);
+        List<Record> records = OdpsUtil.getOdpsData(sql);
+        if (records == null || records.isEmpty()) {
+            return new ArrayList<>();
+        }
+        return records.stream()
+                .map(record -> record.getBigint("video_id"))
+                .filter(Objects::nonNull)
+                .collect(Collectors.toList());
+    }
+
+    /**
+     * 查询视频详情
+     */
+    private JSONObject queryVideoDetail(Long videoId) {
+        String sql = String.format(
+                "SELECT video_id, title, description, tags, category FROM your_detail_table WHERE video_id = %d",
+                videoId);
+        List<Record> records = OdpsUtil.getOdpsData(sql);
+        if (records == null || records.isEmpty()) {
+            return null;
+        }
+        Record record = records.get(0);
+        JSONObject result = new JSONObject();
+        result.put("video_id", record.getBigint("video_id"));
+        result.put("title", record.getString("title"));
+        result.put("description", record.getString("description"));
+        result.put("tags", record.getString("tags"));
+        result.put("category", record.getString("category"));
+        return result;
+    }
+
+    /**
+     * 提取字段并向量化
+     */
+    private List<Float> extractAndVectorize(JSONObject videoDetail) {
+        // 提取用于向量化的文本字段
+        String title = videoDetail.getString("title");
+        String description = videoDetail.getString("description");
+        String tags = videoDetail.getString("tags");
+
+        // 拼接文本
+        StringBuilder textBuilder = new StringBuilder();
+        if (title != null && !title.isEmpty()) {
+            textBuilder.append(title).append(" ");
+        }
+        if (description != null && !description.isEmpty()) {
+            textBuilder.append(description).append(" ");
+        }
+        if (tags != null && !tags.isEmpty()) {
+            textBuilder.append(tags);
+        }
+        String text = textBuilder.toString().trim();
+        if (text.isEmpty()) {
+            return null;
+        }
+
+        // 使用 EmbeddingService 进行向量化
+        return embeddingService.embed(text);
+    }
+
+    /**
+     * 将向量数据存储到 Milvus
+     */
+    private void insertToMilvus(Long videoId, List<Float> vector, JSONObject videoDetail) {
+        // 使用 MilvusUtil 进行插入操作
+        // 注意:需要确保 Milvus 集合已创建,且包含 video_id 和 vector 字段
+        List<List<Float>> vectors = new ArrayList<>();
+        vectors.add(vector);
+        milvusUtil.insertVectors(COLLECTION_NAME, vectors);
     }
 
 }

+ 29 - 0
core/src/main/java/com/tzld/videoVector/service/EmbeddingService.java

@@ -0,0 +1,29 @@
+package com.tzld.videoVector.service;
+
+import java.util.List;
+
+/**
+ * 文本向量化服务接口
+ */
+public interface EmbeddingService {
+
+    /**
+     * 将文本转换为向量
+     * @param text 输入文本
+     * @return 向量数据
+     */
+    List<Float> embed(String text);
+
+    /**
+     * 批量将文本转换为向量
+     * @param texts 输入文本列表
+     * @return 向量数据列表
+     */
+    List<List<Float>> batchEmbed(List<String> texts);
+
+    /**
+     * 获取向量维度
+     * @return 向量维度
+     */
+    int getDimension();
+}

+ 9 - 0
core/src/main/java/com/tzld/videoVector/service/MilvusService.java

@@ -1,6 +1,7 @@
 package com.tzld.videoVector.service;
 
 import java.util.List;
+import java.util.Set;
 
 /**
  * Milvus 向量服务接口
@@ -59,6 +60,14 @@ public interface MilvusService {
      */
     void deleteCollection(String collectionName);
 
+    /**
+     * 根据 videoId 列表查询在 Milvus 中已存在的 ID
+     * @param collectionName 集合名称
+     * @param videoIds 要查询的 videoId 列表
+     * @return 存在的 videoId 集合
+     */
+    Set<Long> existsByIds(String collectionName, List<Long> videoIds);
+
     /**
      * 搜索结果实体类
      */

+ 209 - 0
core/src/main/java/com/tzld/videoVector/service/impl/EmbeddingServiceImpl.java

@@ -0,0 +1,209 @@
+package com.tzld.videoVector.service.impl;
+
+import com.tzld.videoVector.service.EmbeddingService;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Service;
+
+import javax.annotation.PostConstruct;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * 基于本地哈希向量化的文本向量化服务实现
+ * 使用 MurmurHash + TF 特征提取,完全本地执行,无需下载模型
+ */
+@Slf4j
+@Service
+public class EmbeddingServiceImpl implements EmbeddingService {
+
+    /**
+     * 向量维度
+     */
+    @Value("${embedding.dimension:1024}")
+    private int dimension;
+
+    /**
+     * N-gram 大小
+     */
+    @Value("${embedding.ngram:2}")
+    private int ngram;
+
+    private int[] hashSeeds;
+
+    @PostConstruct
+    public void init() {
+        // 初始化哈希种子
+        hashSeeds = new int[]{31, 37, 41, 43, 47, 53, 59, 61};
+        log.info("本地向量化服务初始化完成,向量维度: {}, N-gram: {}", dimension, ngram);
+    }
+
+    @Override
+    public List<Float> embed(String text) {
+        if (text == null || text.trim().isEmpty()) {
+            log.warn("输入文本为空,返回空向量");
+            return Collections.emptyList();
+        }
+
+        try {
+            // 1. 文本预处理
+            String processedText = preprocess(text);
+            
+            // 2. 提取 N-gram 特征
+            List<String> features = extractNgrams(processedText, ngram);
+            
+            // 3. 生成哈希向量
+            float[] vector = new float[dimension];
+            
+            for (String feature : features) {
+                // 使用多个哈希函数
+                for (int seed : hashSeeds) {
+                    int hash = murmurHash(feature, seed);
+                    int index = Math.abs(hash % dimension);
+                    // 使用符号位决定加减
+                    float value = (hash & 0x80000000) == 0 ? 1.0f : -1.0f;
+                    vector[index] += value;
+                }
+            }
+            
+            // 4. L2 归一化
+            normalize(vector);
+            
+            return toFloatList(vector);
+        } catch (Exception e) {
+            log.error("文本向量化失败: {}", e.getMessage(), e);
+            return Collections.emptyList();
+        }
+    }
+
+    @Override
+    public List<List<Float>> batchEmbed(List<String> texts) {
+        if (texts == null || texts.isEmpty()) {
+            return Collections.emptyList();
+        }
+
+        List<List<Float>> results = new ArrayList<>(texts.size());
+        for (String text : texts) {
+            results.add(embed(text));
+        }
+        return results;
+    }
+
+    @Override
+    public int getDimension() {
+        return dimension;
+    }
+
+    /**
+     * 文本预处理:转小写、去除特殊字符
+     */
+    private String preprocess(String text) {
+        return text.toLowerCase()
+                .replaceAll("[^a-z0-9\\u4e00-\\u9fa5\\s]", " ")
+                .replaceAll("\\s+", " ")
+                .trim();
+    }
+
+    /**
+     * 提取 N-gram 特征
+     */
+    private List<String> extractNgrams(String text, int n) {
+        List<String> ngrams = new ArrayList<>();
+        
+        // 字符级 N-gram(适合中文)
+        for (int i = 0; i <= text.length() - n; i++) {
+            ngrams.add(text.substring(i, i + n));
+        }
+        
+        // 词级特征(适合英文和分词后的中文)
+        String[] words = text.split("\\s+");
+        for (String word : words) {
+            if (!word.isEmpty()) {
+                ngrams.add(word);
+            }
+        }
+        
+        // 词级 N-gram
+        for (int i = 0; i <= words.length - n; i++) {
+            StringBuilder sb = new StringBuilder();
+            for (int j = 0; j < n; j++) {
+                if (j > 0) sb.append(" ");
+                sb.append(words[i + j]);
+            }
+            ngrams.add(sb.toString());
+        }
+        
+        return ngrams;
+    }
+
+    /**
+     * MurmurHash 哈希函数
+     */
+    private int murmurHash(String text, int seed) {
+        byte[] data = text.getBytes(StandardCharsets.UTF_8);
+        int length = data.length;
+        int h = seed ^ length;
+        int k;
+        
+        for (int i = 0; i + 4 <= length; i += 4) {
+            k = (data[i] & 0xff) 
+                | ((data[i + 1] & 0xff) << 8)
+                | ((data[i + 2] & 0xff) << 16)
+                | ((data[i + 3] & 0xff) << 24);
+            
+            k *= 0x5bd1e995;
+            k ^= k >>> 24;
+            k *= 0x5bd1e995;
+            
+            h *= 0x5bd1e995;
+            h ^= k;
+        }
+        
+        // 处理剩余字节
+        int remaining = length % 4;
+        if (remaining > 0) {
+            int tail = 0;
+            for (int i = 0; i < remaining; i++) {
+                tail |= (data[length - remaining + i] & 0xff) << (i * 8);
+            }
+            h ^= tail;
+            h *= 0x5bd1e995;
+        }
+        
+        h ^= h >>> 13;
+        h *= 0x5bd1e995;
+        h ^= h >>> 15;
+        
+        return h;
+    }
+
+    /**
+     * L2 归一化
+     */
+    private void normalize(float[] vector) {
+        float norm = 0.0f;
+        for (float v : vector) {
+            norm += v * v;
+        }
+        norm = (float) Math.sqrt(norm);
+        
+        if (norm > 0) {
+            for (int i = 0; i < vector.length; i++) {
+                vector[i] /= norm;
+            }
+        }
+    }
+
+    /**
+     * 将 float[] 转换为 List<Float>
+     */
+    private List<Float> toFloatList(float[] array) {
+        List<Float> list = new ArrayList<>(array.length);
+        for (float f : array) {
+            list.add(f);
+        }
+        return list;
+    }
+}

+ 7 - 0
core/src/main/java/com/tzld/videoVector/service/impl/MilvusServiceImpl.java

@@ -10,6 +10,7 @@ import javax.annotation.Resource;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Set;
 
 /**
  * Milvus 向量服务实现类
@@ -82,4 +83,10 @@ public class MilvusServiceImpl implements MilvusService {
         log.info("删除集合:{}", collectionName);
         milvusUtil.dropCollection(collectionName);
     }
+
+    @Override
+    public Set<Long> existsByIds(String collectionName, List<Long> videoIds) {
+        log.info("查询 {} 个 videoId 是否在集合 {} 中存在", videoIds.size(), collectionName);
+        return milvusUtil.existsByIds(collectionName, videoIds);
+    }
 }

+ 46 - 3
core/src/main/java/com/tzld/videoVector/util/MilvusUtil.java

@@ -1,23 +1,26 @@
 package com.tzld.videoVector.util;
 
+import com.tzld.videoVector.config.MilvusConfig;
 import io.milvus.client.MilvusClient;
 import io.milvus.client.MilvusServiceClient;
 import io.milvus.grpc.DataType;
 import io.milvus.grpc.MutationResult;
+import io.milvus.grpc.QueryResults;
 import io.milvus.grpc.SearchResults;
 import io.milvus.param.ConnectParam;
-import io.milvus.param.MetricType;
 import io.milvus.param.IndexType;
+import io.milvus.param.MetricType;
 import io.milvus.param.R;
 import io.milvus.param.collection.*;
 import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.QueryParam;
 import io.milvus.param.dml.SearchParam;
 import io.milvus.param.index.CreateIndexParam;
-import io.milvus.response.SearchResultsWrapper;
 import io.milvus.response.MutationResultWrapper;
+import io.milvus.response.QueryResultsWrapper;
+import io.milvus.response.SearchResultsWrapper;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Component;
-import com.tzld.videoVector.config.MilvusConfig;
 
 import javax.annotation.PostConstruct;
 import javax.annotation.PreDestroy;
@@ -227,6 +230,46 @@ public class MilvusUtil {
         }
     }
 
+    /**
+     * 根据 videoId 列表查询在 Milvus 中存在的 ID
+     * @param collectionName 集合名称
+     * @param videoIds 要查询的 videoId 列表
+     * @return 存在的 videoId 集合
+     */
+    public java.util.Set<Long> existsByIds(String collectionName, List<Long> videoIds) {
+        java.util.Set<Long> existingIds = new java.util.HashSet<>();
+        if (videoIds == null || videoIds.isEmpty()) {
+            return existingIds;
+        }
+        try {
+            // 构建 IN 查询表达式
+            String expr = "video_id in [" + videoIds.stream()
+                    .map(String::valueOf)
+                    .collect(java.util.stream.Collectors.joining(",")) + "]";
+            
+            R<QueryResults> response = milvusClient.query(QueryParam.newBuilder()
+                    .withCollectionName(collectionName)
+                    .withExpr(expr)
+                    .withOutFields(Collections.singletonList("video_id"))
+                    .build());
+            
+            if (response.getData() != null) {
+                QueryResultsWrapper wrapper = new QueryResultsWrapper(response.getData());
+                List<?> fieldData = wrapper.getFieldWrapper("video_id").getFieldData();
+                for (Object id : fieldData) {
+                    if (id instanceof Long) {
+                        existingIds.add((Long) id);
+                    }
+                }
+            }
+            log.info("查询完成,找到 {} 条已存在的记录", existingIds.size());
+            return existingIds;
+        } catch (Exception e) {
+            log.error("查询 videoId 是否存在失败:{}", e.getMessage(), e);
+            return existingIds;
+        }
+    }
+
     /**
      * 获取客户端实例
      */

+ 1 - 1
server/src/main/java/com/tzld/videoVector/controller/XxlJobController.java

@@ -16,7 +16,7 @@ public class XxlJobController {
 
     @GetMapping("/spiderTaskJob")
     public CommonResponse<Void> spiderTaskJob() {
-        videoVectorJob.spiderTaskJob(null);
+        videoVectorJob.vectorVideoJob(null);
         return CommonResponse.success();
     }
 

+ 5 - 1
server/src/main/resources/application.yml

@@ -97,4 +97,8 @@ oss:
 
 cdn:
   upload:
-    domain: https://weappupload.piaoquantv.com/
+    domain: https://weappupload.piaoquantv.com/
+
+embedding:
+  dimension: 1024    # 向量维度
+  ngram: 2          # N-gram 大小