|
|
@@ -0,0 +1,199 @@
|
|
|
+package com.tzld.piaoquan.recommend.server.service.recall.strategy;
|
|
|
+
|
|
|
+import com.tzld.piaoquan.recommend.server.model.Video;
|
|
|
+import com.tzld.piaoquan.recommend.server.service.filter.FilterParam;
|
|
|
+import com.tzld.piaoquan.recommend.server.service.filter.FilterResult;
|
|
|
+import com.tzld.piaoquan.recommend.server.service.filter.FilterService;
|
|
|
+import com.tzld.piaoquan.recommend.server.service.recall.FilterParamFactory;
|
|
|
+import com.tzld.piaoquan.recommend.server.service.recall.RecallParam;
|
|
|
+import com.tzld.piaoquan.recommend.server.service.recall.RecallStrategy;
|
|
|
+import com.tzld.piaoquan.recommend.server.util.DkElementsUtils;
|
|
|
+import com.tzld.piaoquan.recommend.server.util.FeatureUtils;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.apache.commons.collections4.CollectionUtils;
|
|
|
+import org.apache.commons.collections4.MapUtils;
|
|
|
+import org.apache.commons.lang3.StringUtils;
|
|
|
+import org.apache.commons.lang3.math.NumberUtils;
|
|
|
+import org.apache.commons.lang3.tuple.Pair;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.beans.factory.annotation.Qualifier;
|
|
|
+import org.springframework.data.redis.core.RedisTemplate;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+
|
|
|
+import java.util.*;
|
|
|
+import java.util.stream.Collectors;
|
|
|
+
|
|
|
+/**
|
|
|
+ * 视频解构 实质元素 rovn 召回 (用户近期 click 回流行为 -> dk_elements)
|
|
|
+ * 触发源相对 YearShareDkElementsRecallStrategy 改为 click 行为, 其余逻辑、Redis 倒排、参数完全一致。
|
|
|
+ *
|
|
|
+ * 每个 click vid 一般有多个 element, parseUserActionVideoAndElements 返回扁平的 (vid, element) pair 列表;
|
|
|
+ * parse 时已按 c >= MIN_CONTRIB_SCORE (0.8) 过滤掉低贡献分元素, 噪声不进召回.
|
|
|
+ *
|
|
|
+ * 挑 kw 逻辑: 按近期 click 行为时间序遍历摊平 element, distinct 取前 topN (30) 个 -> 一次 multiGet
|
|
|
+ * elements_rovn_recall:{kw} 倒排. 不再做"最近+最频"并集 (元素粒度比 cate2 细, 取近期 30 更直接).
|
|
|
+ *
|
|
|
+ * 上游 ODPS: alg_recsys_recall_elements_rovn (原始元素 -> top-50 vid + rovn 得分)
|
|
|
+ * Redis key: elements_rovn_recall:{原始元素}
|
|
|
+ * value: vid1,vid2,...\tscore1,score2,...
|
|
|
+ */
|
|
|
+@Slf4j
|
|
|
+@Component
|
|
|
+public class YearReturnDkElementsRecallStrategy implements RecallStrategy {
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ @Qualifier("redisTemplate")
|
|
|
+ private RedisTemplate<String, String> redisTemplate;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private FilterService filterService;
|
|
|
+
|
|
|
+ private final String CLASS_NAME = this.getClass().getSimpleName();
|
|
|
+
|
|
|
+ public static final String PUSH_FROM = "recall_user_year_return_dk_elements";
|
|
|
+ public static final String redisKeyPrefix = "elements_rovn_recall";
|
|
|
+
|
|
|
+ /** 元素贡献分过滤阈值 (parse 时丢弃 c < 0.8 的 element, 噪声元素不进召回) */
|
|
|
+ public static final double MIN_CONTRIB_SCORE = 0.8;
|
|
|
+ /** 摊平后按时间序 distinct 取前 N 个 element 进 Redis 倒排查询 */
|
|
|
+ public static final int topN = 30;
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public List<Video> recall(RecallParam param) {
|
|
|
+
|
|
|
+ List<Video> videosResult = new ArrayList<>();
|
|
|
+ try {
|
|
|
+
|
|
|
+ if (MapUtils.isEmpty(param.getUserNetworkSeqVideoInfoMap())) {
|
|
|
+ return videosResult;
|
|
|
+ }
|
|
|
+
|
|
|
+ List<Pair<Long, String>> userNetworkVideoElement = this.parseUserActionVideoAndElements(param.getUserNetworkSeqFeature(), param.getUserNetworkSeqVideoInfoMap());
|
|
|
+ if (CollectionUtils.isEmpty(userNetworkVideoElement)) {
|
|
|
+ return videosResult;
|
|
|
+ }
|
|
|
+ // 按用户近期 click 行为时间序遍历, distinct 取前 topN 个高贡献分 element
|
|
|
+ // (贡献分过滤已在 parseUserActionVideoAndElements 内 c >= MIN_CONTRIB_SCORE 完成)
|
|
|
+ List<String> allElements = userNetworkVideoElement.stream()
|
|
|
+ .map(Pair::getValue)
|
|
|
+ .filter(StringUtils::isNotBlank)
|
|
|
+ .distinct()
|
|
|
+ .limit(topN)
|
|
|
+ .collect(Collectors.toList());
|
|
|
+
|
|
|
+ List<String> keys = this.getRedisKey(allElements);
|
|
|
+ List<String> values = redisTemplate.opsForValue().multiGet(keys);
|
|
|
+
|
|
|
+ // 保留 Redis 倒排的真实 rovn 分 (而非位置分): scoresMap 的 score 会写到 Video.rovScore,
|
|
|
+ // 粗排截断 coarseMap miss 的 vid 会 fallback 用 Video.rovScore 排序, 真实分更有信号.
|
|
|
+ Map<Long, Double> scoresMap = recall(param.getVideoId(), values);
|
|
|
+ List<Long> ids = scoresMap.entrySet().stream()
|
|
|
+ .sorted(Comparator.comparingDouble((Map.Entry<Long, Double> e) -> e.getValue()).reversed())
|
|
|
+ .map(Map.Entry::getKey)
|
|
|
+ .collect(Collectors.toList());
|
|
|
+
|
|
|
+ FilterParam filterParam = FilterParamFactory.create(param, ids, pushFrom(), scoresMap);
|
|
|
+ FilterResult filterResult = filterService.filter(filterParam);
|
|
|
+ if (filterResult != null && CollectionUtils.isNotEmpty(filterResult.getVideoIds())) {
|
|
|
+ for (Long vid : filterResult.getVideoIds()) {
|
|
|
+ Video video = new Video();
|
|
|
+ video.setVideoId(vid);
|
|
|
+ video.setRovScore(scoresMap.getOrDefault(vid, 0.0));
|
|
|
+ video.setPushFrom(pushFrom());
|
|
|
+ videosResult.add(video);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.error("recall is wrong in {}, error={}", CLASS_NAME, e);
|
|
|
+ }
|
|
|
+
|
|
|
+ return videosResult;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 摊平: 每个 click vid 一般有多个 dk_element, 输出 (vid, element) pair 序列, 按 vid 时间序保留
|
|
|
+ */
|
|
|
+ private List<Pair<Long, String>> parseUserActionVideoAndElements(Map<String, String> userNetworkSeqFeature, Map<Long, Map<String, String>> userNetworkSeqVideoInfoMap) {
|
|
|
+ List<Pair<Long, String>> result = new ArrayList<>();
|
|
|
+ List<String> actVidSeq = FeatureUtils.extractVidsFromUserNetworkSeqFeature(userNetworkSeqFeature, "a_v_s");
|
|
|
+ List<String> actTypeSeq = FeatureUtils.extractVidsFromUserNetworkSeqFeature(userNetworkSeqFeature, "a_t_s");
|
|
|
+ if (actVidSeq.size() != actTypeSeq.size()) {
|
|
|
+ return new ArrayList<>();
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int i = 0; i < actVidSeq.size(); i++) {
|
|
|
+ long videoIdL = NumberUtils.toLong(actVidSeq.get(i), -1);
|
|
|
+ if (videoIdL <= 0) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ String type = actTypeSeq.get(i);
|
|
|
+ if (!"click".equals(type)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ Map<String, String> videoBaseInfo = userNetworkSeqVideoInfoMap.getOrDefault(videoIdL, new HashMap<>());
|
|
|
+ String dkElementsStr = videoBaseInfo.get("dk_elements");
|
|
|
+ if (StringUtils.isBlank(dkElementsStr)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ List<String> kws = DkElementsUtils.parseElementKws(dkElementsStr, MIN_CONTRIB_SCORE);
|
|
|
+ for (String kw : kws) {
|
|
|
+ result.add(Pair.of(videoIdL, kw));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ private List<String> getRedisKey(List<String> elementList) {
|
|
|
+ List<String> keys = new ArrayList<>();
|
|
|
+ for (String element : elementList) {
|
|
|
+ keys.add(String.format("%s:%s", redisKeyPrefix, element));
|
|
|
+ }
|
|
|
+ return keys;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 解析 multiGet 拿到的 N 个 Redis value, 拼成 vid -> 真实 score map.
|
|
|
+ * value 格式: vid1,vid2,...\tscore1,score2,... (rovn 真实分)
|
|
|
+ * 同 vid 在多个 element 倒排里出现时, 取 max score (跟 AbstractRedisRecallStrategy 一致).
|
|
|
+ */
|
|
|
+ private Map<Long, Double> recall(Long headVid, List<String> values) {
|
|
|
+ Map<Long, Double> scoresMap = new HashMap<>();
|
|
|
+ if (CollectionUtils.isEmpty(values)) {
|
|
|
+ return scoresMap;
|
|
|
+ }
|
|
|
+ for (String value : values) {
|
|
|
+ if (StringUtils.isBlank(value)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ String[] cells = value.split("\t");
|
|
|
+ if (cells.length != 2) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ List<Long> ids;
|
|
|
+ List<Double> scores;
|
|
|
+ try {
|
|
|
+ ids = Arrays.stream(cells[0].split(",")).map(Long::valueOf).collect(Collectors.toList());
|
|
|
+ scores = Arrays.stream(cells[1].split(",")).map(Double::valueOf).collect(Collectors.toList());
|
|
|
+ } catch (NumberFormatException nfe) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (ids.isEmpty() || ids.size() != scores.size()) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ for (int i = 0; i < ids.size(); i++) {
|
|
|
+ long id = ids.get(i);
|
|
|
+ if (headVid != null && headVid == id) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ scoresMap.merge(id, scores.get(i), Math::max);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return scoresMap;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String pushFrom() {
|
|
|
+ return PUSH_FROM;
|
|
|
+ }
|
|
|
+}
|