sunxy 11 місяців тому
батько
коміт
2e5033a001

+ 11 - 8
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/ContentBaseRecallStrategy.java

@@ -67,14 +67,17 @@ public class ContentBaseRecallStrategy implements RecallStrategy {
         FilterResult filterResult = filterService.filter(filterParam);
         List<Video> videosResult = new ArrayList<>();
         if (filterResult != null && CollectionUtils.isNotEmpty(filterResult.getVideoIds())) {
-            filterResult.getVideoIds().stream().limit(limit).forEach(vid -> {
-                Video video = new Video();
-                video.setVideoId(vid);
-                video.setAbCode(param.getAbCode());
-                video.setRovScore(videoMap.get(vid));
-                video.setPushFrom(pushFrom());
-                videosResult.add(video);
-            });
+            filterResult.getVideoIds().stream()
+                    // 按照 rovScore 倒序排序
+                    .sorted(Comparator.comparing(vid -> videoMap.getOrDefault(vid, 0.0)).reversed())
+                    .limit(limit).forEach(vid -> {
+                        Video video = new Video();
+                        video.setVideoId(vid);
+                        video.setAbCode(param.getAbCode());
+                        video.setRovScore(videoMap.get(vid));
+                        video.setPushFrom(pushFrom());
+                        videosResult.add(video);
+                    });
         }
         return videosResult;
     }

+ 21 - 13
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score4recall/strategy/ContentBaseRecallScore.java

@@ -2,9 +2,9 @@ package com.tzld.piaoquan.recommend.server.service.score4recall.strategy;
 
 import com.tzld.piaoquan.recommend.server.service.score.ScorerConfigInfo;
 import com.tzld.piaoquan.recommend.server.service.score4recall.AbstractScorer4Recall;
-import com.tzld.piaoquan.recommend.server.service.score4recall.model4recall.VideoTagModel4RecallMap;
-import com.tzld.piaoquan.recommend.server.util.ListMerger;
+import com.tzld.piaoquan.recommend.server.service.score4recall.model4recall.Model4RecallKeyValue;
 import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.lang3.tuple.Pair;
 
@@ -20,13 +20,13 @@ public class ContentBaseRecallScore extends AbstractScorer4Recall {
 
     @Override
     public void loadModel() {
-        doLoadModel(VideoTagModel4RecallMap.class);
+        doLoadModel(Model4RecallKeyValue.class);
     }
 
     @Override
     public List<Pair<Long, Double>> recall(Map<String, String> params) {
-        VideoTagModel4RecallMap model = (VideoTagModel4RecallMap) this.getModel();
-        if (model == null || model.map == null) {
+        Model4RecallKeyValue model = (Model4RecallKeyValue) this.getModel();
+        if (model == null || MapUtils.isEmpty(model.kv)) {
             return Collections.emptyList();
         }
         String tags = params.get("tags");
@@ -34,17 +34,25 @@ public class ContentBaseRecallScore extends AbstractScorer4Recall {
             return Collections.emptyList();
         }
         List<String> tagList = Arrays.stream(tags.split(",")).collect(Collectors.toList());
-        Map<String, List<Long>> tagVideoIdMap = model.map;
-        List<List<Pair<Long, Double>>> multiTagVideos = new ArrayList<>();
+        List<Pair<Long, Double>> result = new ArrayList<>();
         for (String tag : tagList) {
-            List<Long> videoIds = tagVideoIdMap.get(tag);
-            if (CollectionUtils.isNotEmpty(videoIds)) {
-                multiTagVideos.add(videoIds.stream().map(videoId -> Pair.of(videoId, 1.0))
-                        .collect(Collectors.toList()));
+            List<Pair<Long, Double>> videoAndScores = model.kv.get(tag);
+            if (CollectionUtils.isNotEmpty(videoAndScores)) {
+                result.addAll(videoAndScores);
             }
         }
-
-        return ListMerger.mergeLists(multiTagVideos);
+        // 结果去重
+        Set<Long> videoIdSet = new HashSet<>();
+        List<Pair<Long, Double>> distinctResult = new ArrayList<>();
+        for (Pair<Long, Double> pair : result) {
+            if (pair.getLeft() == null || pair.getRight() == null) {
+                continue;
+            }
+            if (videoIdSet.add(pair.getLeft())) {
+                distinctResult.add(pair);
+            }
+        }
+        return distinctResult;
     }