Przeglądaj źródła

优化Redis向量存储服务,新增本地缓存与并行搜索

wangyunpeng 13 godzin temu
rodzic
commit
658c42c0e7

+ 287 - 35
core/src/main/java/com/tzld/videoVector/service/impl/RedisVectorStoreServiceImpl.java

@@ -1,19 +1,32 @@
 package com.tzld.videoVector.service.impl;
 
 import com.alibaba.fastjson.JSONArray;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
 import com.tzld.videoVector.model.entity.VideoMatch;
 import com.tzld.videoVector.service.VectorStoreService;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.data.redis.core.RedisCallback;
 import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.scheduling.annotation.Scheduled;
 import org.springframework.stereotype.Service;
 
 import java.util.*;
 import java.util.stream.Collectors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 /**
- * 基于 Redis 的向量存储服务实现
+ * 基于 Redis 的向量存储服务实现(优化版)
+ *
+ * <p>优化点:
+ * <ul>
+ *   <li>本地缓存:使用 Guava Cache 缓存向量数据,减少 Redis 查询</li>
+ *   <li>预归一化:存储时即归一化,查询时直接使用</li>
+ *   <li>并行计算:使用并行流加速相似度计算</li>
+ *   <li>堆排序优化:使用优先队列快速获取 TopN</li>
+ * </ul>
  *
  * <p>存储结构:
  * <ul>
@@ -28,6 +41,15 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
     /** 向量 Key 前缀 */
     private static final String VECTOR_KEY_PREFIX = "video:vector:";
 
+    /** 本地缓存:configCode -> (videoId -> 归一化向量) */
+    private final Cache<String, Map<Long, float[]>> vectorCache = CacheBuilder.newBuilder()
+            .maximumSize(10)
+            .expireAfterAccess(30, TimeUnit.MINUTES)
+            .build();
+
+    /** 缓存读写锁 */
+    private final ReentrantReadWriteLock cacheLock = new ReentrantReadWriteLock();
+
     @Autowired
     private RedisTemplate<String, String> redisTemplate;
 
@@ -47,10 +69,17 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
         if (configCode == null || configCode.isEmpty()) {
             configCode = DEFAULT_CONFIG_CODE;
         }
+
+        // 预归一化后存储
+        float[] normalized = l2Normalize(vector);
         String key = buildKey(configCode, videoId);
-        String value = JSONArray.toJSONString(vector);
+        String value = serializeVector(normalized);
         redisTemplate.opsForValue().set(key, value);
         redisTemplate.opsForSet().add(buildIdsKey(configCode), videoId.toString());
+
+        // 更新本地缓存
+        updateLocalCache(configCode, videoId, normalized);
+
         log.debug("保存向量成功,configCode={}, videoId={}, 维度={}", configCode, videoId, vector.size());
     }
 
@@ -118,8 +147,20 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
         if (configCode == null || configCode.isEmpty()) {
             configCode = DEFAULT_CONFIG_CODE;
         }
+
+        // 先查本地缓存
+        float[] cached = getFromLocalCache(configCode, videoId);
+        if (cached != null) {
+            return toList(cached);
+        }
+
+        // 缓存未命中,从 Redis 获取
         String value = redisTemplate.opsForValue().get(buildKey(configCode, videoId));
-        return parseVector(value);
+        float[] vector = parseVector(value);
+        if (vector != null) {
+            return toList(vector);
+        }
+        return null;
     }
 
     @Override
@@ -133,19 +174,60 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
         if (configCode == null || configCode.isEmpty()) {
             configCode = DEFAULT_CONFIG_CODE;
         }
-        List<Long> idList = new ArrayList<>(videoIds);
-        final String finalConfigCode = configCode;
-        List<String> keys = idList.stream().map(id -> buildKey(finalConfigCode, id)).collect(Collectors.toList());
-        List<String> values = redisTemplate.opsForValue().multiGet(keys);
 
-        Map<Long, List<Float>> result = new HashMap<>();
-        if (values == null) return result;
-        for (int i = 0; i < idList.size(); i++) {
-            List<Float> vector = parseVector(values.get(i));
-            if (vector != null) {
-                result.put(idList.get(i), vector);
+        Map<Long, float[]> floatResult = getVectorsAsFloatArray(configCode, videoIds);
+        Map<Long, List<Float>> result = new HashMap<>(floatResult.size());
+        for (Map.Entry<Long, float[]> entry : floatResult.entrySet()) {
+            result.put(entry.getKey(), toList(entry.getValue()));
+        }
+        return result;
+    }
+
+    /**
+     * 批量获取向量(返回 float[],避免装箱开销)
+     */
+    private Map<Long, float[]> getVectorsAsFloatArray(String configCode, Collection<Long> videoIds) {
+        if (videoIds == null || videoIds.isEmpty()) return Collections.emptyMap();
+
+        Map<Long, float[]> result = new HashMap<>();
+        List<Long> missedIds = new ArrayList<>();
+
+        // 先从本地缓存获取
+        cacheLock.readLock().lock();
+        try {
+            Map<Long, float[]> cachedMap = vectorCache.getIfPresent(configCode);
+            if (cachedMap != null) {
+                for (Long id : videoIds) {
+                    float[] vector = cachedMap.get(id);
+                    if (vector != null) {
+                        result.put(id, vector);
+                    } else {
+                        missedIds.add(id);
+                    }
+                }
+            } else {
+                missedIds.addAll(videoIds);
+            }
+        } finally {
+            cacheLock.readLock().unlock();
+        }
+
+        // 缓存未命中的从 Redis 获取
+        if (!missedIds.isEmpty()) {
+            List<Long> idList = new ArrayList<>(missedIds);
+            List<String> keys = idList.stream().map(id -> buildKey(configCode, id)).collect(Collectors.toList());
+            List<String> values = redisTemplate.opsForValue().multiGet(keys);
+
+            if (values != null) {
+                for (int i = 0; i < idList.size(); i++) {
+                    float[] vector = parseVector(values.get(i));
+                    if (vector != null) {
+                        result.put(idList.get(i), vector);
+                    }
+                }
             }
         }
+
         return result;
     }
 
@@ -179,10 +261,14 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
         }
         redisTemplate.delete(buildKey(configCode, videoId));
         redisTemplate.opsForSet().remove(buildIdsKey(configCode), videoId.toString());
+
+        // 从本地缓存移除
+        removeFromLocalCache(configCode, videoId);
+
         log.debug("删除向量成功,configCode={}, videoId={}", configCode, videoId);
     }
 
-    // ---------------------------------------------------------------- 搜索
+    // ---------------------------------------------------------------- 搜索(优化版)
 
     @Override
     public List<VideoMatch> searchTopN(List<Float> queryVector, int topN) {
@@ -198,34 +284,174 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
             configCode = DEFAULT_CONFIG_CODE;
         }
 
-        Set<Long> allIds = getAllVideoIds(configCode);
-        if (allIds.isEmpty()) {
+        // 获取或加载全部向量到本地缓存
+        Map<Long, float[]> allVectors = loadAllVectors(configCode);
+        if (allVectors.isEmpty()) {
             log.info("向量库为空,configCode={},无法搜索", configCode);
             return Collections.emptyList();
         }
 
-        log.info("开始向量搜索,configCode={},库中共 {} 条记录,topN={}", configCode, allIds.size(), topN);
+        log.info("开始向量搜索,configCode={},库中共 {} 条记录,topN={}", configCode, allVectors.size(), topN);
+
+        // 查询向量归一化
+        float[] queryNorm = l2Normalize(queryVector);
+
+        // 并行计算相似度 + 堆排序获取 TopN
+        List<VideoMatch> topMatches = parallelSearchTopN(queryNorm, allVectors, topN);
 
-        // 批量获取所有向量
-        Map<Long, List<Float>> allVectors = getVectors(configCode, allIds);
+        log.info("向量搜索完成,configCode={},返回 {} 条结果", configCode, topMatches.size());
+        return topMatches;
+    }
 
-        // 对查询向量做 L2 归一化,加速余弦相似度计算
-        float[] qNorm = l2Normalize(queryVector);
+    /**
+     * 并行搜索 TopN(使用堆排序优化)
+     */
+    private List<VideoMatch> parallelSearchTopN(float[] queryNorm, Map<Long, float[]> vectors, int topN) {
+        // 使用优先队列(小顶堆)维护 TopN
+        PriorityQueue<VideoMatch> heap = new PriorityQueue<>(topN, Comparator.comparingDouble(VideoMatch::getScore));
+
+        // 并行计算相似度
+        vectors.entrySet().parallelStream().forEach(entry -> {
+            double score = dotProduct(queryNorm, entry.getValue());
+
+            synchronized (heap) {
+                if (heap.size() < topN) {
+                    heap.offer(new VideoMatch(entry.getKey(), score));
+                } else if (score > heap.peek().getScore()) {
+                    heap.poll();
+                    heap.offer(new VideoMatch(entry.getKey(), score));
+                }
+            }
+        });
 
-        // 计算每条记录的余弦相似度
-        List<VideoMatch> matches = new ArrayList<>(allVectors.size());
-        for (Map.Entry<Long, List<Float>> entry : allVectors.entrySet()) {
-            float[] storedNorm = l2Normalize(entry.getValue());
-            double score = dotProduct(qNorm, storedNorm);
-            matches.add(new VideoMatch(entry.getKey(), score));
+        // 按相似度降序输出
+        List<VideoMatch> result = new ArrayList<>(heap.size());
+        while (!heap.isEmpty()) {
+            result.add(heap.poll());
         }
+        Collections.reverse(result);
+        return result;
+    }
 
-        // 按相似度降序,取前 topN
-        matches.sort((a, b) -> Double.compare(b.getScore(), a.getScore()));
-        List<VideoMatch> topMatches = matches.subList(0, Math.min(topN, matches.size()));
+    // ---------------------------------------------------------------- 本地缓存管理
 
-        log.info("向量搜索完成,configCode={},返回 {} 条结果", configCode, topMatches.size());
-        return topMatches;
+    /**
+     * 加载全部向量到本地缓存
+     */
+    private Map<Long, float[]> loadAllVectors(String configCode) {
+        // 先尝试从缓存获取
+        cacheLock.readLock().lock();
+        try {
+            Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
+            if (cached != null && !cached.isEmpty()) {
+                return cached;
+            }
+        } finally {
+            cacheLock.readLock().unlock();
+        }
+
+        // 缓存为空,从 Redis 加载
+        Set<Long> allIds = getAllVideoIds(configCode);
+        if (allIds.isEmpty()) {
+            return Collections.emptyMap();
+        }
+
+        log.info("从 Redis 加载向量数据,configCode={},数量={}", configCode, allIds.size());
+
+        List<Long> idList = new ArrayList<>(allIds);
+        List<String> keys = idList.stream().map(id -> buildKey(configCode, id)).collect(Collectors.toList());
+
+        // 分批加载,避免一次性加载过多数据
+        int batchSize = 1000;
+        Map<Long, float[]> allVectors = new HashMap<>(idList.size());
+
+        for (int i = 0; i < keys.size(); i += batchSize) {
+            int end = Math.min(i + batchSize, keys.size());
+            List<String> batchKeys = keys.subList(i, end);
+            List<String> values = redisTemplate.opsForValue().multiGet(batchKeys);
+
+            if (values != null) {
+                for (int j = 0; j < batchKeys.size(); j++) {
+                    float[] vector = parseVector(values.get(j));
+                    if (vector != null) {
+                        allVectors.put(idList.get(i + j), vector);
+                    }
+                }
+            }
+        }
+
+        // 写入缓存
+        if (!allVectors.isEmpty()) {
+            cacheLock.writeLock().lock();
+            try {
+                vectorCache.put(configCode, allVectors);
+            } finally {
+                cacheLock.writeLock().unlock();
+            }
+        }
+
+        return allVectors;
+    }
+
+    /**
+     * 定时刷新缓存(每5分钟)
+     */
+    @Scheduled(fixedRate = 10 * 60 * 1000)
+    public void refreshCache() {
+        log.info("开始定时刷新向量缓存");
+        cacheLock.writeLock().lock();
+        try {
+            vectorCache.invalidateAll();
+            log.info("向量缓存已清空,下次查询时重新加载");
+        } finally {
+            cacheLock.writeLock().unlock();
+        }
+    }
+
+    /**
+     * 更新本地缓存中的单个向量
+     */
+    private void updateLocalCache(String configCode, Long videoId, float[] vector) {
+        cacheLock.writeLock().lock();
+        try {
+            Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
+            if (cached != null) {
+                cached.put(videoId, vector);
+            }
+        } finally {
+            cacheLock.writeLock().unlock();
+        }
+    }
+
+    /**
+     * 从本地缓存获取单个向量
+     */
+    private float[] getFromLocalCache(String configCode, Long videoId) {
+        cacheLock.readLock().lock();
+        try {
+            Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
+            if (cached != null) {
+                return cached.get(videoId);
+            }
+        } finally {
+            cacheLock.readLock().unlock();
+        }
+        return null;
+    }
+
+    /**
+     * 从本地缓存移除单个向量
+     */
+    private void removeFromLocalCache(String configCode, Long videoId) {
+        cacheLock.writeLock().lock();
+        try {
+            Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
+            if (cached != null) {
+                cached.remove(videoId);
+            }
+        } finally {
+            cacheLock.writeLock().unlock();
+        }
     }
 
     // ---------------------------------------------------------------- 工具方法
@@ -246,13 +472,27 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
         return VECTOR_KEY_PREFIX + configCode + ":ids";
     }
 
-    private List<Float> parseVector(String value) {
+    /**
+     * 序列化向量为 JSON 字符串
+     */
+    private String serializeVector(float[] vector) {
+        JSONArray array = new JSONArray(vector.length);
+        for (float v : vector) {
+            array.add(v);
+        }
+        return array.toJSONString();
+    }
+
+    /**
+     * 解析向量为 float[]
+     */
+    private float[] parseVector(String value) {
         if (value == null || value.isEmpty()) return null;
         try {
             JSONArray array = JSONArray.parseArray(value);
-            List<Float> vector = new ArrayList<>(array.size());
+            float[] vector = new float[array.size()];
             for (int i = 0; i < array.size(); i++) {
-                vector.add(array.getFloat(i));
+                vector[i] = array.getFloatValue(i);
             }
             return vector;
         } catch (Exception e) {
@@ -261,6 +501,17 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
         }
     }
 
+    /**
+     * float[] 转 List<Float>
+     */
+    private List<Float> toList(float[] arr) {
+        List<Float> list = new ArrayList<>(arr.length);
+        for (float v : arr) {
+            list.add(v);
+        }
+        return list;
+    }
+
     /**
      * L2 归一化:将向量转换为单位向量
      */
@@ -282,6 +533,7 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
 
     /**
      * 两个已归一化向量的点积 = 余弦相似度
+     * 使用 double 累加提高精度
      */
     private double dotProduct(float[] a, float[] b) {
         int len = Math.min(a.length, b.length);