Selaa lähdekoodia

Merge branch 'fix/funnel-coarse-rank-step3-subset' of algorithm/recommend-server into master

yangxiaohui 3 päivää sitten
vanhempi
commit
1106d07c31

+ 1 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/funnel/FunnelAggregator.java

@@ -116,7 +116,7 @@ public class FunnelAggregator {
         ctx.getStages123RecallByStrategy().forEach((pf, entries) -> {
             JSONArray arr = new JSONArray();
             for (RecallVideoEntry e : entries) {
-                if (e.getSelect() == null) continue;
+                if (!e.isFilteredIn() || e.getSelect() == null) continue;
                 JSONObject o = new JSONObject();
                 o.put("vid", e.getVideoId());
                 o.put("index", displayIndex(e.getIndex()));

+ 6 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategy4RegionMergeModelV564.java

@@ -114,11 +114,17 @@ public class RankStrategy4RegionMergeModelV564 extends RankStrategy4RegionMergeM
         int personalTopN = (int) Math.round(totalTopN * personalRatio);
         Map<Long, Double> coarseRankMap = fetchCoarseRankScores(param);
 
+        int personalCandidates = RecallUtils.countDistinctCandidates(param, setVideo, PERSONAL_RECALL_PUSH_FROMS);
         int sizeBeforePersonal = rovRecallRank.size();
         RecallUtils.extractAllAndTruncateByCoarseRank(personalTopN, param, setVideo, rovRecallRank, coarseRankMap, PERSONAL_RECALL_PUSH_FROMS);
         int personalActual = rovRecallRank.size() - sizeBeforePersonal;
         int nonPersonalBudget = totalTopN - personalActual;  // 个性化不足时, 名额转给非个性化
+        int nonPersonalCandidates = RecallUtils.countDistinctCandidates(param, setVideo, NON_PERSONAL_RECALL_PUSH_FROMS);
+        int sizeBeforeNonPersonal = rovRecallRank.size();
         RecallUtils.extractAllAndTruncateByCoarseRank(nonPersonalBudget, param, setVideo, rovRecallRank, coarseRankMap, NON_PERSONAL_RECALL_PUSH_FROMS);
+        int nonPersonalActual = rovRecallRank.size() - sizeBeforeNonPersonal;
+        log.info("coarse_rank_summary exp=564 quota={} pc={} ps={} nc={} ns={}",
+                totalTopN, personalCandidates, personalActual, nonPersonalCandidates, nonPersonalActual);
 
         // 记录召回源中的视频
         this.rankBeforePostProcessor(rovRecallRank);

+ 6 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategy4RegionMergeModelV566.java

@@ -114,11 +114,17 @@ public class RankStrategy4RegionMergeModelV566 extends RankStrategy4RegionMergeM
         int personalTopN = (int) Math.round(totalTopN * personalRatio);
         Map<Long, Double> coarseRankMap = fetchCoarseRankScores(param);
 
+        int personalCandidates = RecallUtils.countDistinctCandidates(param, setVideo, PERSONAL_RECALL_PUSH_FROMS);
         int sizeBeforePersonal = rovRecallRank.size();
         RecallUtils.extractAllAndTruncateByCoarseRank(personalTopN, param, setVideo, rovRecallRank, coarseRankMap, PERSONAL_RECALL_PUSH_FROMS);
         int personalActual = rovRecallRank.size() - sizeBeforePersonal;
         int nonPersonalBudget = totalTopN - personalActual;  // 个性化不足时, 名额转给非个性化
+        int nonPersonalCandidates = RecallUtils.countDistinctCandidates(param, setVideo, NON_PERSONAL_RECALL_PUSH_FROMS);
+        int sizeBeforeNonPersonal = rovRecallRank.size();
         RecallUtils.extractAllAndTruncateByCoarseRank(nonPersonalBudget, param, setVideo, rovRecallRank, coarseRankMap, NON_PERSONAL_RECALL_PUSH_FROMS);
+        int nonPersonalActual = rovRecallRank.size() - sizeBeforeNonPersonal;
+        log.info("coarse_rank_summary exp=566 quota={} pc={} ps={} nc={} ns={}",
+                totalTopN, personalCandidates, personalActual, nonPersonalCandidates, nonPersonalActual);
 
         // 记录召回源中的视频
         this.rankBeforePostProcessor(rovRecallRank);

+ 6 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/rank/strategy/RankStrategy4RegionMergeModelV569.java

@@ -114,11 +114,17 @@ public class RankStrategy4RegionMergeModelV569 extends RankStrategy4RegionMergeM
         int personalTopN = (int) Math.round(totalTopN * personalRatio);
         Map<Long, Double> coarseRankMap = fetchCoarseRankScores(param);
 
+        int personalCandidates = RecallUtils.countDistinctCandidates(param, setVideo, PERSONAL_RECALL_PUSH_FROMS);
         int sizeBeforePersonal = rovRecallRank.size();
         RecallUtils.extractAllAndTruncateByCoarseRank(personalTopN, param, setVideo, rovRecallRank, coarseRankMap, PERSONAL_RECALL_PUSH_FROMS);
         int personalActual = rovRecallRank.size() - sizeBeforePersonal;
         int nonPersonalBudget = totalTopN - personalActual;  // 个性化不足时, 名额转给非个性化
+        int nonPersonalCandidates = RecallUtils.countDistinctCandidates(param, setVideo, NON_PERSONAL_RECALL_PUSH_FROMS);
+        int sizeBeforeNonPersonal = rovRecallRank.size();
         RecallUtils.extractAllAndTruncateByCoarseRank(nonPersonalBudget, param, setVideo, rovRecallRank, coarseRankMap, NON_PERSONAL_RECALL_PUSH_FROMS);
+        int nonPersonalActual = rovRecallRank.size() - sizeBeforeNonPersonal;
+        log.info("coarse_rank_summary exp=569 quota={} pc={} ps={} nc={} ns={}",
+                totalTopN, personalCandidates, personalActual, nonPersonalCandidates, nonPersonalActual);
 
         // 记录召回源中的视频
         this.rankBeforePostProcessor(rovRecallRank);

+ 52 - 10
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/util/RecallUtils.java

@@ -48,6 +48,29 @@ public class RecallUtils {
         return Collections.emptyList();
     }
 
+    public static int countDistinctCandidates(RankParam param, Set<Long> excludedVideoIds,
+                                              Set<String> includedPushFroms) {
+        if (param == null || param.getRecallResult() == null
+                || CollectionUtils.isEmpty(param.getRecallResult().getData())
+                || CollectionUtils.isEmpty(includedPushFroms)) {
+            return 0;
+        }
+        Set<Long> candidates = new HashSet<>();
+        for (RecallResult.RecallData data : param.getRecallResult().getData()) {
+            if (data == null || CollectionUtils.isEmpty(data.getVideos())
+                    || !includedPushFroms.contains(data.getPushFrom())) {
+                continue;
+            }
+            for (Video video : data.getVideos()) {
+                if (video == null) continue;
+                long videoId = video.getVideoId();
+                if (excludedVideoIds != null && excludedVideoIds.contains(videoId)) continue;
+                candidates.add(videoId);
+            }
+        }
+        return candidates.size();
+    }
+
     public static void extractOldSpecialRecall(int sizeReturn, RankParam param, Set<Long> setVideo, List<Video> rovRecallRank) {
         List<Video> oldRovs = new ArrayList<>();
         oldRovs.addAll(extractAndSort(param, RegionHRecallStrategy.PUSH_FORM));
@@ -93,8 +116,12 @@ public class RecallUtils {
             return;
         }
         Map<Long, Double> coarseMap = coarseRankMap != null ? coarseRankMap : Collections.emptyMap();
+        Set<Long> selectedBefore = rovRecallRank.stream()
+                .map(Video::getVideoId)
+                .collect(Collectors.toSet());
         // 1. 全并 + dedupe(首次命中保留,只挑 includedPushFroms 命中的路)
         Map<Long, Video> mergedById = new LinkedHashMap<>();
+        Map<Long, String> attributionById = new HashMap<>();
         for (RecallResult.RecallData data : param.getRecallResult().getData()) {
             if (data == null || CollectionUtils.isEmpty(data.getVideos())) continue;
             if (!includedPushFroms.contains(data.getPushFrom())) continue;
@@ -103,11 +130,13 @@ public class RecallUtils {
                 long vid = v.getVideoId();
                 if (setVideo.contains(vid)) continue;
                 mergedById.putIfAbsent(vid, v);
+                attributionById.putIfAbsent(vid, data.getPushFrom());
             }
         }
         if (mergedById.isEmpty()) {
-            // 仍然给 includedPushFroms 范围内 entry 归因:本次没新增 vid,但要标 OTHER(被前面池抢光 + 没命中本类)
-            markCoarseRankAttribution(param, includedPushFroms, Collections.emptySet(), topN, coarseMap);
+            // 本池没有新增 vid 时,只标记确实已被前序池选走的重复候选。
+            markCoarseRankAttribution(param, includedPushFroms, Collections.emptyMap(),
+                    selectedBefore, topN, coarseMap);
             return;
         }
 
@@ -122,19 +151,25 @@ public class RecallUtils {
 
         // 3. 漏斗归因(仅 includedPushFroms 范围内的 entry):
         //    - entry.score 覆盖为粗排分(命运分;同 vid 在所有命中路 entry 都覆盖同一个值)
-        //    - entry.select = SELF/OTHER 复用 V568 二元语义:vid 在本次截断胜出 → SELF,否则 OTHER
-        //      (跨类抢占场景: vid 同时在个性化+非个性化命中, 个性化先抢走时, 非个性化路的 vid entry
-        //       在第二次调用时 pickedIds 不含 it → 自动标 OTHER, 含义="vid 被前面池抢走", 与 V568 一致)
+        //    - SELF = 本次截断胜出且由该 pushFrom 首次贡献
+        //    - OTHER = 同一 vid 被本次其他 pushFrom 或前序池选走
+        //    - 普通粗排落选保持 null,不进入阶段 3
         //    - entry.truncate = 本次调用的 topN
-        markCoarseRankAttribution(param, includedPushFroms, pickedIds, topN, coarseMap);
+        Map<Long, String> pickedAttribution = new HashMap<>();
+        for (Long pickedId : pickedIds) {
+            pickedAttribution.put(pickedId, attributionById.get(pickedId));
+        }
+        markCoarseRankAttribution(param, includedPushFroms, pickedAttribution,
+                selectedBefore, topN, coarseMap);
     }
 
     /**
-     * V564:覆盖 entry.score + 写 SELF/OTHER + truncate(仅 includedPushFroms 范围)。
+     * 统一粗排漏斗归因:仅阶段 2 过滤后的视频可进入阶段 3
      * 流量池 3 路等 includedPushFroms 之外的 entry 保持原状(select=null, 不参与 V564 截断)。
      */
     private static void markCoarseRankAttribution(RankParam param, Set<String> includedPushFroms,
-                                                   Set<Long> pickedIds, int truncate,
+                                                   Map<Long, String> pickedAttribution,
+                                                   Set<Long> selectedBefore, int truncate,
                                                    Map<Long, Double> coarseRankMap) {
         if (param == null || param.getFunnelContext() == null
                 || CollectionUtils.isEmpty(includedPushFroms)) return;
@@ -143,10 +178,17 @@ public class RecallUtils {
             List<RecallVideoEntry> entries = ctx.getStages123RecallByStrategy().get(pf);
             if (CollectionUtils.isEmpty(entries)) continue;
             for (RecallVideoEntry e : entries) {
+                if (!e.isFilteredIn()) continue;
                 Double coarse = coarseRankMap.get(e.getVideoId());
                 if (coarse != null) e.setScore(coarse);
-                e.setSelect(pickedIds.contains(e.getVideoId()) ? SelectKind.SELF : SelectKind.OTHER);
-                e.setTruncate(truncate);
+                String attributedPushFrom = pickedAttribution.get(e.getVideoId());
+                if (attributedPushFrom != null) {
+                    e.setSelect(attributedPushFrom.equals(pf) ? SelectKind.SELF : SelectKind.OTHER);
+                    e.setTruncate(truncate);
+                } else if (selectedBefore.contains(e.getVideoId())) {
+                    e.setSelect(SelectKind.OTHER);
+                    e.setTruncate(truncate);
+                }
             }
         }
     }

+ 38 - 0
recommend-server-service/src/test/java/com/tzld/piaoquan/recommend/server/service/funnel/FunnelAggregatorTest.java

@@ -0,0 +1,38 @@
+package com.tzld.piaoquan.recommend.server.service.funnel;
+
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONArray;
+import com.alibaba.fastjson.JSONObject;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class FunnelAggregatorTest {
+
+    @Test
+    void step3NeverContainsVideoFilteredOutAtStep2() {
+        FunnelContext context = new FunnelContext();
+        RecallVideoEntry kept = entry(1, true, SelectKind.SELF);
+        RecallVideoEntry filtered = entry(2, false, SelectKind.SELF);
+        context.getStages123RecallByStrategy().put("p1", Arrays.asList(kept, filtered));
+
+        Map<String, String> logItem = FunnelAggregator.toLogItem(context);
+        JSONObject step3 = JSON.parseObject(logItem.get("step_3_truncated"));
+        JSONArray videos = step3.getJSONArray("p1");
+
+        assertEquals(1, videos.size());
+        assertEquals(1L, videos.getJSONObject(0).getLongValue("vid"));
+    }
+
+    private static RecallVideoEntry entry(long videoId, boolean filteredIn, SelectKind select) {
+        RecallVideoEntry entry = new RecallVideoEntry(videoId, 0, 0);
+        entry.setFilteredIn(filteredIn);
+        entry.setIndexNewAfterFilter(filteredIn ? 0 : -1);
+        entry.setSelect(select);
+        entry.setTruncate(1);
+        return entry;
+    }
+}

+ 161 - 0
recommend-server-service/src/test/java/com/tzld/piaoquan/recommend/server/util/RecallUtilsTest.java

@@ -0,0 +1,161 @@
+package com.tzld.piaoquan.recommend.server.util;
+
+import com.tzld.piaoquan.recommend.server.model.Video;
+import com.tzld.piaoquan.recommend.server.service.funnel.FunnelContext;
+import com.tzld.piaoquan.recommend.server.service.funnel.RecallVideoEntry;
+import com.tzld.piaoquan.recommend.server.service.funnel.SelectKind;
+import com.tzld.piaoquan.recommend.server.service.rank.RankParam;
+import com.tzld.piaoquan.recommend.server.service.recall.RecallResult;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+class RecallUtilsTest {
+
+    @Test
+    void countDistinctCandidatesDeduplicatesAndExcludesPreviouslySelectedVideos() {
+        RankParam param = rankParam(
+                Arrays.asList(
+                        recallData("p1", video(1, 10), video(2, 9)),
+                        recallData("p2", video(2, 9), video(3, 8)),
+                        recallData("ignored", video(4, 7))
+                ),
+                Collections.emptyList()
+        );
+
+        int candidates = RecallUtils.countDistinctCandidates(
+                param, Collections.singleton(1L), new HashSet<>(Arrays.asList("p1", "p2")));
+
+        assertEquals(2, candidates);
+    }
+
+    @Test
+    void coarseRankOnlyAttributesFilteredCandidatesThatWereSelected() {
+        RankParam param = rankParam(
+                recallData("p1", video(1, 10), video(3, 1)),
+                entries("p1",
+                        entry(1, true),
+                        entry(2, false),
+                        entry(3, true))
+        );
+
+        RecallUtils.extractAllAndTruncateByCoarseRank(
+                1, param, new HashSet<>(), new ArrayList<>(),
+                scores(1, 10, 3, 1), Collections.singleton("p1"));
+
+        List<RecallVideoEntry> entries = param.getFunnelContext()
+                .getStages123RecallByStrategy().get("p1");
+        assertEquals(SelectKind.SELF, entries.get(0).getSelect());
+        assertNull(entries.get(1).getSelect(), "filtered video must not enter step 3");
+        assertNull(entries.get(2).getSelect(), "ordinary coarse-rank loser must not enter step 3");
+    }
+
+    @Test
+    void coarseRankMarksDuplicatePushFromAsOther() {
+        RankParam param = rankParam(
+                Arrays.asList(
+                        recallData("p1", video(1, 10)),
+                        recallData("p2", video(1, 10))
+                ),
+                Arrays.asList(
+                        entries("p1", entry(1, true)),
+                        entries("p2", entry(1, true))
+                )
+        );
+
+        RecallUtils.extractAllAndTruncateByCoarseRank(
+                1, param, new HashSet<>(), new ArrayList<>(),
+                scores(1, 10), new HashSet<>(Arrays.asList("p1", "p2")));
+
+        assertEquals(SelectKind.SELF, entry(param, "p1", 0).getSelect());
+        assertEquals(SelectKind.OTHER, entry(param, "p2", 0).getSelect());
+    }
+
+    @Test
+    void coarseRankMarksCandidateSelectedByPreviousPoolAsOther() {
+        RankParam param = rankParam(
+                Arrays.asList(
+                        recallData("personal", video(1, 10)),
+                        recallData("nonPersonal", video(1, 10), video(2, 9))
+                ),
+                Arrays.asList(
+                        entries("personal", entry(1, true)),
+                        entries("nonPersonal", entry(1, true), entry(2, true))
+                )
+        );
+        Set<Long> selectedIds = new HashSet<>();
+        List<Video> result = new ArrayList<>();
+
+        RecallUtils.extractAllAndTruncateByCoarseRank(
+                1, param, selectedIds, result, scores(1, 10, 2, 9),
+                Collections.singleton("personal"));
+        RecallUtils.extractAllAndTruncateByCoarseRank(
+                1, param, selectedIds, result, scores(1, 10, 2, 9),
+                Collections.singleton("nonPersonal"));
+
+        assertEquals(SelectKind.OTHER, entry(param, "nonPersonal", 0).getSelect());
+        assertEquals(SelectKind.SELF, entry(param, "nonPersonal", 1).getSelect());
+    }
+
+    private static RankParam rankParam(RecallResult.RecallData data,
+                                       Map.Entry<String, List<RecallVideoEntry>> entries) {
+        return rankParam(Collections.singletonList(data), Collections.singletonList(entries));
+    }
+
+    private static RankParam rankParam(List<RecallResult.RecallData> data,
+                                       List<Map.Entry<String, List<RecallVideoEntry>>> entries) {
+        FunnelContext context = new FunnelContext();
+        for (Map.Entry<String, List<RecallVideoEntry>> item : entries) {
+            context.getStages123RecallByStrategy().put(item.getKey(), item.getValue());
+        }
+        RankParam param = new RankParam();
+        param.setRecallResult(new RecallResult(data));
+        param.setFunnelContext(context);
+        return param;
+    }
+
+    private static RecallResult.RecallData recallData(String pushFrom, Video... videos) {
+        return new RecallResult.RecallData(pushFrom, Arrays.asList(videos));
+    }
+
+    private static Map.Entry<String, List<RecallVideoEntry>> entries(
+            String pushFrom, RecallVideoEntry... entries) {
+        return new java.util.AbstractMap.SimpleEntry<>(pushFrom, Arrays.asList(entries));
+    }
+
+    private static RecallVideoEntry entry(long videoId, boolean filteredIn) {
+        RecallVideoEntry entry = new RecallVideoEntry(videoId, 0, 0);
+        entry.setFilteredIn(filteredIn);
+        entry.setIndexNewAfterFilter(filteredIn ? 0 : -1);
+        return entry;
+    }
+
+    private static RecallVideoEntry entry(RankParam param, String pushFrom, int index) {
+        return param.getFunnelContext().getStages123RecallByStrategy().get(pushFrom).get(index);
+    }
+
+    private static Video video(long videoId, double score) {
+        Video video = new Video();
+        video.setVideoId(videoId);
+        video.setRovScore(score);
+        return video;
+    }
+
+    private static Map<Long, Double> scores(double... values) {
+        Map<Long, Double> scores = new HashMap<>();
+        for (int i = 0; i < values.length; i += 2) {
+            scores.put((long) values[i], values[i + 1]);
+        }
+        return scores;
+    }
+}