瀏覽代碼

feat:rerank优化

zhaohaipeng 1 月之前
父節點
當前提交
7c291dc68a

+ 4 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/RankService.java

@@ -7,6 +7,7 @@ import com.tzld.piaoquan.recommend.server.model.Video;
 import com.tzld.piaoquan.recommend.server.service.flowpool.FlowPoolConstants;
 import com.tzld.piaoquan.recommend.server.service.recall.RecallResult;
 import com.tzld.piaoquan.recommend.server.service.recall.strategy.*;
+import com.tzld.piaoquan.recommend.server.service.rerank.RerankService;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.lang3.math.NumberUtils;
@@ -30,6 +31,9 @@ public abstract class RankService {
     @ApolloJsonValue("${alg.recall.special.app&vid:{}}")
     protected Map<String, List<Long>> specialAppVid;
 
+    @Autowired
+    protected RerankService rerankService;
+
     public RankResult rank(RankParam param) {
         if (param == null
                 || param.getRecallResult() == null

+ 28 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategy4RegionMergeModelBasic.java

@@ -16,6 +16,8 @@ import com.tzld.piaoquan.recommend.server.service.rank.processor.RankProcessorBo
 import com.tzld.piaoquan.recommend.server.service.rank.processor.RankProcessorDensity;
 import com.tzld.piaoquan.recommend.server.service.rank.processor.RankProcessorInsert;
 import com.tzld.piaoquan.recommend.server.service.rank.processor.RankProcessorTagFilter;
+import com.tzld.piaoquan.recommend.server.service.rerank.RerankParam;
+import com.tzld.piaoquan.recommend.server.service.rerank.RerankResult;
 import com.tzld.piaoquan.recommend.server.util.JSONUtils;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.CollectionUtils;
@@ -82,8 +84,33 @@ public abstract class RankStrategy4RegionMergeModelBasic extends RankService {
             }
         }
 
-        // 2 根据实验号解析阿波罗参数。
         Set<String> abExpCodes = param.getAbExpCodes();
+        if (CollectionUtils.isNotEmpty(abExpCodes) && abExpCodes.contains("810")) {
+            RerankParam rerankParam = new RerankParam();
+            rerankParam.setRovVideos(rovVideos);
+            rerankParam.setFlowPoolVideos(flowVideos);
+            rerankParam.setDouHotFlowPoolVideos(douHotFlowPoolVideos);
+            rerankParam.setSize(param.getSize());
+            rerankParam.setTopK(param.getTopK());
+            rerankParam.setFlowPoolP(this.getFlowPoolP(param));
+            rerankParam.setSpecialRecommend(param.isSpecialRecommend());
+            rerankParam.setAppType(param.getAppType());
+            rerankParam.setMid(param.getMid());
+            rerankParam.setUid(param.getUid());
+            rerankParam.setProvince(param.getProvince());
+            rerankParam.setCity(param.getCity());
+            rerankParam.setMachineInfo(param.getMachineInfo());
+            rerankParam.setAbExpCodes(abExpCodes);
+            rerankParam.setHeadVid(param.getHeadVid());
+            rerankParam.setHotSceneType(param.getHotSceneType());
+
+            RerankResult rerank = rerankService.rerank(rerankParam);
+
+            return new RankResult(rerank.getVideos());
+        }
+
+
+        // 2 根据实验号解析阿波罗参数。
         Map<String, Map<String, String>> rulesMap = Collections.emptyMap();
 
         Map<String, List<Map<String, String>>> rankReduceRulesMap = Collections.emptyMap();

+ 40 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rerank/RerankParam.java

@@ -0,0 +1,40 @@
+package com.tzld.piaoquan.recommend.server.service.rerank;
+
+import com.tzld.piaoquan.recommend.server.model.MachineInfo;
+import com.tzld.piaoquan.recommend.server.model.Video;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Mutable input bundle handed to the rerank stage. The same instance is passed to every
+ * strategy in the chain, so a strategy may publish its output to the next one by updating it.
+ * Lombok generates the accessors, a no-args constructor and an all-args constructor whose
+ * parameter order follows the field order below.
+ */
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class RerankParam {
+    // Videos already ranked by the algorithm, or the algorithm videos as reranked by the previous node
+    private List<Video> rovVideos = new ArrayList<>();
+    // Flow-pool videos (covers both flow_pool and quick_flow_pool)
+    private List<Video> flowPoolVideos = new ArrayList<>();
+    // Hot-topic ("dou hot") flow-pool videos
+    private List<Video> douHotFlowPoolVideos = new ArrayList<>();
+
+    // Read by FlowPoolVideoInsertRerankStrategy: the first topK rov videos are copied as-is and
+    // the mixing loop then runs (size - topK) times; size also caps the tail fill.
+    private int size;
+    private int topK;
+    // Threshold compared against a random draw in [0, 1) to decide whether a slot takes a flow-pool video
+    private double flowPoolP;
+    // Copied from RankParam by the caller; not read by the strategies visible in this change
+    private boolean specialRecommend;
+
+    // Request context copied from RankParam by the caller
+    private int appType;
+    private String mid;
+    private String uid;
+    private String province;
+    private String city;
+    private MachineInfo machineInfo;
+    // A/B experiment codes of the current request
+    private Set<String> abExpCodes;
+
+    // Both default to 0 when the caller does not set them
+    private Long headVid = 0L;
+    private Long hotSceneType = 0L;
+
+}

+ 17 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rerank/RerankResult.java

@@ -0,0 +1,17 @@
+package com.tzld.piaoquan.recommend.server.service.rerank;
+
+import com.tzld.piaoquan.recommend.server.model.Video;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.List;
+
+/**
+ * Output of the rerank stage: the final ordered video list.
+ * Lombok generates the accessors and both constructors.
+ */
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class RerankResult {
+
+    // Videos in their final order, as returned by the last strategy in the chain
+    private List<Video> videos;
+
+}

+ 62 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rerank/RerankService.java

@@ -0,0 +1,62 @@
+package com.tzld.piaoquan.recommend.server.service.rerank;
+
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.rerank.strategy.FlowPoolVideoInsertRerankStrategy;
+import com.tzld.piaoquan.recommend.server.service.rerank.strategy.VideoAttrWeightRerankStrategy;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.ApplicationContextAware;
+import org.springframework.stereotype.Service;
+
+import javax.annotation.PostConstruct;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Entry point of the rerank stage. Runs a fixed sequence of {@link RerankStrategy} beans
+ * against the same {@link RerankParam} and returns the output of the last one.
+ */
+@Slf4j
+@Service
+public class RerankService implements ApplicationContextAware {
+
+    private final Map<String, RerankStrategy> strategyMap = new HashMap<>();
+    private ApplicationContext applicationContext;
+
+    /**
+     * Indexes every {@link RerankStrategy} bean by the simple name of its runtime class.
+     * NOTE(review): if a strategy bean is ever wrapped in a class-based proxy, its runtime
+     * simple name will differ from the declared class name and the lookup below will miss it.
+     */
+    @PostConstruct
+    public void init() {
+        Map<String, RerankStrategy> beans = applicationContext.getBeansOfType(RerankStrategy.class);
+        for (RerankStrategy strategy : beans.values()) {
+            strategyMap.put(strategy.getClass().getSimpleName(), strategy);
+        }
+    }
+
+    /**
+     * Chain-of-responsibility style execution: the strategies are invoked one after another.
+     * A strategy hands its output to the next one through {@code param}; the list returned by
+     * the last strategy is the final result.
+     *
+     * @param param shared rerank input, possibly mutated by the strategies
+     * @return result wrapping the list produced by the last strategy; when no strategy could
+     *         be resolved the incoming rov list is passed through unchanged
+     */
+    public RerankResult rerank(RerankParam param) {
+        List<RerankStrategy> rerankStrategies = this.getRerankStrategies(param);
+
+        // Seed with the incoming rov list so an empty chain degrades to "no rerank"
+        // instead of returning nothing.
+        List<Video> videos = param.getRovVideos();
+        if (videos == null) {
+            videos = new ArrayList<>();
+        }
+        for (RerankStrategy rerankStrategy : rerankStrategies) {
+            videos = rerankStrategy.rerank(param);
+        }
+
+        RerankResult rerankResult = new RerankResult();
+        rerankResult.setVideos(videos);
+        return rerankResult;
+    }
+
+    /**
+     * Resolves the strategy chain in execution order. A strategy whose bean was not registered
+     * is logged and skipped so that a missing bean cannot surface as a NullPointerException
+     * while serving a request.
+     */
+    private List<RerankStrategy> getRerankStrategies(RerankParam param) {
+        String[] chain = {
+                VideoAttrWeightRerankStrategy.class.getSimpleName(),
+                FlowPoolVideoInsertRerankStrategy.class.getSimpleName()
+        };
+
+        List<RerankStrategy> rerankStrategies = new ArrayList<>(chain.length);
+        for (String name : chain) {
+            RerankStrategy strategy = strategyMap.get(name);
+            if (strategy == null) {
+                log.error("rerank strategy bean not found, skipped. name={}", name);
+                continue;
+            }
+            rerankStrategies.add(strategy);
+        }
+        return rerankStrategies;
+    }
+
+    @Override
+    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
+        this.applicationContext = applicationContext;
+    }
+}

+ 11 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rerank/RerankStrategy.java

@@ -0,0 +1,11 @@
+package com.tzld.piaoquan.recommend.server.service.rerank;
+
+import com.tzld.piaoquan.recommend.server.model.Video;
+
+import java.util.List;
+
+/**
+ * One step of the rerank chain executed by RerankService.
+ */
+public interface RerankStrategy {
+
+    /**
+     * Reranks the videos described by {@code param}.
+     *
+     * @param param shared rerank input; RerankService passes the same instance to every
+     *              strategy, so an implementation may update it for the following strategy
+     * @return the reranked list; RerankService uses the list returned by the last strategy
+     *         as the final result
+     */
+    List<Video> rerank(RerankParam param);
+
+}

+ 93 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rerank/strategy/BasicRerankStrategy.java

@@ -0,0 +1,93 @@
+package com.tzld.piaoquan.recommend.server.service.rerank.strategy;
+
+import com.tzld.piaoquan.recommend.server.service.rerank.RerankStrategy;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.springframework.beans.factory.annotation.Qualifier;
+import org.springframework.data.redis.core.RedisTemplate;
+
+import javax.annotation.Resource;
+import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
+import java.util.Objects;
+
+/**
+ * Base class for rerank strategies that need the shared time-window helpers.
+ */
+@Slf4j
+public abstract class BasicRerankStrategy implements RerankStrategy {
+
+    @Resource
+    @Qualifier("redisTemplate")
+    public RedisTemplate<String, String> redisTemplate;
+
+    private static final DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyyMMddHH");
+
+    // Number of characters in a bound written as yyyyMMddHH
+    private static final int DATE_TIME_BOUND_LENGTH = 10;
+
+    /**
+     * Tells whether the current local time falls inside {@code timeRange}.
+     * <p>
+     * Accepted forms are {@code start-end} with both bounds inclusive, where the bounds are
+     * either {@code yyyyMMddHH} timestamps or hours of day (0-23). A blank range means
+     * "no restriction" and always matches; anything that cannot be parsed never matches.
+     * An hour range whose start is greater than its end (for example {@code 22-6}) never
+     * matches. NOTE(review): confirm whether overnight windows should be supported.
+     * <p>
+     * "Now" is taken from {@link LocalDateTime#now()}, i.e. the JVM default time zone.
+     *
+     * @param timeRange range expression, may be blank
+     * @return {@code true} when there is no restriction or the current time is in range
+     */
+    protected boolean isInTimeRangeCondition(String timeRange) {
+        if (StringUtils.isBlank(timeRange)) {
+            return true;
+        }
+        String[] split = timeRange.split("-");
+        if (split.length != 2) {
+            return false;
+        }
+
+        String start = split[0].trim();
+        String end = split[1].trim();
+
+        LocalDateTime now = LocalDateTime.now();
+
+        // Absolute timestamps. Only attempted when both bounds have the yyyyMMddHH shape:
+        // this method is called for every video and config item, and feeding an hour range
+        // such as "8-20" to the timestamp parser would throw and log an error each time.
+        if (hasDateTimeShape(start) && hasDateTimeShape(end)) {
+            Pair<LocalDateTime, LocalDateTime> localDateTimePair = this.parseDateTimeRange(start, end);
+            if (Objects.nonNull(localDateTimePair)) {
+                LocalDateTime startDt = localDateTimePair.getLeft();
+                LocalDateTime endDt = localDateTimePair.getRight();
+
+                // Equivalent to startDt <= now <= endDt
+                return !now.isBefore(startDt) && !now.isAfter(endDt);
+            }
+        }
+
+        // Hour-of-day window
+        Pair<Integer, Integer> hourPair = this.parseHourRange(start, end);
+        if (Objects.nonNull(hourPair)) {
+            int nowHour = now.getHour();
+            return hourPair.getLeft() <= nowHour && nowHour <= hourPair.getRight();
+        }
+
+        return false;
+    }
+
+    /** Cheap pre-check: exactly ten digits, the shape of a yyyyMMddHH bound. */
+    private static boolean hasDateTimeShape(String bound) {
+        return bound.length() == DATE_TIME_BOUND_LENGTH && StringUtils.isNumeric(bound);
+    }
+
+    /**
+     * Parses both bounds as {@code yyyyMMddHH} timestamps.
+     *
+     * @return the parsed pair, or {@code null} when a bound is blank or not a valid timestamp
+     *         (the parse failure is logged)
+     */
+    protected Pair<LocalDateTime, LocalDateTime> parseDateTimeRange(String start, String end) {
+        try {
+            if (StringUtils.isBlank(start) || StringUtils.isBlank(end)) {
+                return null;
+            }
+
+            return Pair.of(LocalDateTime.parse(start, formatter), LocalDateTime.parse(end, formatter));
+        } catch (Exception e) {
+            log.error("datetime parse error. {}-{} \n", start, end, e);
+        }
+        return null;
+    }
+
+    /**
+     * Parses both bounds as hours of day.
+     *
+     * @return the parsed pair, or {@code null} when a bound is blank, not an integer
+     *         (logged) or outside 0-23
+     */
+    protected Pair<Integer, Integer> parseHourRange(String start, String end) {
+        try {
+            if (StringUtils.isBlank(start) || StringUtils.isBlank(end)) {
+                return null;
+            }
+
+            int startHour = Integer.parseInt(start);
+            int endHour = Integer.parseInt(end);
+            if (startHour < 0 || startHour > 23 || endHour < 0 || endHour > 23) {
+                return null;
+            }
+
+            return Pair.of(startHour, endHour);
+        } catch (Exception e) {
+            log.error("parse hour range error. {}-{} \n", start, end, e);
+            return null;
+        }
+    }
+}

+ 83 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rerank/strategy/FlowPoolVideoInsertRerankStrategy.java

@@ -0,0 +1,83 @@
+package com.tzld.piaoquan.recommend.server.service.rerank.strategy;
+
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.rerank.RerankParam;
+import com.tzld.piaoquan.recommend.server.service.rerank.RerankStrategy;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.RandomUtils;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Flow-pool insertion: builds the final list by mixing flow-pool and dou-hot videos into the
+ * rov-ranked list.
+ */
+@Slf4j
+@Component
+public class FlowPoolVideoInsertRerankStrategy implements RerankStrategy {
+
+    @Value("${new.flow.pool.select.rate:1}")
+    private double newFlowPoolSelectRate;
+
+    /**
+     * Produces the merged list:
+     * <ol>
+     *   <li>the first {@code topK} rov videos are copied unchanged;</li>
+     *   <li>each of the next {@code size - topK} slots takes a flow-pool video with
+     *       probability {@code flowPoolP}; otherwise it takes a dou-hot video when one is
+     *       left and the select-rate draw allows it, and a rov video in every other case;</li>
+     *   <li>mixing stops as soon as the list chosen for a slot (flow pool or rov) is
+     *       exhausted, and the remaining slots up to {@code size} are filled from the
+     *       other list.</li>
+     * </ol>
+     * Each source list is consumed through its own cursor. The flow-pool and dou-hot cursors
+     * start at 0; only the rov cursor starts at {@code topK}, because the rov head has
+     * already been emitted.
+     *
+     * @param param rerank input; the three video lists may be {@code null} or empty
+     * @return a new list, never {@code null}
+     */
+    @Override
+    public List<Video> rerank(RerankParam param) {
+        List<Video> result = new ArrayList<>();
+
+        double flowPoolP = param.getFlowPoolP();
+        int topK = param.getTopK();
+        int size = param.getSize();
+
+        List<Video> rovVideos = nullToEmpty(param.getRovVideos());
+        List<Video> flowVideos = nullToEmpty(param.getFlowPoolVideos());
+        List<Video> douHotFlowPoolVideos = nullToEmpty(param.getDouHotFlowPoolVideos());
+
+        int flowPoolIndex = 0;
+        int douHotIndex = 0;
+        int rovPoolIndex = topK;
+
+        // The first topK slots always come from the algorithm ranking
+        for (int i = 0; i < topK && i < rovVideos.size(); i++) {
+            result.add(rovVideos.get(i));
+        }
+
+        for (int i = 0; i < size - topK; i++) {
+            double rand = RandomUtils.nextDouble(0, 1);
+            if (rand < flowPoolP) {
+                if (flowPoolIndex >= flowVideos.size()) {
+                    break;
+                }
+                result.add(flowVideos.get(flowPoolIndex++));
+                continue;
+            }
+
+            // An exhausted (or empty) dou-hot list must not end the mixing; the slot simply
+            // falls back to the rov list.
+            if (douHotIndex < douHotFlowPoolVideos.size() && this.isInsertDouHotFlowPoolVideo()) {
+                result.add(douHotFlowPoolVideos.get(douHotIndex++));
+                continue;
+            }
+
+            if (rovPoolIndex >= rovVideos.size()) {
+                break;
+            }
+            result.add(rovVideos.get(rovPoolIndex++));
+        }
+
+        // Rov exhausted: fill the remaining slots from the flow pool
+        if (rovPoolIndex >= rovVideos.size()) {
+            for (int i = flowPoolIndex; i < flowVideos.size() && result.size() < size; i++) {
+                result.add(flowVideos.get(i));
+            }
+        }
+        // Flow pool exhausted: fill the remaining slots from rov
+        if (flowPoolIndex >= flowVideos.size()) {
+            for (int i = rovPoolIndex; i < rovVideos.size() && result.size() < size; i++) {
+                result.add(rovVideos.get(i));
+            }
+        }
+
+        return result;
+    }
+
+    /** Random draw against the configured select rate; {@code true} means "take a dou-hot video". */
+    private boolean isInsertDouHotFlowPoolVideo() {
+        double rand = RandomUtils.nextDouble(0, 1);
+        return rand <= newFlowPoolSelectRate;
+    }
+
+    /** Treats a missing list as an empty one so the cursors above never dereference null. */
+    private static List<Video> nullToEmpty(List<Video> videos) {
+        return videos == null ? new ArrayList<>() : videos;
+    }
+}

+ 97 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rerank/strategy/VideoAttrWeightRerankStrategy.java

@@ -0,0 +1,97 @@
+package com.tzld.piaoquan.recommend.server.service.rerank.strategy;
+
+import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.rerank.RerankParam;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.stereotype.Component;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * Up- or down-weights each video's score according to the video's attributes.
+ */
+@Slf4j
+@Component
+public class VideoAttrWeightRerankStrategy extends BasicRerankStrategy {
+
+    @ApolloJsonValue("${video.attr.rerank.config:[]}")
+    private List<VideoAttributeConfigItem> configItems;
+
+    /**
+     * Re-scores every rov video with the config items whose time window is currently open,
+     * sorts the list in place by descending sort score and writes it back to {@code param}.
+     *
+     * @param param rerank input; its rov list is mutated in place
+     * @return the rov list of {@code param}; returned untouched when it or the config is empty
+     */
+    @Override
+    public List<Video> rerank(RerankParam param) {
+
+        List<Video> rovVideos = param.getRovVideos();
+        if (CollectionUtils.isEmpty(rovVideos) || CollectionUtils.isEmpty(configItems)) {
+            return rovVideos;
+        }
+
+        // The time window does not depend on the video, so evaluate it once per request
+        // instead of once per video and item.
+        List<VideoAttributeConfigItem> activeItems = this.selectActiveItems();
+
+        // Recompute the scores, then re-sort
+        for (Video video : rovVideos) {
+            this.reCalcVideoScore(video, activeItems);
+        }
+
+        rovVideos.sort(Comparator.comparingDouble(o -> -o.getSortScore()));
+        param.setRovVideos(rovVideos);
+        return rovVideos;
+    }
+
+    /**
+     * Returns the config items that are usable right now: complete (key, value and weight
+     * present) and inside their time window. Incomplete items are skipped because applying
+     * them would fail on the missing value or weight.
+     */
+    private List<VideoAttributeConfigItem> selectActiveItems() {
+        List<VideoAttributeConfigItem> activeItems = new ArrayList<>();
+        for (VideoAttributeConfigItem configItem : configItems) {
+            if (configItem == null
+                    || configItem.getKey() == null
+                    || configItem.getValue() == null
+                    || configItem.getWeight() == null) {
+                continue;
+            }
+            if (this.isInTimeRangeCondition(configItem.getTime())) {
+                activeItems.add(configItem);
+            }
+        }
+        return activeItems;
+    }
+
+    /**
+     * Multiplies the video's score by the weight of every active item whose attribute matches,
+     * and records the score before and after plus each applied weight in the video's score map.
+     */
+    private void reCalcVideoScore(Video video, List<VideoAttributeConfigItem> activeItems) {
+        double score = video.getScore();
+        video.getScoresMap().put("rerankBeforeScore", score);
+
+        double newScore = score;
+        Map<String, String> basicInfo = video.getMetaFeatureMap().getOrDefault("alg_vid_feature_basic_info", new HashMap<>());
+
+        for (VideoAttributeConfigItem configItem : activeItems) {
+
+            String key = configItem.getKey();
+            String value = configItem.getValue();
+
+            // The video does not carry this attribute
+            if (!basicInfo.containsKey(key)) {
+                continue;
+            }
+
+            // The attribute value is not one of the comma-separated configured values
+            Set<String> configValueSet = Arrays.stream(value.split(","))
+                    .filter(StringUtils::isNotBlank)
+                    .collect(Collectors.toSet());
+            if (!configValueSet.contains(basicInfo.get(key))) {
+                continue;
+            }
+
+            double weight = configItem.getWeight();
+
+            newScore = newScore * weight;
+            video.getScoresMap().put(String.format("%s_%s_%s", key, value, configItem.getTime()), weight);
+
+        }
+
+        // Record the weighted score, not the original one
+        video.getScoresMap().put("rerankAfterScore", newScore);
+        video.setScore(newScore);
+        video.setSortScore(newScore);
+    }
+
+    /** One weighting rule, deserialized from the {@code video.attr.rerank.config} JSON array. */
+    @Data
+    @NoArgsConstructor
+    @AllArgsConstructor
+    private static class VideoAttributeConfigItem {
+        // Attribute name looked up in the video's basic-info feature map
+        private String key;
+        // Comma-separated attribute values the rule applies to
+        private String value;
+        // Time window understood by isInTimeRangeCondition; blank means always
+        private String time;
+        // Factor the score is multiplied by when the rule matches
+        private Double weight;
+    }
+}

+ 1 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/util/DateUtils.java

@@ -4,6 +4,7 @@ import java.text.SimpleDateFormat;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.format.DateTimeFormatter;
+import java.time.format.DateTimeParseException;
 import java.util.Calendar;
 import java.util.Date;