Przeglądaj źródła

calibration scores v2

sunmingze 1 rok temu
rodzic
commit
a3fe1797db

+ 1 - 1
ad-engine-commons/pom.xml

@@ -25,7 +25,7 @@
         <dependency>
             <groupId>com.tzld.piaoquan</groupId>
             <artifactId>recommend-feature-client</artifactId>
-            <version>1.0.2</version>
+            <version>1.0.3</version>
         </dependency>
         <dependency>
             <groupId>com.tzld.piaoquan</groupId>

+ 20 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/BaseCalibrationScorer.java

@@ -0,0 +1,20 @@
+package com.tzld.piaoquan.ad.engine.commons.score;
+
+import com.tzld.piaoquan.ad.engine.commons.score.model.CalibrationModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Base class for scorers that calibrate predicted values with a {@link CalibrationModel}.
+ *
+ * <p>This class only fixes the model type that gets loaded; the scoring logic itself is
+ * left to concrete subclasses.
+ */
+public abstract class BaseCalibrationScorer extends AbstractScorer {
+
+    // Declared final so the shared logger reference can never be reassigned.
+    private static final Logger LOGGER = LoggerFactory.getLogger(BaseCalibrationScorer.class);
+
+    /**
+     * @param scorerConfigInfo scorer configuration, passed through to {@link AbstractScorer}
+     */
+    public BaseCalibrationScorer(ScorerConfigInfo scorerConfigInfo) {
+        super(scorerConfigInfo);
+    }
+
+    /**
+     * Loads the model as a {@link CalibrationModel} by delegating to {@code doLoadModel}.
+     */
+    @Override
+    public void loadModel() {
+        doLoadModel(CalibrationModel.class);
+    }
+}

+ 76 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/CalibrationModel.java

@@ -0,0 +1,76 @@
+package com.tzld.piaoquan.ad.engine.commons.score.model;
+
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdActionFeature;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import org.apache.commons.math3.distribution.BetaDistribution;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.math.BigInteger;
+import java.util.HashMap;
+import java.util.TreeMap;
+
+
+/**
+ * Lookup-table calibration model backed by a sorted map.
+ *
+ * <p>The table is read from a stream of tab-separated {@code key<TAB>value} lines. A raw score
+ * is calibrated by returning the value stored under the greatest key that is less than or equal
+ * to that score.
+ */
+public class CalibrationModel extends Model {
+    protected static final int MODEL_FIRST_LOAD_COUNT = 1 << 25;  // 32M
+    private static final Logger LOGGER = LoggerFactory.getLogger(CalibrationModel.class);
+
+    private TreeMap<Double, Double> calibrationModel;
+
+    /**
+     * Creates a model with an empty table; until a table is loaded, {@link #score} returns the
+     * raw value unchanged.
+     */
+    public CalibrationModel() {
+        this.calibrationModel = new TreeMap<>();
+    }
+
+    /**
+     * @return the live calibration table (not a copy)
+     */
+    public TreeMap<Double, Double> getCalibrationModel() {
+        return this.calibrationModel;
+    }
+
+
+    /**
+     * Replaces the calibration table with the content of {@code in}.
+     *
+     * <p>Each line is split on tabs; lines with fewer than two columns are skipped. The first two
+     * columns are parsed as the bin key and the calibrated value. The new table is installed only
+     * after the whole stream has been parsed, so a failed load leaves the previous table in place.
+     *
+     * @param in reader over the model file; it is always closed by this method
+     * @return {@code true} once the table has been replaced
+     * @throws IOException           if reading the stream fails
+     * @throws NumberFormatException if a key or value column is not a valid double
+     */
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws IOException {
+        TreeMap<Double, Double> bins = new TreeMap<>();
+        // Closing the BufferedReader also closes the wrapped reader, even when parsing throws.
+        try (BufferedReader input = new BufferedReader(in)) {
+            String line;
+            while ((line = input.readLine()) != null) {
+                String[] items = line.split("\t");
+                if (items.length < 2) {
+                    continue;
+                }
+                bins.put(Double.parseDouble(items[0]), Double.parseDouble(items[1]));
+            }
+        }
+
+        this.calibrationModel = bins;
+        // Report the number of bins actually loaded (the previous counter was never incremented).
+        LOGGER.info("[MODELLOAD] calibration model load over and size " + bins.size());
+        return true;
+    }
+
+
+    /**
+     * @return the number of entries in the calibration table, or 0 if there is no table
+     */
+    @Override
+    public int getModelSize() {
+        TreeMap<Double, Double> table = this.calibrationModel;
+        return table == null ? 0 : table.size();
+    }
+
+    /**
+     * Returns the calibrated CTR or CVR of an item.
+     *
+     * @param adRankItem item whose raw CTR / CVR is read
+     * @param ctrOrCVR   {@code "ctr"} to calibrate {@code getCtr()}, {@code "cvr"} to calibrate
+     *                   {@code getCvr()}
+     * @return the calibrated value; the raw value when no bin covers it; {@code 0.0} when
+     *         {@code ctrOrCVR} is neither {@code "ctr"} nor {@code "cvr"} (including null)
+     */
+    public double score(AdRankItem adRankItem, String ctrOrCVR) {
+        // Constant-first equals keeps a null selector from throwing.
+        if ("ctr".equals(ctrOrCVR)) {
+            return calibrate(adRankItem.getCtr());
+        }
+        if ("cvr".equals(ctrOrCVR)) {
+            return calibrate(adRankItem.getCvr());
+        }
+        return 0.0;
+    }
+
+    /**
+     * Looks up the bin with the greatest key that is less than or equal to {@code rawScore}.
+     * Falls back to {@code rawScore} when the table is empty or every key is larger, instead of
+     * failing on a null lookup result.
+     */
+    private double calibrate(double rawScore) {
+        // Read the field once so the key lookup and the value come from the same table.
+        TreeMap<Double, Double> table = this.calibrationModel;
+        if (table == null) {
+            return rawScore;
+        }
+        java.util.Map.Entry<Double, Double> bin = table.floorEntry(rawScore);
+        if (bin == null || bin.getValue() == null) {
+            return rawScore;
+        }
+        return bin.getValue();
+    }
+
+}

+ 2 - 2
ad-engine-server/src/main/resources/feeds_score_config_baseline.conf

@@ -10,13 +10,13 @@ scorer-config = {
       model-path = "ad_cvr_model/model_ad_cvr.txt"
   }
   lr-ctr-calibretion-config = {
-      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogCtrCalibretionScorer"
+      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCtrCalibrationScorer"
       scorer-priority = 99
       model-path = "ad_ctr_calibretion/model_ctr_calibretion.txt"
   }
 
   lr-cvr-calibretion-config = {
-      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogCvrCalibretionScorer"
+      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCvrCalibrationScorer"
       scorer-priority = 99
       model-path = "ad_cvr_calibretion/model_cvr_calibretion.txt"
   }

+ 143 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCtrCalibrationScorer.java

@@ -0,0 +1,143 @@
+package com.tzld.piaoquan.ad.engine.service.score;
+
+
+import com.tzld.piaoquan.ad.engine.commons.score.BaseCalibrationScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.CalibrationModel;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRequestContext;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.*;
+
+
+//@Service
+/**
+ * Scorer that replaces each item's predicted CTR with the value looked up in a
+ * {@link CalibrationModel}.
+ */
+public class VlogAdCtrCalibrationScorer extends BaseCalibrationScorer {
+
+    // Timeout, in milliseconds, for calibrating one batch of items.
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogAdCtrCalibrationScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
+
+    public VlogAdCtrCalibrationScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    /**
+     * Calibrates the CTR of every item in {@code rankItems} in place.
+     *
+     * @param param       scoring parameters (not used by this scorer)
+     * @param userFeature user features; when null the items are returned untouched
+     * @param rankItems   items to calibrate; when null or empty they are returned untouched
+     * @return the same list instance that was passed in
+     */
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userFeature,
+                                    final List<AdRankItem> rankItems) {
+
+        if (userFeature == null || CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        CalibrationModel model = (CalibrationModel) this.getModel();
+        if (model == null) {
+            // Without a model there is nothing to calibrate against; keep the raw predictions.
+            LOGGER.error("not found model for ctr calibration");
+            return rankItems;
+        }
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        List<AdRankItem> result = rankByJava(rankItems, model);
+
+        LOGGER.debug("calibration ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    /**
+     * Calibrates all items with the given model and emits debug timing logs.
+     */
+    private List<AdRankItem> rankByJava(final List<AdRankItem> items,
+                                        final CalibrationModel model) {
+        long startTime = System.currentTimeMillis();
+
+        // Every item takes part in calibration.
+        multipleCtrScore(items, model);
+
+        if (LOGGER.isDebugEnabled()) {
+            for (AdRankItem item : items) {
+                // Log the calibrated ctr; the score field was reset to 0 and carries no information.
+                LOGGER.debug("after enter feeds model predict ctr score [{}] [{}]", item, item.getCtr());
+            }
+        }
+
+        LOGGER.debug("calibration ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+
+    /**
+     * Calibrates the predicted CTR of one item and writes it back with {@code setCtr}.
+     *
+     * @param model calibration model used for the lookup
+     * @param item  item to calibrate
+     * @return the calibrated CTR, or {@code 0.0} if the lookup threw an exception
+     */
+    public double calcScore(final CalibrationModel model,
+                            final AdRankItem item) {
+        double pctr = 0.0;
+        try {
+            pctr = model.score(item, "ctr");
+        } catch (Exception e) {
+            LOGGER.error("score error for doc={} exception={}", new Object[]{
+                    item.getAdId(), ExceptionUtils.getFullStackTrace(e)});
+        }
+        item.setCtr(pctr);
+        return pctr;
+    }
+
+
+    /**
+     * Calibrates all items in parallel on the shared executor, waiting at most
+     * {@link #LOCAL_TIME_OUT} milliseconds. Tasks that do not finish in time are cancelled.
+     *
+     * @param items items to calibrate
+     * @param model calibration model used for the lookup
+     */
+    private void multipleCtrScore(final List<AdRankItem> items,
+                                  final CalibrationModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>(items.size());
+        for (final AdRankItem item : items) {
+            item.setScore(0.0);   // reset the score to 0 before calibration
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, item);
+                    } catch (Exception e) {
+                        LOGGER.error("calibration exception: [{}] [{}]", item.getAdId(), ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            // Restore the interrupt flag so callers can still observe the interruption.
+            Thread.currentThread().interrupt();
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        // invokeAll has already waited for completion or timeout; count the tasks that did not finish.
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                    LOGGER.error("InterruptedException {}", ExceptionUtils.getFullStackTrace(e));
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {}", ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+        LOGGER.debug("calibration ctr score total: {}, cancel: {}", items.size(), cancel);
+    }
+}

+ 143 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCvrCalibrationScorer.java

@@ -0,0 +1,143 @@
+package com.tzld.piaoquan.ad.engine.service.score;
+
+
+import com.tzld.piaoquan.ad.engine.commons.score.BaseCalibrationScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.CalibrationModel;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRequestContext;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.*;
+
+
+//@Service
+/**
+ * Scorer that replaces each item's predicted CVR with the value looked up in a
+ * {@link CalibrationModel}.
+ */
+public class VlogAdCvrCalibrationScorer extends BaseCalibrationScorer {
+
+    // Timeout, in milliseconds, for calibrating one batch of items.
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogAdCvrCalibrationScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
+
+    public VlogAdCvrCalibrationScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    /**
+     * Calibrates the CVR of every item in {@code rankItems} in place.
+     *
+     * @param param       scoring parameters (not used by this scorer)
+     * @param userFeature user features; when null the items are returned untouched
+     * @param rankItems   items to calibrate; when null or empty they are returned untouched
+     * @return the same list instance that was passed in
+     */
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userFeature,
+                                    final List<AdRankItem> rankItems) {
+
+        if (userFeature == null || CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        CalibrationModel model = (CalibrationModel) this.getModel();
+        if (model == null) {
+            // Without a model there is nothing to calibrate against; keep the raw predictions.
+            LOGGER.error("not found model for cvr calibration");
+            return rankItems;
+        }
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        List<AdRankItem> result = rankByJava(rankItems, model);
+
+        LOGGER.debug("calibration cvr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    /**
+     * Calibrates all items with the given model and emits debug timing logs.
+     */
+    private List<AdRankItem> rankByJava(final List<AdRankItem> items,
+                                        final CalibrationModel model) {
+        long startTime = System.currentTimeMillis();
+
+        // Every item takes part in calibration.
+        multipleCvrScore(items, model);
+
+        if (LOGGER.isDebugEnabled()) {
+            for (AdRankItem item : items) {
+                // Log the calibrated cvr; the score field was reset to 0 and carries no information.
+                LOGGER.debug("after enter feeds model predict cvr score [{}] [{}]", item, item.getCvr());
+            }
+        }
+
+        LOGGER.debug("calibration ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+
+    /**
+     * Calibrates the predicted CVR of one item and writes it back with {@code setCvr}.
+     *
+     * @param model calibration model used for the lookup
+     * @param item  item to calibrate
+     * @return the calibrated CVR, or {@code 0.0} if the lookup threw an exception
+     */
+    public double calcScore(final CalibrationModel model,
+                            final AdRankItem item) {
+        double pcvr = 0.0;
+        try {
+            pcvr = model.score(item, "cvr");
+        } catch (Exception e) {
+            LOGGER.error("score error for doc={} exception={}", new Object[]{
+                    item.getAdId(), ExceptionUtils.getFullStackTrace(e)});
+        }
+        item.setCvr(pcvr);
+        return pcvr;
+    }
+
+
+    /**
+     * Calibrates all items in parallel on the shared executor, waiting at most
+     * {@link #LOCAL_TIME_OUT} milliseconds. Tasks that do not finish in time are cancelled.
+     *
+     * @param items items to calibrate
+     * @param model calibration model used for the lookup
+     */
+    private void multipleCvrScore(final List<AdRankItem> items,
+                                  final CalibrationModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>(items.size());
+        for (final AdRankItem item : items) {
+            item.setScore(0.0);   // reset the score to 0 before calibration
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, item);
+                    } catch (Exception e) {
+                        LOGGER.error("calibration exception: [{}] [{}]", item.getAdId(), ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            // Restore the interrupt flag so callers can still observe the interruption.
+            Thread.currentThread().interrupt();
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        // invokeAll has already waited for completion or timeout; count the tasks that did not finish.
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                    LOGGER.error("InterruptedException {}", ExceptionUtils.getFullStackTrace(e));
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {}", ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+        LOGGER.debug("calibration cvr score total: {}, cancel: {}", items.size(), cancel);
+    }
+}

+ 0 - 80
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogCtrCalibretionScorer.java

@@ -1,80 +0,0 @@
-package com.tzld.piaoquan.ad.engine.service.score;
-
-
-import com.tzld.piaoquan.ad.engine.commons.score.BaseLRModelScorer;
-import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
-import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
-import com.tzld.piaoquan.ad.engine.commons.score.model.LRModel;
-import com.tzld.piaoquan.recommend.feature.domain.ad.base.*;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Collections;
-import java.util.List;
-
-
-//@Service
-public class VlogCtrCalibretionScorer extends BaseLRModelScorer {
-
-    private final static Logger LOGGER = LoggerFactory.getLogger(VlogCtrCalibretionScorer.class);
-
-
-    public VlogCtrCalibretionScorer(ScorerConfigInfo configInfo) {
-        super(configInfo);
-    }
-
-
-    @Override
-    public List<AdRankItem> scoring(final ScoreParam param,
-                                    final UserAdFeature userFeature,
-                                    final List<AdRankItem> rankItems) {
-
-
-        long startTime = System.currentTimeMillis();
-        LRModel model = (LRModel) this.getModel();
-        LOGGER.debug("model size: [{}]", model.getModelSize());
-        List<AdRankItem> result = ctrCalibretion(rankItems);
-
-        LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
-                System.currentTimeMillis() - startTime);
-
-        return result;
-    }
-
-
-    public List<AdRankItem> ctrCalibretion(List<AdRankItem> items) {
-        long startTime = System.currentTimeMillis();
-        LRModel model = (LRModel) this.getModel();
-        if (model == null) {
-            LOGGER.error("not found model for ctr calibration [{}]");
-            return items;
-        }
-        int model_size = model.getModelSize();
-        for (AdRankItem item : items) {
-            double oldScore = item.getCtr();
-            double newScore = 0.0;
-            if (oldScore > 1.0 || oldScore < 0) {
-                item.setCtr(0.0);
-                continue;
-            }
-
-            try {
-                long key = (long) Math.floor(oldScore * model_size);
-                newScore = model.getWeight(model.getLrModel(), key);
-                if (newScore == 0) {
-                    LOGGER.error("ctr ctrCalibretion ctr Score: {} error", oldScore);
-                } else {
-                    item.setCtr(newScore);
-                }
-            } catch (Exception e) {
-                LOGGER.error("ctr ctrCalibretion ctr Score: {} couldn`t get key", oldScore);
-                item.setCtr(0.0);
-            }
-            LOGGER.debug("ctr ctrCalibretion ranker , score: {}->{}", new Object[]{oldScore, newScore});
-        }
-        Collections.sort(items);
-        return items;
-    }
-
-
-}

+ 0 - 81
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogCvrCalibretionScorer.java

@@ -1,81 +0,0 @@
-package com.tzld.piaoquan.ad.engine.service.score;
-
-
-import com.tzld.piaoquan.ad.engine.commons.score.BaseLRModelScorer;
-import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
-import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
-import com.tzld.piaoquan.ad.engine.commons.score.model.LRModel;
-import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
-import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Collections;
-import java.util.List;
-
-
-//@Service
-public class VlogCvrCalibretionScorer extends BaseLRModelScorer {
-
-    private final static Logger LOGGER = LoggerFactory.getLogger(VlogCvrCalibretionScorer.class);
-
-
-    public VlogCvrCalibretionScorer(ScorerConfigInfo configInfo) {
-        super(configInfo);
-    }
-
-
-    @Override
-    public List<AdRankItem> scoring(final ScoreParam param,
-                                    final UserAdFeature userFeature,
-                                    final List<AdRankItem> rankItems) {
-
-
-        long startTime = System.currentTimeMillis();
-        LRModel model = (LRModel) this.getModel();
-        LOGGER.debug("model size: [{}]", model.getModelSize());
-        List<AdRankItem> result = cvrCalibretion(rankItems);
-
-        LOGGER.debug("cvr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
-                System.currentTimeMillis() - startTime);
-
-        return result;
-    }
-
-
-    public List<AdRankItem> cvrCalibretion(List<AdRankItem> items) {
-        long startTime = System.currentTimeMillis();
-        LRModel model = (LRModel) this.getModel();
-        if (model == null) {
-            LOGGER.error("not found model for cvr calibration [{}]");
-            return items;
-        }
-        int model_size = model.getModelSize();
-        for (AdRankItem item : items) {
-            double oldScore = item.getCvr();
-            double newScore = 0.0;
-            if (oldScore > 1.0 || oldScore < 0) {
-                item.setCvr(0.0);
-                continue;
-            }
-
-            try {
-                long key = (long) Math.floor(oldScore * model_size);
-                newScore = model.getWeight(model.getLrModel(), key);
-                if (newScore == 0) {
-                    LOGGER.error("cvr cvrCalibretion cvr Score: {} error", oldScore);
-                } else {
-                    item.setCvr(newScore);
-                }
-            } catch (Exception e) {
-                LOGGER.error("cvr cvrCalibretion ctr Score: {} couldn`t get key", oldScore);
-                item.setCvr(0.0);
-            }
-            LOGGER.debug("cvr cvrCalibretion ranker , score: {}->{}", new Object[]{oldScore, newScore});
-        }
-        Collections.sort(items);
-        return items;
-    }
-
-
-}