Browse Source

老内容召回-增加召回源内的权重随机

zhangbo 11 months ago
parent
commit
5859debde0

+ 1 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/RegionRealtimeRecallStrategyV7VovLongTermV1.java

@@ -29,7 +29,7 @@ public class RegionRealtimeRecallStrategyV7VovLongTermV1 implements RecallStrate
     @Override
     public List<Video> recall(RecallParam param) {
         Map<String, String> param4Model = new HashMap<>(1);
-        param4Model.put("t_2_8", "100");
+        param4Model.put("t_2_8", "50");
         ScorerPipeline4Recall pipeline = ScorerUtils.getScorerPipeline4Recall("feeds_recall_config_region_v7_vov_longterm.conf");
         List<List<Pair<Long, Double>>> results = pipeline.recall(param4Model);
         List<Pair<Long, Double>> result = results.get(0);

+ 1 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/RegionRealtimeRecallStrategyV7VovLongTermV2.java

@@ -29,7 +29,7 @@ public class RegionRealtimeRecallStrategyV7VovLongTermV2 implements RecallStrate
     @Override
     public List<Video> recall(RecallParam param) {
         Map<String, String> param4Model = new HashMap<>(1);
-        param4Model.put("t_9_36", "100");
+        param4Model.put("t_9_36", "50");
         ScorerPipeline4Recall pipeline = ScorerUtils.getScorerPipeline4Recall("feeds_recall_config_region_v7_vov_longterm.conf");
         List<List<Pair<Long, Double>>> results = pipeline.recall(param4Model);
         List<Pair<Long, Double>> result = results.get(0);

+ 1 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/RegionRealtimeRecallStrategyV7VovLongTermV3.java

@@ -29,7 +29,7 @@ public class RegionRealtimeRecallStrategyV7VovLongTermV3 implements RecallStrate
     @Override
     public List<Video> recall(RecallParam param) {
         Map<String, String> param4Model = new HashMap<>(1);
-        param4Model.put("t_37_90", "100");
+        param4Model.put("t_37_90", "50");
         ScorerPipeline4Recall pipeline = ScorerUtils.getScorerPipeline4Recall("feeds_recall_config_region_v7_vov_longterm.conf");
         List<List<Pair<Long, Double>>> results = pipeline.recall(param4Model);
         List<Pair<Long, Double>> result = results.get(0);

+ 1 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/recall/strategy/RegionRealtimeRecallStrategyV7VovLongTermV4.java

@@ -29,7 +29,7 @@ public class RegionRealtimeRecallStrategyV7VovLongTermV4 implements RecallStrate
     @Override
     public List<Video> recall(RecallParam param) {
         Map<String, String> param4Model = new HashMap<>(1);
-        param4Model.put("t_91_365", "100");
+        param4Model.put("t_91_365", "50");
         ScorerPipeline4Recall pipeline = ScorerUtils.getScorerPipeline4Recall("feeds_recall_config_region_v7_vov_longterm.conf");
         List<List<Pair<Long, Double>>> results = pipeline.recall(param4Model);
         List<Pair<Long, Double>> result = results.get(0);

+ 36 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score4recall/strategy/RegionRecallScorerV7VovLongTerm.java

@@ -27,10 +27,44 @@ public class RegionRecallScorerV7VovLongTerm extends AbstractScorer4Recall {
             if (kv.containsKey(key)) {
                 List<Pair<Long, Double>> copy = new ArrayList<>(kv.get(key));
                 // 先随机,再截断。
-                Collections.shuffle(copy);
-                result.addAll(copy.subList(0, Math.min(count, copy.size())));
+                List<Pair<Long, Double>> selected = getWeightedRandomElements(copy, count);
+                result.addAll(selected);
             }
         }
         return result;
     }
+
+    public static List<Pair<Long, Double>> getWeightedRandomElements(List<Pair<Long, Double>> list, int count) {
+        if (list == null || list.isEmpty()) {
+            return new ArrayList<>(0);
+        }
+        if (count >= list.size()) {
+            return new ArrayList<>(list);
+        }
+
+        // 计算权重总和
+        double totalWeight = list.stream().mapToDouble(Pair::getRight).sum();
+
+        // 使用加权随机选择
+        List<Pair<Long, Double>> result = new ArrayList<>();
+        Random random = new Random();
+
+        for (int i = 0; i < count; i++) {
+            double rand = random.nextDouble() * totalWeight; // 生成一个0到totalWeight之间的随机数
+            double cumulativeWeight = 0.0;
+
+            for (Pair<Long, Double> pair : list) {
+                cumulativeWeight += pair.getRight(); // 累加权重
+                if (rand <= cumulativeWeight) {
+                    result.add(pair); // 选择当前元素
+                    totalWeight -= pair.getRight(); // 更新总权重
+                    list.remove(pair); // 防止重复选择
+                    break;
+                }
+            }
+        }
+
+        return result;
+    }
+
 }