|
@@ -0,0 +1,149 @@
|
|
|
+package com.tzld.piaoquan.ad.engine.commons.score;
|
|
|
+
|
|
|
+import com.tzld.piaoquan.ad.engine.commons.score.model.UnionThompsonSamplingModel;
|
|
|
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
|
|
|
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRequestContext;
|
|
|
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
|
|
|
+import org.apache.commons.collections4.CollectionUtils;
|
|
|
+import org.apache.commons.lang.exception.ExceptionUtils;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.List;
|
|
|
+import java.util.concurrent.*;
|
|
|
+
|
|
|
+public class UnionThompsonSamplingScore extends BaseThompsonSamplingScorer {
|
|
|
+
|
|
|
+ private static final Logger LOGGER = LoggerFactory.getLogger(UnionThompsonSamplingModel.class);
|
|
|
+ private static final int LOCAL_TIME_OUT = 150;
|
|
|
+ private static final ExecutorService executorService = Executors.newFixedThreadPool(8);
|
|
|
+
|
|
|
+ public UnionThompsonSamplingScore(ScorerConfigInfo scorerConfigInfo) {
|
|
|
+ super(scorerConfigInfo);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void loadModel() {
|
|
|
+ doLoadModel(UnionThompsonSamplingModel.class);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public List<AdRankItem> scoring(ScoreParam param, UserAdFeature userFeature, List<AdRankItem> rankItems) {
|
|
|
+ if (CollectionUtils.isEmpty(rankItems)) {
|
|
|
+ return rankItems;
|
|
|
+ }
|
|
|
+ long startTime = System.currentTimeMillis();
|
|
|
+
|
|
|
+ List<AdRankItem> result = rankByJava(rankItems, param.getRequestContext(), userFeature);
|
|
|
+
|
|
|
+ LOGGER.debug("union thompson sampling ctr ranker time java items size={}, time={} ",
|
|
|
+ result.size(), System.currentTimeMillis() - startTime);
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ private List<AdRankItem> rankByJava(final List<AdRankItem> items,
|
|
|
+ final AdRequestContext requestContext,
|
|
|
+ final UserAdFeature user) {
|
|
|
+ long startTime = System.currentTimeMillis();
|
|
|
+ UnionThompsonSamplingModel model = (UnionThompsonSamplingModel) this.getModel();
|
|
|
+ LOGGER.debug("model size: [{}]", model.getModelSize());
|
|
|
+
|
|
|
+ // 所有都参与打分,按照ctr排序
|
|
|
+ multipleCtrScore(items, model);
|
|
|
+
|
|
|
+ // debug log
|
|
|
+ if (LOGGER.isDebugEnabled()) {
|
|
|
+ for (AdRankItem item : items) {
|
|
|
+ LOGGER.debug("after enter feeds model predict ctr score [{}] [{}]", item, item.getScore());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ LOGGER.debug("union thompson ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
|
|
|
+ LOGGER.debug("[union thompson ranker time java] items size={}, cost={} ",
|
|
|
+ items.size(), System.currentTimeMillis() - startTime);
|
|
|
+ return items;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 计算 predict ecpm
|
|
|
+ */
|
|
|
+ public void calcScore(final UnionThompsonSamplingModel model, final AdRankItem item) {
|
|
|
+
|
|
|
+ double pctr = 0.0;
|
|
|
+ double pcvr = 0.0;
|
|
|
+ double ecpm = 0.0;
|
|
|
+
|
|
|
+ try {
|
|
|
+ pctr = model.ctrScore(item);
|
|
|
+ pcvr = model.cvrScore(item);
|
|
|
+ ecpm = item.getCpa() * pctr * pcvr;
|
|
|
+ } catch (
|
|
|
+ Exception e) {
|
|
|
+ LOGGER.error("score error for doc={} exception={}",
|
|
|
+ item.getAdId(), ExceptionUtils.getFullStackTrace(e));
|
|
|
+ }
|
|
|
+ item.setTf_ctr(pctr);
|
|
|
+ item.setTf_cvr(pcvr);
|
|
|
+ item.setEcpm1(ecpm);
|
|
|
+ item.setScore(ecpm);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 并行打分 ecpm
|
|
|
+ *
|
|
|
+ * @param items
|
|
|
+ * @param model
|
|
|
+ */
|
|
|
+ private void multipleCtrScore(final List<AdRankItem> items,
|
|
|
+ final UnionThompsonSamplingModel model) {
|
|
|
+
|
|
|
+ List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
|
|
|
+ for (int index = 0; index < items.size(); index++) {
|
|
|
+ final int fIndex = index;
|
|
|
+ items.get(fIndex).setScore(0.0); // 设置为原始值为0
|
|
|
+ calls.add(new Callable<Object>() {
|
|
|
+ @Override
|
|
|
+ public Object call() {
|
|
|
+ try {
|
|
|
+ calcScore(model, items.get(fIndex));
|
|
|
+ } catch (
|
|
|
+ Exception e) {
|
|
|
+ LOGGER.error("thompson exception: [{}] [{}]", items.get(fIndex).adId, ExceptionUtils.getFullStackTrace(e));
|
|
|
+ }
|
|
|
+ return new Object();
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ List<Future<Object>> futures = null;
|
|
|
+ try {
|
|
|
+ futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
|
|
|
+ } catch (
|
|
|
+ InterruptedException e) {
|
|
|
+ LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
|
|
|
+ }
|
|
|
+
|
|
|
+ // 等待所有请求的结果返回, 超时也返回
|
|
|
+ int cancel = 0;
|
|
|
+ if (futures != null) {
|
|
|
+ for (Future<Object> future : futures) {
|
|
|
+ try {
|
|
|
+ if (!future.isDone() || future.isCancelled() || future.get() == null) {
|
|
|
+ cancel++;
|
|
|
+ }
|
|
|
+ } catch (
|
|
|
+ InterruptedException e) {
|
|
|
+ LOGGER.error("InterruptedException {}", ExceptionUtils.getFullStackTrace(e));
|
|
|
+ } catch (
|
|
|
+ ExecutionException e) {
|
|
|
+ LOGGER.error("ExecutionException {}", ExceptionUtils.getFullStackTrace(e));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ LOGGER.debug("ecpm Score {}, Total: {}, Cancel: {}", new Object[]{items.size(), cancel});
|
|
|
+ }
|
|
|
+}
|