소스 검색

特征点泛化 使用特征树召回

wangyunpeng 18 시간 전
부모
커밋
c433703238

+ 205 - 0
core/src/main/java/com/tzld/videoVector/api/LibraryApiService.java

@@ -0,0 +1,205 @@
+package com.tzld.videoVector.api;
+
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.annotation.JSONField;
+import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.*;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Service;
+
+import javax.annotation.PostConstruct;
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Library API 服务
+ * 封装 library.aiddit.com 的话题元素搜索和视频帖子列表两个接口
+ */
+@Slf4j
+@Service
+public class LibraryApiService {
+
+    private OkHttpClient client;
+
+    private static final String BASE_URL = "https://library.aiddit.com";
+
+    private static final String TOPIC_ELEMENTS_SEARCH_URL = BASE_URL + "/api/pattern/executions/%d/topic-elements/search";
+
+    private static final String POSTS_URL = BASE_URL + "/api/posts";
+
+    @Value("${library.api.timeout:30}")
+    private int timeout;
+
+    @PostConstruct
+    public void init() {
+        client = new OkHttpClient.Builder()
+                .connectTimeout(timeout, TimeUnit.SECONDS)
+                .readTimeout(timeout, TimeUnit.SECONDS)
+                .writeTimeout(timeout, TimeUnit.SECONDS)
+                .build();
+        log.info("Library API 服务初始化完成");
+    }
+
+    /**
+     * 搜索话题元素
+     *
+     * @param executionId 执行 ID
+     * @param query       搜索关键词
+     * @param limit       返回数量上限
+     * @param layer       搜索层级(如 "all")
+     * @return 话题元素搜索结果,失败时返回 null
+     */
+    public TopicElementSearchResponse searchTopicElements(int executionId, String query, int limit, String layer) {
+        try {
+            HttpUrl url = HttpUrl.parse(String.format(TOPIC_ELEMENTS_SEARCH_URL, executionId));
+            if (url == null) {
+                log.error("searchTopicElements URL 解析失败, executionId: {}", executionId);
+                return null;
+            }
+            HttpUrl.Builder urlBuilder = url.newBuilder()
+                    .addQueryParameter("q", query)
+                    .addQueryParameter("limit", String.valueOf(limit));
+            if (layer != null && !layer.isEmpty()) {
+                urlBuilder.addQueryParameter("layer", layer);
+            }
+
+            Request request = new Request.Builder()
+                    .url(urlBuilder.build())
+                    .get()
+                    .build();
+
+            try (Response response = client.newCall(request).execute()) {
+                if (!response.isSuccessful()) {
+                    log.error("搜索话题元素请求失败,HTTP状态码: {}, executionId: {}, query: {}",
+                            response.code(), executionId, query);
+                    return null;
+                }
+                String respStr = response.body().string();
+                log.info("搜索话题元素响应, executionId: {}, query: {}, resp: {}", executionId, query, respStr);
+                return JSON.parseObject(respStr, TopicElementSearchResponse.class);
+            }
+        } catch (IOException e) {
+            log.error("搜索话题元素异常, executionId: {}, query: {}, error: {}",
+                    executionId, query, e.getMessage(), e);
+            return null;
+        }
+    }
+
+    /**
+     * 获取视频帖子列表
+     *
+     * @param page        页码(从 1 开始)
+     * @param pageSize    每页数量
+     * @param platform    平台(如 "piaoquan")
+     * @param executionId 执行 ID
+     * @param elementId   元素 ID
+     * @return 帖子列表响应,失败时返回 null
+     */
+    public PostListResponse getPosts(int page, int pageSize, String platform, int executionId, long elementId) {
+        try {
+            HttpUrl url = HttpUrl.parse(POSTS_URL);
+            if (url == null) {
+                log.error("getPosts URL 解析失败");
+                return null;
+            }
+            HttpUrl.Builder urlBuilder = url.newBuilder()
+                    .addQueryParameter("page", String.valueOf(page))
+                    .addQueryParameter("page_size", String.valueOf(pageSize))
+                    .addQueryParameter("platform", platform)
+                    .addQueryParameter("execution_id", String.valueOf(executionId))
+                    .addQueryParameter("element_id", String.valueOf(elementId));
+
+            Request request = new Request.Builder()
+                    .url(urlBuilder.build())
+                    .get()
+                    .build();
+
+            try (Response response = client.newCall(request).execute()) {
+                if (!response.isSuccessful()) {
+                    log.error("获取帖子列表请求失败,HTTP状态码: {}, executionId: {}, elementId: {}",
+                            response.code(), executionId, elementId);
+                    return null;
+                }
+                String respStr = response.body().string();
+                log.info("获取帖子列表响应, executionId: {}, elementId: {}, page: {}, resp: {}",
+                        executionId, elementId, page, respStr);
+                return JSON.parseObject(respStr, PostListResponse.class);
+            }
+        } catch (IOException e) {
+            log.error("获取帖子列表异常, executionId: {}, elementId: {}, error: {}",
+                    executionId, elementId, e.getMessage(), e);
+            return null;
+        }
+    }
+
+    // ======================== 响应模型 ========================
+
+    @Data
+    public static class TopicElementSearchResponse {
+        private String q;
+        private String layer;
+        private List<TopicElementItem> items;
+    }
+
+    @Data
+    public static class TopicElementItem {
+        private Long id;
+        private String name;
+        private String dimension;
+        @JSONField(name = "point_types")
+        private List<String> pointTypes;
+        @JSONField(name = "occurrence_count")
+        private Integer occurrenceCount;
+    }
+
+    @Data
+    public static class PostListResponse {
+        private Boolean success;
+        private List<PostItem> posts;
+        private Integer total;
+        private Integer page;
+        @JSONField(name = "page_size")
+        private Integer pageSize;
+        @JSONField(name = "total_pages")
+        private Integer totalPages;
+        private List<String> platforms;
+        private ExecutionInfo execution;
+    }
+
+    @Data
+    public static class PostItem {
+        private Long id;
+        @JSONField(name = "post_id")
+        private String postId;
+        private String title;
+        private String platform;
+        @JSONField(name = "platform_account_name")
+        private String platformAccountName;
+        private String mergeLeve1;
+        private String mergeLeve2;
+        private List<String> images;
+        @JSONField(name = "like_count")
+        private Integer likeCount;
+        @JSONField(name = "comment_count")
+        private Integer commentCount;
+        @JSONField(name = "collect_count")
+        private Integer collectCount;
+        @JSONField(name = "import_date")
+        private String importDate;
+        @JSONField(name = "has_topic_decode")
+        private Boolean hasTopicDecode;
+        @JSONField(name = "has_script_decode")
+        private Boolean hasScriptDecode;
+    }
+
+    @Data
+    public static class ExecutionInfo {
+        private Integer id;
+        @JSONField(name = "snapshot_date")
+        private String snapshotDate;
+        @JSONField(name = "post_count")
+        private Integer postCount;
+    }
+}

+ 187 - 7
core/src/main/java/com/tzld/videoVector/job/ChannelDemandMatchJob.java

@@ -1,9 +1,12 @@
 package com.tzld.videoVector.job;
 
 import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONObject;
 import com.aliyun.odps.data.Record;
 import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
 import com.google.common.collect.Lists;
+import com.tzld.videoVector.api.LibraryApiService;
+import com.tzld.videoVector.common.constant.VectorConstants;
 import com.tzld.videoVector.dao.mapper.pgVector.ChannelDemandMatchConfigMapper;
 import com.tzld.videoVector.dao.mapper.pgVector.ChannelDemandMatchResultMapper;
 import com.tzld.videoVector.dao.mapper.pgVector.ext.ChannelDemandMatchResultMapperExt;
@@ -61,6 +64,9 @@ public class ChannelDemandMatchJob {
     @Resource
     private RedisUtils redisUtils;
 
+    @Resource
+    private LibraryApiService libraryApiService;
+
     /**
      * 召回结果Redis缓存前缀
      */
@@ -161,6 +167,36 @@ public class ChannelDemandMatchJob {
     @Value("${channel.demand.result.retention-days:14}")
     private int resultRetentionDays;
 
+    /**
+     * Library API 执行 ID
+     */
+    @Value("${library.api.execution-id:581}")
+    private int libraryExecutionId;
+
+    /**
+     * Library API 平台
+     */
+    @Value("${library.api.platform:piaoquan}")
+    private String libraryPlatform;
+
+    /**
+     * Library API 话题元素搜索返回上限
+     */
+    @Value("${library.api.element-search-limit:20}")
+    private int elementSearchLimit;
+
+    /**
+     * Library API 帖子列表每页数量
+     */
+    @Value("${library.api.post-page-size:50}")
+    private int libraryPostPageSize;
+
+    /**
+     * 视频详情指标天数维度
+     */
+    @Value("${video.detail.metrics.days:7}")
+    private int metricsDays;
+
     /**
      * 点类型 → 向量配置编码映射
      */
@@ -521,15 +557,10 @@ public class ChannelDemandMatchJob {
             allBatchRows.addAll(rows);
         }
 
-        // 策略三:需求特征点类型+需求特征点 均有值 → 用需求特征点召回
+        // 策略三-泛化:特征点泛化 → 使用 Library API 召回
         if ("特征点泛化".equals(demand.getDemandType()) && hasValidValue(demand.getMatchGeneralizedPointType())
                 && hasValidValue(demand.getMatchGeneralizedElement())) {
-            List<String> configCodes = POINT_TYPE_CONFIG_CODE_MAP.getOrDefault(demand.getMatchGeneralizedPointType(), Arrays.asList("VIDEO_TOPIC"));
-            List<ChannelDemandMatchResult> rows = new ArrayList<>();
-            for (String configCode : configCodes) {
-                rows.addAll(doRecall(demand, demand.getMatchGeneralizedElement(), configCode, topN / configCodes.size()));
-
-            }
+            List<ChannelDemandMatchResult> rows = doLibraryRecall(demand, topN);
             allBatchRows.addAll(rows);
         }
 
@@ -584,6 +615,155 @@ public class ChannelDemandMatchJob {
         return batchRows;
     }
 
+    /**
+     * Library API 召回:通过话题元素搜索 + 帖子列表获取召回视频
+     * <p>
+     * 流程:
+     * 1. 用泛化元素名称搜索话题元素,获取元素 ID
+     * 2. 仅保留名称全等匹配的元素
+     * 3. 遍历元素 ID,调用帖子列表接口获取视频帖子
+     * 4. 按 post_id 去重
+     * 5. 从 Redis 批量获取视频详情指标(按 video.detail.metrics.days 天数维度)
+     * 6. 按 rov 降序排列,取 topN
+     */
+    private List<ChannelDemandMatchResult> doLibraryRecall(ChannelDemandMatchResult demand, int topN) {
+        List<ChannelDemandMatchResult> batchRows = new ArrayList<>();
+        String elementName = demand.getMatchGeneralizedElement();
+
+        // 1. 搜索话题元素
+        LibraryApiService.TopicElementSearchResponse elementResp = libraryApiService.searchTopicElements(
+                libraryExecutionId, elementName, elementSearchLimit, "all");
+        if (elementResp == null || CollectionUtils.isEmpty(elementResp.getItems())) {
+            log.info("Library API 话题元素搜索无结果, executionId={}, elementName={}", libraryExecutionId, elementName);
+            return batchRows;
+        }
+        log.info("Library API 话题元素搜索到 {} 个元素, elementName={}", elementResp.getItems().size(), elementName);
+
+        // 2. 仅保留名称全等匹配的元素
+        List<LibraryApiService.TopicElementItem> matchedItems = elementResp.getItems().stream()
+                .filter(e -> elementName.equals(e.getName()))
+                .collect(Collectors.toList());
+        if (matchedItems.isEmpty()) {
+            log.info("Library API 话题元素无全等匹配, elementName={}", elementName);
+            return batchRows;
+        }
+        log.info("Library API 话题元素全等匹配 {} 个, elementName={}", matchedItems.size(), elementName);
+
+        // 3. 遍历匹配的元素获取帖子,按 post_id 去重
+        Map<Long, LibraryApiService.PostItem> postMap = new LinkedHashMap<>();
+        for (LibraryApiService.TopicElementItem element : matchedItems) {
+            if (element.getId() == null) {
+                continue;
+            }
+            LibraryApiService.PostListResponse postResp = libraryApiService.getPosts(
+                    1, libraryPostPageSize, libraryPlatform, libraryExecutionId, element.getId());
+            if (postResp == null || CollectionUtils.isEmpty(postResp.getPosts())) {
+                continue;
+            }
+            for (LibraryApiService.PostItem post : postResp.getPosts()) {
+                if (post.getPostId() == null) {
+                    continue;
+                }
+                Long postIdLong;
+                try {
+                    postIdLong = Long.parseLong(post.getPostId());
+                } catch (NumberFormatException e) {
+                    log.warn("post_id 解析失败, postId={}, elementId={}", post.getPostId(), element.getId());
+                    continue;
+                }
+                postMap.putIfAbsent(postIdLong, post);
+            }
+        }
+        if (postMap.isEmpty()) {
+            log.info("Library API 帖子列表无结果, elementName={}", elementName);
+            return batchRows;
+        }
+        log.info("Library API 去重后获取到 {} 个帖子, elementName={}", postMap.size(), elementName);
+
+        // 4. 从 Redis 批量获取视频详情指标
+        List<Long> postIdList = new ArrayList<>(postMap.keySet());
+        List<String> redisKeys = postIdList.stream()
+                .map(id -> VectorConstants.VIDEO_DETAIL_DAYS_KEY_PREFIX + metricsDays + "d:" + id)
+                .collect(Collectors.toList());
+        List<String> redisValues = redisUtils.mGet(redisKeys);
+
+        // 5. 解析指标并构建结果行
+        List<PostWithMetrics> postWithMetricsList = new ArrayList<>();
+        for (int i = 0; i < postIdList.size(); i++) {
+            Long postId = postIdList.get(i);
+            LibraryApiService.PostItem post = postMap.get(postId);
+            Map<String, Object> videoDetail = null;
+            Double rov = null;
+
+            if (redisValues != null && i < redisValues.size() && redisValues.get(i) != null) {
+                try {
+                    videoDetail = JSONObject.parseObject(redisValues.get(i), Map.class);
+                    if (videoDetail != null) {
+                        Object rovObj = videoDetail.get("rov");
+                        if (rovObj != null) {
+                            rov = rovObj instanceof Number ? ((Number) rovObj).doubleValue()
+                                    : Double.parseDouble(rovObj.toString());
+                        }
+                    }
+                } catch (Exception e) {
+                    log.warn("解析视频详情失败, postId={}: {}", postId, e.getMessage());
+                }
+            }
+            postWithMetricsList.add(new PostWithMetrics(post, videoDetail, rov));
+        }
+
+        // 6. 按 rov 降序排列(无 rov 数据的排在最后),取 topN
+        postWithMetricsList.sort((a, b) -> {
+            Double aRov = a.rov != null ? a.rov : -1.0;
+            Double bRov = b.rov != null ? b.rov : -1.0;
+            return bRov.compareTo(aRov);
+        });
+
+        int count = 0;
+        for (PostWithMetrics pm : postWithMetricsList) {
+            if (count >= topN) {
+                break;
+            }
+            LibraryApiService.PostItem post = pm.post;
+            Map<String, Object> detail = pm.videoDetail;
+            Long postId = Long.parseLong(post.getPostId());
+
+            ChannelDemandMatchResult row = copyDemandFields(demand);
+            row.setMatchVideoId(postId);
+            row.setMatchConfigCode("LIBRARY_TOPIC_ELEMENT");
+            row.setMatchRov(pm.rov);
+            row.setMatchScore(pm.rov);
+            row.setMatchSim(null);
+            row.setMatchExposurePv(extractNumber(detail, "分发曝光pv", Long.class));
+            row.setMatchHeadSingleReturnRate(extractNumber(detail, "头部单层回流率", Double.class));
+            row.setMatchHeadDistributionSingleReturnRate(extractNumber(detail, "头部进分发单层回流率", Double.class));
+            row.setMatchText(post.getTitle());
+            row.setMatchStatus((short) 1);
+            row.setExperimentId(generateExperimentId(demand, postId, "LIBRARY_TOPIC_ELEMENT"));
+            batchRows.add(row);
+            count++;
+        }
+
+        log.info("Library API 召回完成, elementName={}, 候选{}条, 返回{}条",
+                elementName, postWithMetricsList.size(), batchRows.size());
+        return batchRows;
+    }
+
+    /**
+     * 帖子与指标数据组装
+     */
+    private static class PostWithMetrics {
+        final LibraryApiService.PostItem post;
+        final Map<String, Object> videoDetail;
+        final Double rov;
+
+        PostWithMetrics(LibraryApiService.PostItem post, Map<String, Object> videoDetail, Double rov) {
+            this.post = post;
+            this.videoDetail = videoDetail;
+            this.rov = rov;
+        }
+    }
+
     /**
      * 带Redis缓存的召回:相同queryText+configCode+topN直接复用缓存结果
      */