Browse Source

feat:添加CVR模型校准逻辑

zhaohaipeng 11 months ago
parent
commit
0620233e4d

+ 92 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/CvrAdjustingModel.java

@@ -0,0 +1,92 @@
+package com.tzld.piaoquan.ad.engine.commons.score.model;
+
+import com.google.common.collect.HashBasedTable;
+import com.google.common.collect.Table;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+
+public class CvrAdjustingModel extends Model {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(CvrAdjustingModel.class);
+
+    private Table<Double, Double, Double> table = HashBasedTable.create();
+
+    @Override
+    public int getModelSize() {
+        return table.size();
+    }
+
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws Exception {
+        Table<Double, Double, Double> initTable = HashBasedTable.create();
+        try (BufferedReader input = new BufferedReader(in)) {
+            String line;
+            int cnt = 0;
+            Map<Double, Double> initModel = new TreeMap<>();
+            while ((line = input.readLine()) != null) {
+                String[] items = line.split("\t");
+                if (items.length < 4) {
+                    continue;
+                }
+
+                double key = new BigDecimal(items[2]).doubleValue();
+                double value = new BigDecimal(items[3]).doubleValue();
+                initModel.put(key, value);
+            }
+
+            // 最终生成的格式为  区间最小值,区间最大值,系数
+            List<Double> keySet = initModel.keySet().stream().sorted().collect(Collectors.toList());
+            double preKey = 0.0;
+            for (Double key : keySet) {
+                initTable.put(preKey, key, initModel.get(key));
+                preKey = key;
+            }
+            initTable.put(preKey, Double.MAX_VALUE, initModel.get(preKey));
+
+            this.table = initTable;
+
+            for (Table.Cell<Double, Double, Double> cell : this.table.cellSet()) {
+                LOGGER.info("cell.row: {}, cell.column: {}, cell.value: {}", cell.getRowKey(), cell.getColumnKey(), cell.getValue());
+            }
+            
+            LOGGER.info("[CvrAdjustingModel] model load over and size {}", cnt);
+        } catch (
+                Exception e) {
+            LOGGER.info("[CvrAdjustingModel] model load error ", e);
+        } finally {
+            in.close();
+
+        }
+        return true;
+    }
+
+    public Double getAdjustingCoefficien(double score) {
+        if (Objects.isNull(table)) {
+            return 1.0;
+        }
+
+        for (Table.Cell<Double, Double, Double> cell : table.cellSet()) {
+            double rowKey = cell.getRowKey();
+            double columnKey = cell.getColumnKey();
+            if (rowKey <= score & score < columnKey) {
+                return cell.getValue();
+            }
+        }
+
+        return 1.0;
+    }
+
+    public static void main(String[] args) {
+        System.out.println(0.5 / 0.0);
+    }
+
+}

+ 35 - 4
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCvrLRScorer.java

@@ -4,6 +4,7 @@ package com.tzld.piaoquan.ad.engine.service.score;
 import com.tzld.piaoquan.ad.engine.commons.score.BaseLRModelScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.CvrAdjustingModel;
 import com.tzld.piaoquan.ad.engine.commons.score.model.LRModel;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.*;
 import com.tzld.piaoquan.recommend.feature.domain.ad.feature.VlogAdCtrLRFeatureExtractor;
@@ -16,6 +17,7 @@ import org.slf4j.LoggerFactory;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.concurrent.*;
 
 
@@ -31,11 +33,29 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
     private static final int enterFeedsScoreRatio = 10;
     private static final int enterFeedsScoreNum = 20;
 
+    // OSS file path of the calibration strategy data
+    public static final String OSS_FILE_PATH = "ad_cvr_model/cvr_adjusting_strategy_coefficient.txt";
+
 
     /**
      * Creates the scorer and passes {@code configInfo} through to the superclass constructor.
      *
      * @param configInfo scorer configuration, handed unchanged to the superclass constructor
      */
     public VlogAdCvrLRScorer(ScorerConfigInfo configInfo) {
         super(configInfo);
     }
 
+    /**
+     * 因CVR有校验策略,除了加载CVR自己的Model外还需要加载校准策略的数据Model
+     * <br />
+     * 故在此重写loadModel方法,加载校验策略的Model
+     */
+    @Override
+    public void loadModel() {
+        super.loadModel();
+        try {
+            modelManager.registerModel(OSS_FILE_PATH, OSS_FILE_PATH, CvrAdjustingModel.class);
+        } catch (
+                Exception e) {
+            LOGGER.error("加载校准策略数据异常: ", e);
+        }
+    }
 
     @Override
     public List<AdRankItem> scoring(final ScoreParam param,
@@ -70,8 +90,10 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
         UserAdBytesFeature userInfoBytes = null;
         userInfoBytes = new UserAdBytesFeature(user);
 
+        CvrAdjustingModel adjustingModel = (CvrAdjustingModel) modelManager.getModel(OSS_FILE_PATH);
+
         // 所有都参与打分,按照cvr排序
-        multipleScore(items, userInfoBytes, requestContext, model);
+        multipleScore(items, userInfoBytes, requestContext, model, adjustingModel);
 
         // debug log
         if (LOGGER.isDebugEnabled()) {
@@ -93,7 +115,8 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
     public double calcScore(final LRModel lrModel,
                             final AdRankItem item,
                             final UserAdBytesFeature userInfoBytes,
-                            final AdRequestContext requestContext) {
+                            final AdRequestContext requestContext,
+                            final CvrAdjustingModel adjustingModel) {
 
         LRSamples lrSamples = null;
         VlogAdCtrLRFeatureExtractor bytesFeatureExtractor;
@@ -113,6 +136,13 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
         if (lrSamples != null && lrSamples.getFeaturesList() != null) {
             try {
                 pro = lrModel.score(lrSamples);
+
+                // CVR校准
+                Double coef = adjustingModel.getAdjustingCoefficien(pro);
+                if (Objects.nonNull(coef)) {
+                    pro = pro / coef;
+                }
+
             } catch (Exception e) {
                 LOGGER.error("score error for doc={} exception={}", new Object[]{
                         item.getAdId(), ExceptionUtils.getFullStackTrace(e)});
@@ -135,7 +165,8 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
     private void multipleScore(final List<AdRankItem> items,
                                   final UserAdBytesFeature userInfoBytes,
                                   final AdRequestContext requestContext,
-                                  final LRModel model) {
+                                  final LRModel model,
+                               final CvrAdjustingModel adjustingModel) {
 
         List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
         for (int index = 0; index < items.size(); index++) {
@@ -145,7 +176,7 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
                 @Override
                 public Object call() throws Exception {
                     try {
-                        calcScore(model, items.get(fIndex), userInfoBytes, requestContext);
+                        calcScore(model, items.get(fIndex), userInfoBytes, requestContext, adjustingModel);
                     } catch (Exception e) {
                         LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).adId, ExceptionUtils.getFullStackTrace(e));
                     }