|
@@ -1,6 +1,6 @@
|
|
|
package com.tzld.piaoquan.recommend.server.service.rank;
|
|
|
|
|
|
-import com.tzld.piaoquan.recommend.server.service.recall.RecallResult;
|
|
|
+import com.tzld.piaoquan.recommend.server.model.Video;
|
|
|
import org.apache.commons.collections4.CollectionUtils;
|
|
|
import org.apache.commons.lang3.RandomUtils;
|
|
|
import org.apache.commons.lang3.math.NumberUtils;
|
|
@@ -29,7 +29,7 @@ public class RankService {
|
|
|
List<String> videoIdKeys = param.getRecallResult().getRovPoolRecall().stream()
|
|
|
.map(t -> param.getRankKeyPrefix() + t.getVideoId())
|
|
|
.collect(Collectors.toList());
|
|
|
- List<RecallResult.RecallData> rov_recall_rank = param.getRecallResult().getRovPoolRecall().stream()
|
|
|
+ List<Video> rov_recall_rank = param.getRecallResult().getRovPoolRecall().stream()
|
|
|
.collect(Collectors.toList());
|
|
|
List<String> video_scores = redisTemplate.opsForValue().multiGet(videoIdKeys);
|
|
|
if (CollectionUtils.isNotEmpty(video_scores)
|
|
@@ -41,7 +41,7 @@ public class RankService {
|
|
|
}
|
|
|
|
|
|
// rank flow pool recall
|
|
|
- List<RecallResult.RecallData> flow_recall_rank = param.getRecallResult().getFlowPoolRecall().stream()
|
|
|
+ List<Video> flow_recall_rank = param.getRecallResult().getFlowPoolRecall().stream()
|
|
|
.collect(Collectors.toList());
|
|
|
Collections.sort(flow_recall_rank, Comparator.comparingDouble(o -> -o.getRovScore()));
|
|
|
|
|
@@ -55,9 +55,9 @@ public class RankService {
|
|
|
}
|
|
|
|
|
|
Set<Long> flowPoolVideoIds = new HashSet<>();
|
|
|
- Iterator<RecallResult.RecallData> flowRecallRankIte = flow_recall_rank.iterator();
|
|
|
+ Iterator<Video> flowRecallRankIte = flow_recall_rank.iterator();
|
|
|
while (flowRecallRankIte.hasNext()) {
|
|
|
- RecallResult.RecallData data = flowRecallRankIte.next();
|
|
|
+ Video data = flowRecallRankIte.next();
|
|
|
if (rovTopKVideoIds.contains(data.getVideoId())) {
|
|
|
flowRecallRankIte.remove();
|
|
|
} else {
|
|
@@ -65,15 +65,15 @@ public class RankService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- Iterator<RecallResult.RecallData> rovRecallRankIte = rov_recall_rank.iterator();
|
|
|
+ Iterator<Video> rovRecallRankIte = rov_recall_rank.iterator();
|
|
|
while (rovRecallRankIte.hasNext()) {
|
|
|
- RecallResult.RecallData data = rovRecallRankIte.next();
|
|
|
+ Video data = rovRecallRankIte.next();
|
|
|
if (flowPoolVideoIds.contains(data.getVideoId())) {
|
|
|
rovRecallRankIte.remove();
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 1 取topK
|
|
|
+ // 融合排序
|
|
|
if (CollectionUtils.isEmpty(rov_recall_rank)) {
|
|
|
if (param.getSize() < flow_recall_rank.size()) {
|
|
|
return new RankResult(flow_recall_rank.subList(0, param.getSize()));
|
|
@@ -82,8 +82,7 @@ public class RankService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // rov 取topK
|
|
|
- List<RecallResult.RecallData> datas = new ArrayList<>();
|
|
|
+ List<Video> datas = new ArrayList<>();
|
|
|
for (int i = 0; i < param.getTopK() && i < rov_recall_rank.size(); i++) {
|
|
|
datas.add(rov_recall_rank.get(i));
|
|
|
}
|
|
@@ -113,25 +112,16 @@ public class RankService {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- // while i < size - top_K:
|
|
|
- // # 随机生成[0, 1)浮点数
|
|
|
- // rand = random.random()
|
|
|
- // # log_.info('rand: {}'.format(rand))
|
|
|
- // if rand < flow_pool_P:
|
|
|
- // if flow_recall_rank:
|
|
|
- // rank_result.append(flow_recall_rank[0])
|
|
|
- // flow_recall_rank.remove(flow_recall_rank[0])
|
|
|
- // else:
|
|
|
- // rank_result.extend(rov_recall_rank[:size - top_K - i])
|
|
|
- // return rank_result[:size], flow_num
|
|
|
- // else:
|
|
|
- // if rov_recall_rank:
|
|
|
- // rank_result.append(rov_recall_rank[0])
|
|
|
- // rov_recall_rank.remove(rov_recall_rank[0])
|
|
|
- // else:
|
|
|
- // rank_result.extend(flow_recall_rank[:size - top_K - i])
|
|
|
- // return rank_result[:size], flow_num
|
|
|
- // i += 1
|
|
|
+ if (rovPoolIndex >= rov_recall_rank.size()) {
|
|
|
+ for (int i = flowPoolIndex; i < flow_recall_rank.size(); i++) {
|
|
|
+ datas.add(flow_recall_rank.get(i));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (flowPoolIndex >= flow_recall_rank.size()) {
|
|
|
+ for (int i = rovPoolIndex; i < rov_recall_rank.size(); i++) {
|
|
|
+ datas.add(rov_recall_rank.get(i));
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
if (param.getSize() < datas.size()) {
|
|
|
return new RankResult(datas.subList(0, param.getSize()));
|