Selaa lähdekoodia

feat:完善打分策略

zhaohaipeng 1 vuosi sitten
vanhempi
commit
7ad54e0864

+ 21 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/BaseLRV2ModelScorer.java

@@ -0,0 +1,21 @@
+package com.tzld.piaoquan.ad.engine.commons.score;
+
+import com.tzld.piaoquan.ad.engine.commons.score.model.LRV2Model;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+public abstract class BaseLRV2ModelScorer extends AbstractScorer {
+
+    private static Logger LOGGER = LoggerFactory.getLogger(BaseLRV2ModelScorer.class);
+
+    public BaseLRV2ModelScorer(ScorerConfigInfo scorerConfigInfo) {
+        super(scorerConfigInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        doLoadModel(LRV2Model.class);
+    }
+
+}

+ 111 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/LRV2Model.java

@@ -0,0 +1,111 @@
+package com.tzld.piaoquan.ad.engine.commons.score.model;
+
+
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang.math.NumberUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.HashMap;
+import java.util.Map;
+
+
+public class LRV2Model extends Model {
+    protected static final int MODEL_FIRST_LOAD_COUNT = 1 << 25; // 32M
+    private static final Logger LOGGER = LoggerFactory.getLogger(LRV2Model.class);
+    private Map<String, Float> lrModel;
+
+    public void putFeature(Map<String, Float> model, String featureKey, float weight) {
+        model.put(featureKey, weight);
+    }
+
+    public float getWeight(Map<String, Float> model, String featureKey) {
+        return model.getOrDefault(featureKey, 0.0f);
+    }
+
+    @Override
+    public int getModelSize() {
+        if (this.lrModel == null)
+            return 0;
+        return lrModel.size();
+    }
+
+    public void cleanModel() {
+        this.lrModel = null;
+    }
+
+    public Float score(Map<String, String> featureMap) {
+        float sum = 0.0f;
+
+        if (MapUtils.isNotEmpty(featureMap)) {
+            for (Map.Entry<String, String> e : featureMap.entrySet()) {
+                float w = getWeight(this.lrModel, e.getKey());
+                sum += w * NumberUtils.toFloat(e.getValue(), 0.0f);
+            }
+
+            float biasW = lrModel.get("bias");
+            sum += biasW;
+        }
+
+
+        return (float) (1.0f / (1 + Math.exp(-sum)));
+    }
+
+    /**
+     * 目前模型比较大,分两个阶段load模型
+     * (1). load 8M 模型, 并更新;
+     * (2). load 剩余的模型
+     * 中间提供一段时间有损的打分服务
+     *
+     * @param in
+     * @return
+     * @throws IOException
+     */
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws IOException {
+
+        Map<String, Float> model = new HashMap<>();
+        BufferedReader input = new BufferedReader(in);
+        String line = null;
+        int cnt = 0;
+
+        Integer curTime = new Long(System.currentTimeMillis() / 1000).intValue();
+        //first stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split("\t");
+            if (items.length < 2) {
+                continue;
+            }
+
+            putFeature(model, items[0], Float.valueOf(items[1].trim()).floatValue());
+            if (cnt++ < 10) {
+                LOGGER.debug("fea: " + items[0] + ", weight: " + items[1]);
+            }
+            if (cnt > MODEL_FIRST_LOAD_COUNT) {
+                break;
+            }
+        }
+        //model update
+        this.lrModel = model;
+
+        LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", lrModel.size(), curTime);
+        //final stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split("\t");
+            if (items.length < 2) {
+                continue;
+            }
+            putFeature(model, items[0], Float.valueOf(items[1]).floatValue());
+        }
+        LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", lrModel.size(), curTime);
+
+        LOGGER.info("[MODELLOAD] model load over and size " + cnt);
+        input.close();
+        in.close();
+        return true;
+    }
+
+}

+ 153 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogRovLRScorer.java

@@ -0,0 +1,153 @@
+package com.tzld.piaoquan.ad.engine.service.score;
+
+
+import com.tzld.piaoquan.ad.engine.commons.score.BaseLRV2ModelScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.LRV2Model;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+import java.util.concurrent.*;
+
+
+public class VlogRovLRScorer extends BaseLRV2ModelScorer {
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogRovLRScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
+
+
+    public VlogRovLRScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userFeature,
+                                    final List<AdRankItem> rankItems) {
+        throw new NoSuchMethodError();
+    }
+
+    public List<AdRankItem> scoring(final Map<String, String> sceneFeatureMap,
+                                    final Map<String, String> userFeatureMap,
+                                    final List<AdRankItem> rankItems) {
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        LRV2Model model = (LRV2Model) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        List<AdRankItem> result = rankByJava(sceneFeatureMap, userFeatureMap, rankItems);
+
+        LOGGER.debug("ctr ranker time java items size={}, time={} ", result.size(),
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+                                        final Map<String, String> userFeatureMap,
+                                        final List<AdRankItem> items) {
+        long startTime = System.currentTimeMillis();
+        LRV2Model model = (LRV2Model) this.getModel();
+        // 所有都参与打分,按照ctr排序
+        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (AdRankItem item : items) {
+                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", item, item);
+            }
+        }
+
+        Collections.sort(items);
+
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items.size(),
+                System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+    private void multipleCtrScore(final List<AdRankItem> items,
+                                  final Map<String, String> userFeatureMap,
+                                  final Map<String, String> sceneFeatureMap,
+                                  final LRV2Model model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), userFeatureMap, sceneFeatureMap);
+                    } catch (Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex), ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        // 等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    LOGGER.error("InterruptedException: ", e);
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}", sceneFeatureMap.size(),
+                            ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+    }
+
+    public double calcScore(final LRV2Model lrModel,
+                            final AdRankItem item,
+                            final Map<String, String> userFeatureMap,
+                            final Map<String, String> sceneFeatureMap) {
+
+
+        Map<String, String> featureMap = new HashMap<>();
+        if (MapUtils.isNotEmpty(item.getFeatureMap())) {
+            featureMap.putAll(item.getFeatureMap());
+        }
+        if (MapUtils.isNotEmpty(userFeatureMap)) {
+            featureMap.putAll(userFeatureMap);
+        }
+        if (MapUtils.isNotEmpty(sceneFeatureMap)) {
+            featureMap.putAll(sceneFeatureMap);
+        }
+
+        double pro = 0.0;
+        if (MapUtils.isNotEmpty(featureMap)) {
+            try {
+                pro = lrModel.score(featureMap);
+                // LOGGER.info("fea : {}, score:{}", JSONUtils.toJson(featureMap), pro);
+            } catch (Exception e) {
+                LOGGER.error("score error for doc={} exception={}", item.getVideoId(), ExceptionUtils.getFullStackTrace(e));
+            }
+        }
+        return pro;
+    }
+}

+ 17 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/dto/AdDirectionScore.java

@@ -0,0 +1,17 @@
+package com.tzld.piaoquan.ad.engine.service.score.dto;
+
+import lombok.Builder;
+import lombok.Data;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+@Data
+@Builder
+public class AdDirectionScore {
+
+    private Double exponent;
+
+
+    private ConcurrentHashMap<String, Double> scoreDetail;
+
+}

+ 1 - 1
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/dto/AdPlatformCreativeDTO.java

@@ -37,7 +37,7 @@ public class AdPlatformCreativeDTO {
     /**
      * 定向打分参数
      */
-    private HumanDto humanScore;
+    private AdDirectionScore adDirectionScore;
 
     public static void main(String[] args) {
         System.out.println(JSON.toJSONString(AdPlatformCreativeDTO.builder()

+ 0 - 18
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/dto/HumanDto.java

@@ -1,18 +0,0 @@
-package com.tzld.piaoquan.ad.engine.service.score.dto;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-
-import java.util.Map;
-
-@Data
-@NoArgsConstructor
-@AllArgsConstructor
-public class HumanDto {
-
-    private Double exponent;
-
-    private Map<String, Double> detail;
-
-}