Procházet zdrojové kódy

增加批量删除向量和审核状态过滤功能

wangyunpeng před 1 měsícem
rodič
revize
57040cc4aa

+ 169 - 0
core/src/main/java/com/tzld/videoVector/api/VideoApiService.java

@@ -0,0 +1,169 @@
+package com.tzld.videoVector.api;
+
+import com.alibaba.fastjson.JSONArray;
+import com.alibaba.fastjson.JSONObject;
+import com.google.common.collect.Lists;
+import com.tzld.videoVector.model.entity.VideoDetail;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.*;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Service;
+import org.springframework.util.CollectionUtils;
+
+import javax.annotation.PostConstruct;
+import java.io.IOException;
+import java.util.*;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * 视频 API 服务
+ * 调用长视频 API 获取视频详情信息
+ */
+@Slf4j
+@Service
+public class VideoApiService {
+
+    private OkHttpClient client;
+
+    private static final String POST_VIDEO_DETAIL_URL = "https://longvideoapi.piaoquantv.com/longvideoapi/openapi/video/batchSelectVideoInfo";
+
+    private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8");
+
+    /**
+     * 每批次最大查询数量
+     */
+    private static final int BATCH_SIZE = 20;
+
+    @Value("${video.api.timeout:30}")
+    private int timeout;
+
+    @PostConstruct
+    public void init() {
+        client = new OkHttpClient.Builder()
+                .connectTimeout(timeout, TimeUnit.SECONDS)
+                .readTimeout(timeout, TimeUnit.SECONDS)
+                .writeTimeout(timeout, TimeUnit.SECONDS)
+                .build();
+        log.info("视频 API 服务初始化完成");
+    }
+
+    /**
+     * 批量获取视频详情信息
+     * 自动分批处理,每批次最多20个videoId
+     *
+     * @param videoIdList 视频ID列表
+     * @return videoId -> VideoDetail 映射
+     */
+    public Map<Long, VideoDetail> getVideoDetail(Set<Long> videoIdList) {
+        if (CollectionUtils.isEmpty(videoIdList)) {
+            return Collections.emptyMap();
+        }
+
+        Map<Long, VideoDetail> result = new HashMap<>();
+        List<Long> videoIds = new ArrayList<>(videoIdList);
+
+        // 分批处理,每批次最多20个
+        List<List<Long>> partitions = Lists.partition(videoIds, BATCH_SIZE);
+        for (List<Long> batch : partitions) {
+            try {
+                Map<Long, VideoDetail> batchResult = getVideoDetailBatch(new HashSet<>(batch));
+                result.putAll(batchResult);
+            } catch (Exception e) {
+                log.error("批量获取视频详情失败, batch={}, error={}", batch, e.getMessage(), e);
+            }
+        }
+
+        return result;
+    }
+
+    /**
+     * 单批次获取视频详情信息
+     *
+     * @param videoIdList 视频ID列表(最多20个)
+     * @return videoId -> VideoDetail 映射
+     */
+    private Map<Long, VideoDetail> getVideoDetailBatch(Set<Long> videoIdList) {
+        try {
+            Map<Long, VideoDetail> map = new HashMap<>();
+            JSONObject params = new JSONObject();
+            params.put("videoIdList", videoIdList);
+
+            RequestBody body = RequestBody.create(JSON_MEDIA_TYPE, params.toJSONString());
+            Request request = new Request.Builder()
+                    .url(POST_VIDEO_DETAIL_URL)
+                    .post(body)
+                    .addHeader("Content-Type", "application/json")
+                    .build();
+
+            try (Response response = client.newCall(request).execute()) {
+                if (!response.isSuccessful()) {
+                    String errorBody = response.body() != null ? response.body().string() : "无";
+                    log.error("获取视频详情 API 请求失败,HTTP状态码: {}, 错误信息: {}", response.code(), errorBody);
+                    return map;
+                }
+
+                String post = response.body().string();
+                JSONObject res = JSONObject.parseObject(post);
+                JSONArray data = res.getJSONArray("data");
+
+                if (data == null || data.isEmpty()) {
+                    return map;
+                }
+
+                for (int i = 0; i < data.size(); i++) {
+                    JSONObject jsonObject = data.getJSONObject(i);
+                    if (jsonObject == null) {
+                        continue;
+                    }
+
+                    Long videoId = jsonObject.getLong("id");
+                    if (videoId == null) {
+                        continue;
+                    }
+
+                    VideoDetail videoDetail = new VideoDetail();
+                    videoDetail.setCover(jsonObject.getString("shareImgPath"));
+                    videoDetail.setTitle(jsonObject.getString("title"));
+                    videoDetail.setVideoPath(jsonObject.getString("videoPath"));
+                    videoDetail.setVideoCoverSnapshotPath(jsonObject.getString("videoCoverSnapshotPath"));
+                    videoDetail.setAuditStatus(jsonObject.getInteger("auditStatus"));
+                    videoDetail.setRecommendStatus(jsonObject.getInteger("recommendStatus"));
+
+                    map.put(videoId, videoDetail);
+                }
+            }
+
+            return map;
+        } catch (IOException e) {
+            log.error("VideoApiService getVideoDetail error", e);
+        }
+
+        return new HashMap<>();
+    }
+
+    /**
+     * 获取审核未通过的视频ID列表
+     *
+     * @param videoIdList 视频ID列表
+     * @return 审核未通过的视频ID集合
+     */
+    public Set<Long> getNotAuditPassedVideoIds(Set<Long> videoIdList) {
+        Map<Long, VideoDetail> videoDetails = getVideoDetail(videoIdList);
+        Set<Long> notPassedIds = new HashSet<>();
+
+        for (Map.Entry<Long, VideoDetail> entry : videoDetails.entrySet()) {
+            if (!entry.getValue().isAuditPassed()) {
+                notPassedIds.add(entry.getKey());
+            }
+        }
+
+        // 如果查询结果中没有某些视频ID,说明这些视频可能已被删除,也视为审核不通过
+        for (Long videoId : videoIdList) {
+            if (!videoDetails.containsKey(videoId)) {
+                notPassedIds.add(videoId);
+            }
+        }
+
+        return notPassedIds;
+    }
+}

+ 56 - 0
core/src/main/java/com/tzld/videoVector/job/VideoVectorJob.java

@@ -12,6 +12,7 @@ import com.tzld.videoVector.model.po.videoVector.deconstruct.DeconstructContent;
 import com.tzld.videoVector.model.po.videoVector.deconstruct.DeconstructContentExample;
 import com.tzld.videoVector.model.po.videoVector.deconstruct.DeconstructVectorConfig;
 import com.tzld.videoVector.model.po.videoVector.deconstruct.DeconstructVectorConfigExample;
+import com.tzld.videoVector.api.VideoApiService;
 import com.tzld.videoVector.service.DeconstructService;
 import com.tzld.videoVector.service.EmbeddingService;
 import com.tzld.videoVector.service.VectorStoreService;
@@ -47,11 +48,19 @@ public class VideoVectorJob {
     @Resource
     private EmbeddingService embeddingService;
 
+    @Resource
+    private VideoApiService videoApiService;
+
     /**
      * 每页查询数量
      */
     private static final int PAGE_SIZE = 1000;
 
+    /**
+     * 审核状态检查批次大小
+     */
+    private static final int AUDIT_CHECK_BATCH_SIZE = 20;
+
     /**
      * 超时时间:1小时(毫秒)
      */
@@ -98,6 +107,10 @@ public class VideoVectorJob {
                 // 3. 对每个配置进行处理
                 for (DeconstructVectorConfig config : configs) {
                     String configCode = config.getConfigCode();
+
+                    // 3.0 检查已有向量的审核状态,移除审核不通过的视频
+                    checkAndRemoveNotAuditPassedVideos(configCode);
+
                     // 3.1 查询哪些 videoId 在该配置下已有向量
                     Set<Long> existingIds = vectorStoreService.existsByIds(configCode, videoIds);
                     // 3.2 过滤出需要处理的 videoId(排除已有向量的)
@@ -496,6 +509,49 @@ public class VideoVectorJob {
         }
     }
 
+    /**
+     * 检查并移除审核状态不通过的视频向量
+     * 每批次最多检查20个videoId
+     *
+     * @param configCode 配置编码
+     */
+    private void checkAndRemoveNotAuditPassedVideos(String configCode) {
+        try {
+            // 获取该配置下所有已有的视频ID
+            Set<Long> allVideoIds = vectorStoreService.getAllVideoIds(configCode);
+            if (allVideoIds == null || allVideoIds.isEmpty()) {
+                log.debug("配置 {} 下没有已存储的向量,跳过审核检查", configCode);
+                return;
+            }
+
+            log.info("配置 {} 开始检查审核状态,共 {} 个视频", configCode, allVideoIds.size());
+
+            // 分批检查审核状态
+            List<Long> videoIdList = new ArrayList<>(allVideoIds);
+            int totalRemoved = 0;
+
+            for (int i = 0; i < videoIdList.size(); i += AUDIT_CHECK_BATCH_SIZE) {
+                int end = Math.min(i + AUDIT_CHECK_BATCH_SIZE, videoIdList.size());
+                Set<Long> batchIds = new HashSet<>(videoIdList.subList(i, end));
+
+                // 获取审核未通过的视频ID
+                Set<Long> notPassedIds = videoApiService.getNotAuditPassedVideoIds(batchIds);
+
+                if (!notPassedIds.isEmpty()) {
+                    // 批量删除审核不通过的视频向量
+                    vectorStoreService.deleteBatch(configCode, notPassedIds);
+                    totalRemoved += notPassedIds.size();
+                    log.info("配置 {} 移除审核不通过的视频 {} 个: {}", configCode, notPassedIds.size(), notPassedIds);
+                }
+            }
+
+            log.info("配置 {} 审核检查完成,共移除 {} 个视频向量", configCode, totalRemoved);
+
+        } catch (Exception e) {
+            log.error("配置 {} 检查审核状态失败: {}", configCode, e.getMessage(), e);
+        }
+    }
+
     /**
      * 更新内容状态为失败
      */

+ 46 - 0
core/src/main/java/com/tzld/videoVector/model/entity/VideoDetail.java

@@ -0,0 +1,46 @@
+package com.tzld.videoVector.model.entity;
+
+import lombok.Data;
+
+/**
+ * 视频详情信息
+ */
+@Data
+public class VideoDetail {
+
+    private String cover;
+
+    private String title;
+
+    private String videoPath;
+
+    private String videoCoverSnapshotPath;
+
+    /**
+     * 审核状态
+     * 1: 审核中
+     * 2: 不通过
+     * 3: 待修改
+     * 4: 自己可见
+     * 5: 通过
+     */
+    private Integer auditStatus;
+
+    /**
+     * 推荐状态
+     * 0: 不可搜
+     * -6: 待推荐
+     * 1: 普通推荐
+     * 10: 编辑推荐
+     * -7: 可搜索
+     */
+    private Integer recommendStatus;
+
+    /**
+     * 判断审核状态是否通过
+     * @return true 表示审核通过
+     */
+    public boolean isAuditPassed() {
+        return auditStatus != null && auditStatus == 5;
+    }
+}

+ 13 - 0
core/src/main/java/com/tzld/videoVector/service/VectorStoreService.java

@@ -117,6 +117,19 @@ public interface VectorStoreService {
      */
     void delete(String configCode, Long videoId);
 
+    /**
+     * 批量删除指定视频向量(默认配置)
+     * @param videoIds 视频ID集合
+     */
+    void deleteBatch(Collection<Long> videoIds);
+
+    /**
+     * 批量删除指定视频向量(指定配置)
+     * @param configCode 配置编码
+     * @param videoIds   视频ID集合
+     */
+    void deleteBatch(String configCode, Collection<Long> videoIds);
+
     /**
      * 在所有向量中搜索 Top-N 最相似的视频(默认配置)
      * @param queryVector 查询向量

+ 42 - 0
core/src/main/java/com/tzld/videoVector/service/impl/RedisVectorStoreServiceImpl.java

@@ -274,6 +274,48 @@ public class RedisVectorStoreServiceImpl implements VectorStoreService {
         log.debug("删除向量成功,configCode={}, videoId={}", configCode, videoId);
     }
 
+    @Override
+    public void deleteBatch(Collection<Long> videoIds) {
+        deleteBatch(DEFAULT_CONFIG_CODE, videoIds);
+    }
+
+    @Override
+    public void deleteBatch(String configCode, Collection<Long> videoIds) {
+        if (videoIds == null || videoIds.isEmpty()) {
+            return;
+        }
+        if (configCode == null || configCode.isEmpty()) {
+            configCode = DEFAULT_CONFIG_CODE;
+        }
+
+        final String finalConfigCode = configCode;
+        List<String> keys = videoIds.stream()
+                .map(id -> buildKey(finalConfigCode, id))
+                .collect(Collectors.toList());
+
+        // 批量删除向量数据
+        redisTemplate.delete(keys);
+
+        // 批量从 ID 集合中移除
+        redisTemplate.opsForSet().remove(buildIdsKey(configCode),
+                videoIds.stream().map(String::valueOf).toArray());
+
+        // 从本地缓存批量移除
+        cacheLock.writeLock().lock();
+        try {
+            Map<Long, float[]> cached = vectorCache.getIfPresent(configCode);
+            if (cached != null) {
+                for (Long videoId : videoIds) {
+                    cached.remove(videoId);
+                }
+            }
+        } finally {
+            cacheLock.writeLock().unlock();
+        }
+
+        log.info("批量删除向量成功,configCode={}, 数量={}", configCode, videoIds.size());
+    }
+
     // ---------------------------------------------------------------- 搜索(优化版)
 
     @Override

+ 51 - 8
core/src/main/java/com/tzld/videoVector/service/impl/VideoSearchServiceImpl.java

@@ -1,8 +1,11 @@
 package com.tzld.videoVector.service.impl;
 
 import com.alibaba.fastjson.JSONObject;
+import com.google.common.collect.Lists;
+import com.tzld.videoVector.api.VideoApiService;
 import com.tzld.videoVector.dao.mapper.videoVector.deconstruct.DeconstructContentMapper;
 import com.tzld.videoVector.model.entity.DeconstructResult;
+import com.tzld.videoVector.model.entity.VideoDetail;
 import com.tzld.videoVector.model.entity.VideoMatch;
 import com.tzld.videoVector.model.param.DeconstructParam;
 import com.tzld.videoVector.model.param.GetDeconstructParam;
@@ -20,10 +23,8 @@ import org.springframework.stereotype.Service;
 import org.springframework.util.StringUtils;
 
 import javax.annotation.Resource;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Date;
-import java.util.List;
+import java.util.*;
+import java.util.stream.Collectors;
 
 import static com.tzld.videoVector.service.VectorStoreService.DEFAULT_CONFIG_CODE;
 
@@ -46,6 +47,9 @@ public class VideoSearchServiceImpl implements VideoSearchService {
     @Resource
     private VectorizeService vectorizeService;
 
+    @Resource
+    private VideoApiService videoApiService;
+
     @Override
     public String deconstruct(DeconstructParam param) {
         if (param == null) {
@@ -395,16 +399,21 @@ public class VideoSearchServiceImpl implements VideoSearchService {
         log.info("开始匹配 Top-{} 视频,configCode: {},向量维度: {}", topN, configCode, queryVector.size());
 
         // 在 Redis 中搜索(支持按 configCode 搜索)
+        // 为了确保返回足够的结果,搜索更多的候选
+        int candidateSize = topN * 3;
         List<VideoMatch> matches;
         if (configCode != null && !configCode.isEmpty()) {
-            matches = vectorStoreService.searchTopN(configCode, queryVector, topN);
+            matches = vectorStoreService.searchTopN(configCode, queryVector, candidateSize);
         } else {
-            matches = vectorStoreService.searchTopN(queryVector, topN);
+            matches = vectorStoreService.searchTopN(queryVector, candidateSize);
         }
 
+        // 过滤审核状态不通过的视频
+        List<VideoMatch> filteredMatches = filterByAuditStatus(matches, topN);
+
         // 转化为返回格式
-        List<Object> result = new ArrayList<>(matches.size());
-        for (VideoMatch match : matches) {
+        List<Object> result = new ArrayList<>(filteredMatches.size());
+        for (VideoMatch match : filteredMatches) {
             JSONObject item = new JSONObject();
             item.put("videoId", match.getVideoId());
             item.put("score", match.getScore());
@@ -414,4 +423,38 @@ public class VideoSearchServiceImpl implements VideoSearchService {
         log.info("匹配完成,configCode: {},返回 {} 条结果", configCode, result.size());
         return result;
     }
+
+    /**
+     * 根据审核状态过滤视频
+     * 每批次最多查询20个videoId
+     *
+     * @param matches 匹配结果列表
+     * @param topN    需要返回的数量
+     * @return 过滤后的结果列表
+     */
+    private List<VideoMatch> filterByAuditStatus(List<VideoMatch> matches, int topN) {
+        if (matches == null || matches.isEmpty()) {
+            return Collections.emptyList();
+        }
+
+        // 提取所有 videoId
+        Set<Long> videoIds = matches.stream()
+                .map(VideoMatch::getVideoId)
+                .collect(Collectors.toSet());
+
+        // 批量获取视频详情
+        Map<Long, VideoDetail> videoDetails = videoApiService.getVideoDetail(videoIds);
+
+        // 过滤审核通过的视频
+        List<VideoMatch> filteredMatches = matches.stream()
+                .filter(match -> {
+                    VideoDetail detail = videoDetails.get(match.getVideoId());
+                    // 如果查询不到视频详情,视为审核不通过
+                    return detail != null && detail.isAuditPassed();
+                })
+                .limit(topN)
+                .collect(Collectors.toList());
+
+        return filteredMatches;
+    }
 }