ソースを参照

Merge branch 'feature_20240507_supeng_supply_ab' of algorithm/recommend-server into master

qingqu-git 11 ヶ月 前
コミット
55137a67e1

+ 4 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/common/base/Constant.java

@@ -15,4 +15,8 @@ public class Constant {
      * 流量池头部视频redis key
      */
     public static final String VIDEO_PERFORMANCE_DATA_REDIS_KEY = "video_performance_data_redis_key:";
+    /**
+     * 供给流量池实验 648 random
+     */
+    public static final String SUPPLY_AB_CODE = "60600";
 }

+ 18 - 5
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/RecommendService.java

@@ -4,6 +4,7 @@ import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
 import com.google.common.base.Stopwatch;
 import com.google.common.base.Strings;
 import com.google.common.reflect.TypeToken;
+import com.tzld.piaoquan.recommend.server.common.base.Constant;
 import com.tzld.piaoquan.recommend.server.common.enums.AppTypeEnum;
 import com.tzld.piaoquan.recommend.server.gen.common.Result;
 import com.tzld.piaoquan.recommend.server.gen.recommend.MachineInfoProto;
@@ -335,12 +336,21 @@ public class RecommendService {
 
         // 流量池分发实验组划分
         int flowPoolIdChoice = flowPoolIds.get(RandomUtils.nextInt(0, flowPoolIds.size()));
-        param.setFlowPoolId(flowPoolIdChoice);
-        param.setFlowPoolAbtestGroup("control_group");
         Map<String, List<Integer>> flowPoolConfig = flowPoolConfigService.getFlowPoolConfig();
-        for (Map.Entry<String, List<Integer>> entry : flowPoolConfig.entrySet()) {
-            if (entry.getValue().contains(flowPoolIdChoice)) {
-                param.setFlowPoolAbtestGroup(entry.getKey());
+        if (Objects.equals(Constant.SUPPLY_AB_CODE, param.getAbCode())) {
+            List<Integer> supplyFlowPoolIdList = flowPoolConfig.get(FlowPoolConstants.SUPPLY_FLOW_SET_LEVEL);
+            if (Objects.nonNull(supplyFlowPoolIdList) && !supplyFlowPoolIdList.isEmpty()) {
+                flowPoolIdChoice = supplyFlowPoolIdList.get(0);
+                param.setFlowPoolId(flowPoolIdChoice);
+                param.setFlowPoolAbtestGroup(FlowPoolConstants.SUPPLY_FLOW_SET_LEVEL);
+            }
+        } else {
+            param.setFlowPoolId(flowPoolIdChoice);
+            param.setFlowPoolAbtestGroup("control_group");
+            for (Map.Entry<String, List<Integer>> entry : flowPoolConfig.entrySet()) {
+                if (entry.getValue().contains(flowPoolIdChoice)) {
+                    param.setFlowPoolAbtestGroup(entry.getKey());
+                }
             }
         }
 
@@ -642,6 +652,9 @@ public class RecommendService {
         }
 
         switch (param.getFlowPoolAbtestGroup()) {
+            case FlowPoolConstants.SUPPLY_FLOW_SET_LEVEL:
+                flowPoolService.updateSupplyDistributeCountWithLevel(flowPoolVideos);
+                break;
             case FlowPoolConstants.EXPERIMENTAL_FLOW_SET_LEVEL:
                 flowPoolService.updateDistributeCountWithLevel(flowPoolVideos);
                 break;

+ 11 - 4
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/filter/FlowPoolWithLevelFilterService.java

@@ -1,5 +1,6 @@
 package com.tzld.piaoquan.recommend.server.service.filter;
 
+import com.tzld.piaoquan.recommend.server.common.base.Constant;
 import com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolService;
 import com.tzld.piaoquan.recommend.server.util.JSONUtils;
 import lombok.extern.slf4j.Slf4j;
@@ -9,9 +10,7 @@ import org.apache.commons.lang3.StringUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 import java.util.stream.Collectors;
 
 /**
@@ -43,7 +42,15 @@ public class FlowPoolWithLevelFilterService extends AbstractFilterService {
                         v -> v,
                         v -> param.getFlowPoolMap().get(v)));
 
-        Map<Long, Integer> distributeCountMap = flowPoolService.getDistributeCountWithLevel(flowPoolMap);
+        Map<Long, Integer> distributeCountMap;
+        //供给流量池实验
+        if (Objects.equals(Constant.SUPPLY_AB_CODE, param.getAbCode())) {
+            distributeCountMap = flowPoolService.getSupplyDistributeCountWithLevel(flowPoolMap);
+        } else {
+            distributeCountMap = flowPoolService.getDistributeCountWithLevel(flowPoolMap);
+        }
+
+//        Map<Long, Integer> distributeCountMap = flowPoolService.getDistributeCountWithLevel(flowPoolMap);
 
         List<Long> remainVideoIds = new ArrayList<>();
         for (Long videoId : videoIds) {

+ 3 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/flowpool/FlowPoolConstants.java

@@ -6,11 +6,14 @@ package com.tzld.piaoquan.recommend.server.service.flowpool;
 public class FlowPoolConstants {
     public static final String EXPERIMENTAL_FLOW_SET_LEVEL = "experimental_flow_set_level";
     public static final String EXPERIMENTAL_FLOW_SET_LEVEL_SCORE = "experimental_flow_set_level_score";
+    public static final String SUPPLY_FLOW_SET_LEVEL = "supply_flow_set_level";
 
     public static final String PUSH_FORM = "flow_pool";
     public static final String QUICK_PUSH_FORM = "quick_flow_pool";
+    public static final String SUPPLY_PUSH_FORM = "supply_flow_pool";
 
     public static final String KEY_WITH_LEVEL_FORMAT = "flow:pool:level:item:%s:%s";
+    public static final String KEY_WITH_LEVEL_SUPPLY_FORMAT = "flow:pool:level:item:supply:%s:%s";
     public static final String KEY_QUICK_WITH_LEVEL_FORMAT = "flow:pool:quick:item:%s:3";
     public static final String KEY_WITH_LEVEL_SCORE_FORMAT = "flow:pool:level:item:score:%s:%s";
     public static final String KEY_QUICK_WITH_LEVEL_SCORE_FORMAT = "flow:pool:quick:item:score:%s:3";

+ 72 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/flowpool/FlowPoolService.java

@@ -33,6 +33,11 @@ public class FlowPoolService {
     private FlowPoolConfigService flowPoolConfigService;
 
     private final String localDistributeCountFormat = "flow:pool:local:distribute:count:%s:%s";
+    /**
+     * 供给池 本地缓存
+     * flow:pool:supply:local:distribute:count:{videoId}:{flowPool标记}
+     */
+    private final String supplyLocalDistributeCountFormat = "flow:pool:supply:local:distribute:count:%s:%s";
 
     public final String valueFormat = "%s-%s";
 
@@ -205,4 +210,71 @@ public class FlowPoolService {
             }
         });
     }
+
+    public Map<Long, Integer> getSupplyDistributeCountWithLevel(Map<Long, String> videoFlowPoolMap) {
+        if (MapUtils.isEmpty(videoFlowPoolMap)) {
+            return Collections.emptyMap();
+        }
+
+        Map<Long, Integer> result = getSupplyDistributeCount(videoFlowPoolMap);
+
+
+        // 处理脏数据:分发数<0
+        Map<Long, String> dirties = videoFlowPoolMap.entrySet().stream()
+                .filter(e -> result.get(e.getKey()) <= 0)
+                .collect(Collectors.toMap(
+                        e -> e.getKey(),
+                        e -> e.getValue()
+                ));
+        asyncDelSupplyDistributeCountWithLevel(dirties);
+
+        return result;
+    }
+
+    private Map<Long, Integer> getSupplyDistributeCount(Map<Long, String> videoFlowPoolMap) {
+        // 为了保证有序
+        List<Map.Entry<Long, String>> entries = videoFlowPoolMap.entrySet().stream()
+                .sorted(Comparator.comparingLong(e -> e.getKey()))
+                .collect(Collectors.toList());
+
+        List<String> keys = entries.stream()
+                .map(v -> String.format(supplyLocalDistributeCountFormat, v.getKey(), v.getValue()))
+                .collect(Collectors.toList());
+        List<String> counts = redisTemplate.opsForValue().multiGet(keys);
+        Map<Long, Integer> result = new HashMap<>();
+        for (int i = 0; i < entries.size(); i++) {
+            result.put(entries.get(i).getKey(), NumberUtils.toInt(counts.get(i), 0));
+        }
+        return result;
+    }
+
+    private Map<Long, String> updateSupplyDistributeCount(List<Video> videos) {
+        // TODO 异步更新
+        Map<Long, String> removeMap = new HashMap<>();
+        videos.stream().forEach(v -> {
+            String key = String.format(supplyLocalDistributeCountFormat, v.getVideoId(), v.getFlowPool());
+            Long count = redisTemplate.opsForValue().decrement(key);
+            if (count <= 0) {
+                removeMap.put(v.getVideoId(), v.getFlowPool());
+            }
+        });
+        return removeMap;
+    }
+
+    public void updateSupplyDistributeCountWithLevel(List<Video> videos) {
+        if (CollectionUtils.isEmpty(videos)) {
+            return;
+        }
+        Map<Long, String> removeMap = updateSupplyDistributeCount(videos);
+
+        asyncDelSupplyDistributeCountWithLevel(removeMap);
+
+    }
+
+    private void asyncDelSupplyDistributeCountWithLevel(Map<Long, String> videoFlowPoolMap) {
+        asyncDelDistributeCount(videoFlowPoolMap, (appType, level, values) -> {
+            String key = String.format(KEY_WITH_LEVEL_SUPPLY_FORMAT, appType, level);
+            Long count = redisTemplate.opsForSet().remove(key, values);
+        });
+    }
 }

+ 6 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/RecallService.java

@@ -1,6 +1,7 @@
 package com.tzld.piaoquan.recommend.server.service.recall;
 import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
 import com.tzld.piaoquan.recommend.server.common.ThreadPoolFactory;
+import com.tzld.piaoquan.recommend.server.common.base.Constant;
 import com.tzld.piaoquan.recommend.server.common.enums.AppTypeEnum;
 import com.tzld.piaoquan.recommend.server.model.Video;
 import com.tzld.piaoquan.recommend.server.service.filter.strategy.BlacklistContainer;
@@ -184,8 +185,11 @@ public class RecallService implements ApplicationContextAware {
         }
         // 命中用户黑名单不走流量池
         if (!hitUserBlacklist || !isInBlacklist) {
-            //2:通过“流量池标记”控制“流量池召回子策略” 其中有9组会走EXPERIMENTAL_FLOW_SET_LEVEL 有1组会走EXPERIMENTAL_FLOW_SET_LEVEL_SCORE
-            if ("60116".equals(abCode)) {
+            if (Objects.equals(Constant.SUPPLY_AB_CODE, abCode)) {
+                // 供给流量池策略 648 实验 random
+                strategies.add(strategyMap.get(FlowPoolWithLevelSupplyRecallStrategy.class.getSimpleName()));
+                //2:通过“流量池标记”控制“流量池召回子策略” 其中有9组会走EXPERIMENTAL_FLOW_SET_LEVEL 有1组会走EXPERIMENTAL_FLOW_SET_LEVEL_SCORE
+            } else if ("60116".equals(abCode)) {
                 int lastDigit = param.getLastDigit();
                 String lastDigitAB = lastDigitAbcode != null ? lastDigitAbcode.getOrDefault(lastDigit, "default") : "default";
                 switch (lastDigitAB) {

+ 2 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/AbstractFlowPoolWithLevelRecallStrategy.java

@@ -6,6 +6,7 @@ import com.tzld.piaoquan.recommend.server.service.filter.FlowPoolWithLevelFilter
 import com.tzld.piaoquan.recommend.server.service.recall.FilterParamFactory;
 import com.tzld.piaoquan.recommend.server.service.recall.RecallParam;
 import com.tzld.piaoquan.recommend.server.service.recall.RecallStrategy;
+import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.math.NumberUtils;
@@ -22,6 +23,7 @@ import java.util.Map;
 /**
  * @author dyp
  */
+@Slf4j
 public abstract class AbstractFlowPoolWithLevelRecallStrategy implements RecallStrategy {
     @Autowired
     @Qualifier("redisTemplate")
@@ -46,9 +48,7 @@ public abstract class AbstractFlowPoolWithLevelRecallStrategy implements RecallS
             String[] values = value.split("-");
             videoFlowPoolMap.put(NumberUtils.toLong(values[0], 0), values[1]);
         }
-
         FilterResult filterResult = filterService.filter(FilterParamFactory.create(param, videoFlowPoolMap));
-
         if (filterResult != null && CollectionUtils.isNotEmpty(filterResult.getVideoIds())) {
             filterResult.getVideoIds().stream().forEach(vid -> {
                 Video recallData = new Video();

+ 98 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/FlowPoolWithLevelSupplyRecallStrategy.java

@@ -0,0 +1,98 @@
+package com.tzld.piaoquan.recommend.server.service.recall.strategy;
+
+import com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConfigService;
+import com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConstants;
+import com.tzld.piaoquan.recommend.server.service.recall.RecallParam;
+import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.RandomUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.*;
+
+import static com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConstants.KEY_WITH_LEVEL_SUPPLY_FORMAT;
+
+@Service
+@Slf4j
+public class FlowPoolWithLevelSupplyRecallStrategy extends AbstractFlowPoolWithLevelRecallStrategy {
+
+    @Autowired
+    private FlowPoolConfigService flowPoolConfigService;
+
+    @Override
+    Pair<String, String> flowPoolKeyAndLevel(RecallParam param) {
+        //# 1. 获取流量池各层级分发概率权重
+        Map<String, Double> levelWeightMap = flowPoolConfigService.getLevelWeight();
+        // 2. 判断各层级是否有视频需分发
+        List<LevelWeight> availableLevels = new ArrayList<>();
+        for (Map.Entry<String, Double> entry : levelWeightMap.entrySet()) {
+            String levelKey = String.format(KEY_WITH_LEVEL_SUPPLY_FORMAT, param.getAppType(), entry.getKey());
+            if (redisTemplate.hasKey(levelKey)) {
+                LevelWeight lw = new LevelWeight();
+                lw.setLevel(entry.getKey());
+                lw.setLevelKey(levelKey);
+                lw.setWeight(entry.getValue());
+                availableLevels.add(lw);
+            }
+        }
+        if (CollectionUtils.isEmpty(availableLevels)) {
+            return Pair.of("", "");
+        }
+
+        // 3. 根据可分发层级权重设置分发概率
+        Collections.sort(availableLevels, Comparator.comparingDouble(LevelWeight::getWeight));
+
+        double weightSum = availableLevels.stream().mapToDouble(o -> o.getWeight()).sum();
+        BigDecimal weightSumBD = new BigDecimal(weightSum);
+        double level_p_low = 0;
+        double weight_temp = 0;
+        double level_p_up = 0;
+        Map<String, LevelP> level_p_mapping = new HashMap<>();
+        for (LevelWeight lw : availableLevels) {
+            BigDecimal bd = new BigDecimal(weight_temp + lw.getWeight());
+            level_p_up = bd.divide(weightSumBD, 2, RoundingMode.HALF_UP).doubleValue();
+            LevelP levelP = new LevelP();
+            levelP.setMin(level_p_low);
+            levelP.setMax(level_p_up);
+            levelP.setLevelKey(lw.getLevelKey());
+            level_p_mapping.put(lw.level, levelP);
+            level_p_low = level_p_up;
+
+            weight_temp += lw.getWeight();
+        }
+
+        // 4. 随机生成[0,1)之间数,返回相应概率区间的key
+        double random_p = RandomUtils.nextDouble(0, 1);
+        for (Map.Entry<String, LevelP> entry : level_p_mapping.entrySet()) {
+            if (random_p >= entry.getValue().getMin()
+                    && random_p <= entry.getValue().getMax()) {
+                return Pair.of(entry.getValue().getLevelKey(), entry.getKey());
+            }
+        }
+        return Pair.of("", "");
+    }
+
+    @Data
+    static class LevelWeight {
+        private String level;
+        private String levelKey;
+        private Double weight;
+    }
+
+    @Data
+    static class LevelP {
+        private String levelKey;
+        private double min;
+        private double max;
+    }
+
+    @Override
+    public String pushFrom() {
+        return FlowPoolConstants.SUPPLY_PUSH_FORM;
+    }
+}