浏览代码

MOD: merge sort

sunxy 1 年之前
父节点
当前提交
1e54ddf89d

+ 1 - 6
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/framework/merger/MergeUtils.java

@@ -119,12 +119,7 @@ public class MergeUtils {
                                           int freeRecNum,
                                           int expId) {
         // 先定义按score排序优选的独立额
-        PriorityQueue<Pair<String, Integer>> mergePriorityQueue = new PriorityQueue<Pair<String, Integer>>(freeRecNum, new Comparator<Pair<String, Integer>>() {
-            @Override
-            public int compare(Pair<String, Integer> o1, Pair<String, Integer> o2) {
-                return rankerItemsListMap.get(o1.getLeft()).getRight().get(o1.getRight()).compareTo(rankerItemsListMap.get(o2.getLeft()).getRight().get(o2.getRight()));
-            }
-        });
+        PriorityQueue<Pair<String, Integer>> mergePriorityQueue = new PriorityQueue<Pair<String, Integer>>(freeRecNum, Comparator.comparing(o -> rankerItemsListMap.get(o.getLeft()).getRight().get(o.getRight())));
 
         // 大于最小mergenum 同时小于maxMergenum,在队列中add
         for (Pair<MergeRule, List<RankItem>> entry : rankerItemsListMap.values()) {

+ 37 - 37
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/implement/TopRecommendPipeline.java

@@ -91,13 +91,34 @@ public class TopRecommendPipeline {
         return videos;
     }
 
+    private List<Video> rankItem2Video(List<RankItem> rankItems) { // Builds one Video per RankItem and returns them sorted by score, highest first.
+        List<Video> videos = new ArrayList<>();
+        for (RankItem item : rankItems) {
+            Video video = new Video();
+            video.setVideoId(Long.parseLong(item.getId())); // Long.parseLong throws NumberFormatException if the id is not a decimal long
+            video.setPushFrom(item.getQueue());
+            video.setScore(item.getScore()); // the score is copied from the item as-is; nothing is computed here
+            video.setSortScore(item.getScore()); // sort score is set to the same value as score
+            video.setScoreStr(item.getScoreStr());
+            video.setScoresMap(item.getScoresMap()); // passes the item's map reference through (no copy is made here)
+
+            Map<String, List<String>> pushFromIndex = new HashMap<>(); // holds a single entry: item queue -> candidate queue names
+            pushFromIndex.put(item.getQueue(), item.getCandidateInfoList().stream()
+                    .map(CandidateInfo::getCandidateQueueName).collect(Collectors.toList()));
+            video.setPushFromIndex(pushFromIndex);
+            videos.add(video);
+        }
+        videos.sort(Comparator.comparing(Video::getScore).reversed()); // descending by getScore(); NOTE(review): if getScore() is a boxed Double, a null score would throw here - confirm
+        return videos;
+    }
+
     public List<Double> getStaticData(Map<String, Map<String, Double>> itemRealMap,
-                                      List<String> datehours, String key){
+                                      List<String> datehours, String key) {
         List<Double> views = new LinkedList<>();
         Map<String, Double> tmp = itemRealMap.getOrDefault(key, new HashMap<>());
-        for (String dh : datehours){
+        for (String dh : datehours) {
             views.add(tmp.getOrDefault(dh, 0.0D) +
-                    (views.isEmpty() ? 0.0: views.get(views.size()-1))
+                    (views.isEmpty() ? 0.0 : views.get(views.size() - 1))
             );
         }
         return views;
@@ -158,14 +179,18 @@ public class TopRecommendPipeline {
         timeLogMap.put("recalling-cost", stopwatch.elapsed().toMillis() + "");
         timeLogMap.put("recalling-size", items == null ? "0" : items.size() + "");
 
-        // Step 4: Advance Scoring
-//        timestamp = System.currentTimeMillis();
-//        ScorerPipeline scorerPipeline = getScorerPipeline(requestData);
-//        items = scorerPipeline.scoring(requestData, userInfo, requestIndex, items);
         if (CollectionUtils.isEmpty(items)) {
             return new ArrayList<>();
         }
 
+        // Step 4: Advance Scoring
+        stopwatch.reset().start();
+        videoScoredByFeature(items);
+        if (logPrint) {
+            log.info("traceId = {}, cost = {}, items = {}", requestData.getRequestId(),
+                    stopwatch.elapsed().toMillis(), JSONUtils.toJson(items));
+        }
+
         stopwatch.reset().start();
         // Step 5: Merger
         MergeUtils.distributeItemsToMultiQueues(topQueue, items);
@@ -188,18 +213,7 @@ public class TopRecommendPipeline {
 //        MergeUtils.diversityRerank(mergeItems, SimilarityUtils.getIsSameUserTagOrCategoryFunc(), recallNum, 6, 2);
 
         // Step 6: Global Rank & subList
-        // TODO 前置和后置处理逻辑 hardcode,后续优化
-        stopwatch.reset().start();
-        List<RankItem> rovRecallRankNewScore = rankByScore(mergeItems, requestData);
-
-        timeLogMap.put("rankByScore-cost", stopwatch.elapsed().toMillis() + "");
-        timeLogMap.put("rankByScore-size", rovRecallRankNewScore.size() + "");
-
-        if (logPrint) {
-            log.info("traceId = {}, cost = {}, rovRecallRankNewScore = {}", requestData.getRequestId(),
-                    stopwatch.elapsed().toMillis(), JSONUtils.toJson(rovRecallRankNewScore));
-        }
-        return rovRecallRankNewScore;
+        return mergeItems;
     }
 
     public Double calScoreWeightNoTimeDecay(List<Double> data) {
@@ -212,7 +226,7 @@ public class TopRecommendPipeline {
         return down > 1E-8 ? up / down : 0.0;
     }
 
-    private List<Video> rankItem2Video(List<RankItem> items) {
+    private void videoScoredByFeature(List<RankItem> items) {
         // 1 模型分
         List<String> rtFeaPartKey = new ArrayList<>(Arrays.asList("item_rt_fea_1day_partition", "item_rt_fea_1h_partition"));
         List<String> rtFeaPartKeyResult = this.redisTemplate.opsForValue().multiGet(rtFeaPartKey);
@@ -280,7 +294,6 @@ public class TopRecommendPipeline {
 
         }
         // 3 融合公式
-        List<Video> result = new ArrayList<>();
         double a = mergeWeight.getOrDefault("a", 0.1);
         double b = mergeWeight.getOrDefault("b", 0.0);
         double c = mergeWeight.getOrDefault("c", 0.000001);
@@ -300,7 +313,7 @@ public class TopRecommendPipeline {
             double share2allreturnScore = item.scoresMap.getOrDefault("share2allreturnScore", 0.0);
             double view2allreturnScore = item.scoresMap.getOrDefault("view2allreturnScore", 0.0);
             double preturnsScore = Math.log(1 + item.scoresMap.getOrDefault("preturnsScore", 0.0));
-            double score = 0.0;
+            double score;
             if (ifAdd < 0.5) {
                 score = Math.pow(strScore, a) * Math.pow(rosScore, b) + c * preturnsScore +
                         (newVideoScore > 1E-8 ? d * trendScore * (e + newVideoScore) : 0.0);
@@ -313,22 +326,9 @@ public class TopRecommendPipeline {
             if (allreturnsScore > h) {
                 score += (f * share2allreturnScore + g * view2allreturnScore);
             }
-            Video video = new Video();
-            video.setVideoId(Long.parseLong(item.getId()));
-            video.setPushFrom(item.getQueue());
-            video.setScore(score);
-            video.setSortScore(score);
-            video.setScoreStr(item.getScoreStr());
-            video.setScoresMap(item.getScoresMap());
-
-            Map<String, List<String>> pushFromIndex = new HashMap<>();
-            pushFromIndex.put(item.getQueue(), item.getCandidateInfoList().stream()
-                    .map(CandidateInfo::getCandidateQueueName).collect(Collectors.toList()));
-            video.setPushFromIndex(pushFromIndex);
-            result.add(video);
+            // 设置计算好的分数
+            item.setScore(score);
         }
-        Collections.sort(result, Comparator.comparingDouble(o -> -o.getSortScore()));
-        return result;
     }
 
     public double calNewVideoScore(Map<String, String> itemBasicMap) {