Pārlūkot izejas kodu

Merge branch 'master' into feature/rov_nor_rank

jch 4 mēneši atpakaļ
vecāks
revīzija
603d1c4bb8

+ 2 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/framework/score/AbstractScorer.java

@@ -10,6 +10,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -37,7 +38,7 @@ public abstract class AbstractScorer {
         if (StringUtils.isNotBlank(modelPath)) {
             try {
                 // 使用 modelPath 作为 modelName 注册
-                modelManager.registerModel(modelPath, modelPath, modelClass);
+                modelManager.registerModel(modelPath, modelPath, modelClass, Collections.emptyMap());
                 LOGGER.info("register model success, model path [{}], model class [{}]", modelPath, modelClass);
             } catch (ModelManager.ModelRegisterException e) {
                 LOGGER.error("register model fail [{}]:[{}]", modelPath, e);

+ 5 - 5
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/filter/strategy/VideoSourceTypeStrategy.java

@@ -154,11 +154,11 @@ public class VideoSourceTypeStrategy implements FilterStrategy {
                         || (vid2UidMap.containsKey(l) && notUserUploadUserIds.contains(vid2UidMap.get(l))))
                 .collect(Collectors.toList());
 
-        log.info("VideoSourceTypeStrategy \t param={} \t before={} \t " +
-                        "after={}",
-                JSONUtils.toJson(param),
-                JSONUtils.toJson(param.getVideoIds()),
-                JSONUtils.toJson(videoIds));
+//        log.info("VideoSourceTypeStrategy \t param={} \t before={} \t " +
+//                        "after={}",
+//                JSONUtils.toJson(param),
+//                JSONUtils.toJson(param.getVideoIds()),
+//                JSONUtils.toJson(videoIds));
 
         return videoIds;
     }

+ 1 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/RecallService.java

@@ -224,6 +224,7 @@ public class RecallService implements ApplicationContextAware {
                     }
                 } else if (param.getFlowPoolAbtestGroup().equals(FlowPoolConstants.EXPERIMENTAL_FLOW_SET_LEVEL_SCORE)) {
                     strategies.add(strategyMap.get(QuickFlowPoolWithLevelScoreRecallStrategy.class.getSimpleName()));
+                    // 在执行中
                     strategies.add(strategyMap.get(FlowPoolWithLevelScoreRecallStrategy.class.getSimpleName()));
                 } else {
                     strategies.add(strategyMap.get(QuickFlowPoolWithScoreRecallStrategy.class.getSimpleName()));

+ 84 - 4
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/FlowPoolWithLevelScoreRecallStrategy.java

@@ -1,18 +1,31 @@
 package com.tzld.piaoquan.recommend.server.service.recall.strategy;
 
+import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
+import com.google.common.collect.Lists;
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.filter.FilterParam;
+import com.tzld.piaoquan.recommend.server.service.filter.FilterResult;
 import com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConfigService;
 import com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConstants;
+import com.tzld.piaoquan.recommend.server.service.recall.FilterParamFactory;
 import com.tzld.piaoquan.recommend.server.service.recall.RecallParam;
+import com.tzld.piaoquan.recommend.server.service.score.ScorerUtils;
+import com.tzld.piaoquan.recommend.server.service.score4recall.ScorerPipeline4Recall;
+import com.tzld.piaoquan.recommend.server.util.JSONUtils;
 import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.lang3.RandomUtils;
+import org.apache.commons.lang3.math.NumberUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.data.redis.core.ZSetOperations;
 import org.springframework.stereotype.Service;
 
 import java.math.BigDecimal;
 import java.math.RoundingMode;
 import java.util.*;
+import java.util.stream.Collectors;
 
 import static com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConstants.KEY_WITH_LEVEL_SCORE_FORMAT;
 
@@ -21,7 +34,8 @@ import static com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConsta
  */
 @Service
 public class FlowPoolWithLevelScoreRecallStrategy extends AbstractFlowPoolWithLevelScoreRecallStrategy {
-
+    @ApolloJsonValue("${ifOneLevelRandom:true}")
+    private boolean ifOneLevelRandom;
     @Autowired
     private FlowPoolConfigService flowPoolConfigService;
 
@@ -29,11 +43,12 @@ public class FlowPoolWithLevelScoreRecallStrategy extends AbstractFlowPoolWithLe
     Pair<String, String> flowPoolKeyAndLevel(RecallParam param) {
         //# 1. 获取流量池各层级分发概率权重
         Map<String, Double> levelWeightMap = flowPoolConfigService.getLevelWeight();
+
         // 2. 判断各层级是否有视频需分发
         List<LevelWeight> availableLevels = new ArrayList<>();
         for (Map.Entry<String, Double> entry : levelWeightMap.entrySet()) {
             String levelKey = String.format(KEY_WITH_LEVEL_SCORE_FORMAT, param.getAppType(), entry.getKey());
-            if (redisTemplate.hasKey(levelKey)) {
+            if (Boolean.TRUE.equals(redisTemplate.hasKey(levelKey))) {
                 LevelWeight lw = new LevelWeight();
                 lw.setLevel(entry.getKey());
                 lw.setLevelKey(levelKey);
@@ -41,6 +56,8 @@ public class FlowPoolWithLevelScoreRecallStrategy extends AbstractFlowPoolWithLe
                 availableLevels.add(lw);
             }
         }
+
+        //log.info("availableLevels {}", JSONUtils.toJson(availableLevels));
         if (CollectionUtils.isEmpty(availableLevels)) {
             return Pair.of("", "");
         }
@@ -52,7 +69,7 @@ public class FlowPoolWithLevelScoreRecallStrategy extends AbstractFlowPoolWithLe
         BigDecimal weightSumBD = new BigDecimal(weightSum);
         double level_p_low = 0;
         double weight_temp = 0;
-        double level_p_up;
+        double level_p_up = 0;
         Map<String, LevelP> level_p_mapping = new HashMap<>();
         for (LevelWeight lw : availableLevels) {
             BigDecimal bd = new BigDecimal(weight_temp + lw.getWeight());
@@ -68,7 +85,6 @@ public class FlowPoolWithLevelScoreRecallStrategy extends AbstractFlowPoolWithLe
         }
 
         // 4. 随机生成[0,1)之间数,返回相应概率区间的key
-
         double random_p = RandomUtils.nextDouble(0, 1);
         for (Map.Entry<String, LevelP> entry : level_p_mapping.entrySet()) {
             if (random_p >= entry.getValue().getMin()
@@ -97,4 +113,68 @@ public class FlowPoolWithLevelScoreRecallStrategy extends AbstractFlowPoolWithLe
     public String pushFrom() {
         return FlowPoolConstants.PUSH_FORM;
     }
+
+    /**
+     * Recalls candidate videos from the flow pool level chosen by
+     * {@link #flowPoolKeyAndLevel(RecallParam)}.
+     *
+     * <p>Steps:
+     * <ol>
+     *   <li>Read the members at indices 0 through 1000 (highest score first) of the selected
+     *       Redis sorted set. Each member is expected to look like {@code videoId-flowPool}.</li>
+     *   <li>Score the candidates: when the level is {@code "1"} and {@code ifOneLevelRandom} is
+     *       enabled, pick at most 60 videos at random and give each a random score; otherwise
+     *       score them with the recall pipeline configured by
+     *       {@code feeds_recall_config_tomson.conf}.</li>
+     *   <li>Run the scored ids through {@code filterService} and build one {@link Video} per
+     *       surviving id.</li>
+     * </ol>
+     *
+     * @param param recall request; supplies the app type, AB code and flow pool AB test group
+     * @return the recalled videos ordered by descending rov score; an empty list (never
+     *         {@code null}) when nothing can be recalled
+     */
+    @Override
+    public List<Video> recall(RecallParam param) {
+        Pair<String, String> flowPoolKeyAndLevel = flowPoolKeyAndLevel(param);
+        String flowPoolKey = flowPoolKeyAndLevel.getLeft();
+        String level = flowPoolKeyAndLevel.getRight();
+        // flowPoolKeyAndLevel returns an empty key when no level has data: skip the Redis round trip.
+        if (flowPoolKey == null || flowPoolKey.isEmpty()) {
+            return new ArrayList<>();
+        }
+        Set<ZSetOperations.TypedTuple<String>> data =
+                redisTemplate.opsForZSet().reverseRangeWithScores(flowPoolKey, 0, 1000);
+        if (CollectionUtils.isEmpty(data)) {
+            // Return an empty list instead of null so callers never need a null check.
+            return new ArrayList<>();
+        }
+
+        // Both maps keep the Redis order; one is keyed by the raw id string (pipeline input),
+        // the other by the parsed id (used to look up the flow pool of a filtered video).
+        Map<String, String> flowPoolByRawId = new LinkedHashMap<>();
+        Map<Long, String> flowPoolByVideoId = new LinkedHashMap<>();
+        for (ZSetOperations.TypedTuple<String> tuple : data) {
+            String member = tuple.getValue();
+            if (member == null) {
+                continue;
+            }
+            String[] parts = member.split("-");
+            if (parts.length < 2) {
+                // Malformed member without a flow pool part: skip it rather than fail the whole recall.
+                continue;
+            }
+            flowPoolByRawId.put(parts[0], parts[1]);
+            flowPoolByVideoId.put(NumberUtils.toLong(parts[0], 0), parts[1]);
+        }
+        if (flowPoolByVideoId.isEmpty()) {
+            return new ArrayList<>();
+        }
+
+        final Map<Long, Double> scores;
+        if ("1".equals(level) && ifOneLevelRandom) {
+            // Level 1 of the flow pool is distributed fully at random.
+            int limitSize = 60;
+            List<Long> keyList = new ArrayList<>(flowPoolByVideoId.keySet());
+            Collections.shuffle(keyList);
+            scores = keyList.stream().limit(limitSize).collect(Collectors.toMap(
+                    key -> key,
+                    key -> Math.random()
+            ));
+        } else {
+            ScorerPipeline4Recall pipeline = ScorerUtils.getScorerPipeline4Recall("feeds_recall_config_tomson.conf");
+            List<List<Pair<Long, Double>>> results = pipeline.recall(flowPoolByRawId);
+            // Only the first result list is used; tolerate a pipeline that produced none.
+            List<Pair<Long, Double>> ranked = CollectionUtils.isEmpty(results)
+                    ? Collections.<Pair<Long, Double>>emptyList()
+                    : results.get(0);
+            scores = ranked.stream()
+                    .collect(Collectors.toMap(
+                            Pair::getLeft,
+                            Pair::getRight,
+                            // On a duplicate video id keep the score that was seen first.
+                            (existingValue, newValue) -> existingValue,
+                            // LinkedHashMap keeps the pipeline's ordering.
+                            LinkedHashMap::new
+                    ));
+        }
+
+        // Step 3: filtering inside the recall stage.
+        FilterParam filterParam = FilterParamFactory.create(param, new ArrayList<>(scores.keySet()));
+        filterParam.setForceTruncation(10000);
+        filterParam.setConcurrent(true);
+        filterParam.setNotUsePreView(false);
+        FilterResult filterResult = filterService.filter(filterParam);
+        List<Video> videosResult = new ArrayList<>();
+        if (filterResult != null && CollectionUtils.isNotEmpty(filterResult.getVideoIds())) {
+            filterResult.getVideoIds().forEach(vid -> {
+                Video recallData = new Video();
+                recallData.setVideoId(vid);
+                recallData.setAbCode(param.getAbCode());
+                recallData.setRovScore(scores.getOrDefault(vid, 0.0));
+                recallData.setPushFrom(pushFrom());
+                recallData.setFlowPool(flowPoolByVideoId.get(vid));
+                recallData.setFlowPoolAbtestGroup(param.getFlowPoolAbtestGroup());
+                recallData.setLevel(level);
+                videosResult.add(recallData);
+            });
+        }
+        // Negating the score sorts from the highest rov score to the lowest.
+        videosResult.sort(Comparator.comparingDouble(o -> -o.getRovScore()));
+        return videosResult;
+    }
 }

+ 5 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/AbstractScorer.java

@@ -1,6 +1,7 @@
 package com.tzld.piaoquan.recommend.server.service.score;
 
 
+import com.typesafe.config.ConfigObject;
 import com.tzld.piaoquan.recommend.feature.domain.video.base.UserFeature;
 import com.tzld.piaoquan.recommend.server.common.base.RankItem;
 import com.tzld.piaoquan.recommend.server.service.score.model.Model;
@@ -10,6 +11,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -37,7 +39,9 @@ public abstract class AbstractScorer {
         if (StringUtils.isNotBlank(modelPath)) {
             try {
                 // 使用 modelPath 作为 modelName 注册
-                modelManager.registerModel(modelPath, modelPath, modelClass);
+                ConfigObject paramMap = scorerConfigInfo.getParamMap();
+                modelManager.registerModel(modelPath, modelPath, modelClass, paramMap == null ?
+                        Collections.emptyMap() : paramMap.unwrapped());
                 LOGGER.info("register model success, model path [{}], model class [{}]", modelPath, modelClass);
             } catch (ModelManager.ModelRegisterException e) {
                 LOGGER.error("register model fail [{}]:[{}]", modelPath, e);

+ 10 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/Model.java

@@ -3,13 +3,22 @@ package com.tzld.piaoquan.recommend.server.service.score.model;
 
 import java.io.InputStream;
 import java.io.InputStreamReader;
+import java.util.Map;
+
+/**
+ * Base class for models that {@code ModelManager} creates via {@code newInstance()} and then
+ * loads from a stream.
+ *
+ * <p>{@code ModelManager} calls {@link #setParams(Map)} before {@code loadFromStream}, so a
+ * subclass may read {@link #params} while loading.
+ */
+public abstract class Model {
+    /**
+     * Loader parameters for this model. Never {@code null}: defaults to an empty map so that
+     * subclasses can call {@code params.getOrDefault(...)} without a null check.
+     */
+    protected Map<String, Object> params = java.util.Collections.emptyMap();
 
-abstract public class Model {
+    /** Returns the size of the model, as defined by the concrete model type. */
     public abstract int getModelSize();
 
+    /**
+     * Loads the model from a character stream.
+     *
+     * @param in reader positioned at the start of the model data
+     * @return {@code true} if the model was loaded successfully
+     * @throws Exception if reading or parsing fails
+     */
     public abstract boolean loadFromStream(InputStreamReader in) throws Exception;
+
+    /**
+     * Loads the model from a byte stream by wrapping it in an {@link InputStreamReader}
+     * that uses the platform default charset.
+     *
+     * @param is stream positioned at the start of the model data
+     * @return {@code true} if the model was loaded successfully
+     * @throws Exception if reading or parsing fails
+     */
     public boolean loadFromStream(InputStream is) throws Exception {
         return loadFromStream(new InputStreamReader(is));
     }
+
+    /**
+     * Sets the loader parameters.
+     *
+     * @param params parameters to use; {@code null} is treated as an empty map
+     */
+    public void setParams(Map<String, Object> params) {
+        this.params = params == null
+                ? java.util.Collections.<String, Object>emptyMap()
+                : params;
+    }
 }
 

+ 7 - 3
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/ModelManager.java

@@ -83,7 +83,8 @@ public class ModelManager {
      * @param path       Model在OSS上的全路径
      * @param modelClass Model的子类型
      */
-    public void registerModel(String modelName, String path, Class<? extends Model> modelClass) throws ModelRegisterException, IOException {
+    public void registerModel(String modelName, String path, Class<? extends Model> modelClass,
+                              Map<String, Object> params) throws ModelRegisterException, IOException {
         if (modelPathMap.containsKey(modelName)) {
             // fail fast
             // throw new RuntimeException(modelName + " already exists");
@@ -96,7 +97,7 @@ public class ModelManager {
             ModelLoadTask loadTask = loadTasks.get(path);
             loadTask.refCount++;
         } else {
-            ModelLoadTask task = new ModelLoadTask(path, modelClass);
+            ModelLoadTask task = new ModelLoadTask(path, modelClass, params);
             task.refCount++;
             loadTasks.put(path, task);
             loadModelWithRetry(task, false, true);
@@ -205,6 +206,7 @@ public class ModelManager {
                         loadTask.lastModifyTime, timeStamp);
 
                 Model model = loadTask.modelClass.newInstance();
+                model.setParams(loadTask.params);
                 if (model.loadFromStream(ossObj.getObjectContent())) {
                     loadTask.model = model;
                     loadTask.lastModifyTime = timeStamp;
@@ -245,12 +247,14 @@ public class ModelManager {
         private boolean isLoading;
         private final Class<? extends Model> modelClass;
         private Model model;
+        private Map<String, Object> params;
 
-        ModelLoadTask(String path, Class<? extends Model> modelClass) {
+        ModelLoadTask(String path, Class<? extends Model> modelClass, Map<String, Object> params) {
             this.refCount = 0;
             this.path = path;
             this.lastModifyTime = 0;
             this.modelClass = modelClass;
+            this.params = params;
         }
     }
 }

+ 7 - 3
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/XGBoostModel.java

@@ -6,6 +6,7 @@ import com.tzld.piaoquan.recommend.server.util.PropertiesUtil;
 import ml.dmlc.xgboost4j.scala.DMatrix;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -13,6 +14,7 @@ import java.io.File;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.util.Map;
+import java.util.UUID;
 
 
 public class XGBoostModel extends Model {
@@ -21,7 +23,7 @@ public class XGBoostModel extends Model {
 
     private String[] features;
 
-    public void setFeatures(String[] features){
+    public void setFeatures(String[] features) {
         this.features = features;
     }
 
@@ -59,9 +61,11 @@ public class XGBoostModel extends Model {
 
     @Override
     public boolean loadFromStream(InputStream in) throws Exception {
-        String modelDir = PropertiesUtil.getString("model.xgboost.path");
+        Object localDir = params.getOrDefault("localDir",
+                PropertiesUtil.getString("model.xgboost.path") + "/" + UUID.randomUUID());
+        String modelDir = String.valueOf(localDir);
         CompressUtil.decompressGzFile(in, modelDir);
-        String absolutePath =new File(modelDir).getAbsolutePath();
+        String absolutePath = new File(modelDir).getAbsolutePath();
         XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + absolutePath);
         model2.setMissing(0.0f);
         this.model = model2;

+ 6 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score4recall/AbstractScorer4Recall.java

@@ -1,6 +1,7 @@
 package com.tzld.piaoquan.recommend.server.service.score4recall;
 
 
+import com.typesafe.config.ConfigObject;
 import com.tzld.piaoquan.recommend.server.service.score.ScorerConfigInfo;
 import com.tzld.piaoquan.recommend.server.service.score.model.Model;
 import com.tzld.piaoquan.recommend.server.service.score.model.ModelManager;
@@ -10,6 +11,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -36,7 +38,9 @@ public abstract class AbstractScorer4Recall {
         if (StringUtils.isNotBlank(modelPath)) {
             try {
                 // 使用 modelPath 作为 modelName 注册
-                modelManager.registerModel(modelPath, modelPath, modelClass);
+                ConfigObject paramMap = scorerConfigInfo.getParamMap();
+                modelManager.registerModel(modelPath, modelPath, modelClass,
+                        paramMap == null ? Collections.emptyMap() : paramMap.unwrapped());
                 LOGGER.info("register model success, model path [{}], model class [{}]", modelPath, modelClass);
             } catch (ModelManager.ModelRegisterException e) {
                 LOGGER.error("register model fail [{}]:[{}]", modelPath, e);
@@ -47,6 +51,7 @@ public abstract class AbstractScorer4Recall {
             LOGGER.error("modelpath is null, for model class [{}]", modelClass);
         }
     }
+
     public Model getModel() {
         if (StringUtils.isBlank(scorerConfigInfo.getModelPath())) {
             return null;

+ 1 - 0
recommend-server-service/src/main/resources/feeds_score_config_xgb_20240828.conf

@@ -4,6 +4,7 @@ scorer-config = {
     scorer-priority = 99
     model-path = "zhangbo/model_xgb_for_recsys.tar.gz"
     param = {
+      localDir = "xgboost/recsys"
       features = [
         "b123_1h_STR",
         "b123_1h_log(share)",