Parcourir la source

feat:修改CVR校准逻辑

zhaohaipeng il y a 11 mois
Parent
commit
78ab47e9b5

+ 2 - 1
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/ScorerUtils.java

@@ -26,6 +26,7 @@ public final class ScorerUtils {
     public static String BREAK_CONFIG = "feeds_score_config_break.conf";
     public static String SHARE0_CONFIG = "feeds_score_config_share0.conf";
 
+    public static String CVR_ADJUSTING = "feeds_score_config_cvr_adjusting.conf"; // scorer chain config that includes the CVR adjusting scorer
 
     public static void warmUp() {
         log.info("scorer warm up ");
@@ -33,7 +34,7 @@ public final class ScorerUtils {
         ScorerUtils.init(THOMPSON_CONF);
         ScorerUtils.init(BREAK_CONFIG);
         ScorerUtils.init(SHARE0_CONFIG);
-
+        ScorerUtils.init(CVR_ADJUSTING);
     }
 
     private ScorerUtils() {

+ 27 - 0
ad-engine-server/src/main/resources/feeds_score_config_cvr_adjusting.conf

@@ -0,0 +1,27 @@
+# Scorer chain for the CVR adjusting strategy.
+# NOTE(review): presumably scorers run in descending scorer-priority, so the
+# adjusting scorer (96) runs after the CVR scorer (98) whose result it reads
+# and before the eCPM merge (1) - confirm against the scorer loader.
+scorer-config = {
+  lr-ctr-score-config = {
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCtrLRScorer"
+    scorer-priority = 99
+    model-path = "ad_ctr_model/model_ad_ctr.txt"
+  }
+  lr-cvr-score-config = {
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCvrLRScorer"
+    scorer-priority = 98
+    model-path = "ad_cvr_model/model_ad_cvr.txt"
+  }
+  tf-ctr-score-config = {
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdThompsonScorer"
+    scorer-priority = 97
+    model-path = "ad_thompson_model/model_ad_thompson.txt"
+  }
+  # Adjusting coefficients are loaded from their own model file.
+  lr-cvr-adjusting-score-config = {
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCvrLRAdjustingScorer"
+    scorer-priority = 96
+    model-path = "ad_cvr_model/cvr_adjusting_strategy_coefficient.txt"
+  }
+  # No model-path is configured for the merge scorer.
+  lr-ecpm-merge-config = {
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogMergeEcpmScorer"
+    scorer-priority = 1
+  }
+}

+ 163 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCvrLRAdjustingScorer.java

@@ -0,0 +1,163 @@
+package com.tzld.piaoquan.ad.engine.service.score;
+
+
+import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.CvrAdjustingModel;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRequestContext;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.*;
+
+
+//@Service
+public class VlogAdCvrLRAdjustingScorer extends AbstractScorer {
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogAdCvrLRAdjustingScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(8);
+
+    public VlogAdCvrLRAdjustingScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        doLoadModel(CvrAdjustingModel.class);
+    }
+
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userFeature,
+                                    final List<AdRankItem> rankItems) {
+
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        List<AdRankItem> result = rankByJava(rankItems, param.getRequestContext(), userFeature);
+        LOGGER.debug("ctr ranker time java items size={}, time={} ",
+                result.size(), System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<AdRankItem> rankByJava(final List<AdRankItem> items,
+                                        final AdRequestContext requestContext,
+                                        final UserAdFeature user) {
+        long startTime = System.currentTimeMillis();
+        CvrAdjustingModel model = (CvrAdjustingModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        // 所有都参与打分,按照cvr排序
+        multipleScore(items, requestContext, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (AdRankItem item : items) {
+                LOGGER.debug("after enter feeds model predict cvr adjusting score [{}] [{}]", item, item.getScore());
+            }
+        }
+
+        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ",
+                items.size(), System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+
+    /**
+     * 校准cvr
+     */
+    public void calcScore(final CvrAdjustingModel lrModel,
+                            final AdRankItem item,
+                            final AdRequestContext requestContext) {
+
+        double pro = item.getCvr();
+        try {
+            Double coef = lrModel.getAdjustingCoefficien(pro);
+            if (Objects.nonNull(coef)) {
+                LOGGER.info("[VlogAdCvrLRAdjustingScorer.cvr adjusting] before: {}", pro);
+                pro = pro / coef;
+                LOGGER.info("[VlogAdCvrLRAdjustingScorer.cvr adjusting] after: {}, coef: {}", pro, coef);
+
+            }
+
+        } catch (
+                Exception e) {
+            LOGGER.error("score error for doc={} exception={}",
+                    item.getAdId(), ExceptionUtils.getFullStackTrace(e));
+        }
+        item.setCvr(pro);
+    }
+
+
+    /**
+     * 并行打分
+     *
+     * @param items
+     * @param userInfoBytes
+     * @param requestContext
+     * @param model
+     */
+    private void multipleScore(final List<AdRankItem> items,
+                               final AdRequestContext requestContext,
+                               final CvrAdjustingModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), requestContext);
+                    } catch (
+                            Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).adId, ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (
+                InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        // 等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (
+                        InterruptedException e) {
+                    LOGGER.error("InterruptedException {},{}",
+                            requestContext.getApptype(), ExceptionUtils.getFullStackTrace(e));
+                } catch (
+                        ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}",
+                            requestContext.getApptype(), ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+        LOGGER.debug("Ctr Score {}, Total: {}, Cancel: {}", requestContext.getApptype(), items.size(), cancel);
+    }
+}

+ 4 - 38
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCvrLRScorer.java

@@ -4,7 +4,6 @@ package com.tzld.piaoquan.ad.engine.service.score;
 import com.tzld.piaoquan.ad.engine.commons.score.BaseLRModelScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
-import com.tzld.piaoquan.ad.engine.commons.score.model.CvrAdjustingModel;
 import com.tzld.piaoquan.ad.engine.commons.score.model.LRModel;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.*;
 import com.tzld.piaoquan.recommend.feature.domain.ad.feature.VlogAdCtrLRFeatureExtractor;
@@ -15,7 +14,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.*;
@@ -33,30 +31,10 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
     private static final int enterFeedsScoreRatio = 10;
     private static final int enterFeedsScoreNum = 20;
 
-    // 校准策略数据的OSS文件路径
-    public static final String OSS_FILE_PATH = "ad_cvr_model/cvr_adjusting_strategy_coefficient.txt";
-
-
     public VlogAdCvrLRScorer(ScorerConfigInfo configInfo) {
         super(configInfo);
     }
 
-    /**
-     * 因CVR有校验策略,除了加载CVR自己的Model外还需要加载校准策略的数据Model
-     * <br />
-     * 故在此重写loadModel方法,加载校验策略的Model
-     */
-    @Override
-    public void loadModel() {
-        super.loadModel();
-        try {
-            modelManager.registerModel(OSS_FILE_PATH, OSS_FILE_PATH, CvrAdjustingModel.class);
-        } catch (
-                Exception e) {
-            LOGGER.error("加载校准策略数据异常: ", e);
-        }
-    }
-
     @Override
     public List<AdRankItem> scoring(final ScoreParam param,
                                     final UserAdFeature userFeature,
@@ -90,10 +68,8 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
         UserAdBytesFeature userInfoBytes = null;
         userInfoBytes = new UserAdBytesFeature(user);
 
-        CvrAdjustingModel adjustingModel = (CvrAdjustingModel) modelManager.getModel(OSS_FILE_PATH);
-
         // 所有都参与打分,按照cvr排序
-        multipleScore(items, userInfoBytes, requestContext, model, adjustingModel);
+        multipleScore(items, userInfoBytes, requestContext, model);
 
         // debug log
         if (LOGGER.isDebugEnabled()) {
@@ -115,8 +91,7 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
     public double calcScore(final LRModel lrModel,
                             final AdRankItem item,
                             final UserAdBytesFeature userInfoBytes,
-                            final AdRequestContext requestContext,
-                            final CvrAdjustingModel adjustingModel) {
+                            final AdRequestContext requestContext) {
 
         LRSamples lrSamples = null;
         VlogAdCtrLRFeatureExtractor bytesFeatureExtractor;
@@ -137,14 +112,6 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
             try {
                 pro = lrModel.score(lrSamples);
 
-                // CVR校准
-                Double coef = adjustingModel.getAdjustingCoefficien(pro);
-                if (Objects.nonNull(coef)) {
-                    LOGGER.info("[VlogAdCvrLRScorer.cvr adjusting] before: {}", pro);
-                    pro = pro / coef;
-                    LOGGER.info("[VlogAdCvrLRScorer.cvr adjusting] after: {}, coef: {}", pro, coef);
-                }
-
             } catch (Exception e) {
                 LOGGER.error("score error for doc={} exception={}", new Object[]{
                         item.getAdId(), ExceptionUtils.getFullStackTrace(e)});
@@ -167,8 +134,7 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
     private void multipleScore(final List<AdRankItem> items,
                                   final UserAdBytesFeature userInfoBytes,
                                   final AdRequestContext requestContext,
-                                  final LRModel model,
-                               final CvrAdjustingModel adjustingModel) {
+                                  final LRModel model) {
 
         List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
         for (int index = 0; index < items.size(); index++) {
@@ -178,7 +144,7 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
                 @Override
                 public Object call() throws Exception {
                     try {
-                        calcScore(model, items.get(fIndex), userInfoBytes, requestContext, adjustingModel);
+                        calcScore(model, items.get(fIndex), userInfoBytes, requestContext);
                     } catch (Exception e) {
                         LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).adId, ExceptionUtils.getFullStackTrace(e));
                     }

+ 1 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogMergeEcpmScorer.java

@@ -109,6 +109,7 @@ public class VlogMergeEcpmScorer extends BaseLRModelScorer {
             double bid2 = item.getBid2();
             double pctr = item.getCtr();
             double pcvr = item.getCvr();
+            LOGGER.info("VlogMergeEcmpScore.pcvr: {}", pcvr);
 //            item.setScore_type( isTfType?1:0);
             item.setScore_type( 0);
             //todo

+ 7 - 2
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/impl/RankServiceImpl.java

@@ -131,8 +131,13 @@ public class RankServiceImpl implements RankService {
             }
         }
 
-        //兜底方案
-        List<AdRankItem> rankResult=rank(param, userAdFeature, rankItems,ScorerUtils.BASE_CONF);
+        // 兜底方案
+        List<AdRankItem> rankResult;
+        if (inCpcPidExp) {
+            rankResult = rank(param, userAdFeature, rankItems, ScorerUtils.CVR_ADJUSTING);
+        } else {
+            rankResult = rank(param, userAdFeature, rankItems, ScorerUtils.BASE_CONF);
+        }
 
         if (!CollectionUtils.isEmpty(rankResult)) {
             JSONObject object=new JSONObject();