Parcourir la source

Merge branch 'feature_20241213_creative_calibration_v2' of algorithm/ad-engine into master

zhaohaipeng il y a 4 mois
Parent
commit
f4219e89b6

+ 53 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/CreativeCalibrationModel.java

@@ -0,0 +1,53 @@
+package com.tzld.piaoquan.ad.engine.commons.score.model;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.util.HashMap;
+import java.util.Map;
+
+public class CreativeCalibrationModel extends Model {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(CreativeCalibrationModel.class);
+
+    private Map<Long, Double> creativeDiffRateMap = new HashMap<>();
+
+    @Override
+    public int getModelSize() {
+        return this.creativeDiffRateMap.size();
+    }
+
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws Exception {
+        Map<Long, Double> initMap = new HashMap<>();
+        try (BufferedReader input = new BufferedReader(in)) {
+            String line;
+            while ((line = input.readLine()) != null) {
+                String[] items = line.split("\t");
+                if (items.length < 2) {
+                    continue;
+                }
+
+                long cid = Long.parseLong(items[0]);
+                double diffRate = Double.parseDouble(items[1]);
+
+                initMap.put(cid, diffRate);
+
+            }
+            this.creativeDiffRateMap = initMap;
+            LOGGER.info("[CreativeCalibrationModel] model load over and size {}", this.creativeDiffRateMap.size());
+        } catch (Exception e) {
+            LOGGER.info("[CreativeCalibrationModel] model load error ", e);
+        } finally {
+            in.close();
+
+        }
+        return true;
+    }
+
+    public double getDiffRate(long cid) {
+        return this.creativeDiffRateMap.getOrDefault(cid, 0d);
+    }
+}

+ 5 - 5
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/XGBCalibrationModel.java → ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/ValueRangeCalibrationModel.java

@@ -9,9 +9,9 @@ import java.io.BufferedReader;
 import java.io.InputStreamReader;
 import java.util.Objects;
 
-public class XGBCalibrationModel extends Model {
+public class ValueRangeCalibrationModel extends Model {
 
-    private static final Logger LOGGER = LoggerFactory.getLogger(XGBCalibrationModel.class);
+    private static final Logger LOGGER = LoggerFactory.getLogger(ValueRangeCalibrationModel.class);
 
     private Table<Double, Double, Double> table = HashBasedTable.create();
 
@@ -40,12 +40,12 @@ public class XGBCalibrationModel extends Model {
             this.table = initTable;
 
             for (Table.Cell<Double, Double, Double> cell : this.table.cellSet()) {
-                LOGGER.info("cell.row: {}, cell.column: {}, cell.value: {}", cell.getRowKey(), cell.getColumnKey(), cell.getValue());
+                LOGGER.info("[ValueRangeCalibrationModel] cell.row: {}, cell.column: {}, cell.value: {}", cell.getRowKey(), cell.getColumnKey(), cell.getValue());
             }
 
-            LOGGER.info("[XGBCalibrationModel] model load over and size {}", this.table.size());
+            LOGGER.info("[ValueRangeCalibrationModel] model load over and size {}", this.table.size());
         } catch (Exception e) {
-            LOGGER.info("[XGBCalibrationModel] model load error ", e);
+            LOGGER.info("[ValueRangeCalibrationModel] model load error ", e);
         } finally {
             in.close();
 

+ 1 - 1
ad-engine-server/src/main/resources/ad_score_config_xgboost_20241105.conf

@@ -5,7 +5,7 @@ scorer-config = {
     model-path = "zhangbo/model_xgb_351_1000_v2.tar.gz"
   }
   calibration-score-config = {
-    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.scorer.ScoreCalibrationScorer"
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.scorer.CreativeCalibrationScorer"
     scorer-priority = 98
     model-path = "zhangbo/model_xgb_351_1000_v2_calibration.txt"
   }

+ 148 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/CreativeCalibrationScorer.java

@@ -0,0 +1,148 @@
+package com.tzld.piaoquan.ad.engine.service.score.scorer;
+
+import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.CreativeCalibrationModel;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.*;
+
+public class CreativeCalibrationScorer extends AbstractScorer {
+
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private static final Logger LOGGER = LoggerFactory.getLogger(CreativeCalibrationScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
+
+    public CreativeCalibrationScorer(ScorerConfigInfo scorerConfigInfo) {
+        super(scorerConfigInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        super.doLoadModel(CreativeCalibrationModel.class);
+    }
+
+    @Override
+    public List<AdRankItem> scoring(ScoreParam param, UserAdFeature userAdFeature, List<AdRankItem> rankItems) {
+        throw new NoSuchMethodError();
+    }
+
+    public List<AdRankItem> scoring(final Map<String, String> sceneFeatureMap,
+                                    final Map<String, String> userFeatureMap,
+                                    final List<AdRankItem> rankItems) {
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+
+        List<AdRankItem> result = rankByJava(sceneFeatureMap, userFeatureMap, rankItems);
+
+        LOGGER.debug("[CreativeCalibrationScorer] scoring items size={}, time={} ",
+                result.size(), System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+                                        final Map<String, String> userFeatureMap,
+                                        final List<AdRankItem> items) {
+        long startTime = System.currentTimeMillis();
+        CreativeCalibrationModel model = (CreativeCalibrationModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        // 所有都参与打分,按照ctr排序
+        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (AdRankItem item : items) {
+                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", item, item);
+            }
+        }
+
+        Collections.sort(items);
+
+        LOGGER.debug("[CreativeCalibrationScorer] items size={}, cost={} ",
+                items.size(), System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+    private void multipleCtrScore(final List<AdRankItem> items,
+                                  final Map<String, String> userFeatureMap,
+                                  final Map<String, String> sceneFeatureMap,
+                                  final CreativeCalibrationModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), userFeatureMap, sceneFeatureMap);
+                    } catch (Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex), ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        // 等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    LOGGER.error("InterruptedException: ", e);
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {},", sceneFeatureMap.size(), e);
+                }
+            }
+        }
+    }
+
+    public double calcScore(final CreativeCalibrationModel model,
+                            final AdRankItem item,
+                            final Map<String, String> userFeatureMap,
+                            final Map<String, String> sceneFeatureMap) {
+        double ctcvrScore = item.getLrScore();
+        double newCtcvrScore = ctcvrScore;
+        try {
+
+            double diffRate = model.getDiffRate(item.getAdId());
+            newCtcvrScore = ctcvrScore / (1 + diffRate);
+            item.setLrScore(newCtcvrScore);
+            item.getScoreMap().put("diff_rate", diffRate);
+            item.getScoreMap().put("originCtcvrScore", ctcvrScore);
+            item.getScoreMap().put("ctcvrScore", newCtcvrScore);
+        } catch (Exception e) {
+            LOGGER.error("[score calibration] calcScore error: ", e);
+        }
+
+        return newCtcvrScore;
+    }
+
+}

+ 8 - 8
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/ScoreCalibrationScorer.java → ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/ValueRangeCalibrationScorer.java

@@ -3,7 +3,7 @@ package com.tzld.piaoquan.ad.engine.service.score.scorer;
 import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
-import com.tzld.piaoquan.ad.engine.commons.score.model.XGBCalibrationModel;
+import com.tzld.piaoquan.ad.engine.commons.score.model.ValueRangeCalibrationModel;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
 import org.apache.commons.collections4.CollectionUtils;
@@ -17,20 +17,20 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.*;
 
-public class ScoreCalibrationScorer extends AbstractScorer {
+public class ValueRangeCalibrationScorer extends AbstractScorer {
 
     private static final int LOCAL_TIME_OUT = 150;
-    private static Logger LOGGER = LoggerFactory.getLogger(ScoreCalibrationScorer.class);
+    private static final Logger LOGGER = LoggerFactory.getLogger(ValueRangeCalibrationScorer.class);
     private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
 
 
-    public ScoreCalibrationScorer(ScorerConfigInfo scorerConfigInfo) {
+    public ValueRangeCalibrationScorer(ScorerConfigInfo scorerConfigInfo) {
         super(scorerConfigInfo);
     }
 
     @Override
     public void loadModel() {
-        super.doLoadModel(XGBCalibrationModel.class);
+        super.doLoadModel(ValueRangeCalibrationModel.class);
     }
 
     @Override
@@ -59,7 +59,7 @@ public class ScoreCalibrationScorer extends AbstractScorer {
                                         final Map<String, String> userFeatureMap,
                                         final List<AdRankItem> items) {
         long startTime = System.currentTimeMillis();
-        XGBCalibrationModel model = (XGBCalibrationModel) this.getModel();
+        ValueRangeCalibrationModel model = (ValueRangeCalibrationModel) this.getModel();
         LOGGER.debug("model size: [{}]", model.getModelSize());
 
         // 所有都参与打分,按照ctr排序
@@ -82,7 +82,7 @@ public class ScoreCalibrationScorer extends AbstractScorer {
     private void multipleCtrScore(final List<AdRankItem> items,
                                   final Map<String, String> userFeatureMap,
                                   final Map<String, String> sceneFeatureMap,
-                                  final XGBCalibrationModel model) {
+                                  final ValueRangeCalibrationModel model) {
 
         List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
         for (int index = 0; index < items.size(); index++) {
@@ -124,7 +124,7 @@ public class ScoreCalibrationScorer extends AbstractScorer {
         }
     }
 
-    public double calcScore(final XGBCalibrationModel model,
+    public double calcScore(final ValueRangeCalibrationModel model,
                             final AdRankItem item,
                             final Map<String, String> userFeatureMap,
                             final Map<String, String> sceneFeatureMap) {