Browse Source

Merge branch 'dev-xym-update-branch' of algorithm/ad-engine into master

xueyiming 1 week ago
parent
commit
c92b727446

+ 1 - 1
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/ScorerPipeline.java

@@ -18,7 +18,7 @@ import java.util.concurrent.*;
 @Slf4j
 public class ScorerPipeline {
     public static final int corePoolSize = 128;
-    public static final int SCORE_TIME_OUT = 400;
+    public static final int SCORE_TIME_OUT = 500;
     public static final Logger LOGGER = LoggerFactory.getLogger(ScorerPipeline.class);
     public static final ExecutorService executorService = Executors.newFixedThreadPool(corePoolSize);
 

+ 8 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/feature/Feature.java

@@ -23,4 +23,12 @@ public class Feature {
     // k1:skuid、k2:表、v:特征值
     private Map<String, Map<String, Map<String, String>>> skuFeature = new HashMap<>();
 
+    // 合并方法,使用 putAll() 合并,且如果目标为空则取源
+    public void merge(Feature other) {
+        this.cidFeature.putAll(other.cidFeature);
+        this.videoFeature.putAll(other.videoFeature);
+        this.adVerFeature.putAll(other.adVerFeature);
+        this.userFeature.putAll(other.userFeature);
+        this.skuFeature.putAll(other.skuFeature);
+    }
 }

+ 63 - 11
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorer.java

@@ -7,7 +7,9 @@ import com.tzld.piaoquan.ad.engine.commons.score.BaseXGBoostModelScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
 import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV1;
+import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV2;
 import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel683;
+import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
 import org.apache.commons.collections4.CollectionUtils;
@@ -53,29 +55,79 @@ public class PAIScorer extends AbstractScorer {
         return result;
     }
 
+//    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+//                                        final Map<String, String> userFeatureMap,
+//                                        final List<AdRankItem> items) {
+//        long startTime = System.currentTimeMillis();
+//        PAIModelV1 model = PAIModelV1.getModel();
+//        // 所有都参与打分,按照ctr排序
+//        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+//
+//        // debug log
+//        if (LOGGER.isDebugEnabled()) {
+//            for (int i = 0; i < items.size(); i++) {
+//                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+//            }
+//        }
+//
+//        Collections.sort(items);
+//
+//        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+//        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+//                System.currentTimeMillis() - startTime);
+//        return items;
+//    }
+
     private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
                                         final Map<String, String> userFeatureMap,
                                         final List<AdRankItem> items) {
+        if (items == null || items.isEmpty()) {
+            return Collections.emptyList();
+        }
+
         long startTime = System.currentTimeMillis();
         PAIModelV1 model = PAIModelV1.getModel();
-        // 所有都参与打分,按照ctr排序
-        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
 
-        // debug log
-        if (LOGGER.isDebugEnabled()) {
-            for (int i = 0; i < items.size(); i++) {
-                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+        final int batchSize = 300;
+        List<List<AdRankItem>> batches = new ArrayList<>();
+        for (int i = 0; i < items.size(); i += batchSize) {
+            batches.add(new ArrayList<>(items.subList(i, Math.min(i + batchSize, items.size()))));
+        }
+
+        ExecutorService executor = ThreadPoolFactory.defaultPool();
+        List<Future<List<AdRankItem>>> futures = new ArrayList<>();
+
+        for (List<AdRankItem> batch : batches) {
+            futures.add(executor.submit(() -> {
+                try {
+                    multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
+                } catch (Exception e) {
+                    LOGGER.error("Error during multipleCtrScore batch execution", e);
+                }
+                return batch;
+            }));
+        }
+
+        // 合并结果
+        List<AdRankItem> merged = new ArrayList<>();
+        for (Future<List<AdRankItem>> future : futures) {
+            try {
+                merged.addAll(future.get(400, TimeUnit.MILLISECONDS));
+            } catch (Exception e) {
+                LOGGER.error("Execution error in batch", e);
             }
         }
 
-        Collections.sort(items);
+        Collections.sort(merged);
 
-        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
-        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
-                System.currentTimeMillis() - startTime);
-        return items;
+        LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
+        return merged;
     }
 
+
+
+
+
     private void multipleCtrScore(final List<AdRankItem> items,
                                   final Map<String, String> userFeatureMap,
                                   final Map<String, String> sceneFeatureMap,

+ 64 - 11
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorerV2.java

@@ -4,16 +4,23 @@ package com.tzld.piaoquan.ad.engine.service.score.scorer;
 import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV1;
 import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV2;
+import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
 import org.apache.commons.collections4.CollectionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
 
 public class PAIScorerV2 extends AbstractScorer {
 
@@ -48,27 +55,73 @@ public class PAIScorerV2 extends AbstractScorer {
         return result;
     }
 
+//    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+//                                        final Map<String, String> userFeatureMap,
+//                                        final List<AdRankItem> items) {
+//        long startTime = System.currentTimeMillis();
+//        PAIModelV2 model = PAIModelV2.getModel();
+//        // 所有都参与打分,按照ctr排序
+//        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+//
+//        // debug log
+//        if (LOGGER.isDebugEnabled()) {
+//            for (int i = 0; i < items.size(); i++) {
+//                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+//            }
+//        }
+//
+//        Collections.sort(items);
+//
+//        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+//        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+//                System.currentTimeMillis() - startTime);
+//        return items;
+//    }
+
     private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
                                         final Map<String, String> userFeatureMap,
                                         final List<AdRankItem> items) {
+        if (items == null || items.isEmpty()) {
+            return Collections.emptyList();
+        }
+
         long startTime = System.currentTimeMillis();
         PAIModelV2 model = PAIModelV2.getModel();
-        // 所有都参与打分,按照ctr排序
-        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
 
-        // debug log
-        if (LOGGER.isDebugEnabled()) {
-            for (int i = 0; i < items.size(); i++) {
-                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+        final int batchSize = 300;
+        List<List<AdRankItem>> batches = new ArrayList<>();
+        for (int i = 0; i < items.size(); i += batchSize) {
+            batches.add(new ArrayList<>(items.subList(i, Math.min(i + batchSize, items.size()))));
+        }
+
+        ExecutorService executor = ThreadPoolFactory.defaultPool();
+        List<Future<List<AdRankItem>>> futures = new ArrayList<>();
+
+        for (List<AdRankItem> batch : batches) {
+            futures.add(executor.submit(() -> {
+                try {
+                    multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
+                } catch (Exception e) {
+                    LOGGER.error("Error during multipleCtrScore batch execution", e);
+                }
+                return batch;
+            }));
+        }
+
+        // 合并结果
+        List<AdRankItem> merged = new ArrayList<>();
+        for (Future<List<AdRankItem>> future : futures) {
+            try {
+                merged.addAll(future.get(400, TimeUnit.MILLISECONDS));
+            } catch (Exception e) {
+                LOGGER.error("Execution error in batch", e);
             }
         }
 
-        Collections.sort(items);
+        Collections.sort(merged);
 
-        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
-        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
-                System.currentTimeMillis() - startTime);
-        return items;
+        LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
+        return merged;
     }
 
     private void multipleCtrScore(final List<AdRankItem> items,

+ 86 - 15
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/strategy/RankStrategyBasic.java

@@ -13,6 +13,7 @@ import com.tzld.piaoquan.ad.engine.commons.param.RankRecommendRequestParam;
 import com.tzld.piaoquan.ad.engine.commons.redis.AdRedisHelper;
 import com.tzld.piaoquan.ad.engine.commons.redis.AlgorithmRedisHelper;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.ad.engine.commons.util.DateUtils;
 import com.tzld.piaoquan.ad.engine.commons.util.ObjUtil;
 import com.tzld.piaoquan.ad.engine.service.entity.*;
@@ -27,6 +28,10 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
 
 import java.util.*;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -113,6 +118,9 @@ public abstract class RankStrategyBasic implements RankStrategy {
     @Value("${filter.config.value:[]}")
     protected String filterConfigValue;
 
+    @Value("${feature.branch.config.value:500}")
+    protected Integer featureBranchConfigValue;
+
     @Autowired
     private FeatureService featureService;
     @Autowired
@@ -153,25 +161,88 @@ public abstract class RankStrategyBasic implements RankStrategy {
     }};
 
 
+//    protected Feature getFeature(ScoreParam param, RankRecommendRequestParam request) {
+//        List<AdPlatformCreativeDTO> adIdList = request.getAdIdList();
+//        List<String> cidList = adIdList.stream()
+//                .map(AdPlatformCreativeDTO::getCreativeId)
+//                .map(Object::toString)
+//                .collect(Collectors.toList());
+//
+//        List<String> adVerIdList = adIdList.stream()
+//                .map(AdPlatformCreativeDTO::getAdVerId)
+//                .filter(StringUtils::isNotBlank)
+//                .distinct()
+//                .collect(Collectors.toList());
+//
+//        List<Long> skuIdList = adIdList.stream()
+//                .map(AdPlatformCreativeDTO::getSkuId)
+//                .filter(Objects::nonNull)
+//                .distinct()
+//                .collect(Collectors.toList());
+//        return featureService.getFeature(cidList, adVerIdList, skuIdList, param);
+//    }
+
     protected Feature getFeature(ScoreParam param, RankRecommendRequestParam request) {
         List<AdPlatformCreativeDTO> adIdList = request.getAdIdList();
-        List<String> cidList = adIdList.stream()
-                .map(AdPlatformCreativeDTO::getCreativeId)
-                .map(Object::toString)
-                .collect(Collectors.toList());
+        Feature finalFeature = null;
 
-        List<String> adVerIdList = adIdList.stream()
-                .map(AdPlatformCreativeDTO::getAdVerId)
-                .filter(StringUtils::isNotBlank)
-                .distinct()
-                .collect(Collectors.toList());
+        // 分批处理 AdPlatformCreativeDTO 列表
+        List<List<AdPlatformCreativeDTO>> adIdBatches = partitionList(adIdList, featureBranchConfigValue);
 
-        List<Long> skuIdList = adIdList.stream()
-                .map(AdPlatformCreativeDTO::getSkuId)
-                .filter(Objects::nonNull)
-                .distinct()
-                .collect(Collectors.toList());
-        return featureService.getFeature(cidList, adVerIdList, skuIdList, param);
+        // 使用 ThreadPoolFactory 获取 DEFAULT 线程池
+        ExecutorService executorService = ThreadPoolFactory.defaultPool();  // 使用 DEFAULT 线程池
+        List<Callable<Feature>> tasks = new ArrayList<>();
+
+        for (List<AdPlatformCreativeDTO> batch : adIdBatches) {
+            // 提取批次中的 cidList、adVerIdList 和 skuIdList
+            List<String> cidList = batch.stream()
+                    .map(AdPlatformCreativeDTO::getCreativeId)
+                    .map(Object::toString)
+                    .collect(Collectors.toList());
+
+            List<String> adVerIdList = batch.stream()
+                    .map(AdPlatformCreativeDTO::getAdVerId)
+                    .filter(StringUtils::isNotBlank)
+                    .distinct()
+                    .collect(Collectors.toList());
+
+            List<Long> skuIdList = batch.stream()
+                    .map(AdPlatformCreativeDTO::getSkuId)
+                    .filter(Objects::nonNull)
+                    .distinct()
+                    .collect(Collectors.toList());
+
+            // 将每个批次的请求任务封装为 Callable
+            tasks.add(() -> featureService.getFeature(cidList, adVerIdList, skuIdList, param));
+        }
+
+        try {
+            // 执行所有的任务并等待所有任务完成
+            List<Future<Feature>> futures = executorService.invokeAll(tasks);
+
+            // 等待所有任务完成并合并结果
+            for (Future<Feature> future : futures) {
+                Feature batchFeature = future.get();  // 获取每个任务的结果
+                if (finalFeature == null) {
+                    finalFeature = batchFeature;
+                } else {
+                    // 合并特征
+                    finalFeature.merge(batchFeature);
+                }
+            }
+        } catch (InterruptedException | ExecutionException e) {
+            log.error("getFeature error", e);
+        }
+        return finalFeature;
+    }
+
+    // 辅助方法:将列表分成指定大小的批次
+    private <T> List<List<T>> partitionList(List<T> originalList, Integer batchSize) {
+        List<List<T>> batches = new ArrayList<>();
+        for (int i = 0; i < originalList.size(); i += batchSize) {
+            batches.add(originalList.subList(i, Math.min(i + batchSize, originalList.size())));
+        }
+        return batches;
     }
 
     protected Set<String> getNoApiAdVerIds() {