|
@@ -1,19 +1,32 @@
|
|
|
package com.tzld.videoVector.service.impl;
|
|
package com.tzld.videoVector.service.impl;
|
|
|
|
|
|
|
|
import com.alibaba.fastjson.JSONArray;
|
|
import com.alibaba.fastjson.JSONArray;
|
|
|
|
|
+import com.google.common.cache.Cache;
|
|
|
|
|
+import com.google.common.cache.CacheBuilder;
|
|
|
import com.tzld.videoVector.model.entity.VideoMatch;
|
|
import com.tzld.videoVector.model.entity.VideoMatch;
|
|
|
import com.tzld.videoVector.service.VectorStoreService;
|
|
import com.tzld.videoVector.service.VectorStoreService;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.data.redis.core.RedisCallback;
|
|
import org.springframework.data.redis.core.RedisCallback;
|
|
|
import org.springframework.data.redis.core.RedisTemplate;
|
|
import org.springframework.data.redis.core.RedisTemplate;
|
|
|
|
|
+import org.springframework.scheduling.annotation.Scheduled;
|
|
|
import org.springframework.stereotype.Service;
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
|
|
|
import java.util.*;
|
|
import java.util.*;
|
|
|
import java.util.stream.Collectors;
|
|
import java.util.stream.Collectors;
|
|
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
|
|
+import java.util.concurrent.locks.ReentrantReadWriteLock;
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
|
- * 基于 Redis 的向量存储服务实现
|
|
|
|
|
|
|
+ * 基于 Redis 的向量存储服务实现(优化版)
|
|
|
|
|
+ *
|
|
|
|
|
+ * <p>优化点:
|
|
|
|
|
+ * <ul>
|
|
|
|
|
+ * <li>本地缓存:使用 Guava Cache 缓存向量数据,减少 Redis 查询</li>
|
|
|
|
|
+ * <li>预归一化:存储时即归一化,查询时直接使用</li>
|
|
|
|
|
+ * <li>并行计算:使用并行流加速相似度计算</li>
|
|
|
|
|
+ * <li>堆排序优化:使用优先队列快速获取 TopN</li>
|
|
|
|
|
+ * </ul>
|
|
|
*
|
|
*
|
|
|
* <p>存储结构:
|
|
* <p>存储结构:
|
|
|
* <ul>
|
|
* <ul>
|
|
@@ -28,6 +41,15 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
/** 向量 Key 前缀 */
|
|
/** 向量 Key 前缀 */
|
|
|
private static final String VECTOR_KEY_PREFIX = "video:vector:";
|
|
private static final String VECTOR_KEY_PREFIX = "video:vector:";
|
|
|
|
|
|
|
|
|
|
+ /** 本地缓存:configCode -> (videoId -> 归一化向量) */
|
|
|
|
|
+ private final Cache<String, Map<Long, float[]>> vectorCache = CacheBuilder.newBuilder()
|
|
|
|
|
+ .maximumSize(10)
|
|
|
|
|
+ .expireAfterAccess(30, TimeUnit.MINUTES)
|
|
|
|
|
+ .build();
|
|
|
|
|
+
|
|
|
|
|
+ /** 缓存读写锁 */
|
|
|
|
|
+ private final ReentrantReadWriteLock cacheLock = new ReentrantReadWriteLock();
|
|
|
|
|
+
|
|
|
@Autowired
|
|
@Autowired
|
|
|
private RedisTemplate<String, String> redisTemplate;
|
|
private RedisTemplate<String, String> redisTemplate;
|
|
|
|
|
|
|
@@ -47,10 +69,17 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
if (configCode == null || configCode.isEmpty()) {
|
|
if (configCode == null || configCode.isEmpty()) {
|
|
|
configCode = DEFAULT_CONFIG_CODE;
|
|
configCode = DEFAULT_CONFIG_CODE;
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ // 预归一化后存储
|
|
|
|
|
+ float[] normalized = l2Normalize(vector);
|
|
|
String key = buildKey(configCode, videoId);
|
|
String key = buildKey(configCode, videoId);
|
|
|
- String value = JSONArray.toJSONString(vector);
|
|
|
|
|
|
|
+ String value = serializeVector(normalized);
|
|
|
redisTemplate.opsForValue().set(key, value);
|
|
redisTemplate.opsForValue().set(key, value);
|
|
|
redisTemplate.opsForSet().add(buildIdsKey(configCode), videoId.toString());
|
|
redisTemplate.opsForSet().add(buildIdsKey(configCode), videoId.toString());
|
|
|
|
|
+
|
|
|
|
|
+ // 更新本地缓存
|
|
|
|
|
+ updateLocalCache(configCode, videoId, normalized);
|
|
|
|
|
+
|
|
|
log.debug("保存向量成功,configCode={}, videoId={}, 维度={}", configCode, videoId, vector.size());
|
|
log.debug("保存向量成功,configCode={}, videoId={}, 维度={}", configCode, videoId, vector.size());
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -118,8 +147,20 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
if (configCode == null || configCode.isEmpty()) {
|
|
if (configCode == null || configCode.isEmpty()) {
|
|
|
configCode = DEFAULT_CONFIG_CODE;
|
|
configCode = DEFAULT_CONFIG_CODE;
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ // 先查本地缓存
|
|
|
|
|
+ float[] cached = getFromLocalCache(configCode, videoId);
|
|
|
|
|
+ if (cached != null) {
|
|
|
|
|
+ return toList(cached);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 缓存未命中,从 Redis 获取
|
|
|
String value = redisTemplate.opsForValue().get(buildKey(configCode, videoId));
|
|
String value = redisTemplate.opsForValue().get(buildKey(configCode, videoId));
|
|
|
- return parseVector(value);
|
|
|
|
|
|
|
+ float[] vector = parseVector(value);
|
|
|
|
|
+ if (vector != null) {
|
|
|
|
|
+ return toList(vector);
|
|
|
|
|
+ }
|
|
|
|
|
+ return null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
@@ -133,19 +174,60 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
if (configCode == null || configCode.isEmpty()) {
|
|
if (configCode == null || configCode.isEmpty()) {
|
|
|
configCode = DEFAULT_CONFIG_CODE;
|
|
configCode = DEFAULT_CONFIG_CODE;
|
|
|
}
|
|
}
|
|
|
- List<Long> idList = new ArrayList<>(videoIds);
|
|
|
|
|
- final String finalConfigCode = configCode;
|
|
|
|
|
- List<String> keys = idList.stream().map(id -> buildKey(finalConfigCode, id)).collect(Collectors.toList());
|
|
|
|
|
- List<String> values = redisTemplate.opsForValue().multiGet(keys);
|
|
|
|
|
|
|
|
|
|
- Map<Long, List<Float>> result = new HashMap<>();
|
|
|
|
|
- if (values == null) return result;
|
|
|
|
|
- for (int i = 0; i < idList.size(); i++) {
|
|
|
|
|
- List<Float> vector = parseVector(values.get(i));
|
|
|
|
|
- if (vector != null) {
|
|
|
|
|
- result.put(idList.get(i), vector);
|
|
|
|
|
|
|
+ Map<Long, float[]> floatResult = getVectorsAsFloatArray(configCode, videoIds);
|
|
|
|
|
+ Map<Long, List<Float>> result = new HashMap<>(floatResult.size());
|
|
|
|
|
+ for (Map.Entry<Long, float[]> entry : floatResult.entrySet()) {
|
|
|
|
|
+ result.put(entry.getKey(), toList(entry.getValue()));
|
|
|
|
|
+ }
|
|
|
|
|
+ return result;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 批量获取向量(返回 float[],避免装箱开销)
|
|
|
|
|
+ */
|
|
|
|
|
+ private Map<Long, float[]> getVectorsAsFloatArray(String configCode, Collection<Long> videoIds) {
|
|
|
|
|
+ if (videoIds == null || videoIds.isEmpty()) return Collections.emptyMap();
|
|
|
|
|
+
|
|
|
|
|
+ Map<Long, float[]> result = new HashMap<>();
|
|
|
|
|
+ List<Long> missedIds = new ArrayList<>();
|
|
|
|
|
+
|
|
|
|
|
+ // 先从本地缓存获取
|
|
|
|
|
+ cacheLock.readLock().lock();
|
|
|
|
|
+ try {
|
|
|
|
|
+ Map<Long, float[]> cachedMap = vectorCache.getIfPresent(configCode);
|
|
|
|
|
+ if (cachedMap != null) {
|
|
|
|
|
+ for (Long id : videoIds) {
|
|
|
|
|
+ float[] vector = cachedMap.get(id);
|
|
|
|
|
+ if (vector != null) {
|
|
|
|
|
+ result.put(id, vector);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ missedIds.add(id);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ missedIds.addAll(videoIds);
|
|
|
|
|
+ }
|
|
|
|
|
+ } finally {
|
|
|
|
|
+ cacheLock.readLock().unlock();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 缓存未命中的从 Redis 获取
|
|
|
|
|
+ if (!missedIds.isEmpty()) {
|
|
|
|
|
+ List<Long> idList = new ArrayList<>(missedIds);
|
|
|
|
|
+ List<String> keys = idList.stream().map(id -> buildKey(configCode, id)).collect(Collectors.toList());
|
|
|
|
|
+ List<String> values = redisTemplate.opsForValue().multiGet(keys);
|
|
|
|
|
+
|
|
|
|
|
+ if (values != null) {
|
|
|
|
|
+ for (int i = 0; i < idList.size(); i++) {
|
|
|
|
|
+ float[] vector = parseVector(values.get(i));
|
|
|
|
|
+ if (vector != null) {
|
|
|
|
|
+ result.put(idList.get(i), vector);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
return result;
|
|
return result;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -179,10 +261,14 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
}
|
|
}
|
|
|
redisTemplate.delete(buildKey(configCode, videoId));
|
|
redisTemplate.delete(buildKey(configCode, videoId));
|
|
|
redisTemplate.opsForSet().remove(buildIdsKey(configCode), videoId.toString());
|
|
redisTemplate.opsForSet().remove(buildIdsKey(configCode), videoId.toString());
|
|
|
|
|
+
|
|
|
|
|
+ // 从本地缓存移除
|
|
|
|
|
+ removeFromLocalCache(configCode, videoId);
|
|
|
|
|
+
|
|
|
log.debug("删除向量成功,configCode={}, videoId={}", configCode, videoId);
|
|
log.debug("删除向量成功,configCode={}, videoId={}", configCode, videoId);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // ---------------------------------------------------------------- 搜索
|
|
|
|
|
|
|
+ // ---------------------------------------------------------------- 搜索(优化版)
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
|
public List<VideoMatch> searchTopN(List<Float> queryVector, int topN) {
|
|
public List<VideoMatch> searchTopN(List<Float> queryVector, int topN) {
|
|
@@ -198,34 +284,174 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
configCode = DEFAULT_CONFIG_CODE;
|
|
configCode = DEFAULT_CONFIG_CODE;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- Set<Long> allIds = getAllVideoIds(configCode);
|
|
|
|
|
- if (allIds.isEmpty()) {
|
|
|
|
|
|
|
+ // 获取或加载全部向量到本地缓存
|
|
|
|
|
+ Map<Long, float[]> allVectors = loadAllVectors(configCode);
|
|
|
|
|
+ if (allVectors.isEmpty()) {
|
|
|
log.info("向量库为空,configCode={},无法搜索", configCode);
|
|
log.info("向量库为空,configCode={},无法搜索", configCode);
|
|
|
return Collections.emptyList();
|
|
return Collections.emptyList();
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- log.info("开始向量搜索,configCode={},库中共 {} 条记录,topN={}", configCode, allIds.size(), topN);
|
|
|
|
|
|
|
+ log.info("开始向量搜索,configCode={},库中共 {} 条记录,topN={}", configCode, allVectors.size(), topN);
|
|
|
|
|
+
|
|
|
|
|
+ // 查询向量归一化
|
|
|
|
|
+ float[] queryNorm = l2Normalize(queryVector);
|
|
|
|
|
+
|
|
|
|
|
+ // 并行计算相似度 + 堆排序获取 TopN
|
|
|
|
|
+ List<VideoMatch> topMatches = parallelSearchTopN(queryNorm, allVectors, topN);
|
|
|
|
|
|
|
|
- // 批量获取所有向量
|
|
|
|
|
- Map<Long, List<Float>> allVectors = getVectors(configCode, allIds);
|
|
|
|
|
|
|
+ log.info("向量搜索完成,configCode={},返回 {} 条结果", configCode, topMatches.size());
|
|
|
|
|
+ return topMatches;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // 对查询向量做 L2 归一化,加速余弦相似度计算
|
|
|
|
|
- float[] qNorm = l2Normalize(queryVector);
|
|
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 并行搜索 TopN(使用堆排序优化)
|
|
|
|
|
+ */
|
|
|
|
|
+ private List<VideoMatch> parallelSearchTopN(float[] queryNorm, Map<Long, float[]> vectors, int topN) {
|
|
|
|
|
+ // 使用优先队列(小顶堆)维护 TopN
|
|
|
|
|
+ PriorityQueue<VideoMatch> heap = new PriorityQueue<>(topN, Comparator.comparingDouble(VideoMatch::getScore));
|
|
|
|
|
+
|
|
|
|
|
+ // 并行计算相似度
|
|
|
|
|
+ vectors.entrySet().parallelStream().forEach(entry -> {
|
|
|
|
|
+ double score = dotProduct(queryNorm, entry.getValue());
|
|
|
|
|
+
|
|
|
|
|
+ synchronized (heap) {
|
|
|
|
|
+ if (heap.size() < topN) {
|
|
|
|
|
+ heap.offer(new VideoMatch(entry.getKey(), score));
|
|
|
|
|
+ } else if (score > heap.peek().getScore()) {
|
|
|
|
|
+ heap.poll();
|
|
|
|
|
+ heap.offer(new VideoMatch(entry.getKey(), score));
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
|
|
|
- // 计算每条记录的余弦相似度
|
|
|
|
|
- List<VideoMatch> matches = new ArrayList<>(allVectors.size());
|
|
|
|
|
- for (Map.Entry<Long, List<Float>> entry : allVectors.entrySet()) {
|
|
|
|
|
- float[] storedNorm = l2Normalize(entry.getValue());
|
|
|
|
|
- double score = dotProduct(qNorm, storedNorm);
|
|
|
|
|
- matches.add(new VideoMatch(entry.getKey(), score));
|
|
|
|
|
|
|
+ // 按相似度降序输出
|
|
|
|
|
+ List<VideoMatch> result = new ArrayList<>(heap.size());
|
|
|
|
|
+ while (!heap.isEmpty()) {
|
|
|
|
|
+ result.add(heap.poll());
|
|
|
}
|
|
}
|
|
|
|
|
+ Collections.reverse(result);
|
|
|
|
|
+ return result;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // 按相似度降序,取前 topN
|
|
|
|
|
- matches.sort((a, b) -> Double.compare(b.getScore(), a.getScore()));
|
|
|
|
|
- List<VideoMatch> topMatches = matches.subList(0, Math.min(topN, matches.size()));
|
|
|
|
|
|
|
+ // ---------------------------------------------------------------- 本地缓存管理
|
|
|
|
|
|
|
|
- log.info("向量搜索完成,configCode={},返回 {} 条结果", configCode, topMatches.size());
|
|
|
|
|
- return topMatches;
|
|
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 加载全部向量到本地缓存
|
|
|
|
|
+ */
|
|
|
|
|
+ private Map<Long, float[]> loadAllVectors(String configCode) {
|
|
|
|
|
+ // 先尝试从缓存获取
|
|
|
|
|
+ cacheLock.readLock().lock();
|
|
|
|
|
+ try {
|
|
|
|
|
+ Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
|
|
|
|
|
+ if (cached != null && !cached.isEmpty()) {
|
|
|
|
|
+ return cached;
|
|
|
|
|
+ }
|
|
|
|
|
+ } finally {
|
|
|
|
|
+ cacheLock.readLock().unlock();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 缓存为空,从 Redis 加载
|
|
|
|
|
+ Set<Long> allIds = getAllVideoIds(configCode);
|
|
|
|
|
+ if (allIds.isEmpty()) {
|
|
|
|
|
+ return Collections.emptyMap();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ log.info("从 Redis 加载向量数据,configCode={},数量={}", configCode, allIds.size());
|
|
|
|
|
+
|
|
|
|
|
+ List<Long> idList = new ArrayList<>(allIds);
|
|
|
|
|
+ List<String> keys = idList.stream().map(id -> buildKey(configCode, id)).collect(Collectors.toList());
|
|
|
|
|
+
|
|
|
|
|
+ // 分批加载,避免一次性加载过多数据
|
|
|
|
|
+ int batchSize = 1000;
|
|
|
|
|
+ Map<Long, float[]> allVectors = new HashMap<>(idList.size());
|
|
|
|
|
+
|
|
|
|
|
+ for (int i = 0; i < keys.size(); i += batchSize) {
|
|
|
|
|
+ int end = Math.min(i + batchSize, keys.size());
|
|
|
|
|
+ List<String> batchKeys = keys.subList(i, end);
|
|
|
|
|
+ List<String> values = redisTemplate.opsForValue().multiGet(batchKeys);
|
|
|
|
|
+
|
|
|
|
|
+ if (values != null) {
|
|
|
|
|
+ for (int j = 0; j < batchKeys.size(); j++) {
|
|
|
|
|
+ float[] vector = parseVector(values.get(j));
|
|
|
|
|
+ if (vector != null) {
|
|
|
|
|
+ allVectors.put(idList.get(i + j), vector);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 写入缓存
|
|
|
|
|
+ if (!allVectors.isEmpty()) {
|
|
|
|
|
+ cacheLock.writeLock().lock();
|
|
|
|
|
+ try {
|
|
|
|
|
+ vectorCache.put(configCode, allVectors);
|
|
|
|
|
+ } finally {
|
|
|
|
|
+ cacheLock.writeLock().unlock();
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return allVectors;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 定时刷新缓存(每5分钟)
|
|
|
|
|
+ */
|
|
|
|
|
+ @Scheduled(fixedRate = 10 * 60 * 1000)
|
|
|
|
|
+ public void refreshCache() {
|
|
|
|
|
+ log.info("开始定时刷新向量缓存");
|
|
|
|
|
+ cacheLock.writeLock().lock();
|
|
|
|
|
+ try {
|
|
|
|
|
+ vectorCache.invalidateAll();
|
|
|
|
|
+ log.info("向量缓存已清空,下次查询时重新加载");
|
|
|
|
|
+ } finally {
|
|
|
|
|
+ cacheLock.writeLock().unlock();
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 更新本地缓存中的单个向量
|
|
|
|
|
+ */
|
|
|
|
|
+ private void updateLocalCache(String configCode, Long videoId, float[] vector) {
|
|
|
|
|
+ cacheLock.writeLock().lock();
|
|
|
|
|
+ try {
|
|
|
|
|
+ Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
|
|
|
|
|
+ if (cached != null) {
|
|
|
|
|
+ cached.put(videoId, vector);
|
|
|
|
|
+ }
|
|
|
|
|
+ } finally {
|
|
|
|
|
+ cacheLock.writeLock().unlock();
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 从本地缓存获取单个向量
|
|
|
|
|
+ */
|
|
|
|
|
+ private float[] getFromLocalCache(String configCode, Long videoId) {
|
|
|
|
|
+ cacheLock.readLock().lock();
|
|
|
|
|
+ try {
|
|
|
|
|
+ Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
|
|
|
|
|
+ if (cached != null) {
|
|
|
|
|
+ return cached.get(videoId);
|
|
|
|
|
+ }
|
|
|
|
|
+ } finally {
|
|
|
|
|
+ cacheLock.readLock().unlock();
|
|
|
|
|
+ }
|
|
|
|
|
+ return null;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 从本地缓存移除单个向量
|
|
|
|
|
+ */
|
|
|
|
|
+ private void removeFromLocalCache(String configCode, Long videoId) {
|
|
|
|
|
+ cacheLock.writeLock().lock();
|
|
|
|
|
+ try {
|
|
|
|
|
+ Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
|
|
|
|
|
+ if (cached != null) {
|
|
|
|
|
+ cached.remove(videoId);
|
|
|
|
|
+ }
|
|
|
|
|
+ } finally {
|
|
|
|
|
+ cacheLock.writeLock().unlock();
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------- 工具方法
|
|
// ---------------------------------------------------------------- 工具方法
|
|
@@ -246,13 +472,27 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
return VECTOR_KEY_PREFIX + configCode + ":ids";
|
|
return VECTOR_KEY_PREFIX + configCode + ":ids";
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- private List<Float> parseVector(String value) {
|
|
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 序列化向量为 JSON 字符串
|
|
|
|
|
+ */
|
|
|
|
|
+ private String serializeVector(float[] vector) {
|
|
|
|
|
+ JSONArray array = new JSONArray(vector.length);
|
|
|
|
|
+ for (float v : vector) {
|
|
|
|
|
+ array.add(v);
|
|
|
|
|
+ }
|
|
|
|
|
+ return array.toJSONString();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 解析向量为 float[]
|
|
|
|
|
+ */
|
|
|
|
|
+ private float[] parseVector(String value) {
|
|
|
if (value == null || value.isEmpty()) return null;
|
|
if (value == null || value.isEmpty()) return null;
|
|
|
try {
|
|
try {
|
|
|
JSONArray array = JSONArray.parseArray(value);
|
|
JSONArray array = JSONArray.parseArray(value);
|
|
|
- List<Float> vector = new ArrayList<>(array.size());
|
|
|
|
|
|
|
+ float[] vector = new float[array.size()];
|
|
|
for (int i = 0; i < array.size(); i++) {
|
|
for (int i = 0; i < array.size(); i++) {
|
|
|
- vector.add(array.getFloat(i));
|
|
|
|
|
|
|
+ vector[i] = array.getFloatValue(i);
|
|
|
}
|
|
}
|
|
|
return vector;
|
|
return vector;
|
|
|
} catch (Exception e) {
|
|
} catch (Exception e) {
|
|
@@ -261,6 +501,17 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ /**
|
|
|
|
|
+ * float[] 转 List<Float>
|
|
|
|
|
+ */
|
|
|
|
|
+ private List<Float> toList(float[] arr) {
|
|
|
|
|
+ List<Float> list = new ArrayList<>(arr.length);
|
|
|
|
|
+ for (float v : arr) {
|
|
|
|
|
+ list.add(v);
|
|
|
|
|
+ }
|
|
|
|
|
+ return list;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
/**
|
|
/**
|
|
|
* L2 归一化:将向量转换为单位向量
|
|
* L2 归一化:将向量转换为单位向量
|
|
|
*/
|
|
*/
|
|
@@ -282,6 +533,7 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
|
* 两个已归一化向量的点积 = 余弦相似度
|
|
* 两个已归一化向量的点积 = 余弦相似度
|
|
|
|
|
+ * 使用 double 累加提高精度
|
|
|
*/
|
|
*/
|
|
|
private double dotProduct(float[] a, float[] b) {
|
|
private double dotProduct(float[] a, float[] b) {
|
|
|
int len = Math.min(a.length, b.length);
|
|
int len = Math.min(a.length, b.length);
|