Kaynağa Gözat

feat:汤姆森实验

zhaohaipeng 1 yıl önce
ebeveyn
işleme
c9c8f159c0

+ 142 - 4
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VideoAdThompsonScorerV2.java

@@ -5,6 +5,7 @@ import com.tzld.piaoquan.ad.engine.commons.redis.AlgorithmRedisHelper;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.service.score.dto.AdPlatformCreativeDTO;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.commons.math3.distribution.BetaDistribution;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -13,6 +14,7 @@ import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Component;
 
 import java.util.*;
+import java.util.stream.Collectors;
 
 @Component
 public class VideoAdThompsonScorerV2 {
@@ -29,6 +31,9 @@ public class VideoAdThompsonScorerV2 {
     private Map<String,Double> exp664Param=new HashMap<>();
     private Map<String,Double> exp665Param=new HashMap<>();
     private Map<String,Double> exp666Param=new HashMap<>();
+    private Map<String,Double> exp669Param=new HashMap<>();
+    private Map<String,Double> exp670Param=new HashMap<>();
+
     Random random=new Random();
     Gson gson=new Gson();
     public List<AdRankItem> thompsonScorerByExp663(ScoreParam param, List<AdPlatformCreativeDTO> adIdList){
@@ -240,19 +245,115 @@ public class VideoAdThompsonScorerV2 {
         return map;
     }
 
-    class CreativeStatistic{
-        public CreativeStatistic() {
+
+    public List<AdRankItem> thompsonScorerByExp669(ScoreParam param, List<AdPlatformCreativeDTO> adIdList) {
+        Map<Long, CreativeStatistic> creativeStatisticsMap = this.batchFindCreativeRedisCache(redisCreativeStatisticsPrefix, adIdList);
+        Map<Long, CreativeStatistic> videoCreativeStatisticsMap = this.batchFindCreativeRedisCache(redisVideoCreativeStatisticsPrefix + param.getVideoId() + "_", adIdList);
+        Double creativeExpSum = this.sumCreativeStatisticExp(creativeStatisticsMap.values());
+        Double videoCreativeExpSum = this.sumCreativeStatisticExp(videoCreativeStatisticsMap.values());
+
+        List<AdRankItem> result = new ArrayList<>(adIdList.size());
+        this.calcScore(result, adIdList,1d, creativeExpSum, videoCreativeExpSum, creativeStatisticsMap, videoCreativeStatisticsMap, exp669Param);
+        result.sort(equalsRandomComparator());
+
+        return result;
+    }
+
+    public List<AdRankItem> thompsonScorerByExp670(ScoreParam param, List<AdPlatformCreativeDTO> adIdList) {
+        Map<Long, CreativeStatistic> creativeStatisticsMap = this.batchFindCreativeRedisCache(redisCreativeStatisticsPrefix, adIdList);
+        Map<Long, CreativeStatistic> videoCreativeStatisticsMap = this.batchFindCreativeRedisCache(redisVideoCreativeStatisticsPrefix + param.getVideoId() + "_", adIdList);
+
+        List<AdRankItem> result = new ArrayList<>(adIdList.size());
+        this.calcScore(result, adIdList, 0.0, 0.0, 0.0, creativeStatisticsMap, videoCreativeStatisticsMap, exp670Param);
+        result.sort(equalsRandomComparator());
+
+        return result;
+    }
+
+    private void calcScore(List<AdRankItem> result, List<AdPlatformCreativeDTO> adIdList,Double alpha, Double cidBeta, Double vidCidBeta, Map<Long, CreativeStatistic> cidMap, Map<Long, CreativeStatistic> vidCidMap, Map<String, Double> expParam) {
+        for (AdPlatformCreativeDTO dto : adIdList) {
+
+            CreativeStatistic cidStatistic = cidMap.getOrDefault(dto.getCreativeId(), new CreativeStatistic());
+            CreativeStatistic vidCidStatistic = vidCidMap.getOrDefault(dto.getCreativeId(), new CreativeStatistic());
+
+            double cidScore = this.calcThompsonScore(expParam, cidStatistic, alpha, cidBeta);
+            double vidCidScore = this.calcThompsonScore(expParam, vidCidStatistic, alpha, vidCidBeta);
+            double w1 = expParam.getOrDefault("w1", 1d);
+            double w2 = expParam.getOrDefault("w2", 2d);
+            double score = w1 * vidCidScore + w2 * cidScore;
+
+            AdRankItem item = new AdRankItem();
+            item.setCpa(dto.getCpa());
+            item.setAdId(dto.getCreativeId());
+            item.setScore(score);
+
+            result.add(item);
         }
+    }
 
-        public void setCpa(String cpa) {
-            this.cpa = cpa;
+
+    private Map<Long, CreativeStatistic> batchFindCreativeRedisCache(String keyPrefix, List<AdPlatformCreativeDTO> adIdList) {
+        Map<Long, CreativeStatistic> resultMap = new HashMap<>();
+        for (AdPlatformCreativeDTO dto : adIdList) {
+            String redisKey = keyPrefix + dto.getCreativeId();
+            String value = redisHelper.get(redisKey);
+            if (StringUtils.isNotBlank(value)) {
+                resultMap.put(dto.getCreativeId(), gson.fromJson(value, CreativeStatistic.class));
+            }
         }
+        return resultMap;
+    }
+
+    private double calcThompsonScore(Map<String, Double> expParam, CreativeStatistic creativeStatistic, Double defaultAlpha, Double defaultBeta) {
+        Double alpha = expParam.getOrDefault("alpha", defaultAlpha);
+        Double beta = expParam.getOrDefault("beta", defaultBeta);
+
+        double order = creativeStatistic.getDoubleOrder() + alpha;
+        double exp = creativeStatistic.getDoubleExp() + beta;
+
+        if (order == 0 || exp == 0) {
+            return 0.0;
+        }
+
+        return this.betaSampler(order, exp);
+    }
+
+    private Double sumCreativeStatisticExp(Collection<CreativeStatistic> creativeStatistics) {
+        return creativeStatistics.stream()
+                .map(CreativeStatistic::getExp)
+                .filter(StringUtils::isNotBlank)
+                .collect(Collectors.summarizingDouble(Double::parseDouble))
+                .getSum();
+    }
+
+    private Comparator<AdRankItem> equalsRandomComparator(){
+        return new Comparator<AdRankItem>() {
+            @Override
+            public int compare(AdRankItem o1, AdRankItem o2) {
+                if (o1.getScore() < o2.getScore()) {
+                    return 1;
+                } else if (o1.getScore() > o2.getScore()) {
+                    return -1;
+                }
+                return random.nextInt(20) - 10;
+            }
+        };
+    }
+
+    class CreativeStatistic{
 
         private String exp;
         private String click;
         private String order;
         private String cpa;
 
+        public CreativeStatistic() {
+        }
+
+        public void setCpa(String cpa) {
+            this.cpa = cpa;
+        }
+
         public String getExp() {
             if (exp == null || "".equals(exp)) {
                 return "0";
@@ -260,6 +361,13 @@ public class VideoAdThompsonScorerV2 {
             return exp;
         }
 
+        public Double getDoubleExp() {
+            if (StringUtils.isBlank(exp)) {
+                return 0.0;
+            }
+            return Double.parseDouble(exp);
+        }
+
         public void setExp(String exp) {
             this.exp = exp;
         }
@@ -271,6 +379,13 @@ public class VideoAdThompsonScorerV2 {
             return click;
         }
 
+        public Double getDoubleClick(){
+            if (StringUtils.isBlank(click)){
+                return 0.0;
+            }
+            return Double.parseDouble(click);
+        }
+
         public void setClick(String click) {
             this.click = click;
         }
@@ -282,6 +397,13 @@ public class VideoAdThompsonScorerV2 {
             return order;
         }
 
+        public Double getDoubleOrder(){
+            if (StringUtils.isBlank(order)){
+                return 0.0;
+            }
+            return Double.parseDouble(order);
+        }
+
         public void setOrder(String order) {
             this.order = order;
         }
@@ -292,6 +414,13 @@ public class VideoAdThompsonScorerV2 {
             }
             return cpa;
         }
+
+        public Double getDoubleCpa() {
+            if (StringUtils.isBlank(cpa)) {
+                return 0.0;
+            }
+            return Double.parseDouble(cpa);
+        }
     }
     double betaSampler(double alpha, double beta) {
         BetaDistribution betaSample = new BetaDistribution(alpha, beta);
@@ -313,4 +442,13 @@ public class VideoAdThompsonScorerV2 {
     public void setExp666Param(String str){
         this.exp663Param=gson.fromJson(str,Map.class);
     }
+
+    @Value("${ad.engine.new.thompson.exp.V2.666:{}}")
+    public void setExp669Param(String str){
+        this.exp663Param=gson.fromJson(str,Map.class);
+    }
+    @Value("${ad.engine.new.thompson.exp.V2.666:{}}")
+    public void setExp670Param(String str){
+        this.exp663Param=gson.fromJson(str,Map.class);
+    }
 }

+ 8 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/impl/RankServiceImpl.java

@@ -263,6 +263,14 @@ public class RankServiceImpl implements RankService {
                 (expCodes.contains(NewExpInfoHelper.flagId)&&NewExpInfoHelper.checkInNewExpGroupAndSetParamIfIn(
                         request.getAppType().toString(),request.getNewExpGroup(),"666",modelParam))){
             rankResult = videoAdThompsonScorerV2.thompsonScorerByExp666(param, request.getAdIdList());
+        } else if (expCodes.contains("669") ||
+                (expCodes.contains(NewExpInfoHelper.flagId) && NewExpInfoHelper.checkInNewExpGroupAndSetParamIfIn(
+                        request.getAppType().toString(), request.getNewExpGroup(), "669", modelParam))) {
+            rankResult = videoAdThompsonScorerV2.thompsonScorerByExp669(param, request.getAdIdList());
+        } else if (expCodes.contains("670") ||
+                (expCodes.contains(NewExpInfoHelper.flagId) && NewExpInfoHelper.checkInNewExpGroupAndSetParamIfIn(
+                        request.getAppType().toString(), request.getNewExpGroup(), "670", modelParam))) {
+            rankResult = videoAdThompsonScorerV2.thompsonScorerByExp670(param, request.getAdIdList());
         }
         log.info("RankServiceImpl.adItemRankWithVideoAdThompson.adIdList: {}, result: {}", JSON.toJSONString(request.getAdIdList()), JSON.toJSONString(rankResult));