Bladeren bron

feat:添加FM模型相关代码

zhaohaipeng 10 maanden geleden
bovenliggende
commit
2d4452b7c0

+ 20 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/BaseFMModelScorer.java

@@ -0,0 +1,20 @@
+package com.tzld.piaoquan.ad.engine.commons.score;
+
+import com.tzld.piaoquan.ad.engine.commons.score.model.FMModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+public abstract class BaseFMModelScorer extends AbstractScorer {
+
+    private static Logger LOGGER = LoggerFactory.getLogger(BaseFMModelScorer.class);
+
+    public BaseFMModelScorer(ScorerConfigInfo scorerConfigInfo) {
+        super(scorerConfigInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        doLoadModel(FMModel.class);
+    }
+}

+ 144 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/FMModel.java

@@ -0,0 +1,144 @@
+package com.tzld.piaoquan.ad.engine.commons.score.model;
+
+
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang.math.NumberUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+
+public class FMModel extends Model {
+    protected static final int MODEL_FIRST_LOAD_COUNT = 1 << 25; // 32M
+    private static final Logger LOGGER = LoggerFactory.getLogger(FMModel.class);
+    private Map<String, List<Float>> model;
+
+    public void putFeature(Map<String, List<Float>> model, String[] items) {
+
+        String featureKey = items[0];
+        List<Float> weights = new ArrayList<>();
+        for (int i = 1; i < items.length; i++) {
+            weights.add(Float.valueOf(items[i]));
+        }
+
+        model.put(featureKey, weights);
+    }
+
+    public float getWeight(Map<String, List<Float>> model, String featureKey, int index) {
+        if (!model.containsKey(featureKey)) {
+            return 0.0f;
+        }
+
+        return model.get(featureKey).get(index);
+
+    }
+
+    @Override
+    public int getModelSize() {
+        if (this.model == null)
+            return 0;
+        return model.size();
+    }
+
+    public void cleanModel() {
+        this.model = null;
+    }
+
+    public Float score(Map<String, String> featureMap) {
+        float sum = 0.0f;
+
+        if (MapUtils.isNotEmpty(featureMap)) {
+            // 计算 sum w*x
+            float sum0 = 0.0f;
+            for (Map.Entry<String, String> e : featureMap.entrySet()) {
+                float x = NumberUtils.toFloat(e.getValue(), 0.0f);
+                float w = getWeight(this.model, e.getKey(), 0);
+                sum0 += w * x;
+            }
+            sum += sum0;
+
+            // 计算 sum v*v*x*X
+            float sum1 = 0.0f;
+            for (int i = 1; i < 9; i++) {
+                float sum10 = 0.0f;
+                float sum11 = 0.0f;
+                for (Map.Entry<String, String> e : featureMap.entrySet()) {
+                    float x = NumberUtils.toFloat(e.getValue(), 0.0f);
+                    float v = getWeight(this.model, e.getKey(), i);
+                    float d = v * x;
+                    sum10 += d;
+                    sum11 += d * d;
+                }
+                sum1 += sum10 * sum10 - sum11;
+            }
+            sum1 = 0.5f * sum1;
+            float biasW = model.get("bias").get(0);
+            sum = biasW + sum0 + sum1;
+        }
+
+        return (float) (1.0f / (1 + Math.exp(-sum)));
+    }
+
+    /**
+     * 目前模型比较大,分两个阶段load模型
+     * (1). load 8M 模型, 并更新;
+     * (2). load 剩余的模型
+     * 中间提供一段时间有损的打分服务
+     *
+     * @param in
+     * @return
+     * @throws IOException
+     */
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws IOException {
+
+        Map<String, List<Float>> model = new HashMap<>();
+        BufferedReader input = new BufferedReader(in);
+        String line = null;
+        int cnt = 0;
+
+        Integer curTime = new Long(System.currentTimeMillis() / 1000).intValue();
+        LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", model.size(), curTime);
+        //first stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split("\t");
+            if (items.length < 9) {
+                if (items[0].equals("bias")) {
+                    putFeature(model, items);
+                }
+                continue;
+            }
+
+            putFeature(model, items);
+            if (cnt > MODEL_FIRST_LOAD_COUNT) {
+                break;
+            }
+        }
+        //model update
+        this.model = model;
+
+        LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", model.size(), curTime);
+        //final stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split("\t");
+            if (items.length < 9) {
+                continue;
+            }
+            putFeature(model, items);
+        }
+        LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", model.size(), curTime);
+
+        LOGGER.info("[MODELLOAD] model load over and size " + cnt);
+        input.close();
+        in.close();
+        return true;
+    }
+
+}

+ 1 - 1
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VideoAdThompsonScorerV2.java

@@ -135,7 +135,7 @@ public class VideoAdThompsonScorerV2 {
             item.setCpa(cpa);
             item.setAdId(dto.getCreativeId());
             item.setScore(score);
-            item.setExt(ext);
+            item.setFeature(ext);
             item.setVideoId(param.getVideoId());
             item.setScore_type(663);
             item.setWeight(dto.getWeight());

+ 160 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogRovFMScorer.java

@@ -0,0 +1,160 @@
+package com.tzld.piaoquan.ad.engine.service.score;
+
+
+import com.tzld.piaoquan.ad.engine.commons.score.BaseFMModelScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.FMModel;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.*;
+
+
+public class VlogRovFMScorer extends BaseFMModelScorer {
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogRovFMScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
+
+
+    public VlogRovFMScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userAdFeature,
+                                    final List<AdRankItem> rankItems) {
+        throw new NoSuchMethodError();
+    }
+
+    public List<AdRankItem> scoring(final Map<String, String> sceneFeatureMap,
+                                    final Map<String, String> userFeatureMap,
+                                    final List<AdRankItem> rankItems) {
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        FMModel model = (FMModel) this.getModel();
+
+        List<AdRankItem> result = rankByJava(sceneFeatureMap, userFeatureMap, rankItems);
+
+        LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+                                      final Map<String, String> userFeatureMap,
+                                      final List<AdRankItem> items) {
+        long startTime = System.currentTimeMillis();
+        FMModel model = (FMModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        // 所有都参与打分,按照ctr排序
+        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (int i = 0; i < items.size(); i++) {
+                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+            }
+        }
+
+        Collections.sort(items);
+
+        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+                System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+    private void multipleCtrScore(final List<AdRankItem> items,
+                                  final Map<String, String> userFeatureMap,
+                                  final Map<String, String> sceneFeatureMap,
+                                  final FMModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), userFeatureMap, sceneFeatureMap);
+                    } catch (Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).videoId, ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        // 等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    LOGGER.error("InterruptedException {},{}", ExceptionUtils.getFullStackTrace(e));
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}", sceneFeatureMap.size(),
+                            ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+    }
+
+    public double calcScore(final FMModel model,
+                            final AdRankItem item,
+                            final Map<String, String> userFeatureMap,
+                            final Map<String, String> sceneFeatureMap) {
+
+
+        // Map<String, String> featureMap = new HashMap<>();
+        // if (MapUtils.isNotEmpty(item.getFeatureMap())) {
+        //     featureMap.putAll(item.getFeatureMap());
+        // }
+        // if (MapUtils.isNotEmpty(userFeatureMap)) {
+        //     featureMap.putAll(userFeatureMap);
+        // }
+        // if (MapUtils.isNotEmpty(sceneFeatureMap)) {
+        //     featureMap.putAll(sceneFeatureMap);
+        // }
+
+        double pro = 0.0;
+        // if (MapUtils.isNotEmpty(featureMap)) {
+        //     try {
+        //         pro = model.score(featureMap);
+        //         // LOGGER.info("fea : {}, score:{}", JSONUtils.toJson(featureMap), pro);
+        //     } catch (Exception e) {
+        //         LOGGER.error("score error for doc={} exception={}", item.getVideoId(), ExceptionUtils.getFullStackTrace(e));
+        //     }
+        // }
+        // item.setScoreRov(pro);
+        // item.getScoresMap().put("RovFMScore", pro);
+        // item.setAllFeatureMap(featureMap);
+        return pro;
+    }
+}