Forráskód Böngészése

Merge branch 'master' into feature_gufengshou_20240513_pid_v7

zhaohaipeng 11 hónapja
szülő
commit
f533915c9a

+ 2 - 1
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/ScorerUtils.java

@@ -26,6 +26,7 @@ public final class ScorerUtils {
     public static String BREAK_CONFIG = "feeds_score_config_break.conf";
     public static String SHARE0_CONFIG = "feeds_score_config_share0.conf";
 
+    public static String CVR_ADJUSTING = "feeds_score_config_cvr_adjusting.conf";
 
     public static void warmUp() {
         log.info("scorer warm up ");
@@ -33,7 +34,7 @@ public final class ScorerUtils {
         ScorerUtils.init(THOMPSON_CONF);
         ScorerUtils.init(BREAK_CONFIG);
         ScorerUtils.init(SHARE0_CONFIG);
-
+        ScorerUtils.init(CVR_ADJUSTING);
     }
 
     private ScorerUtils() {

+ 89 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/CvrAdjustingModel.java

@@ -0,0 +1,89 @@
+package com.tzld.piaoquan.ad.engine.commons.score.model;
+
+import com.google.common.collect.HashBasedTable;
+import com.google.common.collect.Table;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+
+public class CvrAdjustingModel extends Model {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(CvrAdjustingModel.class);
+
+    private Table<Double, Double, Double> table = HashBasedTable.create();
+
+    @Override
+    public int getModelSize() {
+        return table.size();
+    }
+
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws Exception {
+        Table<Double, Double, Double> initTable = HashBasedTable.create();
+        try (BufferedReader input = new BufferedReader(in)) {
+            String line;
+            int cnt = 0;
+            Map<Double, Double> initModel = new TreeMap<>();
+            while ((line = input.readLine()) != null) {
+                String[] items = line.split("\t");
+                if (items.length < 4) {
+                    continue;
+                }
+
+                double key = new BigDecimal(items[2]).doubleValue();
+                double value = new BigDecimal(items[3]).doubleValue();
+                initModel.put(key, value);
+            }
+
+            // 最终生成的格式为  区间最小值,区间最大值,系数
+            List<Double> keySet = initModel.keySet().stream().sorted().collect(Collectors.toList());
+            double preKey = 0.0;
+            for (Double key : keySet) {
+                initTable.put(preKey, key, initModel.get(key));
+                preKey = key;
+            }
+            initTable.put(preKey, Double.MAX_VALUE, initModel.get(preKey));
+
+            this.table = initTable;
+
+            for (Table.Cell<Double, Double, Double> cell : this.table.cellSet()) {
+                LOGGER.info("cell.row: {}, cell.column: {}, cell.value: {}", cell.getRowKey(), cell.getColumnKey(), cell.getValue());
+            }
+            
+            LOGGER.info("[CvrAdjustingModel] model load over and size {}", cnt);
+        } catch (
+                Exception e) {
+            LOGGER.info("[CvrAdjustingModel] model load error ", e);
+        } finally {
+            in.close();
+
+        }
+        return true;
+    }
+
+    public Double getAdjustingCoefficien(double score) {
+        if (Objects.isNull(table)) {
+            return 1.0;
+        }
+
+        for (Table.Cell<Double, Double, Double> cell : table.cellSet()) {
+            double rowKey = cell.getRowKey();
+            double columnKey = cell.getColumnKey();
+            if (rowKey <= score & score < columnKey) {
+                LOGGER.info("score {} in {} - {} , value is {}", score, rowKey, columnKey, cell.getValue());
+                return cell.getValue();
+            }
+        }
+
+        return 1.0;
+    }
+}

+ 27 - 0
ad-engine-server/src/main/resources/feeds_score_config_cvr_adjusting.conf

@@ -0,0 +1,27 @@
+scorer-config = {
+  lr-ctr-score-config = {
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCtrLRScorer"
+    scorer-priority = 99
+    model-path = "ad_ctr_model/model_ad_ctr.txt"
+  }
+  lr-cvr-score-config = {
+      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCvrLRScorer"
+      scorer-priority = 98
+      model-path = "ad_cvr_model/model_ad_cvr.txt"
+  }
+    tf-ctr-score-config = {
+      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdThompsonScorer"
+      scorer-priority = 97
+      model-path = "ad_thompson_model/model_ad_thompson.txt"
+    }
+    lr-cvr-adjusting-score-config = {
+          scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCvrLRAdjustingScorer"
+          scorer-priority = 96
+          model-path = "ad_cvr_model/cvr_adjusting_strategy_coefficient.txt"
+    }
+  lr-ecpm-merge-config = {
+      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogMergeEcpmScorer"
+      scorer-priority = 1
+  }
+
+}

+ 162 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCvrLRAdjustingScorer.java

@@ -0,0 +1,162 @@
+package com.tzld.piaoquan.ad.engine.service.score;
+
+
+import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.CvrAdjustingModel;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRequestContext;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.*;
+
+
+//@Service
+public class VlogAdCvrLRAdjustingScorer extends AbstractScorer {
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogAdCvrLRAdjustingScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(8);
+
+    public VlogAdCvrLRAdjustingScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        doLoadModel(CvrAdjustingModel.class);
+    }
+
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userFeature,
+                                    final List<AdRankItem> rankItems) {
+
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        List<AdRankItem> result = rankByJava(rankItems, param.getRequestContext(), userFeature);
+        LOGGER.debug("ctr ranker time java items size={}, time={} ",
+                result.size(), System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<AdRankItem> rankByJava(final List<AdRankItem> items,
+                                        final AdRequestContext requestContext,
+                                        final UserAdFeature user) {
+        long startTime = System.currentTimeMillis();
+        CvrAdjustingModel model = (CvrAdjustingModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        // 所有都参与打分,按照cvr排序
+        multipleScore(items, requestContext, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (AdRankItem item : items) {
+                LOGGER.debug("after enter feeds model predict cvr adjusting score [{}] [{}]", item, item.getScore());
+            }
+        }
+
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ",
+                items.size(), System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+
+    /**
+     * 校准cvr
+     */
+    public void calcScore(final CvrAdjustingModel model,
+                            final AdRankItem item,
+                            final AdRequestContext requestContext) {
+
+        double pro = item.getCvr();
+        try {
+            Double coef = model.getAdjustingCoefficien(pro);
+            if (Objects.nonNull(coef)) {
+                LOGGER.info("[VlogAdCvrLRAdjustingScorer.cvr adjusting] before: {}", pro);
+                pro = pro / coef;
+                LOGGER.info("[VlogAdCvrLRAdjustingScorer.cvr adjusting] after: {}, coef: {}", pro, coef);
+
+            }
+
+        } catch (
+                Exception e) {
+            LOGGER.error("score error for doc={} exception={}",
+                    item.getAdId(), ExceptionUtils.getFullStackTrace(e));
+        }
+        item.setCvr(pro);
+    }
+
+
+    /**
+     * 并行打分
+     *
+     * @param items
+     * @param userInfoBytes
+     * @param requestContext
+     * @param model
+     */
+    private void multipleScore(final List<AdRankItem> items,
+                               final AdRequestContext requestContext,
+                               final CvrAdjustingModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), requestContext);
+                    } catch (
+                            Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).adId, ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (
+                InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        // 等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (
+                        InterruptedException e) {
+                    LOGGER.error("InterruptedException {},{}",
+                            requestContext.getApptype(), ExceptionUtils.getFullStackTrace(e));
+                } catch (
+                        ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}",
+                            requestContext.getApptype(), ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+        LOGGER.debug("Ctr Score {}, Total: {}, Cancel: {}", requestContext.getApptype(), items.size(), cancel);
+    }
+}

+ 2 - 3
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCvrLRScorer.java

@@ -14,8 +14,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.concurrent.*;
 
 
@@ -31,12 +31,10 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
     private static final int enterFeedsScoreRatio = 10;
     private static final int enterFeedsScoreNum = 20;
 
-
     public VlogAdCvrLRScorer(ScorerConfigInfo configInfo) {
         super(configInfo);
     }
 
-
     @Override
     public List<AdRankItem> scoring(final ScoreParam param,
                                     final UserAdFeature userFeature,
@@ -113,6 +111,7 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
         if (lrSamples != null && lrSamples.getFeaturesList() != null) {
             try {
                 pro = lrModel.score(lrSamples);
+
             } catch (Exception e) {
                 LOGGER.error("score error for doc={} exception={}", new Object[]{
                         item.getAdId(), ExceptionUtils.getFullStackTrace(e)});

+ 1 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogMergeEcpmScorer.java

@@ -109,6 +109,7 @@ public class VlogMergeEcpmScorer extends BaseLRModelScorer {
             double bid2 = item.getBid2();
             double pctr = item.getCtr();
             double pcvr = item.getCvr();
+            LOGGER.info("VlogMergeEcmpScore.pcvr: {}", pcvr);
 //            item.setScore_type( isTfType?1:0);
             item.setScore_type( 0);
             //todo

+ 3 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/container/PidLambdaForCpcContainer.java

@@ -233,6 +233,9 @@ public class PidLambdaForCpcContainer {
             }
             double derivative = (error - lastError) / 1; // 假设采样间隔为1
             lastError = error;
+            if(lambda<0){
+                lambda=setPoint;
+            }
             return lambda+kp * error + ki * integral + kd * derivative;
         }
 

+ 9 - 3
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/impl/RankServiceImpl.java

@@ -129,8 +129,13 @@ public class RankServiceImpl implements RankService {
             }
         }
 
-        //兜底方案
-        List<AdRankItem> rankResult=rank(param, userAdFeature, rankItems,ScorerUtils.BASE_CONF);
+        // 兜底方案
+        List<AdRankItem> rankResult;
+        if (inCpcPidExp) {
+            rankResult = rank(param, userAdFeature, rankItems, ScorerUtils.CVR_ADJUSTING);
+        } else {
+            rankResult = rank(param, userAdFeature, rankItems, ScorerUtils.BASE_CONF);
+        }
 
         if (!CollectionUtils.isEmpty(rankResult)) {
             JSONObject object=new JSONObject();
@@ -143,7 +148,7 @@ public class RankServiceImpl implements RankService {
             object.put("pidLambda",rankResult.get(0).getPidLambda());
             object.put("lrsamples",rankResult.get(0).getLrSampleString());
             object.put("dataTime",currentTime.format(timeFormatter));
-
+            object.put("creativeId",rankResult.get(0).getAdId());
             log.info("svc=adItemRank {}", JSONObject.toJSONString(object));
             object.remove("lrsamples");
             if(inCpcPidExp){
@@ -311,6 +316,7 @@ public class RankServiceImpl implements RankService {
         object.put("pcvr",topItem.getCvr());
         object.put("lrsamples",topItem.getLrSampleString());
         object.put("pidLambda",topItem.getPidLambda());
+
         //临时加入供pid v2使用
         object.put("realECpm",realECpm);
         object.put("creativeId",result.getCreativeId());