Ver código fonte

Merge branch 'dev-xym-add-PAI' into pre-master

xueyiming 3 meses atrás
pai
commit
5bf9ae60bb

+ 8 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/math/Const.java

@@ -0,0 +1,8 @@
+package com.tzld.piaoquan.ad.engine.commons.math;
+
+/**
+ * Numeric constants used by the ad-engine ratio smoothing helpers.
+ *
+ * <p>This is a pure constant holder: it is final and cannot be instantiated.
+ */
+public final class Const {
+    /**
+     * Z-score constant; 1.96 is the two-sided 95% quantile of the standard normal distribution.
+     * NOTE(review): presumably meant as the {@code zscore} argument of NumUtil.divSmoothV1 — confirm.
+     */
+    public static final double WILSON_ZSCORE = 1.96;
+    /** Additive denominator term passed to NumUtil.divSmoothV2 when smoothing click / view. */
+    public static final double CTR_SMOOTH_BETA_FACTOR = 25;
+    /** Additive denominator term passed to NumUtil.divSmoothV2 when smoothing conversion / click. */
+    public static final double CVR_SMOOTH_BETA_FACTOR = 10;
+    /** Additive denominator term passed to NumUtil.divSmoothV2 when smoothing conversion / view. */
+    public static final double CTCVR_SMOOTH_BETA_FACTOR = 100;
+
+    // Constant holder only; prevent accidental instantiation.
+    private Const() {
+    }
+}

+ 2 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/ScorerUtils.java

@@ -29,6 +29,7 @@ public final class ScorerUtils {
     public static String XGBOOST_SCORE_CONF_683 = "ad_score_config_xgboost_683.conf";
     public static String XGBOOST_SCORE_CONF_20240909 = "ad_score_config_xgboost_20240909.conf";
     public static String XGBOOST_SCORE_CONF_20241105 = "ad_score_config_xgboost_20241105.conf";
+    public static String PAI_SCORE_CONF_20250214 = "ad_score_config_pai_20250214.conf";
 
     public static void warmUp() {
         log.info("scorer warm up ");
@@ -39,6 +40,7 @@ public final class ScorerUtils {
         ScorerUtils.init(XGBOOST_SCORE_CONF);
         ScorerUtils.init(XGBOOST_SCORE_CONF_20240909);
         ScorerUtils.init(XGBOOST_SCORE_CONF_20241105);
+        ScorerUtils.init(PAI_SCORE_CONF_20250214);
     }
 
     private ScorerUtils() {

Diferenças do arquivo suprimidas por serem muito extensas
+ 32 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/PAIModel.java


+ 18 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/util/NumUtil.java

@@ -13,6 +13,24 @@ public class NumUtil {
         return d1 / d2;
     }
 
+    /**
+     * Conservative estimate of the ratio {@code a / b}: the lower bound of the Wilson score
+     * interval for {@code a} successes out of {@code b} trials.
+     *
+     * <p>The observed rate is clamped to 1, so inconsistent counts ({@code a > b}) yield
+     * {@code b / (b + zscore^2)} instead of NaN.
+     *
+     * @param a      number of successes (numerator)
+     * @param b      number of trials (denominator)
+     * @param zscore z-score of the desired confidence level
+     * @return the Wilson lower bound, or {@code 0} when {@code a} or {@code b} is not positive
+     */
+    public static double divSmoothV1(double a, double b, double zscore) {
+        // Wilson Smoothing
+        // Zero or negative counts carry no usable evidence; report a zero rate.
+        if (a <= 0 || b <= 0) {
+            return 0d;
+        }
+        double zscore2 = zscore * zscore;
+        // Clamp to 1: with a > b the term p * (1 - p) would go negative and Math.sqrt would return NaN.
+        double p = Math.min(a / b, 1d);
+        double numerator = p + zscore2 / (2 * b) - zscore * Math.sqrt((p * (1 - p) + zscore2 / (4 * b)) / b);
+        double denominator = 1 + zscore2 / b;
+        return numerator / denominator;
+    }
+
+    /**
+     * Additively smoothed ratio: {@code a / (b + beta)}.
+     *
+     * @param a    numerator
+     * @param b    denominator before smoothing
+     * @param beta constant added to the denominator
+     * @return {@code 0} when {@code a} or {@code b} equals zero, otherwise {@code a / (b + beta)}
+     */
+    public static double divSmoothV2(double a, double b, double beta) {
+        boolean noData = (a == 0) || (b == 0);
+        return noData ? 0d : a / (b + beta);
+    }
 
     public static double log(double a) {
         if (a <= 0) {

Diferenças do arquivo suprimidas por serem muito extensas
+ 11 - 0
ad-engine-server/src/main/resources/20250113_ad_bucket_688.txt


+ 6 - 0
ad-engine-server/src/main/resources/ad_score_config_pai_20250214.conf

@@ -0,0 +1,6 @@
+# Scorer pipeline configuration declaring a single scorer entry, pai-score-config.
+scorer-config = {
+  pai-score-config = {
+    # Fully-qualified class name of the scorer implementation.
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.scorer.PAIScorer"
+    # NOTE(review): presumably the scorer's position within the pipeline — confirm against the config loader.
+    scorer-priority = 99
+    # No model-path is declared in this entry.
+  }
+}

+ 1 - 1
ad-engine-server/src/main/resources/ad_score_config_xgboost_683.conf

@@ -2,6 +2,6 @@ scorer-config = {
   xgb-score-config = {
     scorer-name = "com.tzld.piaoquan.ad.engine.service.score.scorer.XGBoostScorer683"
     scorer-priority = 99
-    model-path = "fengzhoutian/model_xgb_351_1000_30d_v1.tar.gz"
+    model-path = "zhangbo/model_xgb_351_1000_v2.tar.gz"
   }
 }

+ 99 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorer.java

@@ -0,0 +1,99 @@
+package com.tzld.piaoquan.ad.engine.service.score.scorer;
+
+
+import com.google.common.collect.Lists;
+import com.tzld.piaoquan.ad.engine.commons.score.BaseXGBoostModelScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModel;
+import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel683;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Component;
+
+import java.util.*;
+import java.util.concurrent.*;
+
+/**
+ * Scorer that obtains one score per ad item from a {@link PAIModel} and sorts the items.
+ *
+ * <p>Only the feature-map based {@code scoring} overload is implemented; the
+ * {@code ScoreParam}/{@code UserAdFeature} overload always throws.
+ */
+public class PAIScorer extends BaseXGBoostModelScorer {
+
+    private final static Logger LOGGER = LoggerFactory.getLogger(PAIScorer.class);
+
+
+    public PAIScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    /**
+     * Loads the model backing this scorer.
+     *
+     * <p>The model class has to be {@link PAIModel} because {@code rankByJava} casts
+     * {@code getModel()} to that type; loading any other model class would make the cast fail.
+     * NOTE(review): assumes {@code doLoadModel} accepts {@code PAIModel.class} — confirm.
+     */
+    @Override
+    public void loadModel() {
+        doLoadModel(PAIModel.class);
+    }
+
+    /**
+     * Unsupported overload.
+     *
+     * @throws NoSuchMethodError always
+     */
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userAdFeature,
+                                    final List<AdRankItem> rankItems) {
+        throw new NoSuchMethodError();
+    }
+
+    /**
+     * Scores and sorts the given items in place.
+     *
+     * @param sceneFeatureMap scene features forwarded to the model
+     * @param userFeatureMap  user features forwarded to the model
+     * @param rankItems       items to score; returned untouched when null or empty
+     * @return the same list instance, scored and sorted
+     */
+    public List<AdRankItem> scoring(final Map<String, String> sceneFeatureMap,
+                                    final Map<String, String> userFeatureMap,
+                                    final List<AdRankItem> rankItems) {
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+
+        // rankByJava returns the (non-empty) input list, so no null check is needed here.
+        List<AdRankItem> result = rankByJava(sceneFeatureMap, userFeatureMap, rankItems);
+
+        LOGGER.debug("ctr ranker time java items size={}, time={} ", result.size(),
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    /**
+     * Scores every item with the loaded model, then sorts the list by the items' natural order.
+     */
+    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+                                        final Map<String, String> userFeatureMap,
+                                        final List<AdRankItem> items) {
+        long startTime = System.currentTimeMillis();
+        PAIModel model = (PAIModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+        // All items take part in scoring and are then sorted by ctr.
+        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+
+        // debug log: each item together with the score it was just assigned
+        if (LOGGER.isDebugEnabled()) {
+            for (AdRankItem item : items) {
+                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", item, item.getLrScore());
+            }
+        }
+
+        Collections.sort(items);
+
+        long cost = System.currentTimeMillis() - startTime;
+        LOGGER.debug("ctr ranker java execute time: [{}]", cost);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items.size(), cost);
+        return items;
+    }
+
+    /**
+     * Asks the model for one score per item and stores it as the item's lrScore and as the
+     * "ctcvrScore" entry of its score map.
+     *
+     * <p>If the model returns fewer scores than items (or none), the mismatch is logged and
+     * only the items that have a matching score are updated.
+     */
+    private void multipleCtrScore(final List<AdRankItem> items,
+                                  final Map<String, String> userFeatureMap,
+                                  final Map<String, String> sceneFeatureMap,
+                                  final PAIModel model) {
+
+        List<Float> score = model.score(items, userFeatureMap, sceneFeatureMap);
+        LOGGER.debug("PAIScorer score={}", score);
+        int scoreSize = score == null ? 0 : score.size();
+        if (scoreSize != items.size()) {
+            LOGGER.error("PAIScorer score size mismatch, items size={}, score size={}", items.size(), scoreSize);
+        }
+        // Bound the loop by both sizes so a short score list cannot raise IndexOutOfBoundsException.
+        int scored = Math.min(scoreSize, items.size());
+        for (int i = 0; i < scored; i++) {
+            Double pro = score.get(i).doubleValue();
+            AdRankItem item = items.get(i);
+            item.setLrScore(pro);
+            item.getScoreMap().put("ctcvrScore", pro);
+        }
+    }
+
+
+}

+ 37 - 20
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/strategy/RankStrategyBy683.java

@@ -17,6 +17,7 @@ import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Component;
 import org.xm.Similarity;
 
+import javax.annotation.PostConstruct;
 import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStream;
@@ -28,6 +29,8 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
+import static com.tzld.piaoquan.ad.engine.commons.math.Const.*;
+
 @Slf4j
 @Component
 public class RankStrategyBy683 extends RankStrategyBasic {
@@ -46,6 +49,11 @@ public class RankStrategyBy683 extends RankStrategyBasic {
     @ApolloJsonValue("${rank.score.neg_sample_rate:0.01}")
     Double negSampleRate;
 
+    /**
+     * Post-construction hook that invokes {@code readBucketFile()} once the bean has been created.
+     */
+    @PostConstruct
+    public void afterInit() {
+        readBucketFile();
+    }
+
     @Override
     public List<AdRankItem> adItemRank(RankRecommendRequestParam request, ScoreParam scoreParam) {
 
@@ -183,7 +191,6 @@ public class RankStrategyBy683 extends RankStrategyBasic {
 
         long time3 = System.currentTimeMillis();
         // 分桶
-        this.readBucketFile();
         userFeatureMap = this.featureBucket(userFeatureMap);
         CountDownLatch cdl4 = new CountDownLatch(adRankItems.size());
         for (AdRankItem adRankItem : adRankItems) {
@@ -204,10 +211,11 @@ public class RankStrategyBy683 extends RankStrategyBasic {
         long time4 = System.currentTimeMillis();
         // 打分排序
         // getScorerPipeline
-        List<AdRankItem> result = ScorerUtils.getScorerPipeline(ScorerUtils.XGBOOST_SCORE_CONF_683).scoring(sceneFeatureMap, userFeatureMap, adRankItems);
+        List<AdRankItem> result = ScorerUtils.getScorerPipeline(ScorerUtils.PAI_SCORE_CONF_20250214).scoring(sceneFeatureMap, userFeatureMap, adRankItems);
         long time5 = System.currentTimeMillis();
 
         // calibrate score for negative sampling
+        /* 02-11 update: 因模型换回基线无采样模型,取消校准
         for (AdRankItem item : result) {
             double originalScore = item.getLrScore();
             double calibratedScore = originalScore / (originalScore + (1 - originalScore) / negSampleRate);
@@ -215,6 +223,7 @@ public class RankStrategyBy683 extends RankStrategyBasic {
             item.getScoreMap().put("originCtcvrScore", originalScore);
             item.getScoreMap().put("ctcvrScore", calibratedScore);
         }
+        */
 
         // loop
         double cpmCoefficient = weightParam.getOrDefault("cpmCoefficient", 0.9);
@@ -322,17 +331,19 @@ public class RankStrategyBy683 extends RankStrategyBasic {
                 double click = Double.parseDouble(feature.getOrDefault("ad_click_" + time, "0"));
                 double conver = Double.parseDouble(feature.getOrDefault("ad_conversion_" + time, "0"));
                 double income = Double.parseDouble(feature.getOrDefault("ad_income_" + time, "0"));
-                double f2 = NumUtil.div(conver, view);
-                double ecpm = NumUtil.div(income * 1000, view);
-                cidFeatureMap.put(prefix + "_" + time + "_ctr", String.valueOf(NumUtil.div(click, view)));
-                cidFeatureMap.put(prefix + "_" + time + "_ctcvr", String.valueOf(f2));
-                cidFeatureMap.put(prefix + "_" + time + "_cvr", String.valueOf(NumUtil.div(conver, click)));
+                double cpc = NumUtil.div(income, click);
+                double ctr = NumUtil.divSmoothV2(click, view, CTR_SMOOTH_BETA_FACTOR);
+                double ctcvr = NumUtil.divSmoothV2(conver, view, CTCVR_SMOOTH_BETA_FACTOR);
+                double ecpm = ctr * cpc * 1000;
+                cidFeatureMap.put(prefix + "_" + time + "_ctr", String.valueOf(ctr));
+                cidFeatureMap.put(prefix + "_" + time + "_ctcvr", String.valueOf(ctcvr));
+                cidFeatureMap.put(prefix + "_" + time + "_cvr", String.valueOf(NumUtil.divSmoothV2(conver, click, CVR_SMOOTH_BETA_FACTOR)));
                 cidFeatureMap.put(prefix + "_" + time + "_conver", String.valueOf(conver));
                 cidFeatureMap.put(prefix + "_" + time + "_ecpm", String.valueOf(ecpm));
 
                 cidFeatureMap.put(prefix + "_" + time + "_click", String.valueOf(click));
                 cidFeatureMap.put(prefix + "_" + time + "_conver*log(view)", String.valueOf(conver * NumUtil.log(view)));
-                cidFeatureMap.put(prefix + "_" + time + "_conver*ctcvr", String.valueOf(conver * f2));
+                cidFeatureMap.put(prefix + "_" + time + "_conver*ctcvr", String.valueOf(conver * ctcvr));
             }
         }
 
@@ -355,17 +366,19 @@ public class RankStrategyBy683 extends RankStrategyBasic {
                 double click = Double.parseDouble(feature.getOrDefault("ad_click_" + time, "0"));
                 double conver = Double.parseDouble(feature.getOrDefault("ad_conversion_" + time, "0"));
                 double income = Double.parseDouble(feature.getOrDefault("ad_income_" + time, "0"));
-                double f2 = NumUtil.div(conver, view);
-                double ecpm = NumUtil.div(income * 1000, view);
-                cidFeatureMap.put(prefix + "_" + time + "_ctr", String.valueOf(NumUtil.div(click, view)));
-                cidFeatureMap.put(prefix + "_" + time + "_ctcvr", String.valueOf(f2));
-                cidFeatureMap.put(prefix + "_" + time + "_cvr", String.valueOf(NumUtil.div(conver, click)));
+                double cpc = NumUtil.div(income, click);
+                double ctr = NumUtil.divSmoothV2(click, view, CTR_SMOOTH_BETA_FACTOR);
+                double ctcvr = NumUtil.divSmoothV2(conver, view, CTCVR_SMOOTH_BETA_FACTOR);
+                double ecpm = ctr * cpc * 1000;
+                cidFeatureMap.put(prefix + "_" + time + "_ctr", String.valueOf(ctr));
+                cidFeatureMap.put(prefix + "_" + time + "_ctcvr", String.valueOf(ctcvr));
+                cidFeatureMap.put(prefix + "_" + time + "_cvr", String.valueOf(NumUtil.divSmoothV2(conver, click, CVR_SMOOTH_BETA_FACTOR)));
                 cidFeatureMap.put(prefix + "_" + time + "_conver", String.valueOf(conver));
                 cidFeatureMap.put(prefix + "_" + time + "_ecpm", String.valueOf(ecpm));
 
                 cidFeatureMap.put(prefix + "_" + time + "_click", String.valueOf(click));
                 cidFeatureMap.put(prefix + "_" + time + "_conver*log(view)", String.valueOf(conver * NumUtil.log(view)));
-                cidFeatureMap.put(prefix + "_" + time + "_conver*ctcvr", String.valueOf(conver * f2));
+                cidFeatureMap.put(prefix + "_" + time + "_conver*ctcvr", String.valueOf(conver * ctcvr));
             }
         }
 
@@ -449,11 +462,13 @@ public class RankStrategyBy683 extends RankStrategyBasic {
             double click = Double.parseDouble(d1Feature.getOrDefault("ad_click_" + prefix, "0"));
             double conver = Double.parseDouble(d1Feature.getOrDefault("ad_conversion_" + prefix, "0"));
             double income = Double.parseDouble(d1Feature.getOrDefault("ad_income_" + prefix, "0"));
-            featureMap.put("d1_feature_" + prefix + "_ctr", String.valueOf(NumUtil.div(click, view)));
-            featureMap.put("d1_feature_" + prefix + "_ctcvr", String.valueOf(NumUtil.div(conver, view)));
-            featureMap.put("d1_feature_" + prefix + "_cvr", String.valueOf(NumUtil.div(conver, click)));
+            double cpc = NumUtil.div(income, click);
+            double ctr = NumUtil.divSmoothV2(click, view, CTR_SMOOTH_BETA_FACTOR);
+            featureMap.put("d1_feature_" + prefix + "_ctr", String.valueOf(ctr));
+            featureMap.put("d1_feature_" + prefix + "_ctcvr", String.valueOf(NumUtil.divSmoothV2(conver, view, CTCVR_SMOOTH_BETA_FACTOR)));
+            featureMap.put("d1_feature_" + prefix + "_cvr", String.valueOf(NumUtil.divSmoothV2(conver, click, CVR_SMOOTH_BETA_FACTOR)));
             featureMap.put("d1_feature_" + prefix + "_conver", String.valueOf(conver));
-            featureMap.put("d1_feature_" + prefix + "_ecpm", String.valueOf(NumUtil.div(income * 1000, view)));
+            featureMap.put("d1_feature_" + prefix + "_ecpm", String.valueOf(ctr * cpc * 1000));
         }
     }
 
@@ -596,7 +611,8 @@ public class RankStrategyBy683 extends RankStrategyBasic {
             return;
         }
         synchronized (this) {
-            InputStream resourceStream = RankStrategyBy683.class.getClassLoader().getResourceAsStream("20240718_ad_bucket_688.txt");
+            String bucketFile = "20240718_ad_bucket_688.txt";
+            InputStream resourceStream = RankStrategyBy683.class.getClassLoader().getResourceAsStream(bucketFile);
             if (resourceStream != null) {
                 try (BufferedReader reader = new BufferedReader(new InputStreamReader(resourceStream))) {
                     Map<String, double[]> bucketsMap = new HashMap<>();
@@ -619,8 +635,9 @@ public class RankStrategyBy683 extends RankStrategyBasic {
                     this.bucketsMap = bucketsMap;
                     this.bucketsLen = bucketsLen;
                 } catch (IOException e) {
-                    log.error("something is wrong in parse bucket file:", e);
+                    log.error("something is wrong in parse bucket file: ", e);
                 }
+                log.info("load bucket file success: {}", bucketFile);
             } else {
                 log.error("no bucket file");
             }

+ 10 - 0
pom.xml

@@ -322,6 +322,16 @@
             <groupId>org.springframework.cloud</groupId>
             <artifactId>spring-cloud-starter-openfeign</artifactId>
         </dependency>
+
+        <!-- https://mvnrepository.com/artifact/com.aliyun.openservices.eas/eas-sdk -->
+        <dependency>
+            <groupId>com.aliyun.openservices.eas</groupId>
+            <artifactId>eas-sdk</artifactId>
+            <version>2.0.23</version>
+        </dependency>
+
+
+
         <!--easyexcel-->
     </dependencies>
 

Alguns arquivos não foram mostrados porque muitos arquivos mudaram nesse diff