ソースを参照

LR、FM模型

丁云鹏 10 ヶ月 前
コミット
f6038276a6

+ 3 - 3
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/filter/strategy/BlacklistContainer.java

@@ -155,7 +155,7 @@ public class BlacklistContainer {
     }
 
     public void refreshVideoTagCache() {
-        LOG.info("同步本地标签ID与视频列表的缓存任务开始");
+        // LOG.info("同步本地标签ID与视频列表的缓存任务开始");
         Map<Long, Set<Long>> tmpMap = new ConcurrentHashMap<>();
 
         if (MapUtils.isNotEmpty(tagFilterConfigMap)) {
@@ -179,13 +179,13 @@ public class BlacklistContainer {
             for (Long tagId : tagIdSet) {
                 List<WxVideoTagRel> wxVideoTagRels = wxVideoTagRelRepository.findAllByTagId(tagId);
                 Set<Long> videoIdSet = wxVideoTagRels.stream().map(WxVideoTagRel::getVideoId).collect(Collectors.toSet());
-                LOG.info("同步本地标签ID与视频列表缓存任务 -- tagId: {}, videoIdSize: {}", tagId, videoIdSet.size());
+                // LOG.info("同步本地标签ID与视频列表缓存任务 -- tagId: {}, videoIdSize: {}", tagId, videoIdSet.size());
                 tmpMap.put(tagId, videoIdSet);
             }
         }
         videoTagCache = tmpMap;
 
-        LOG.info("同步本地标签ID与视频列表的缓存任务结束");
+        // LOG.info("同步本地标签ID与视频列表的缓存任务结束");
     }
 
     public List<Long> filterUnsafeVideoByUser(List<Long> videoIds, String uid, Long hotSceneType, String cityCode, String clientIP, String mid, String usedScene, Integer appType) {

+ 20 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/BaseFMModelScorer.java

@@ -0,0 +1,20 @@
+package com.tzld.piaoquan.recommend.server.service.score;
+
+import com.tzld.piaoquan.recommend.server.service.score.model.FMModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Base class for scorers whose model is an {@link FMModel}.
+ *
+ * <p>The only behavior added on top of {@code AbstractScorer} is {@link #loadModel()},
+ * which requests an {@code FMModel} from {@code doLoadModel}. Scoring itself is left
+ * to subclasses.
+ */
+public abstract class BaseFMModelScorer extends AbstractScorer {
+
+    // Declared final so the shared logger reference cannot be reassigned.
+    private static final Logger LOGGER = LoggerFactory.getLogger(BaseFMModelScorer.class);
+
+    /**
+     * @param scorerConfigInfo scorer configuration, passed unchanged to the
+     *                         {@code AbstractScorer} constructor
+     */
+    public BaseFMModelScorer(ScorerConfigInfo scorerConfigInfo) {
+        super(scorerConfigInfo);
+    }
+
+    /**
+     * Loads the model by delegating to {@code doLoadModel} with {@code FMModel.class}.
+     */
+    @Override
+    public void loadModel() {
+        doLoadModel(FMModel.class);
+    }
+}

+ 159 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/VlogRovFMScorer.java

@@ -0,0 +1,159 @@
+package com.tzld.piaoquan.recommend.server.service.score;
+
+
+import com.tzld.piaoquan.recommend.feature.domain.video.base.UserFeature;
+import com.tzld.piaoquan.recommend.server.common.base.RankItem;
+import com.tzld.piaoquan.recommend.server.service.score.model.FMModel;
+import com.tzld.piaoquan.recommend.server.util.JSONUtils;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+import java.util.concurrent.*;
+
+
+public class VlogRovFMScorer extends BaseLRV2ModelScorer {
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogRovFMScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
+
+
+    public VlogRovFMScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    @Override
+    public List<RankItem> scoring(final ScoreParam param,
+                                  final UserFeature userFeature,
+                                  final List<RankItem> rankItems) {
+        throw new NoSuchMethodError();
+    }
+
+    @Override
+    public List<RankItem> scoring(final Map<String, String> sceneFeatureMap,
+                                  final Map<String, String> userFeatureMap,
+                                  final List<RankItem> rankItems) {
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        FMModel model = (FMModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        List<RankItem> result = rankItems;
+        result = rankByJava(
+                sceneFeatureMap, userFeatureMap, rankItems
+        );
+
+        LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<RankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+                                      final Map<String, String> userFeatureMap,
+                                      final List<RankItem> items) {
+        long startTime = System.currentTimeMillis();
+        FMModel model = (FMModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        // 所有都参与打分,按照ctr排序
+        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (int i = 0; i < items.size(); i++) {
+                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+            }
+        }
+
+        Collections.sort(items);
+
+        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+                System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+    private void multipleCtrScore(final List<RankItem> items,
+                                  final Map<String, String> userFeatureMap,
+                                  final Map<String, String> sceneFeatureMap,
+                                  final FMModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), userFeatureMap, sceneFeatureMap);
+                    } catch (Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).videoId, ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        //等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    LOGGER.error("InterruptedException {},{}", ExceptionUtils.getFullStackTrace(e));
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}", sceneFeatureMap.size(),
+                            ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+    }
+
+    public double calcScore(final FMModel model,
+                            final RankItem item,
+                            final Map<String, String> userFeatureMap,
+                            final Map<String, String> sceneFeatureMap) {
+
+
+        Map<String, String> featureMap = new HashMap<>();
+        if (MapUtils.isNotEmpty(item.getFeatureMap())) {
+            featureMap.putAll(item.getFeatureMap());
+        }
+        if (MapUtils.isNotEmpty(userFeatureMap)) {
+            featureMap.putAll(userFeatureMap);
+        }
+        if (MapUtils.isNotEmpty(sceneFeatureMap)) {
+            featureMap.putAll(sceneFeatureMap);
+        }
+
+        double pro = 0.0;
+        if (MapUtils.isNotEmpty(featureMap)) {
+            try {
+                pro = model.score(featureMap);
+                LOGGER.info("fea : {}, score:{}", JSONUtils.toJson(featureMap), pro);
+            } catch (Exception e) {
+                LOGGER.error("score error for doc={} exception={}", item.getVideoId(), ExceptionUtils.getFullStackTrace(e));
+            }
+        }
+        item.setScoreRov(pro);
+        return pro;
+    }
+}

+ 3 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/VlogRovLRScorer.java

@@ -3,8 +3,8 @@ package com.tzld.piaoquan.recommend.server.service.score;
 
 import com.tzld.piaoquan.recommend.feature.domain.video.base.UserFeature;
 import com.tzld.piaoquan.recommend.server.common.base.RankItem;
-import com.tzld.piaoquan.recommend.server.service.score.model.LRModel;
 import com.tzld.piaoquan.recommend.server.service.score.model.LRV2Model;
+import com.tzld.piaoquan.recommend.server.util.JSONUtils;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.collections4.MapUtils;
 import org.apache.commons.lang.exception.ExceptionUtils;
@@ -42,7 +42,7 @@ public class VlogRovLRScorer extends BaseLRV2ModelScorer {
         }
 
         long startTime = System.currentTimeMillis();
-        LRModel model = (LRModel) this.getModel();
+        LRV2Model model = (LRV2Model) this.getModel();
         LOGGER.debug("model size: [{}]", model.getModelSize());
 
         List<RankItem> result = rankItems;
@@ -148,6 +148,7 @@ public class VlogRovLRScorer extends BaseLRV2ModelScorer {
         if (MapUtils.isNotEmpty(featureMap)) {
             try {
                 pro = lrModel.score(featureMap);
+                LOGGER.info("fea : {}, score:{}", JSONUtils.toJson(featureMap), pro);
             } catch (Exception e) {
                 LOGGER.error("score error for doc={} exception={}", item.getVideoId(), ExceptionUtils.getFullStackTrace(e));
             }

+ 2 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/FMModel.java

@@ -108,7 +108,7 @@ public class FMModel extends Model {
         LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", model.size(), curTime);
         //first stage
         while ((line = input.readLine()) != null) {
-            String[] items = line.split(" ");
+            String[] items = line.split("\t");
             if (items.length < 9) {
                 continue;
             }
@@ -124,7 +124,7 @@ public class FMModel extends Model {
         LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", model.size(), curTime);
         //final stage
         while ((line = input.readLine()) != null) {
-            String[] items = line.split(" ");
+            String[] items = line.split("\t");
             if (items.length < 9) {
                 continue;
             }

+ 2 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/LRV2Model.java

@@ -75,7 +75,7 @@ public class LRV2Model extends Model {
         Integer curTime = new Long(System.currentTimeMillis() / 1000).intValue();
         //first stage
         while ((line = input.readLine()) != null) {
-            String[] items = line.split(" ");
+            String[] items = line.split("\t");
             if (items.length < 2) {
                 continue;
             }
@@ -94,7 +94,7 @@ public class LRV2Model extends Model {
         LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", lrModel.size(), curTime);
         //final stage
         while ((line = input.readLine()) != null) {
-            String[] items = line.split(" ");
+            String[] items = line.split("\t");
             if (items.length < 2) {
                 continue;
             }

+ 2 - 2
recommend-server-service/src/main/resources/feeds_score_config_20240609.conf

@@ -1,7 +1,7 @@
 scorer-config = {
   rov-score-config = {
-    scorer-name = "com.tzld.piaoquan.recommend.server.service.score.VlogRovLRScorer"
+    scorer-name = "com.tzld.piaoquan.recommend.server.service.score.VlogRovFMScorer"
     scorer-priority = 96
-    model-path = "zhangbo/model_aka0.txt"
+    model-path = "zhangbo/model_aka8.txt"
   }
 }