Quellcode durchsuchen

Merge branch 'feature/zhangbo_rank' of algorithm/recommend-server into master

zhangbo vor 1 Jahr
Ursprung
Commit
912f1d7695

+ 4 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/common/base/RankItem.java

@@ -16,6 +16,8 @@ public class RankItem implements Comparable<RankItem> {
     public long videoId;
     private double score; // 记录最终的score
     private Video video;
+    private double scoreRos; // 记录ros的score
+    private double scoreStr; // 记录str的score
 
     // 记录Item侧用到的特征
     private ItemFeature itemFeature;
@@ -23,6 +25,8 @@ public class RankItem implements Comparable<RankItem> {
     public RankItem(Video video) {
         this.videoId = video.getVideoId();
         this.score = 0.0;
+        this.scoreRos = 0.0;
+        this.scoreStr = 0.0;
         this.video = video;
     }
 

+ 4 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/model/Video.java

@@ -30,6 +30,9 @@ public class Video {
     private List<String> tags = new ArrayList<>();
 
     // video的模型打分
-    private double modelScore = 0.0D;
+    private double scoreRos = 0.0D;
+    private double scoreStr = 0.0D;
+    public double score = 0.0D;
+    public double scoreRegion = 0.0D;
 
 }

+ 5 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/RankRouter.java

@@ -1,7 +1,6 @@
 package com.tzld.piaoquan.recommend.server.service.rank;
 
-import com.tzld.piaoquan.recommend.server.service.rank.strategy.RankStrategy4Density;
-import com.tzld.piaoquan.recommend.server.service.rank.strategy.RankStrategy4RankModel;
+import com.tzld.piaoquan.recommend.server.service.rank.strategy.*;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
@@ -12,6 +11,8 @@ public class RankRouter {
     @Autowired
     private RankService rankService;
     @Autowired
+    private RankStrategy4Rankv2Model rankStrategy4Rankv2Model;
+    @Autowired
     private RankStrategy4RankModel rankStrategy4RankModel;
     @Autowired
     private RankStrategy4Density rankStrategy4Density;
@@ -21,6 +22,8 @@ public class RankRouter {
             return rankService.rank(param);
         }
         switch (abCode){
+            case "60106":
+                return rankStrategy4Rankv2Model.rank(param);
             case "60101":
                 return rankStrategy4RankModel.rank(param);
             case "60098":

+ 29 - 6
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategy4RankModel.java

@@ -2,6 +2,7 @@ package com.tzld.piaoquan.recommend.server.service.rank.strategy;
 
 
 import com.alibaba.fastjson.JSONObject;
+import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
 import com.google.common.reflect.TypeToken;
 import com.tzld.piaoquan.recommend.server.common.base.RankItem;
 import com.tzld.piaoquan.recommend.server.model.Video;
@@ -36,12 +37,24 @@ public class RankStrategy4RankModel extends RankService {
 
     @Value("${video.model.weight:}")
     private Double mergeWeight;
+    @ApolloJsonValue("${video.model.weightv2:}")
+    private Map<String, Double> mergeWeightNew;
     final private String CLASS_NAME = this.getClass().getSimpleName();
+
+//    public Video getTestVideo(Long id, String s){
+//        Video a1 = new Video();
+//        a1.setVideoId(id);
+//        a1.setFlowPool(s);
+//        a1.setPushFrom("recall_pool_region_h");
+//        return a1;
+//    }
     @Override
     public List<Video> mergeAndRankRovRecall(RankParam param) {
 
         //-------------------地域内部融合-------------------
         List<Video> rovRecallRank = new ArrayList<>();
+//        rovRecallRank.add(0, getTestVideo(1070462L, ""));
+//        rovRecallRank.add(0, getTestVideo(1085062L, ""));
         rovRecallRank.addAll(extractAndSort(param, RegionHRecallStrategy.PUSH_FORM));
         rovRecallRank.addAll(extractAndSort(param, RegionHDupRecallStrategy.PUSH_FORM));
         rovRecallRank.addAll(extractAndSort(param, Region24HRecallStrategy.PUSH_FORM));
@@ -76,15 +89,24 @@ public class RankStrategy4RankModel extends RankService {
             Collections.sort(rovRecallRank, Comparator.comparingDouble(o -> -o.getSortScore()));
         }
 
-        //------------------- todo zhangbo 增加排序str模型逻辑 合并二者得分-------------------
+        //------------------- todo zhangbo 增加排序str ros模型逻辑 合并二者得分-------------------
         List<Video> videosWithModel = model(rovRecallRank, param);
+        Map<String, Double> mergeWeight = this.mergeWeightNew == null? new HashMap<>(): this.mergeWeightNew;
+        double alpha = mergeWeight.getOrDefault("alpha", 1.0D);
+        double beta = mergeWeight.getOrDefault("beta", 0.0D);
+        double gamma = mergeWeight.getOrDefault("gamma", 0.0D);
         for (Video v : videosWithModel){
-            double mergeWeightIn = this.mergeWeight == null? 0.0D: this.mergeWeight;
-            v.setSortScore(v.getSortScore() + mergeWeightIn * v.getModelScore());
+            double score = alpha * v.getSortScore() + beta * v.getScoreStr() + gamma * v.getScoreRos();
+            if (mergeWeight.containsKey("mul") && mergeWeight.getOrDefault("mul", 0.0D) > 0.5){
+                score = alpha * v.getSortScore() + (beta + v.getScoreStr()) * (gamma + v.getScoreRos());
+            }
+            v.setScoreRegion(v.getSortScore());
+            v.score = score;
+            v.setSortScore(score);
         }
-        Collections.sort(videosWithModel, Comparator.comparingDouble(o -> -o.getSortScore()));
+        videosWithModel.sort(Comparator.comparingDouble(o -> -o.score));
 
-        //------------------- 增加日志-------------------
+        //------------------- 增加日志 -------------------
         int size = 4;
         List<Long> oldRes = rovRecallRank.subList(0, Math.min(rovRecallRank.size(), size)).stream().map(r-> r.getVideoId()).collect(Collectors.toList());
         List<Long> newRes = videosWithModel.subList(0, Math.min(videosWithModel.size(), size)).stream().map(r-> r.getVideoId()).collect(Collectors.toList());
@@ -217,7 +239,8 @@ public class RankStrategy4RankModel extends RankService {
         return CommonCollectionUtils.toList(rovRecallScore, i -> {
             // hard code 将排序分数 赋值给video的sortScore
             Video v = i.getVideo();
-            v.setModelScore(i.getScore());
+            v.setScoreStr(i.getScoreStr());
+            v.setScoreRos(i.getScoreRos());
             return v;
         });
     }

+ 290 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategy4Rankv2Model.java

@@ -0,0 +1,290 @@
+package com.tzld.piaoquan.recommend.server.service.rank.strategy;
+
+
+import com.alibaba.fastjson.JSONObject;
+import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
+import com.google.common.reflect.TypeToken;
+import com.tzld.piaoquan.recommend.server.common.base.RankItem;
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.rank.RankParam;
+import com.tzld.piaoquan.recommend.server.service.rank.RankService;
+import com.tzld.piaoquan.recommend.server.service.recall.strategy.*;
+import com.tzld.piaoquan.recommend.server.service.score.ScorerUtils;
+import com.tzld.piaoquan.recommend.server.util.CommonCollectionUtils;
+import com.tzld.piaoquan.recommend.server.util.JSONUtils;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.math.NumberUtils;
+import org.springframework.data.redis.connection.RedisConnectionFactory;
+import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
+import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.data.redis.serializer.StringRedisSerializer;
+import org.springframework.stereotype.Service;
+
+import java.text.SimpleDateFormat;
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * @author zhangbo
+ * @desc 模型的排序实验
+ */
+@Service
+@Slf4j
+public class RankStrategy4Rankv2Model extends RankService {
+
+    @ApolloJsonValue("${video.model.weightv2:}")
+    private Map<String, Double> mergeWeight;
+    final private String CLASS_NAME = this.getClass().getSimpleName();
+//    public Video getTestVideo(Long id, String s){
+//        Video a1 = new Video();
+//        a1.setVideoId(id);
+//        a1.setFlowPool(s);
+//        a1.setPushFrom("recall_pool_region_h");
+//        return a1;
+//    }
+
+    @Override
+    public List<Video> mergeAndRankRovRecall(RankParam param) {
+
+        //-------------------地域内部融合-------------------
+        List<Video> rovRecallRank = new ArrayList<>();
+//        rovRecallRank.add(0, getTestVideo(1070462L, ""));
+//        rovRecallRank.add(0, getTestVideo(1085062L, ""));
+        rovRecallRank.addAll(extractAndSort(param, RegionHRecallStrategy.PUSH_FORM));
+        rovRecallRank.addAll(extractAndSort(param, RegionHDupRecallStrategy.PUSH_FORM));
+        rovRecallRank.addAll(extractAndSort(param, Region24HRecallStrategy.PUSH_FORM));
+        rovRecallRank.addAll(extractAndSort(param, RegionRelative24HRecallStrategy.PUSH_FORM));
+        rovRecallRank.addAll(extractAndSort(param, RegionRelative24HDupRecallStrategy.PUSH_FORM));
+
+        removeDuplicate(rovRecallRank);
+        rovRecallRank = rovRecallRank.size() <= param.getSize()
+                ? rovRecallRank
+                : rovRecallRank.subList(0, param.getSize());
+
+        //-------------------地域 sim returnv2 融合-------------------
+        rovRecallRank.addAll(extractAndSort(param, SimHotVideoRecallStrategy.PUSH_FORM));
+        rovRecallRank.addAll(extractAndSort(param, ReturnVideoRecallStrategy.PUSH_FORM));
+        removeDuplicate(rovRecallRank);
+
+        //-------------------排-------------------
+        //-------------------序-------------------
+        //-------------------逻-------------------
+        //-------------------辑-------------------
+        List<String> videoIdKeys = rovRecallRank.stream()
+                .map(t -> param.getRankKeyPrefix() + t.getVideoId())
+                .collect(Collectors.toList());
+        List<String> videoScores = this.redisTemplate.opsForValue().multiGet(videoIdKeys);
+        log.info("rank mergeAndRankRovRecall videoIdKeys={}, videoScores={}", JSONUtils.toJson(videoIdKeys),
+                JSONUtils.toJson(videoScores));
+        if (CollectionUtils.isNotEmpty(videoScores)
+                && videoScores.size() == rovRecallRank.size()) {
+            for (int i = 0; i < videoScores.size(); i++) {
+                rovRecallRank.get(i).setSortScore(NumberUtils.toDouble(videoScores.get(i), 0.0));
+            }
+            Collections.sort(rovRecallRank, Comparator.comparingDouble(o -> -o.getSortScore()));
+        }
+
+        //------------------- todo zhangbo 增加排序str ros模型逻辑 合并二者得分-------------------
+        List<Video> videosWithModel = model(rovRecallRank, param);
+        Map<String, Double> mergeWeight = this.mergeWeight == null? new HashMap<>(): this.mergeWeight;
+        double alpha = mergeWeight.getOrDefault("alpha", 1.0D);
+        double beta = mergeWeight.getOrDefault("beta", 0.0D);
+        double gamma = mergeWeight.getOrDefault("gamma", 0.0D);
+        for (Video v : videosWithModel){
+            double score = alpha * v.getSortScore() + beta * v.getScoreStr() + gamma * v.getScoreRos();
+            if (mergeWeight.containsKey("mul") && mergeWeight.getOrDefault("mul", 0.0D) > 0.5){
+                score = alpha * v.getSortScore() + (beta + v.getScoreStr()) * (gamma + v.getScoreRos());
+            }
+            v.setScoreRegion(v.getSortScore());
+            v.score = score;
+            v.setSortScore(score);
+        }
+        videosWithModel.sort(Comparator.comparingDouble(o -> -o.score));
+
+        //------------------- 增加日志 -------------------
+        int size = 4;
+        List<Long> oldRes = rovRecallRank.subList(0, Math.min(rovRecallRank.size(), size)).stream().map(r-> r.getVideoId()).collect(Collectors.toList());
+        List<Long> newRes = videosWithModel.subList(0, Math.min(videosWithModel.size(), size)).stream().map(r-> r.getVideoId()).collect(Collectors.toList());
+        int diffpos = 0;
+        int difftop = 0;
+        for (int i=0; i<newRes.size(); ++i){
+            if (!oldRes.get(i).equals(newRes.get(i))){
+                ++diffpos;
+            }
+            if (!oldRes.contains(newRes.get(i))){
+                ++difftop;
+            }
+        }
+        JSONObject obj = new JSONObject();
+        obj.put("name", "RankStrategy4Rankv2Model");
+        obj.put("diffpos", diffpos);
+        obj.put("difftop", difftop);
+        obj.put("videosWithModel_size", videosWithModel.size());
+        obj.put("oldRes", oldRes.stream()
+                .map(String::valueOf)
+                .collect(Collectors.joining(",")));
+        obj.put("newRes", newRes.stream()
+                .map(String::valueOf)
+                .collect(Collectors.joining(",")));
+        log.info(obj.toString());
+
+        return videosWithModel;
+    }
+
+    public List<Video> model(List<Video> videos, RankParam param){
+        if (videos.isEmpty()){
+            return videos;
+        }
+
+        RedisStandaloneConfiguration redisSC = new RedisStandaloneConfiguration();
+        redisSC.setPort(6379);
+        redisSC.setPassword("Wqsd@2019");
+        redisSC.setHostName("r-bp1pi8wyv6lzvgjy5z.redis.rds.aliyuncs.com");
+        RedisConnectionFactory connectionFactory = new JedisConnectionFactory(redisSC);
+        RedisTemplate<String, String> redisTemplate = new RedisTemplate<>();
+        redisTemplate.setConnectionFactory(connectionFactory);
+        redisTemplate.setDefaultSerializer(new StringRedisSerializer());
+        redisTemplate.afterPropertiesSet();
+
+        Map<String, String> userFeatureMap = new HashMap<>();
+        if (param.getMid() != null && !param.getMid().isEmpty()){
+            String midKey = "user_info_4video_" + param.getMid();
+            String userFeatureStr = redisTemplate.opsForValue().get(midKey);
+            if (userFeatureStr != null){
+                try{
+                    userFeatureMap = JSONUtils.fromJson(userFeatureStr,
+                            new TypeToken<Map<String, String>>() {},
+                            userFeatureMap);
+                }catch (Exception e){
+                    log.error(String.format("parse user json is wrong in {} with {}",
+                            this.CLASS_NAME, e));
+                }
+            }else{
+                JSONObject obj = new JSONObject();
+                obj.put("name", "user_key_in_model_is_null");
+                log.info(obj.toString());
+                return videos;
+            }
+        }
+        final Set<String> userFeatureSet = new HashSet<>(Arrays.asList(
+                "machineinfo_brand", "machineinfo_model", "machineinfo_platform", "machineinfo_system",
+                "u_1day_exp_cnt", "u_1day_click_cnt", "u_1day_share_cnt", "u_1day_return_cnt",
+                "u_ctr_1day","u_str_1day","u_rov_1day","u_ros_1day",
+                "u_3day_exp_cnt","u_3day_click_cnt","u_3day_share_cnt","u_3day_return_cnt",
+                "u_ctr_3day","u_str_3day","u_rov_3day","u_ros_3day"
+        ));
+        Iterator<Map.Entry<String, String>> iterator = userFeatureMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+            Map.Entry<String, String> entry = iterator.next();
+            if (!userFeatureSet.contains(entry.getKey())) {
+                // 删除键值对
+                iterator.remove();
+            }
+        }
+
+        log.info("userFeature in model = {}", JSONUtils.toJson(userFeatureMap));
+
+        final Set<String> itemFeatureSet = new HashSet<>(Arrays.asList(
+                "total_time", "play_count_total",
+                "i_1day_exp_cnt", "i_1day_click_cnt", "i_1day_share_cnt", "i_1day_return_cnt",
+                "i_ctr_1day", "i_str_1day", "i_rov_1day", "i_ros_1day",
+                "i_3day_exp_cnt", "i_3day_click_cnt", "i_3day_share_cnt", "i_3day_return_cnt",
+                "i_ctr_3day", "i_str_3day", "i_rov_3day", "i_ros_3day"
+        ));
+
+        List<RankItem> rankItems = CommonCollectionUtils.toList(videos, RankItem::new);
+        List<Long> videoIds = CommonCollectionUtils.toListDistinct(videos, Video::getVideoId);
+        List<String> videoFeatureKeys = videoIds.stream().map(r-> "video_info_" + r)
+                .collect(Collectors.toList());
+        List<String> videoFeatures = redisTemplate.opsForValue().multiGet(videoFeatureKeys);
+        if (videoFeatures != null){
+            for (int i=0; i<videoFeatures.size(); ++i){
+                String vF = videoFeatures.get(i);
+                Map<String, String> vfMap = new HashMap<>();
+                if (vF == null){
+                    continue;
+                }
+                try{
+                    vfMap = JSONUtils.fromJson(vF, new TypeToken<Map<String, String>>() {}, vfMap);
+                    Iterator<Map.Entry<String, String>> iteratorIn = vfMap.entrySet().iterator();
+                    while (iteratorIn.hasNext()) {
+                        Map.Entry<String, String> entry = iteratorIn.next();
+                        if (!itemFeatureSet.contains(entry.getKey())) {
+                            // 删除键值对
+                            iteratorIn.remove();
+                        }
+                    }
+                    rankItems.get(i).setFeatureMap(vfMap);
+                }catch (Exception e){
+                    log.error(String.format("parse video json is wrong in {} with {}",
+                            this.CLASS_NAME, e));
+                }
+            }
+        }
+        log.info("ItemFeature = {}", JSONUtils.toJson(videoFeatures));
+
+        Map<String, String> sceneFeatureMap =  this.getSceneFeature(param);
+
+        List<RankItem> rovRecallScore = ScorerUtils.getScorerPipeline(ScorerUtils.BASE_CONF)
+                .scoring(sceneFeatureMap, userFeatureMap, rankItems);
+        log.info("mergeAndRankRovRecallNew rovRecallScore={}", JSONUtils.toJson(rovRecallScore));
+        JSONObject obj = new JSONObject();
+        obj.put("name", "user_key_in_model_is_not_null");
+        log.info(obj.toString());
+        return CommonCollectionUtils.toList(rovRecallScore, i -> {
+            // hard code 将排序分数 赋值给video的sortScore
+            Video v = i.getVideo();
+            v.setScoreStr(i.getScoreStr());
+            v.setScoreRos(i.getScoreRos());
+            return v;
+        });
+    }
+
+    private Map<String, String> getSceneFeature(RankParam param) {
+        Map<String, String> sceneFeatureMap = new HashMap<>();
+        String provinceCn = param.getProvince();
+        provinceCn = provinceCn.replaceAll("省$", "");
+        sceneFeatureMap.put("ctx_region", provinceCn);
+        String city = param.getCity();
+        if ("台北市".equals(city) |
+            "高雄市".equals(city) |
+            "台中市".equals(city) |
+            "桃园市".equals(city) |
+            "新北市".equals(city) |
+            "台南市".equals(city) |
+            "基隆市".equals(city) |
+            "吉林市".equals(city) |
+            "新竹市".equals(city) |
+            "嘉义市".equals(city)
+        ){
+            ;
+        }else{
+            city = city.replaceAll("市$", "");
+        }
+        sceneFeatureMap.put("ctx_city", city);
+
+        Calendar calendar = Calendar.getInstance();
+        sceneFeatureMap.put("ctx_week", (calendar.get(Calendar.DAY_OF_WEEK) + 6) % 7 + "");
+        sceneFeatureMap.put("ctx_hour", new SimpleDateFormat("HH").format(calendar.getTime()));
+
+        return sceneFeatureMap;
+    }
+
+    public static void main(String[] args) {
+        Calendar calendar = Calendar.getInstance();
+        calendar.set(Calendar.YEAR, 2022);
+        calendar.set(Calendar.MONTH, 0); // January is 0
+        calendar.set(Calendar.DAY_OF_MONTH, 1);
+        calendar.set(Calendar.HOUR_OF_DAY, 0);
+        calendar.set(Calendar.MINUTE, 12);
+        calendar.set(Calendar.SECOND, 30);
+        System.out.println(new SimpleDateFormat("HH").format(calendar.getTime()));
+
+        String provinceCn = "吉林省2";
+        provinceCn = provinceCn.replaceAll("省$", "");
+        System.out.println(provinceCn);
+    }
+
+}

+ 1 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/RecallService.java

@@ -112,6 +112,7 @@ public class RecallService implements ApplicationContextAware {
             ;
         }else{
             switch (abCode){
+                case "60106":
                 case "60068":
                 case "60092":
                 case "60094":

+ 2 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/VlogShareLRScorer.java

@@ -252,7 +252,7 @@ public class VlogShareLRScorer extends BaseLRModelScorer {
         List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
         for (int index = 0; index < items.size(); index++) {
             final int fIndex = index;
-            items.get(fIndex).setScore(0.0);   //原始分为 cube中的粗打分,如果超时,为原始值存在问题, 需要置0
+            // items.get(fIndex).setScore(0.0);   //原始分为 cube中的粗打分,如果超时,为原始值存在问题, 需要置0
             calls.add(new Callable<Object>() {
                 @Override
                 public Object call() throws Exception {
@@ -323,7 +323,7 @@ public class VlogShareLRScorer extends BaseLRModelScorer {
                         item.getVideoId(), ExceptionUtils.getFullStackTrace(e)});
             }
         }
-        item.setScore(pro);
+        item.setScoreStr(pro);
         return pro;
     }
 }

+ 329 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/VlogShareLRScorer4Ros.java

@@ -0,0 +1,329 @@
+package com.tzld.piaoquan.recommend.server.service.score;
+
+
+import com.tzld.piaoquan.recommend.feature.domain.video.base.*;
+import com.tzld.piaoquan.recommend.server.common.base.*;
+import com.tzld.piaoquan.recommend.feature.model.sample.*;
+import com.tzld.piaoquan.recommend.feature.domain.video.feature.VlogShareLRFeatureExtractor;
+import com.tzld.piaoquan.recommend.server.service.rank.strategy.OfflineVlogShareLRFeatureExtractor;
+import com.tzld.piaoquan.recommend.server.service.score.model.LRModel;
+
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+import java.util.concurrent.*;
+
+
+public class VlogShareLRScorer4Ros extends BaseLRModelScorer {
+
+    private final static int CORE_POOL_SIZE = 64;
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogShareLRScorer4Ros.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
+    private static final double defaultUserCtrGroupNumber = 10.0;
+    private static final int enterFeedsScoreRatio = 10;
+    private static final int enterFeedsScoreNum = 20;
+
+
+    public VlogShareLRScorer4Ros(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+
+    @Override
+    public List<RankItem> scoring(final ScoreParam param,
+                                  final UserFeature userFeature,
+                                  final List<RankItem> rankItems) {
+
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        LRModel model = (LRModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        List<RankItem> result = rankItems;
+        result = rankByJava(rankItems, param.getRequestContext(),
+                userFeature == null ? UserFeature.defaultInstance(param.getMid()) : userFeature);
+
+        LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<RankItem> rankByJava(final List<RankItem> items,
+                                      final RequestContext requestContext,
+                                      final UserFeature user) {
+        long startTime = System.currentTimeMillis();
+        LRModel model = (LRModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        // userBytes
+        UserBytesFeature userInfoBytes = null;
+        userInfoBytes = new UserBytesFeature(user);
+
+        // 所有都参与打分,按照ctr排序
+        multipleCtrScore(items, userInfoBytes, requestContext, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (int i = 0; i < items.size(); i++) {
+                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+            }
+        }
+
+        Collections.sort(items);
+
+        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+                System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+
+    /**
+     * 计算 predict ctr
+     */
+    public double calcScore(final LRModel lrModel,
+                            final RankItem item,
+                            final UserBytesFeature userInfoBytes,
+                            final RequestContext requestContext) {
+
+        LRSamples lrSamples = null;
+        VlogShareLRFeatureExtractor bytesFeatureExtractor;
+        bytesFeatureExtractor = new VlogShareLRFeatureExtractor();
+
+        try {
+            VideoBytesFeature newsInfoBytes = new VideoBytesFeature(item.getItemFeature() == null
+                    ? ItemFeature.defaultInstance(item.getVideoId() + "")
+                    : item.getItemFeature());
+            lrSamples = bytesFeatureExtractor.single(userInfoBytes, newsInfoBytes,
+                    new RequestContextBytesFeature(requestContext));
+        } catch (Exception e) {
+            LOGGER.error("extract feature error for imei={}, doc={}, [{}]", new Object[]{new String(userInfoBytes.getUid()), item.getVideoId(),
+                    ExceptionUtils.getFullStackTrace(e)});
+        }
+
+
+        double pro = 0.0;
+        if (lrSamples != null && lrSamples.getFeaturesList() != null) {
+            try {
+                pro = lrModel.score(lrSamples);
+            } catch (Exception e) {
+                LOGGER.error("score error for doc={} exception={}", new Object[]{
+                        item.getVideoId(), ExceptionUtils.getFullStackTrace(e)});
+            }
+            // 增加实时特征后打开在线存储日志逻辑
+            //
+            // CtrSamples.Builder samples =  com.tzld.piaoquan.recommend.server.gen.recommend.CtrSamples.newBuilder();
+            // samples.setLr_samples(lrSamples);
+            // item.setSamples(samples);
+            //
+        }
+        item.setScore(pro);
+        return pro;
+    }
+
+
+    /**
+     * 并行打分
+     *
+     * @param items
+     * @param userInfoBytes
+     * @param requestContext
+     * @param model
+     */
+    private void multipleCtrScore(final List<RankItem> items,
+                                  final UserBytesFeature userInfoBytes,
+                                  final RequestContext requestContext,
+                                  final LRModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            items.get(fIndex).setScore(0.0);   //原始分为 cube中的粗打分,如果超时,为原始值存在问题, 需要置0
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), userInfoBytes, requestContext);
+                    } catch (Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).videoId, ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        //等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    LOGGER.error("InterruptedException {},{}", ExceptionUtils.getFullStackTrace(e));
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}", requestContext.getRequest_id(),
+                            ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+        LOGGER.debug("Ctr Score {}, Total: {}, Cancel: {}", new Object[]{requestContext.getRequest_id(), items.size(), cancel});
+    }
+    @Override
+    public List<RankItem> scoring(final Map<String, String> sceneFeatureMap,
+                                  final Map<String, String> userFeatureMap,
+                                  final List<RankItem> rankItems){
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        LRModel model = (LRModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        List<RankItem> result = rankItems;
+        result = rankByJava(
+                sceneFeatureMap, userFeatureMap, rankItems
+        );
+
+        LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<RankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+                                      final Map<String, String> userFeatureMap,
+                                      final List<RankItem> items) {
+        long startTime = System.currentTimeMillis();
+        LRModel model = (LRModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+        // userBytes
+        Map<String, byte[]> userFeatureMapByte = new HashMap<>();
+        for(Map.Entry<String, String> entry: userFeatureMap.entrySet()){
+            userFeatureMapByte.put(entry.getKey(), entry.getValue().getBytes());
+        }
+        //sceneBytes
+        Map<String, byte[]> sceneFeatureMapByte = new HashMap<>();
+        for(Map.Entry<String, String> entry: sceneFeatureMap.entrySet()){
+            sceneFeatureMapByte.put(entry.getKey(), entry.getValue().getBytes());
+        }
+
+        // 所有都参与打分,按照ctr排序
+        multipleCtrScore(items, userFeatureMapByte, sceneFeatureMapByte, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (int i = 0; i < items.size(); i++) {
+                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+            }
+        }
+
+        Collections.sort(items);
+
+        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+                System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+    private void multipleCtrScore(final List<RankItem> items,
+                                  final Map<String, byte[]> userFeatureMapByte,
+                                  final Map<String, byte[]> sceneFeatureMapByte,
+                                  final LRModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+//            items.get(fIndex).setScore(0.0);   //原始分为 cube中的粗打分,如果超时,为原始值存在问题, 需要置0
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), userFeatureMapByte, sceneFeatureMapByte);
+                    } catch (Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).videoId, ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        //等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    LOGGER.error("InterruptedException {},{}", ExceptionUtils.getFullStackTrace(e));
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}", sceneFeatureMapByte.size(),
+                            ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+        LOGGER.debug("Ctr Score {}, Total: {}, Cancel: {}", new Object[]{sceneFeatureMapByte.size(), items.size(), cancel});
+    }
+
+    public double calcScore(final LRModel lrModel,
+                            final RankItem item,
+                            final Map<String, byte[]> userFeatureMapByte,
+                            final Map<String, byte[]> sceneFeatureMapByte) {
+
+        LRSamples lrSamples = null;
+        OfflineVlogShareLRFeatureExtractor bytesFeatureExtractor;
+        bytesFeatureExtractor = new OfflineVlogShareLRFeatureExtractor();
+
+        try {
+
+            Map<String, byte[]> itemFeatureByte = new HashMap<>();
+            for (Map.Entry<String, String> entry: item.getFeatureMap().entrySet()){
+                itemFeatureByte.put(entry.getKey(), entry.getValue().getBytes());
+            }
+            lrSamples = bytesFeatureExtractor.single(userFeatureMapByte, itemFeatureByte, sceneFeatureMapByte);
+        } catch (Exception e) {
+            LOGGER.error("extract feature error for imei={}, doc={}, [{}]", new Object[]{"", item.getVideoId(),
+                    ExceptionUtils.getFullStackTrace(e)});
+        }
+
+
+        double pro = 0.0;
+        if (lrSamples != null && lrSamples.getFeaturesList() != null) {
+            try {
+                pro = lrModel.score(lrSamples);
+            } catch (Exception e) {
+                LOGGER.error("score error for doc={} exception={}", new Object[]{
+                        item.getVideoId(), ExceptionUtils.getFullStackTrace(e)});
+            }
+        }
+        item.setScoreRos(pro);
+        return pro;
+    }
+}

+ 6 - 1
recommend-server-service/src/main/resources/feeds_score_config_baseline.conf

@@ -1,7 +1,12 @@
 scorer-config = {
-  related-score-config = {
+  str-score-config = {
     scorer-name = "com.tzld.piaoquan.recommend.server.service.score.VlogShareLRScorer"
     scorer-priority = 99
     model-path = "video_str_model/model_sharev2_20231220_change.txt"
   }
+  ros-score-config = {
+    scorer-name = "com.tzld.piaoquan.recommend.server.service.score.VlogShareLRScorer4Ros"
+    scorer-priority = 99
+    model-path = "video_str_model/model_ros_v2_20231220_change.txt"
+  }
 }