Browse Source

Merge branch 'feature/flowpool_thompson_mz' of algorithm/recommend-server into master

smz&zhangbo
qingqu-git 1 year ago
parent
commit
f952db20a3

+ 7 - 0
recommend-server-service/pom.xml

@@ -28,6 +28,11 @@
             <groupId>org.apache.commons</groupId>
             <artifactId>commons-pool2</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-math3</artifactId>
+            <version>3.6.1</version>
+        </dependency>
         <dependency>
             <groupId>org.apache.commons</groupId>
             <artifactId>commons-lang3</artifactId>
@@ -70,6 +75,8 @@
         </dependency>
 
 
+
+
         <dependency>
             <groupId>org.springframework.boot</groupId>
             <artifactId>spring-boot-starter-test</artifactId>

+ 17 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/common/base/VideoActionFeature.java

@@ -0,0 +1,17 @@
+package com.tzld.piaoquan.recommend.server.common.base;
+
+import lombok.Getter;
+import lombok.Data;
+import lombok.Getter;
+
+@Data
+public class VideoActionFeature {
+
+
+    private double view = 0d;
+    private double play = 0d;
+    private double realPlay = 0d;
+    private double share = 0d;
+    private double returns = 0d;
+
+}

+ 7 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/RankRouter.java

@@ -16,6 +16,11 @@ public class RankRouter {
     private RankStrategy4RankModel rankStrategy4RankModel;
     @Autowired
     private RankStrategy4Density rankStrategy4Density;
+
+    @Autowired
+    private RankStrategyFlowThompsonModel rankStrategyFlowThompsonModel;
+
+
     public RankResult rank(RankParam param) {
         String abCode = param.getAbCode();
         if (abCode == null) {
@@ -28,6 +33,8 @@ public class RankRouter {
                 return rankStrategy4RankModel.rank(param);
             case "60098":
                 return rankStrategy4Density.rank(param);
+            case "60107":
+                return rankStrategyFlowThompsonModel.rank(param);
             default:
                 break;
         }

+ 3 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/RankService.java

@@ -128,7 +128,8 @@ public class RankService {
                 || param.getAbCode().equals("60098")
                 || param.getAbCode().equals("60103")
                 || param.getAbCode().equals("60104")
-                || param.getAbCode().equals("60105")) {
+                || param.getAbCode().equals("60105")
+                || param.getAbCode().equals("60107")) {
             // 地域召回要做截取,再做融合排序
             removeDuplicate(rovRecallRank);
             rovRecallRank = rovRecallRank.size() <= sizeReturn
@@ -210,7 +211,7 @@ public class RankService {
         });
     }
 
-    private ScoreParam convert(RankParam param) {
+    protected ScoreParam convert(RankParam param) {
         ScoreParam scoreParam = new ScoreParam();
 
         scoreParam.setMid(param.getMid());

+ 90 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategyFlowThompsonModel.java

@@ -0,0 +1,90 @@
+package com.tzld.piaoquan.recommend.server.service.rank.strategy;
+
+
+import com.tzld.piaoquan.recommend.feature.domain.video.base.UserFeature;
+import com.tzld.piaoquan.recommend.server.common.base.RankItem;
+import com.tzld.piaoquan.recommend.server.common.enums.AppTypeEnum;
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConstants;
+import com.tzld.piaoquan.recommend.server.service.rank.RankParam;
+import com.tzld.piaoquan.recommend.server.service.rank.RankService;
+
+import com.tzld.piaoquan.recommend.server.service.recall.RecallResult;
+import com.tzld.piaoquan.recommend.server.service.score.ScoreParam;
+import com.tzld.piaoquan.recommend.server.service.score.ScorerUtils;
+import com.tzld.piaoquan.recommend.server.util.CommonCollectionUtils;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.collections4.CollectionUtils;
+import org.springframework.stereotype.Service;
+
+import java.util.*;
+
+/**
+ * @author 孙铭泽
+ * @desc 对流量池增加排序Thompson sampling 策略
+ * @date 2024-01-17
+ */
+@Service
+@Slf4j
+public class RankStrategyFlowThompsonModel extends RankService {
+
+    public List<Video> mergeAndRankFlowPoolRecall(RankParam param) {
+        if (param.getAppType() == AppTypeEnum.LAO_HAO_KAN_VIDEO.getCode()
+                || param.getAppType() == AppTypeEnum.ZUI_JING_QI.getCode()) {
+            if (param.getAbCode().equals("60054")
+                    || param.getAbCode().equals("60068")
+                    || param.getAbCode().equals("60081")
+                    || param.getAbCode().equals("60084")) {
+                return extractAndSort(param, FlowPoolConstants.QUICK_PUSH_FORM);
+            } else {
+                return Collections.emptyList();
+            }
+        } else {
+            List<Video> quickFlowPoolVideos = sortFlowPoolByThompson(param, FlowPoolConstants.QUICK_PUSH_FORM);
+            if (CollectionUtils.isNotEmpty(quickFlowPoolVideos)) {
+                return quickFlowPoolVideos;
+            } else {
+                return sortFlowPoolByThompson(param, FlowPoolConstants.PUSH_FORM);
+            }
+        }
+    }
+
+    public List<Video> sortFlowPoolByThompson(RankParam param, String pushFrom) {
+
+        //初始化 userid
+        UserFeature userFeature = new UserFeature();
+        userFeature.setMid(param.getMid());
+
+        // 初始化RankItem
+        Optional<RecallResult.RecallData> data = param.getRecallResult().getData().stream()
+                .filter(d -> d.getPushFrom().equals(pushFrom))
+                .findFirst();
+        List<Video> videoList = data.get().getVideos();
+
+        if (videoList == null) {
+            return Collections.emptyList();
+        }
+        List<RankItem> rankItems = new ArrayList<>();
+        for (int i = 0; i < videoList.size(); i++) {
+            RankItem rankItem = new RankItem(videoList.get(i));
+            rankItems.add(rankItem);
+        }
+
+        // 初始化上下文参数
+        ScoreParam scoreParam = convert(param);
+        List<RankItem> rovRecallScore = ScorerUtils.getScorerPipeline(ScorerUtils.FLOWPOOL_CONF)
+                .scoring(scoreParam, userFeature, rankItems);
+
+        if (rovRecallScore == null) {
+            return Collections.emptyList();
+        }
+
+        return CommonCollectionUtils.toList(rovRecallScore, i -> {
+            // hard code 将排序分数 赋值给video的sortScore
+            Video v = i.getVideo();
+            v.setSortScore(i.getScore());
+            return v;
+        });
+    }
+
+}

+ 1 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/RecallService.java

@@ -115,6 +115,7 @@ public class RecallService implements ApplicationContextAware {
                     strategies.add(strategyMap.get(SimHotVideoRecallStrategy.class.getSimpleName()));
                     strategies.add(strategyMap.get(ReturnVideoRecallStrategy.class.getSimpleName()));
                     break;
+                case "60107":
                 case "60106":
                 case "60068":
                 case "60092":

+ 32 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/BaseThompsonSamplingScorer.java

@@ -0,0 +1,32 @@
+package com.tzld.piaoquan.recommend.server.service.score;
+
+import com.tzld.piaoquan.recommend.server.common.base.RankItem;
+import com.tzld.piaoquan.recommend.server.service.score.model.ThompsonSamplingModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+import java.util.Map;
+
+
+public abstract class BaseThompsonSamplingScorer extends AbstractScorer {
+
+    private static Logger LOGGER = LoggerFactory.getLogger(BaseThompsonSamplingScorer.class);
+
+    public BaseThompsonSamplingScorer(ScorerConfigInfo scorerConfigInfo) {
+        super(scorerConfigInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        doLoadModel(ThompsonSamplingModel.class);
+    }
+    @Override
+    public List<RankItem> scoring(final Map<String, String> sceneFeatureMap,
+                                           final Map<String, String> userFeatureMap,
+                                           final List<RankItem> rankItems){
+
+        return rankItems;
+    }
+
+}

+ 6 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/ScorerUtils.java

@@ -21,9 +21,14 @@ public final class ScorerUtils {
 
     public static String BASE_CONF = "feeds_score_config_baseline.conf";
 
+    public static String FLOWPOOL_CONF = "feeds_score_config_thompson.conf";
+
+
     public static void warmUp() {
         log.info("scorer warm up ");
         ScorerUtils.init(BASE_CONF);
+        ScorerUtils.init(FLOWPOOL_CONF);
+
     }
 
     private ScorerUtils() {
@@ -59,6 +64,7 @@ public final class ScorerUtils {
         initLoadModel(scorers);
     }
 
+
     public static void initLoadModel(Config config) {
         ScorerConfig scorerConfig = new ScorerConfig();
         scorerConfig.load(config);

+ 141 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/VlogThompsonScorer.java

@@ -0,0 +1,141 @@
+package com.tzld.piaoquan.recommend.server.service.score;
+
+import com.tzld.piaoquan.recommend.feature.domain.video.base.*;
+import com.tzld.piaoquan.recommend.feature.domain.video.base.UserFeature;
+import com.tzld.piaoquan.recommend.server.common.base.RankItem;
+import com.tzld.piaoquan.recommend.server.service.score.model.ThompsonSamplingModel;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.*;
+
+
+//@Service
+public class VlogThompsonScorer extends BaseThompsonSamplingScorer {
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogThompsonScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
+
+    public VlogThompsonScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    @Override
+    public List<RankItem> scoring(final ScoreParam param,
+                                  final UserFeature userFeature,
+                                  final List<RankItem> rankItems) {
+
+        if (userFeature == null || CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        ThompsonSamplingModel model = (ThompsonSamplingModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        List<RankItem> result = rankItems;
+        result = rankByJava(rankItems, param.getRequestContext(), userFeature);
+
+        LOGGER.debug("thompson sampling ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<RankItem> rankByJava(final List<RankItem> items,
+                                      final RequestContext requestContext,
+                                      final UserFeature user) {
+        long startTime = System.currentTimeMillis();
+        ThompsonSamplingModel model = (ThompsonSamplingModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        // 所有都参与打分,按照ROV Thompson排序
+        multipleCtrScore(items, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (int i = 0; i < items.size(); i++) {
+                LOGGER.debug("after enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i).getScore());
+            }
+        }
+
+        LOGGER.debug("thompson ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[thompson ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+                System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+
+    /**
+     * 计算 predict ROV
+     */
+    public double calcScore(final ThompsonSamplingModel model,
+                            final RankItem item) {
+        double score = 0d;
+        try {
+            score = model.score(item);
+        } catch (Exception e) {
+            LOGGER.error("score error for doc={} exception={}", new Object[]{
+                    item.getVideo(), ExceptionUtils.getFullStackTrace(e)});
+        }
+        item.setScore(score);
+        return score;
+    }
+
+
+    /**
+     * 并行打分 Thompson ROV
+     *
+     * @param items
+     * @param model
+     */
+    private void multipleCtrScore(final List<RankItem> items,
+                                  final ThompsonSamplingModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            items.get(fIndex).setScore(0.0);   //设置为原始值为0
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex));
+                    } catch (Exception e) {
+                        LOGGER.error("thompson exception: [{}] [{}]", items.get(fIndex).videoId, ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        //等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (InterruptedException e) {
+                    LOGGER.error("InterruptedException {},{}", ExceptionUtils.getFullStackTrace(e));
+                } catch (ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}", ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+        LOGGER.debug("ROV-Thompson Score {}, Total: {}, Cancel: {}", new Object[]{items.size(), cancel});
+    }
+}

+ 91 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/ThompsonSamplingModel.java

@@ -0,0 +1,91 @@
+package com.tzld.piaoquan.recommend.server.service.score.model;
+
+import com.tzld.piaoquan.recommend.server.common.base.RankItem;
+import com.tzld.piaoquan.recommend.server.common.base.VideoActionFeature;
+import com.tzld.piaoquan.recommend.server.model.Video;
+import org.apache.commons.math3.distribution.BetaDistribution;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.math.BigInteger;
+import java.util.HashMap;
+import java.util.Map;
+
+
+public class ThompsonSamplingModel extends Model {
+    protected static final int MODEL_FIRST_LOAD_COUNT = 1 << 25;  // 32M
+    private static final Logger LOGGER = LoggerFactory.getLogger(ThompsonSamplingModel.class);
+
+    // key = videoid, value = < push, exp, play, realplay, share, retures >
+    private Map<Long, VideoActionFeature> thompsonSamplingModel;
+
+    private static final int alpha = 20;
+    private static final int beta_returns = 100;
+
+    public ThompsonSamplingModel() {
+        //配置不同环境的hdfs conf
+        this.thompsonSamplingModel = new HashMap<>();
+    }
+
+    public Map<Long, VideoActionFeature> getThompsonSamplingModel() {
+        return this.thompsonSamplingModel;
+    }
+
+
+
+
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws IOException {
+        Map<Long, VideoActionFeature> initModel = new HashMap<>();
+        BufferedReader input = new BufferedReader(in);
+        String line = null;
+        int cnt = 0;
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split("\t");
+            if (items.length < 3) {
+                continue;
+            }
+            Long videoId = new BigInteger(items[0].trim()).longValue();
+            VideoActionFeature videoFeature = new VideoActionFeature();
+            videoFeature.setView(Double.valueOf(items[1].trim()));
+            videoFeature.setPlay(Double.valueOf(items[2].trim()));
+            videoFeature.setShare(Double.valueOf(items[3].trim()));
+            videoFeature.setReturns(Double.valueOf(items[5].trim()));
+            initModel.put(videoId, videoFeature);
+        }
+
+        this.thompsonSamplingModel = initModel;
+        LOGGER.info("[MODELLOAD] model load over and size " + cnt);
+        input.close();
+        in.close();
+        return true;
+    }
+
+    @Override
+    public int getModelSize() {
+        if (this.thompsonSamplingModel == null)
+            return 0;
+        int sum = this.thompsonSamplingModel.size();
+        return sum;
+    }
+
+    public double score(RankItem rankItem) {
+        double score = 0.0f;
+        VideoActionFeature videoActionFeature = this.thompsonSamplingModel.getOrDefault(rankItem.getVideoId(), new VideoActionFeature());
+
+        int alpha = (int) videoActionFeature.getReturns() + this.alpha;
+        int beta = this.beta_returns + (int) videoActionFeature.getView();
+        score = this.betaSampler(alpha, beta);
+        return score;
+    }
+
+    public double betaSampler(double alpha, double beta) {
+        BetaDistribution betaSample = new BetaDistribution(alpha, beta);
+        return betaSample.sample();
+    }
+
+
+}

+ 7 - 0
recommend-server-service/src/main/resources/feeds_score_config_thompson.conf

@@ -0,0 +1,7 @@
+scorer-config = {
+  flowpool-score-config = {
+    scorer-name = "com.tzld.piaoquan.recommend.server.service.score.VlogThompsonScorer"
+    scorer-priority = 99
+    model-path = "video_thompson_model/model_video_thompson.txt"
+  }
+}