|
@@ -81,9 +81,15 @@ public class YearShareDkElementsRecallStrategy implements RecallStrategy {
|
|
|
|
|
|
|
|
List<String> keys = this.getRedisKey(allElements);
|
|
List<String> keys = this.getRedisKey(allElements);
|
|
|
List<String> values = redisTemplate.opsForValue().multiGet(keys);
|
|
List<String> values = redisTemplate.opsForValue().multiGet(keys);
|
|
|
- List<Long> ids = recall(param.getVideoId(), values);
|
|
|
|
|
|
|
|
|
|
- Map<Long, Double> scoresMap = FilterParamFactory.positionScores(ids);
|
|
|
|
|
|
|
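+ // Keep the real rovn scores from the Redis inverted index (instead of position-based scores): the scores in scoresMap are written to Video.rovScore,
|
|
|
|
|
+ // and vids that miss coarseMap during coarse-ranking truncation fall back to sorting by Video.rovScore, where the real score carries more signal.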
|
|
|
|
|
+ Map<Long, Double> scoresMap = recall(param.getVideoId(), values);
|
|
|
|
|
+ List<Long> ids = scoresMap.entrySet().stream()
|
|
|
|
|
+ .sorted(Comparator.comparingDouble((Map.Entry<Long, Double> e) -> e.getValue()).reversed())
|
|
|
|
|
+ .map(Map.Entry::getKey)
|
|
|
|
|
+ .collect(Collectors.toList());
|
|
|
|
|
+
|
|
|
FilterParam filterParam = FilterParamFactory.create(param, ids, pushFrom(), scoresMap);
|
|
FilterParam filterParam = FilterParamFactory.create(param, ids, pushFrom(), scoresMap);
|
|
|
FilterResult filterResult = filterService.filter(filterParam);
|
|
FilterResult filterResult = filterService.filter(filterParam);
|
|
|
if (filterResult != null && CollectionUtils.isNotEmpty(filterResult.getVideoIds())) {
|
|
if (filterResult != null && CollectionUtils.isNotEmpty(filterResult.getVideoIds())) {
|
|
@@ -144,40 +150,44 @@ public class YearShareDkElementsRecallStrategy implements RecallStrategy {
|
|
|
return keys;
|
|
return keys;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- private List<Long> recall(Long headVid, List<String> values) {
|
|
|
|
|
- List<Long> vidList = new ArrayList<>();
|
|
|
|
|
- if (null != values && !values.isEmpty()) {
|
|
|
|
|
- Set<Long> hits = new HashSet<>();
|
|
|
|
|
- hits.add(headVid);
|
|
|
|
|
- List<org.apache.commons.math3.util.Pair<Long, Double>> list = new ArrayList<>();
|
|
|
|
|
- for (String value : values) {
|
|
|
|
|
- if (null != value && !value.isEmpty()) {
|
|
|
|
|
- String[] cells = value.split("\t");
|
|
|
|
|
- if (2 == cells.length) {
|
|
|
|
|
- List<Long> ids = Arrays.stream(cells[0].split(",")).map(Long::valueOf).collect(Collectors.toList());
|
|
|
|
|
- List<Double> scores = Arrays.stream(cells[1].split(",")).map(Double::valueOf).collect(Collectors.toList());
|
|
|
|
|
- if (!ids.isEmpty() && ids.size() == scores.size()) {
|
|
|
|
|
- for (int i = 0; i < ids.size(); ++i) {
|
|
|
|
|
- long id = ids.get(i);
|
|
|
|
|
- double score = scores.get(i);
|
|
|
|
|
- if (hits.contains(id)) {
|
|
|
|
|
- continue;
|
|
|
|
|
- }
|
|
|
|
|
- hits.add(id);
|
|
|
|
|
- list.add(org.apache.commons.math3.util.Pair.create(id, score));
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ /**
|
|
|
|
|
+ * Parses the Redis values fetched via multiGet and merges them into a vid -> real score map.
|
|
|
|
|
+ * Value format: vid1,vid2,...\tscore1,score2,... (real rovn scores, per the original author's note).
|
|
|
|
|
+ * When the same vid appears in several values, the max score is kept (the original note says this
+ * matches AbstractRedisRecallStrategy; that class is not visible here, so confirm before relying on it).
+ * Blank values, values without exactly two tab-separated cells, values containing an unparsable
+ * number, and values whose id count differs from their score count are skipped as a whole.
+ *
+ * @param headVid vid left out of the result when non-null
+ * @param values raw Redis values; the visible caller passes the multiGet result
+ * @return map from vid to the highest score parsed for it; empty when nothing usable was parsed
|
|
|
|
|
+ */
|
|
|
|
|
+ private Map<Long, Double> recall(Long headVid, List<String> values) {
|
|
|
|
|
+ Map<Long, Double> scoresMap = new HashMap<>();
|
|
|
|
|
+ if (CollectionUtils.isEmpty(values)) {
|
|
|
|
|
+ return scoresMap;
|
|
|
|
|
+ }
|
|
|
|
|
+ for (String value : values) {
|
|
|
|
|
+ if (StringUtils.isBlank(value)) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ String[] cells = value.split("\t");
|
|
|
|
|
+ if (cells.length != 2) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ List<Long> ids;
|
|
|
|
|
+ List<Double> scores;
|
|
|
|
|
+ try {
|
|
|
|
|
+ ids = Arrays.stream(cells[0].split(",")).map(Long::valueOf).collect(Collectors.toList());
|
|
|
|
|
+ scores = Arrays.stream(cells[1].split(",")).map(Double::valueOf).collect(Collectors.toList());
|
|
|
|
|
+ } catch (NumberFormatException nfe) { // one unparsable id or score drops the whole value, silently
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (ids.isEmpty() || ids.size() != scores.size()) {
|
|
|
|
|
+ continue;
|
|
|
}
|
|
}
|
|
|
- if (!list.isEmpty()) {
|
|
|
|
|
- list.sort(Comparator.comparingDouble(o -> -o.getSecond()));
|
|
|
|
|
- for (org.apache.commons.math3.util.Pair<Long, Double> pair : list) {
|
|
|
|
|
- vidList.add(pair.getFirst());
|
|
|
|
|
|
|
+ for (int i = 0; i < ids.size(); i++) {
|
|
|
|
|
+ long id = ids.get(i);
|
|
|
|
|
+ if (headVid != null && headVid == id) { // Long == long unboxes headVid, so this compares values
|
|
|
|
|
+ continue;
|
|
|
}
|
|
}
|
|
|
|
|
+ scoresMap.merge(id, scores.get(i), Math::max); // a repeated vid keeps its larger score
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- return vidList;
|
|
|
|
|
|
|
+ return scoresMap;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|