Kaynağa Gözat

Merge branch 'feature_20250415_zhaohaipeng_cate_recall' of algorithm/recommend-server into master

zhaohaipeng 1 hafta önce
ebeveyn
işleme
5f49f7ed7f

+ 2 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/FeatureService.java

@@ -79,6 +79,8 @@ public class FeatureService {
             //protos.add(genWithVidAndHeadVid("alg_recsys_feature_cf_i2i_new_v2", vid, headVid));
         }
 
+        // 头部视频的基础信息
+        protos.add(genWithVid("alg_vid_feature_basic_info", headVid));
 
         // user
         protos.add(genWithMid("alg_mid_feature_play", mid));

+ 73 - 56
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategy4RegionMergeModelV567.java

@@ -12,6 +12,7 @@ import com.tzld.piaoquan.recommend.server.service.score.ScorerUtils;
 import com.tzld.piaoquan.recommend.server.util.CommonCollectionUtils;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.math3.util.Pair;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
@@ -78,6 +79,15 @@ public class RankStrategy4RegionMergeModelV567 extends RankStrategy4RegionMergeM
         rovRecallRank.addAll(sceneCFRosn);
         setVideo.addAll(sceneCFRosn.stream().map(Video::getVideoId).collect(Collectors.toSet()));
 
+        //-------------------head cate2 of rovn------------------
+        List<Video> headCate2Rov = extractAndSort(param, HeadCate2RovRecallStrategy.PUSH_FROM);
+        // 视频去重
+        removeDuplicate(headCate2Rov);
+        headCate2Rov = headCate2Rov.stream().filter(o -> !setVideo.contains(o.getVideoId())).collect(Collectors.toList());
+        headCate2Rov = headCate2Rov.subList(0, Math.min(mergeWeight.getOrDefault("headCate2Rov", 5.0).intValue(), headCate2Rov.size()));
+        rovRecallRank.addAll(headCate2Rov);
+        setVideo.addAll(headCate2Rov.stream().map(Video::getVideoId).collect(Collectors.toSet()));
+
         //-------------------排-------------------
         //-------------------序-------------------
         //-------------------逻-------------------
@@ -318,62 +328,39 @@ public class RankStrategy4RegionMergeModelV567 extends RankStrategy4RegionMergeM
         // 5 排序公式特征
         Map<String, Map<String, String>> vid2MapFeature = this.getVideoRedisFeature(vids, "redis:vid_hasreturn_vor:");
 
-        // Ros增强传播因子
-        Map<String, Map<String, String>> rosSpreadDivMap = this.getVideoRedisFeature(vids, "vid_for_spread:");
-
-        List<Video> result = new ArrayList<>();
-
-        double calcVorMode = mergeWeight.getOrDefault("calcVorMode", 3d);
-        double calcRosMode = mergeWeight.getOrDefault("calcRosMode", 0d);
-        double calcStrMode = mergeWeight.getOrDefault("calcStrMode", 3d);
-
-        double rosAdd = mergeWeight.getOrDefault("ros_add", 0.1d);
-        double ros2Multi = mergeWeight.getOrDefault("ros2_multi", 1d);
-        double vorAdd = mergeWeight.getOrDefault("vor_add", 0d);
+        // 获取权重
+        Map<String, Double> cate2Coefficient = new HashMap<>();
+        double cate2CoefficientFunc = mergeWeight.getOrDefault("cate2CoefficientFunc", 0d);
+        if (cate2CoefficientFunc == 1d) {
+            String headVidStr = String.valueOf(param.getHeadVid());
+            String mergeCate2 = this.findVideoMergeCate2(featureOriginVideo, headVidStr);
+            Double length = mergeWeight.getOrDefault("cate2CoefficientLength", 10000d);
+            Map<String, Double> simCateScore = this.findSimCateScore(mergeCate2, length.intValue());
+            cate2Coefficient.putAll(simCateScore);
+        }
 
-        double rosSpreadDivisorIndex = mergeWeight.getOrDefault("rosSpreadDivisorIndex", 2d);
-        String spreadDivisorKey = this.indexCoverKey(rosSpreadDivisorIndex);
-        log.info("567 spreadDivisorKey is: {}", spreadDivisorKey);
+        Double cate2CoefficientDenominator = mergeWeight.getOrDefault("cate2CoefficientDenominator", 1d);
 
+        List<Video> result = new ArrayList<>();
         for (RankItem item : items) {
             double score;
             double fmRovOrigin = item.getScoreRov();
             item.getScoresMap().put("fmRovOrigin", fmRovOrigin);
-            double str = restoreScore(fmRovOrigin);
-            item.getScoresMap().put("originStr", str);
-            str = this.handleStr(str, calcStrMode, item, mergeWeight);
-            item.getScoresMap().put("xgbRovNegRate", 0.9d);
-            item.getScoresMap().put("fmRov", str);
-            item.getScoresMap().put("str", str);
-            item.getScoresMap().put("calcStrMode", calcStrMode);
-
-            double originRos = Double.parseDouble(vid2MapFeature.getOrDefault(item.getVideoId() + "", new HashMap<>()).getOrDefault("rov", "0"));
-            double ros = this.handleRos(originRos, calcRosMode, item, mergeWeight);
-            item.getScoresMap().put("hasReturnRovScore", ros);
-            item.getScoresMap().put("ros", ros);
-            item.getScoresMap().put("originRos", originRos);
-            item.getScoresMap().put("calcRosMode", calcRosMode);
-
-            String spreadDivStr = rosSpreadDivMap.getOrDefault(String.valueOf(item.getVideoId()), new HashMap<>()).getOrDefault(spreadDivisorKey, "0");
-            double rosSpreadDiv = Double.parseDouble(spreadDivStr);
-            item.getScoresMap().put("rosSpreadDiv", rosSpreadDiv);
-
-            double originVor = Double.parseDouble(vid2MapFeature.getOrDefault(item.getVideoId() + "", new HashMap<>()).getOrDefault("vor", "0"));
-            double vor = this.handleVor(originVor, calcVorMode, item, mergeWeight);
-            item.getScoresMap().put("originVor", originVor);
+            double fmRov = restoreScore(fmRovOrigin);
+            item.getScoresMap().put("fmRov", fmRov);
+            double hasReturnRovScore = Double.parseDouble(vid2MapFeature.getOrDefault(item.getVideoId() + "", new HashMap<>()).getOrDefault("rov", "0"));
+            item.getScoresMap().put("hasReturnRovScore", hasReturnRovScore);
+            double vor = Double.parseDouble(vid2MapFeature.getOrDefault(item.getVideoId() + "", new HashMap<>()).getOrDefault("vor", "0"));
             item.getScoresMap().put("vor", vor);
-            item.getScoresMap().put("calcVorMode", calcVorMode);
 
+            String vidMergeCate2 = this.findVideoMergeCate2(featureOriginVideo, String.valueOf(item.getVideoId()));
+            Double scoreCoefficient = cate2Coefficient.getOrDefault(vidMergeCate2, 0d);
+            item.getScoresMap().put("scoreCoefficient", scoreCoefficient);
+            item.getScoresMap().put("cate2CoefficientDenominator", cate2CoefficientDenominator);
 
-            item.getScoresMap().put("rosAdd", rosAdd);
-            item.getScoresMap().put("vorAdd", vorAdd);
-            item.getScoresMap().put("ros2Multi", ros2Multi);
-            item.getScoresMap().put("rosSpreadDivisorIndex", rosSpreadDivisorIndex);
-            score = str * (rosAdd + ros + ros2Multi * rosSpreadDiv) * (vorAdd + vor);
+            score = fmRov * (0.1 + hasReturnRovScore) * (0.1 + vor) * (1 + scoreCoefficient / cate2CoefficientDenominator);
 
             Video video = item.getVideo();
-            video.setScoreStr(str);
-            video.setScoreRos(rosAdd + ros + ros2Multi * rosSpreadDiv);
             video.setScore(score);
             video.setSortScore(score);
             video.setScoresMap(item.getScoresMap());
@@ -390,17 +377,47 @@ public class RankStrategy4RegionMergeModelV567 extends RankStrategy4RegionMergeM
         return result;
     }
 
-    private String indexCoverKey(double index) {
-        switch (String.valueOf(index)) {
-            case "1":
-                return "head_video_rov1";
-            case "3":
-                return "head_video_recommend_rovn";
-            case "4":
-                return "head_video_recommend_fission_rate";
-            default:
-                return "recommend_123_depth_fission_rate";
+    private Map<String, Double> findSimCateScore(String headCate2, int length) { // Loads the key->score pairs stored in Redis for headCate2, keeping at most `length` of them.
+        if (StringUtils.isBlank(headCate2)) { // No category to look up: return an empty map.
+            return new HashMap<>();
+        }
+
+        String redisKey = String.format("alg_recsys_good_cate_pair_list:%s", headCate2);
+        String cate2Value = redisTemplate.opsForValue().get(redisKey);
+        if (StringUtils.isEmpty(cate2Value)) { // Key missing or empty in Redis: return an empty map.
+            return new HashMap<>();
         }
+
+        return this.parsePair(cate2Value, length); // parsePair splits the value into a key list and a score list.
     }
 
-}
+    private Map<String, Double> parsePair(String value, int length) { // Parses "k1,k2,...<TAB>s1,s2,..." into a key->score map, using at most the first `length` pairs.
+        if (StringUtils.isBlank(value)) { // Blank input yields an empty map.
+            return new HashMap<>();
+        }
+
+        String[] split = value.split("\t"); // Expected layout: a single tab separating the key list from the score list.
+        if (split.length != 2) { // Any other number of tab-separated parts is treated as malformed.
+            return new HashMap<>();
+        }
+
+        String[] valueList = split[0].trim().split(",");
+        String[] scoreList = split[1].trim().split(",");
+        if (valueList.length != scoreList.length) { // Keys and scores must pair up one-to-one, otherwise nothing is returned.
+            return new HashMap<>();
+        }
+
+        int minLength = Math.min(length, valueList.length); // Cap the number of pairs at `length`.
+        Map<String, Double> resultMap = new HashMap<>();
+        for (int i = 0; i < minLength; i++) {
+            resultMap.put(valueList[i].trim(), Double.parseDouble(scoreList[i].trim())); // A repeated key keeps its last score. NOTE(review): parseDouble throws NumberFormatException on a non-numeric score and nothing here catches it.
+        }
+
+        return resultMap;
+    }
+
+    private String findVideoMergeCate2(Map<String, Map<String, Map<String, String>>> featureOriginVideo, String vid) { // Returns the "merge_second_level_cate" field of vid's basic-info feature, or null when it is absent.
+        Map<String, String> videoInfo = featureOriginVideo.getOrDefault(vid, new HashMap<>()).getOrDefault("alg_vid_feature_basic_info", new HashMap<>()); // A missing vid or a missing feature table falls back to an empty map.
+        return videoInfo.get("merge_second_level_cate");
+    }
+}

+ 4 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/RecallService.java

@@ -122,6 +122,10 @@ public class RecallService implements ApplicationContextAware {
             strategies.add(strategyMap.get(HeadProvinceCate2RecallStrategy.class.getSimpleName()));
         }
 
+        if (CollectionUtils.isNotEmpty(abExpCodes) && abExpCodes.contains("567")) {
+            strategies.add(strategyMap.get(HeadCate2RovRecallStrategy.class.getSimpleName()));
+        }
+
         // 命中用户黑名单不走流量池
         if (!param.isRiskUser()) {
             strategies.add(strategyMap.get(QuickFlowPoolWithLevelRecallStrategy.class.getSimpleName()));

+ 182 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/HeadCate2RovRecallStrategy.java

@@ -0,0 +1,182 @@
+package com.tzld.piaoquan.recommend.server.service.recall.strategy;
+
+import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.FeatureService;
+import com.tzld.piaoquan.recommend.server.service.filter.FilterParam;
+import com.tzld.piaoquan.recommend.server.service.filter.FilterResult;
+import com.tzld.piaoquan.recommend.server.service.filter.FilterService;
+import com.tzld.piaoquan.recommend.server.service.recall.FilterParamFactory;
+import com.tzld.piaoquan.recommend.server.service.recall.RecallParam;
+import com.tzld.piaoquan.recommend.server.service.recall.RecallStrategy;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Qualifier;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.stereotype.Component;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * Recall strategy driven by the head video: looks up the categories similar to the head
+ * video's merged second-level category and recalls the videos listed for those categories.
+ *
+ * <p>Both the similar-category list and each category's video list are read from Redis as
+ * {@code "item1,item2,...<TAB>score1,score2,..."} strings. Each recalled video is scored as
+ * category score times video score, and the result list is sorted highest score first.
+ */
+@Slf4j
+@Component
+public class HeadCate2RovRecallStrategy implements RecallStrategy {
+
+    @Autowired
+    private FilterService filterService;
+    @Autowired
+    @Qualifier("redisTemplate")
+    public RedisTemplate<String, String> redisTemplate;
+
+    @Autowired
+    private FeatureService featureService;
+
+    @ApolloJsonValue("${head.cate2.rov.recall.param:{}}")
+    private Map<String, String> paramJson; // Tuning parameters; this class reads "sim_cate_length" and "recall_vid_length".
+
+    public static final String PUSH_FROM = "recall_strategy_head_cate2_rov"; // Tag set on every video recalled by this strategy.
+
+    private static final String SIM_MERGE_CATE2_KEY_FORMAT = "alg_recsys_good_cate_pair_list:%s"; // Redis key of the similar-category list for one category.
+    private static final String MERGE_CATE2_VIDEO_LIST_KEY_FORMAT = "alg_recsys_recall_good_cate_pair_rovn:%s"; // Redis key of the video list for one category.
+
+    @Override
+    public List<Video> recall(RecallParam param) { // Returns the recalled videos sorted by descending rov score, or an empty list when any lookup comes back empty.
+        if (Objects.isNull(param.getVideoId())) { // No head video id: nothing to recall.
+            return Collections.emptyList();
+        }
+
+        // Fetch the head video's basic info
+        String vidStr = String.valueOf(param.getVideoId());
+        Map<String, String> headVideoInfo = featureService.getHeadVideoInfo(vidStr);
+        if (MapUtils.isEmpty(headVideoInfo)) {
+            return Collections.emptyList();
+        }
+
+        // Category missing or invalid
+        String mergeCate2 = headVideoInfo.get("merge_second_level_cate");
+        if (StringUtils.isBlank(mergeCate2) || "unknown".equals(mergeCate2)) {
+            return Collections.emptyList();
+        }
+
+        // Fetch the similar categories
+        String simCate2Key = String.format(SIM_MERGE_CATE2_KEY_FORMAT, mergeCate2);
+        String simCate2List = redisTemplate.opsForValue().get(simCate2Key);
+        if (Objects.isNull(simCate2List)) {
+            return Collections.emptyList();
+        }
+
+        List<Video> videoResult = new ArrayList<>();
+
+        int simCateLength = Integer.parseInt(paramJson.getOrDefault("sim_cate_length", "10000")); // Max similar categories to use; "10000" when unset. NOTE(review): a non-numeric config value makes parseInt throw.
+        Map<String, Double> mergeCate2Pair = this.parsePair(simCate2List, simCateLength); // category -> category score
+
+        Map<String, Map<String, Double>> recallVideoMap = this.cate2Recall(new ArrayList<>(mergeCate2Pair.keySet())); // category -> (video id -> video score)
+
+        // Filter: collect every candidate id across all categories and run them through the filter service once
+        List<Long> allVid = recallVideoMap.values().stream()
+                .map(Map::keySet)
+                .flatMap(Collection::stream)
+                .map(Long::parseLong)
+                .collect(Collectors.toList());
+
+        FilterParam filterParam = FilterParamFactory.create(param, allVid);
+        FilterResult filterResult = filterService.filter(filterParam);
+        Set<Long> filterVids = new HashSet<>(filterResult.getVideoIds()); // Ids that survived filtering. NOTE(review): assumes filter() and getVideoIds() never return null - confirm.
+        filterVids.remove(param.getVideoId()); // Never recall the head video itself.
+
+        for (Map.Entry<String, Double> entry : mergeCate2Pair.entrySet()) {
+            String cate = entry.getKey();
+            Double cateScore = entry.getValue();
+
+            Map<String, Double> videoMap = recallVideoMap.getOrDefault(cate, new HashMap<>());
+            for (Map.Entry<String, Double> videoEntry : videoMap.entrySet()) {
+                long vid = Long.parseLong(videoEntry.getKey());
+
+                // Skip videos that did not survive filtering
+                if (!filterVids.contains(vid)) {
+                    continue;
+                }
+
+                Double videoScore = videoEntry.getValue();
+
+                Video video = new Video();
+                video.setVideoId(vid);
+                video.setRovScore(cateScore * videoScore); // Combined score: category score times the video's score within that category.
+                video.setPushFrom(PUSH_FROM);
+                videoResult.add(video); // A video listed under several categories is added once per category; no de-duplication happens here.
+            }
+
+        }
+
+        videoResult.sort(Comparator.comparingDouble(o -> -o.getRovScore())); // Negated key => descending by rov score.
+
+        return videoResult;
+    }
+
+    private Map<String, Double> parsePair(String value, int length) { // Parses "k1,k2,...<TAB>s1,s2,..." into a key->score map, using at most the first `length` pairs.
+        if (StringUtils.isBlank(value)) { // Blank input yields an empty map.
+            return new HashMap<>();
+        }
+
+        String[] split = value.split("\t"); // Expected layout: a single tab separating the key list from the score list.
+        if (split.length != 2) { // Any other number of tab-separated parts is treated as malformed.
+            return new HashMap<>();
+        }
+
+        String[] valueList = split[0].trim().split(",");
+        String[] scoreList = split[1].trim().split(",");
+        if (valueList.length != scoreList.length) { // Keys and scores must pair up one-to-one, otherwise nothing is returned.
+            return new HashMap<>();
+        }
+
+        int minLength = Math.min(length, valueList.length); // Cap the number of pairs at `length`.
+        Map<String, Double> resultMap = new HashMap<>();
+        for (int i = 0; i < minLength; i++) {
+            resultMap.put(valueList[i].trim(), Double.parseDouble(scoreList[i].trim())); // A repeated key keeps its last score. NOTE(review): parseDouble throws NumberFormatException on a non-numeric score and nothing here catches it.
+        }
+
+        return resultMap;
+    }
+
+    private Map<String, Map<String, Double>> cate2Recall(List<String> mergeCate2List) { // Fetches the video list of every given category in one multiGet; returns category -> (video id -> score), omitting categories with no usable list.
+
+
+        List<String> redisKeys = mergeCate2List.stream().map(i -> String.format(MERGE_CATE2_VIDEO_LIST_KEY_FORMAT, i)).collect(Collectors.toList());
+        List<String> values = redisTemplate.opsForValue().multiGet(redisKeys);
+        if (CollectionUtils.isEmpty(values)) {
+            return new HashMap<>();
+        }
+
+        Map<String, Map<String, Double>> resultMap = new HashMap<>();
+
+        int recallVidLength = Integer.parseInt(paramJson.getOrDefault("recall_vid_length", "10000")); // Max videos kept per category; "10000" when unset.
+        for (int i = 0; i < mergeCate2List.size(); i++) {
+            String mergeCate2 = mergeCate2List.get(i);
+
+            String value = values.get(i); // NOTE(review): relies on multiGet returning one entry per key, in key order - confirm.
+            if (StringUtils.isBlank(value)) { // Missing or empty list for this category: skip it.
+                continue;
+            }
+
+            Map<String, Double> recallVideoMap = this.parsePair(value, recallVidLength);
+            if (MapUtils.isEmpty(recallVideoMap)) { // Malformed value parsed to nothing: skip it.
+                continue;
+            }
+
+            resultMap.put(mergeCate2, recallVideoMap);
+        }
+
+        return resultMap;
+    }
+
+    @Override
+    public String pushFrom() { // Identifier of this strategy, same value as PUSH_FROM.
+        return PUSH_FROM;
+    }
+}