|
@@ -27,10 +27,44 @@ public class RegionRecallScorerV7VovLongTerm extends AbstractScorer4Recall {
|
|
|
if (kv.containsKey(key)) {
|
|
|
List<Pair<Long, Double>> copy = new ArrayList<>(kv.get(key));
|
|
|
// 先随机,再截断。
|
|
|
- Collections.shuffle(copy);
|
|
|
- result.addAll(copy.subList(0, Math.min(count, copy.size())));
|
|
|
+ List<Pair<Long, Double>> selected = getWeightedRandomElements(copy, count);
|
|
|
+ result.addAll(selected);
|
|
|
}
|
|
|
}
|
|
|
return result;
|
|
|
}
|
|
|
+
|
|
|
+ public static List<Pair<Long, Double>> getWeightedRandomElements(List<Pair<Long, Double>> list, int count) {
|
|
|
+ if (list == null || list.isEmpty()) {
|
|
|
+ return new ArrayList<>(0);
|
|
|
+ }
|
|
|
+ if (count >= list.size()) {
|
|
|
+ return new ArrayList<>(list);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 计算权重总和
|
|
|
+ double totalWeight = list.stream().mapToDouble(Pair::getRight).sum();
|
|
|
+
|
|
|
+ // 使用加权随机选择
|
|
|
+ List<Pair<Long, Double>> result = new ArrayList<>();
|
|
|
+ Random random = new Random();
|
|
|
+
|
|
|
+ for (int i = 0; i < count; i++) {
|
|
|
+ double rand = random.nextDouble() * totalWeight; // 生成一个0到totalWeight之间的随机数
|
|
|
+ double cumulativeWeight = 0.0;
|
|
|
+
|
|
|
+ for (Pair<Long, Double> pair : list) {
|
|
|
+ cumulativeWeight += pair.getRight(); // 累加权重
|
|
|
+ if (rand <= cumulativeWeight) {
|
|
|
+ result.add(pair); // 选择当前元素
|
|
|
+ totalWeight -= pair.getRight(); // 更新总权重
|
|
|
+ list.remove(pair); // 防止重复选择
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
}
|