Przeglądaj źródła

多样性实验

zhangbo 1 rok temu
rodzic
commit
c3232b5e3a

+ 6 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/model/Video.java

@@ -2,6 +2,9 @@ package com.tzld.piaoquan.recommend.server.model;
 
 import lombok.Data;
 
+import java.util.ArrayList;
+import java.util.List;
+
 /**
  * @author dyp
  */
@@ -23,4 +26,7 @@ public class Video {
     private double rand;
     private String lastVideoKey;
 
+    // video的特征 tag
+    private List<String> tags = new ArrayList<>();
+
 }

+ 129 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/RankService.java

@@ -1,6 +1,7 @@
 package com.tzld.piaoquan.recommend.server.service.rank;
 
 
+import com.google.common.reflect.TypeToken;
 import com.tzld.piaoquan.recommend.feature.domain.video.base.ItemFeature;
 import com.tzld.piaoquan.recommend.feature.domain.video.base.RequestContext;
 import com.tzld.piaoquan.recommend.feature.domain.video.base.UserFeature;
@@ -10,6 +11,8 @@ import com.tzld.piaoquan.recommend.server.model.MachineInfo;
 import com.tzld.piaoquan.recommend.server.model.Video;
 import com.tzld.piaoquan.recommend.server.remote.FeatureRemoteService;
 import com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConstants;
+import com.tzld.piaoquan.recommend.server.service.rank.extractor.RankExtractorFeature;
+import com.tzld.piaoquan.recommend.server.service.rank.processor.RankProcessorDensity;
 import com.tzld.piaoquan.recommend.server.service.recall.RecallResult;
 import com.tzld.piaoquan.recommend.server.service.recall.strategy.*;
 import com.tzld.piaoquan.recommend.server.service.score.ScoreParam;
@@ -75,7 +78,13 @@ public class RankService {
                 JSONUtils.toJson(flowPoolRank));
 
         // 融合排序
-        return mergeAndSort(param, rovRecallRank, flowPoolRank);
+        String abCode = param.getAbCode();
+        switch (abCode){
+            case "60098":
+                return this.mergeAndSort4Density(param, rovRecallRank, flowPoolRank);
+            default:
+                return mergeAndSort(param, rovRecallRank, flowPoolRank);
+        }
     }
 
     private List<Video> mergeAndRankRovRecall(RankParam param) {
@@ -125,6 +134,31 @@ public class RankService {
             rovRecallRank.addAll(extractAndSort(param, ReturnVideoRecallStrategy.PUSH_FORM));
             removeDuplicate(rovRecallRank);
 
+            // 融合排序
+            List<String> videoIdKeys = rovRecallRank.stream()
+                    .map(t -> param.getRankKeyPrefix() + t.getVideoId())
+                    .collect(Collectors.toList());
+            List<String> videoScores = redisTemplate.opsForValue().multiGet(videoIdKeys);
+            log.info("rank mergeAndRankRovRecall videoIdKeys={}, videoScores={}", JSONUtils.toJson(videoIdKeys),
+                    JSONUtils.toJson(videoScores));
+            if (CollectionUtils.isNotEmpty(videoScores)
+                    && videoScores.size() == rovRecallRank.size()) {
+                for (int i = 0; i < videoScores.size(); i++) {
+                    rovRecallRank.get(i).setSortScore(NumberUtils.toDouble(videoScores.get(i), 0.0));
+                }
+                Collections.sort(rovRecallRank, Comparator.comparingDouble(o -> -o.getSortScore()));
+            }
+        }else if(param.getAbCode().equals("60098")) {
+
+            int sizeNew = param.getSize();
+            removeDuplicate(rovRecallRank);
+            rovRecallRank = rovRecallRank.size() <= sizeNew ? rovRecallRank: rovRecallRank.subList(0, sizeNew);
+
+            // merge sim recall 和 return recall
+            rovRecallRank.addAll(extractAndSort(param, SimHotVideoRecallStrategy.PUSH_FORM));
+            rovRecallRank.addAll(extractAndSort(param, ReturnVideoRecallStrategy.PUSH_FORM));
+            removeDuplicate(rovRecallRank);
+
             // 融合排序
             List<String> videoIdKeys = rovRecallRank.stream()
                     .map(t -> param.getRankKeyPrefix() + t.getVideoId())
@@ -359,4 +393,98 @@ public class RankService {
         return new RankResult(result);
     }
 
+    private RankResult mergeAndSort4Density(RankParam param, List<Video> rovRecallRank, List<Video> flowPoolRank) {
+        if (CollectionUtils.isEmpty(rovRecallRank)) {
+            if (param.getSize() < flowPoolRank.size()) {
+                return new RankResult(flowPoolRank.subList(0, param.getSize()));
+            } else {
+                return new RankResult(flowPoolRank);
+            }
+        }
+        // 1 读取多样性密度控制规则
+        String appType = String.valueOf(param.getAppType());
+        String ruleStr = this.redisTemplate.opsForValue().get("TAGS_FILTER_RULE_V1_JSON");
+        Map<String, Integer> densityRules = new HashMap<>();
+        if (ruleStr != null){
+            Map<String, Map<String, Object>> ruleOrigin = JSONUtils.fromJson(ruleStr,
+                    new TypeToken<Map<String, Map<String, Object>>>() {},
+                    Collections.emptyMap());
+            for (Map.Entry<String, Map<String, Object>> entry : ruleOrigin.entrySet()){
+                String k = entry.getKey();
+                if (!entry.getValue().containsKey(appType)){
+                    continue;
+                }
+                String str = (String) entry.getValue().get(appType);
+                Map<String, Object> tmp = JSONUtils.fromJson(str,
+                        new TypeToken<Map<String, Object>>() {},
+                        Collections.emptyMap());
+                if (tmp.containsKey("density") && tmp.get("density") instanceof Integer){
+                    densityRules.put(k, (Integer)tmp.get("density"));
+                }
+            }
+        }
+        // 2 读取video的tags
+        List<Long> videoIds = new ArrayList<>();
+        for (Video v : rovRecallRank) {
+            videoIds.add(v.getVideoId());
+        }
+        for (Video v : flowPoolRank) {
+            videoIds.add(v.getVideoId());
+        }
+        Map<Long, List<String>> videoTagDict = RankExtractorFeature.getVideoTags(this.redisTemplate, videoIds);
+        for (Video v : rovRecallRank) {
+            v.setTags(videoTagDict.getOrDefault(v.getVideoId(), new ArrayList<>()));
+        }
+        for (Video v : flowPoolRank) {
+            v.setTags(videoTagDict.getOrDefault(v.getVideoId(), new ArrayList<>()));
+        }
+        // ------------------读取video的tags完成---------------------
+
+
+        List<Video> result = new ArrayList<>();
+        for (int i = 0; i < param.getTopK() && i < rovRecallRank.size(); i++) {
+            result.add(rovRecallRank.get(i));
+        }
+
+        double flowPoolP = getFlowPoolP(param);
+        int flowPoolIndex = 0;
+        int rovPoolIndex = param.getTopK();
+
+        for (int i = 0; i < param.getSize() - param.getTopK(); i++) {
+            double rand = RandomUtils.nextDouble(0, 1);
+            log.info("rand={}, flowPoolP={}", rand, flowPoolP);
+            if (rand < flowPoolP) {
+                if (flowPoolIndex < flowPoolRank.size()) {
+                    result.add(flowPoolRank.get(flowPoolIndex++));
+                } else {
+                    break;
+                }
+            } else {
+                if (rovPoolIndex < rovRecallRank.size()) {
+                    result.add(rovRecallRank.get(rovPoolIndex++));
+                } else {
+                    break;
+                }
+            }
+        }
+        if (rovPoolIndex >= rovRecallRank.size()) {
+            for (int i = flowPoolIndex; i < flowPoolRank.size() && result.size() < param.getSize(); i++) {
+                result.add(flowPoolRank.get(i));
+            }
+        }
+        if (flowPoolIndex >= flowPoolRank.size()) {
+            for (int i = rovPoolIndex; i < rovRecallRank.size() && result.size() < param.getSize(); i++) {
+                result.add(rovRecallRank.get(i));
+            }
+        }
+
+        // 3 进行密度控制
+        Set<Long> videosSet = result.stream().map(r-> r.getVideoId()).collect(Collectors.toSet());
+        rovRecallRank = rovRecallRank.stream().filter(r -> !videosSet.contains(r.getVideoId())).collect(Collectors.toList());
+        flowPoolRank = flowPoolRank.stream().filter(r -> !videosSet.contains(r.getVideoId())).collect(Collectors.toList());
+        List<Video> resultWithDnsity = RankProcessorDensity.mergeDensityControl(result,
+                rovRecallRank, flowPoolRank, densityRules);
+        return new RankResult(resultWithDnsity);
+    }
+
 }

+ 29 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/extractor/RankExtractorFeature.java

@@ -0,0 +1,29 @@
+package com.tzld.piaoquan.recommend.server.service.rank.extractor;
+import org.springframework.data.redis.core.RedisTemplate;
+
+import java.util.*;
+
+public class RankExtractorFeature {
+
+    public static Map<Long, List<String>> getVideoTags(RedisTemplate<String, String> redisHelper, List<Long> videoIds) {
+        String REDIS_PREFIX = "alg_recsys_video_tags_";
+        List<String> redisKeys = new ArrayList<>();
+        for (Long videoId : videoIds) {
+            redisKeys.add(REDIS_PREFIX + String.valueOf(videoId));
+        }
+        List<String> videoTags = redisHelper.opsForValue().multiGet(redisKeys);
+        Map<Long, List<String>> videoTagDict = new HashMap<>();
+        if (videoTags != null) {
+            for (int i = 0; i < videoTags.size(); i++) {
+                String tagsStr = videoTags.get(i);
+                List<String> tags = new ArrayList<>();
+                if (tagsStr != null && !tagsStr.isEmpty()) {
+                    String[] tagsArray = tagsStr.split(",");
+                    tags = new ArrayList<>(Arrays.asList(tagsArray));
+                }
+                videoTagDict.put(videoIds.get(i), tags);
+            }
+        }
+        return videoTagDict;
+    }
+}

+ 119 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/processor/RankProcessorDensity.java

@@ -0,0 +1,119 @@
+package com.tzld.piaoquan.recommend.server.service.rank.processor;
+import com.tzld.piaoquan.recommend.server.model.Video;
+
+import java.util.*;
+
+public class RankProcessorDensity {
+
+    public static List<Video> mergeDensityControl(
+            List<Video> data, List<Video> rov, List<Video> flow, Map<String, Integer> rule
+    ) {
+        // 1 判断是否满足规则
+        Map<String, Integer> statusCur = new HashMap<>();
+        for (Video v : data) {
+            List<String> tags = v.getTags();
+            for (String t : tags) {
+                if (rule.containsKey(t)) {
+                    statusCur.put(t, statusCur.getOrDefault(t, 0) + 1);
+                }
+            }
+
+        }
+        Map<String, Integer> statusCurIllegal = new HashMap<>();
+        for (Map.Entry<String, Integer> entry : statusCur.entrySet()) {
+            String k = entry.getKey();
+            Integer v = entry.getValue();
+            if (rule.containsKey(k) && rule.get(k) < v) {
+                statusCurIllegal.put(k, v - rule.get(k));
+            }
+        }
+
+        if (statusCurIllegal.isEmpty()) {
+            return data;
+        }
+
+        // 2 反向遍历,直到statusCurIllegal满足,记录要替换的index和召回池标记。
+        List<Integer> indexes = new ArrayList<>();
+        List<String> pushes = new ArrayList<>();
+        for (int i = data.size() - 1; i >= 0; i--) {
+            Video video = data.get(i);
+            List<String> tags = video.getTags();
+            Set<String> inters = new HashSet<>(tags);
+            // intersection of tags and statusCurIllegal keys
+            inters.retainAll(statusCurIllegal.keySet());
+            if (inters.isEmpty()) {
+                continue;
+            }
+            indexes.add(i);
+            if (video.getFlowPool() != null){
+                pushes.add(video.getFlowPool());
+            }else{
+                pushes.add("");
+            }
+
+            for (String inter : inters) {
+                statusCurIllegal.put(inter, statusCurIllegal.get(inter) - 1);
+                if (statusCurIllegal.get(inter) == 0) {
+                    statusCurIllegal.remove(inter);
+                }
+                statusCur.put(inter, statusCur.get(inter) - 1);
+                if (statusCur.get(inter) == 0) {
+                    statusCur.remove(inter);
+                }
+            }
+        }
+
+        // 3 反向遍历index,再正向遍历增补列表,取可替换的video
+        Collections.reverse(indexes);
+        Collections.reverse(pushes);
+        for (int j = 0; j < indexes.size(); j++) {
+            int index = indexes.get(j);
+            String push = pushes.get(j);
+            List<Video> candidate;
+            if (!push.isEmpty()) {
+                // 5 如果是flow的video 取不到 不做替换
+                candidate = flow;
+            } else {
+                // 5 如果是rov的video  取不到 不做替换
+                candidate = rov;
+            }
+            for (int i = 0; i < candidate.size(); i++) {
+                Video videoNew = candidate.get(i);
+                Set<String> judgeRuleSet = judgeRule(rule, statusCur);
+                Set<String> tags1 = new HashSet<>(videoNew.getTags());
+                tags1.retainAll(judgeRuleSet);
+                if (!tags1.isEmpty()) {
+                    continue;
+                }
+                // 开始插入
+                Video tmp = data.get(index);
+                data.set(index, videoNew);
+                candidate.set(i, tmp);
+                // 更新状态
+                Set<String> tags2 = new HashSet<>(videoNew.getTags());
+                for (String tag : tags2) {
+                    statusCur.put(tag, statusCur.getOrDefault(tag, 0) + 1);
+                }
+                break;
+            }
+            if (!push.isEmpty()) {
+                flow = candidate;
+            } else {
+                rov = candidate;
+            }
+        }
+        return data;
+    }
+
+    public static Set<String> judgeRule(Map<String, Integer> rule, Map<String, Integer> status) {
+        Set<String> result = new HashSet<>();
+        for (Map.Entry<String, Integer> entry : status.entrySet()) {
+            String k = entry.getKey();
+            Integer v = entry.getValue();
+            if (rule.containsKey(k) && rule.get(k) <= v) {
+                result.add(k);
+            }
+        }
+        return result;
+    }
+}

+ 3 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/RecallService.java

@@ -116,7 +116,9 @@ public class RecallService implements ApplicationContextAware {
                 || param.getAbCode().equals("60094")
                 || param.getAbCode().equals("60096")
                 || param.getAbCode().equals("60095")
-                || param.getAbCode().equals("60097")) {
+                || param.getAbCode().equals("60097")
+                || param.getAbCode().equals("60098")
+        ) {
             strategies.add(strategyMap.get(SimHotVideoRecallStrategy.class.getSimpleName()));
             strategies.add(strategyMap.get(ReturnVideoRecallStrategy.class.getSimpleName()));
         }