jch hace 3 semanas
padre
commit
cbed3cb5c6

+ 76 - 53
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategy4RegionMergeModelV564.java

@@ -17,6 +17,7 @@ import com.tzld.piaoquan.recommend.server.service.score.ScorerUtils;
 import com.tzld.piaoquan.recommend.server.util.CommonCollectionUtils;
 import com.tzld.piaoquan.recommend.server.util.FeatureBucketUtils;
 import com.tzld.piaoquan.recommend.server.util.JSONUtils;
+import com.tzld.piaoquan.recommend.server.util.RecallUtils;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.MapUtils;
 import org.apache.commons.lang3.StringUtils;
@@ -26,7 +27,6 @@ import org.springframework.stereotype.Service;
 import java.util.*;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
 
 @Service
 @Slf4j
@@ -47,57 +47,28 @@ public class RankStrategy4RegionMergeModelV564 extends RankStrategy4RegionMergeM
         //-------------------辑-------------------
 
         long currentMs = System.currentTimeMillis();
-        List<Video> oldRovs = new ArrayList<>();
-        oldRovs.addAll(extractAndSort(param, RegionHRecallStrategy.PUSH_FORM));
-        oldRovs.addAll(extractAndSort(param, RegionHDupRecallStrategy.PUSH_FORM));
-        oldRovs.addAll(extractAndSort(param, Region24HRecallStrategy.PUSH_FORM));
-        oldRovs.addAll(extractAndSort(param, RegionRelative24HRecallStrategy.PUSH_FORM));
-        oldRovs.addAll(extractAndSort(param, RegionRelative24HDupRecallStrategy.PUSH_FORM));
-        removeDuplicate(oldRovs);
-        int sizeReturn = param.getSize();
-        List<Video> v0 = oldRovs.size() <= sizeReturn
-                ? oldRovs
-                : oldRovs.subList(0, sizeReturn);
         Set<Long> setVideo = new HashSet<>();
-        this.duplicate(setVideo, v0);
-        setVideo.addAll(v0.stream().map(Video::getVideoId).collect(Collectors.toSet()));
-        List<Video> rovRecallRank = new ArrayList<>(v0);
+        List<Video> rovRecallRank = new ArrayList<>();
+        // -------------------5路特殊旧召回------------------
+        RecallUtils.extractOldSpecialRecall(param, setVideo, rovRecallRank);
         //-------------------return相似召回------------------
-        List<Video> v6 = extractAndSort(param, ReturnVideoRecallStrategy.PUSH_FORM);
-        v6 = v6.stream().filter(r -> !setVideo.contains(r.getVideoId())).collect(Collectors.toList());
-        v6 = v6.subList(0, Math.min(mergeWeight.getOrDefault("v6", 5.0).intValue(), v6.size()));
-        rovRecallRank.addAll(v6);
-        setVideo.addAll(v6.stream().map(Video::getVideoId).collect(Collectors.toSet()));
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("v6", 5.0).intValue(), param, ReturnVideoRecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
         //-------------------新地域召回------------------
-        List<Video> v1 = extractAndSort(param, RegionRealtimeRecallStrategyV1.PUSH_FORM);
-        v1 = v1.stream().filter(r -> !setVideo.contains(r.getVideoId())).collect(Collectors.toList());
-        v1 = v1.subList(0, Math.min(mergeWeight.getOrDefault("v1", 5.0).intValue(), v1.size()));
-        rovRecallRank.addAll(v1);
-        setVideo.addAll(v1.stream().map(Video::getVideoId).collect(Collectors.toSet()));
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("v1", 5.0).intValue(), param, RegionRealtimeRecallStrategyV1.PUSH_FORM, setVideo, rovRecallRank);
         //-------------------scene cf rovn------------------
-        List<Video> sceneCFRovn = extractAndSort(param, SceneCFRovnRecallStrategy.PUSH_FORM);
-        sceneCFRovn = sceneCFRovn.stream().filter(r -> !setVideo.contains(r.getVideoId())).collect(Collectors.toList());
-        sceneCFRovn = sceneCFRovn.subList(0, Math.min(mergeWeight.getOrDefault("sceneCFRovn", 5.0).intValue(), sceneCFRovn.size()));
-        rovRecallRank.addAll(sceneCFRovn);
-        setVideo.addAll(sceneCFRovn.stream().map(Video::getVideoId).collect(Collectors.toSet()));
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("sceneCFRovn", 5.0).intValue(), param, SceneCFRovnRecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
         //-------------------scene cf rosn------------------
-        List<Video> sceneCFRosn = extractAndSort(param, SceneCFRosnRecallStrategy.PUSH_FORM);
-        sceneCFRosn = sceneCFRosn.stream().filter(r -> !setVideo.contains(r.getVideoId())).collect(Collectors.toList());
-        sceneCFRosn = sceneCFRosn.subList(0, Math.min(mergeWeight.getOrDefault("sceneCFRosn", 5.0).intValue(), sceneCFRosn.size()));
-        rovRecallRank.addAll(sceneCFRosn);
-        setVideo.addAll(sceneCFRosn.stream().map(Video::getVideoId).collect(Collectors.toSet()));
-        // -------------------cate1------------------
-        int cate1RecallN = mergeWeight.getOrDefault("cate1RecallN", 5.0).intValue();
-        addRecall(param, cate1RecallN, UserCate1RecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
-        // -------------------cate2------------------
-        int cate2RecallN = mergeWeight.getOrDefault("cate2RecallN", 5.0).intValue();
-        addRecall(param, cate2RecallN, UserCate2RecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("sceneCFRosn", 5.0).intValue(), param, SceneCFRosnRecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
+        // -------------------user cate1------------------
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("cate1RecallN", 5.0).intValue(), param, UserCate1RecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
+        // -------------------user cate2------------------
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("cate2RecallN", 5.0).intValue(), param, UserCate2RecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
         // -------------------head province cate1------------------
-        int headCate1RecallN = mergeWeight.getOrDefault("headCate1RecallN", 3.0).intValue();
-        addRecall(param, headCate1RecallN, HeadProvinceCate1RecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("headCate1RecallN", 3.0).intValue(), param, HeadProvinceCate1RecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
         // -------------------head province cate2------------------
-        int headCate2RecallN = mergeWeight.getOrDefault("headCate2RecallN", 3.0).intValue();
-        addRecall(param, headCate2RecallN, HeadProvinceCate2RecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("headCate2RecallN", 3.0).intValue(), param, HeadProvinceCate2RecallStrategy.PUSH_FORM, setVideo, rovRecallRank);
+        //-------------------head cate2 of rovn------------------
+        RecallUtils.extractRecall(mergeWeight.getOrDefault("headCate2Rov", 5.0).intValue(), param, HeadCate2RovRecallStrategy.PUSH_FROM, setVideo, rovRecallRank);
 
         //-------------------排-------------------
         //-------------------序-------------------
@@ -135,6 +106,19 @@ public class RankStrategy4RegionMergeModelV564 extends RankStrategy4RegionMergeM
         double xgbNorPowerWeight = mergeWeight.getOrDefault("xgbNorPowerWeight", 1.22);
         double xgbNorPowerExp = mergeWeight.getOrDefault("xgbNorPowerExp", 1.15);
         Map<String, Map<String, String>> vid2MapFeature = this.getVideoRedisFeature(vids, "redis:vid_hasreturn_vor:");
+
+        // 获取权重
+        Map<String, Double> cate2Coefficient = new HashMap<>();
+        double cate2CoefficientFunc = mergeWeight.getOrDefault("cate2CoefficientFunc", 0d);
+        if (cate2CoefficientFunc == 1d) {
+            String headVidStr = String.valueOf(param.getHeadVid());
+            String mergeCate2 = this.findVideoMergeCate2(videoBaseInfoMap, headVidStr);
+            Double length = mergeWeight.getOrDefault("cate2CoefficientLength", 10000d);
+            Map<String, Double> simCateScore = this.findSimCateScore(mergeCate2, length.intValue());
+            cate2Coefficient.putAll(simCateScore);
+        }
+        Double cate2CoefficientDenominator = mergeWeight.getOrDefault("cate2CoefficientDenominator", 1d);
+
         List<Video> result = new ArrayList<>();
         for (RankItem item : items) {
             double score;
@@ -149,7 +133,12 @@ public class RankStrategy4RegionMergeModelV564 extends RankStrategy4RegionMergeM
             double vor = Double.parseDouble(vid2MapFeature.getOrDefault(item.getVideoId() + "", new HashMap<>()).getOrDefault("vor", "0"));
             item.getScoresMap().put("vor", vor);
 
-            score = fmRov * (0.1 + newNorXGBScore) * (0.1 + vor);
+            String vidMergeCate2 = this.findVideoMergeCate2(videoBaseInfoMap, String.valueOf(item.getVideoId()));
+            Double scoreCoefficient = cate2Coefficient.getOrDefault(vidMergeCate2, 0d);
+            item.getScoresMap().put("scoreCoefficient", scoreCoefficient);
+            item.getScoresMap().put("cate2CoefficientDenominator", cate2CoefficientDenominator);
+
+            score = fmRov * (0.1 + newNorXGBScore) * (0.1 + vor) * (1 + scoreCoefficient / cate2CoefficientDenominator);
 
             Video video = item.getVideo();
             video.setScore(score);
@@ -356,13 +345,47 @@ public class RankStrategy4RegionMergeModelV564 extends RankStrategy4RegionMergeM
         return newScore;
     }
 
-    private void addRecall(RankParam param, int recallNum, String recallName, Set<Long> setVideo, List<Video> rovRecallRank) {
-        if (recallNum > 0) {
-            List<Video> list = extractAndSort(param, recallName);
-            list = list.stream().filter(r -> !setVideo.contains(r.getVideoId())).collect(Collectors.toList());
-            list = list.subList(0, Math.min(recallNum, list.size()));
-            rovRecallRank.addAll(list);
-            setVideo.addAll(list.stream().map(Video::getVideoId).collect(Collectors.toSet()));
+    private Map<String, Double> findSimCateScore(String headCate2, int length) {
+        if (StringUtils.isBlank(headCate2)) {
+            return new HashMap<>();
+        }
+
+        String redisKey = String.format("alg_recsys_good_cate_pair_list:%s", headCate2);
+        String cate2Value = redisTemplate.opsForValue().get(redisKey);
+        if (StringUtils.isEmpty(cate2Value)) {
+            return new HashMap<>();
+        }
+
+        return this.parsePair(cate2Value, length);
+    }
+
+    private Map<String, Double> parsePair(String value, int length) {
+        if (StringUtils.isBlank(value)) {
+            return new HashMap<>();
         }
+
+        String[] split = value.split("\t");
+        if (split.length != 2) {
+            return new HashMap<>();
+        }
+
+        String[] valueList = split[0].trim().split(",");
+        String[] scoreList = split[1].trim().split(",");
+        if (valueList.length != scoreList.length) {
+            return new HashMap<>();
+        }
+
+        int minLength = Math.min(length, valueList.length);
+        Map<String, Double> resultMap = new HashMap<>();
+        for (int i = 0; i < minLength; i++) {
+            resultMap.put(valueList[i].trim(), Double.parseDouble(scoreList[i].trim()));
+        }
+
+        return resultMap;
+    }
+
+    private String findVideoMergeCate2(Map<String, Map<String, Map<String, String>>> featureOriginVideo, String vid) {
+        Map<String, String> videoInfo = featureOriginVideo.getOrDefault(vid, new HashMap<>()).getOrDefault("alg_vid_feature_basic_info", new HashMap<>());
+        return videoInfo.get("merge_second_level_cate");
     }
 }

+ 6 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/HeadCate2AndChannelRovRecallStrategy.java

@@ -87,6 +87,7 @@ public class HeadCate2AndChannelRovRecallStrategy implements RecallStrategy {
         Set<Long> filterVids = new HashSet<>(filterResult.getVideoIds());
         filterVids.remove(param.getVideoId());
 
+        Set<Long> hit = new HashSet<>();
         for (Map.Entry<String, Double> entry : mergeCate2Pair.entrySet()) {
             String cate = entry.getKey();
             Double cateScore = entry.getValue();
@@ -99,6 +100,11 @@ public class HeadCate2AndChannelRovRecallStrategy implements RecallStrategy {
                 if (!filterVids.contains(vid)) {
                     continue;
                 }
+                // 去重
+                if (hit.contains(vid)) {
+                    continue;
+                }
+                hit.add(vid);
 
                 Double videoScore = videoEntry.getValue();
 

+ 6 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/HeadCate2RovRecallStrategy.java

@@ -85,6 +85,7 @@ public class HeadCate2RovRecallStrategy implements RecallStrategy {
         Set<Long> filterVids = new HashSet<>(filterResult.getVideoIds());
         filterVids.remove(param.getVideoId());
 
+        Set<Long> hit = new HashSet<>();
         for (Map.Entry<String, Double> entry : mergeCate2Pair.entrySet()) {
             String cate = entry.getKey();
             Double cateScore = entry.getValue();
@@ -97,6 +98,11 @@ public class HeadCate2RovRecallStrategy implements RecallStrategy {
                 if (!filterVids.contains(vid)) {
                     continue;
                 }
+                // 去重
+                if (hit.contains(vid)) {
+                    continue;
+                }
+                hit.add(vid);
 
                 Double videoScore = videoEntry.getValue();
 

+ 73 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/util/RecallUtils.java

@@ -0,0 +1,73 @@
+package com.tzld.piaoquan.recommend.server.util;
+
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.rank.RankParam;
+import com.tzld.piaoquan.recommend.server.service.recall.RecallResult;
+import com.tzld.piaoquan.recommend.server.service.recall.strategy.*;
+import org.apache.commons.collections4.CollectionUtils;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+public class RecallUtils {
+    private static void removeDuplicate(List<Video> rovRecallRank) {
+        if (CollectionUtils.isNotEmpty(rovRecallRank)) {
+            Set<Long> videoIds = new HashSet<>();
+            Iterator<Video> ite = rovRecallRank.iterator();
+            while (ite.hasNext()) {
+                Video v = ite.next();
+                if (videoIds.contains(v.getVideoId())) {
+                    ite.remove();
+                    continue;
+                }
+                videoIds.add(v.getVideoId());
+            }
+        }
+    }
+
+    private static List<Video> extractAndSort(RankParam param, String pushFrom) {
+        if (param == null
+                || param.getRecallResult() == null
+                || CollectionUtils.isEmpty(param.getRecallResult().getData())) {
+            return Collections.emptyList();
+        }
+        Optional<RecallResult.RecallData> data = param.getRecallResult().getData().stream()
+                .filter(d -> d.getPushFrom().equals(pushFrom))
+                .findFirst();
+        if (data.isPresent()
+                && data.get() != null
+                && CollectionUtils.isNotEmpty(data.get().getVideos())) {
+            List<Video> result = data.get().getVideos();
+            Collections.sort(result, Comparator.comparingDouble(o -> -o.getRovScore()));
+            return result;
+        }
+        return Collections.emptyList();
+    }
+
+    public static void extractOldSpecialRecall(RankParam param, Set<Long> setVideo, List<Video> rovRecallRank) {
+        List<Video> oldRovs = new ArrayList<>();
+        oldRovs.addAll(extractAndSort(param, RegionHRecallStrategy.PUSH_FORM));
+        oldRovs.addAll(extractAndSort(param, RegionHDupRecallStrategy.PUSH_FORM));
+        oldRovs.addAll(extractAndSort(param, Region24HRecallStrategy.PUSH_FORM));
+        oldRovs.addAll(extractAndSort(param, RegionRelative24HRecallStrategy.PUSH_FORM));
+        oldRovs.addAll(extractAndSort(param, RegionRelative24HDupRecallStrategy.PUSH_FORM));
+        removeDuplicate(oldRovs);
+        int sizeReturn = param.getSize();
+        List<Video> v0 = oldRovs.size() <= sizeReturn
+                ? oldRovs
+                : oldRovs.subList(0, sizeReturn);
+
+        rovRecallRank.addAll(v0);
+        setVideo.addAll(v0.stream().map(Video::getVideoId).collect(Collectors.toSet()));
+    }
+
+    public static void extractRecall(int recallNum, RankParam param, String pushFrom, Set<Long> setVideo, List<Video> rovRecallRank) {
+        if (recallNum > 0) {
+            List<Video> list = extractAndSort(param, pushFrom);
+            list = list.stream().filter(r -> !setVideo.contains(r.getVideoId())).collect(Collectors.toList());
+            list = list.subList(0, Math.min(recallNum, list.size()));
+            rovRecallRank.addAll(list);
+            setVideo.addAll(list.stream().map(Video::getVideoId).collect(Collectors.toSet()));
+        }
+    }
+}