Ver Fonte

Merge branch 'refs/heads/dev-xym-update-branch' into pre-master

xueyiming há 1 semana atrás
pai
commit
3acfe95b20

+ 67 - 9
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorer.java

@@ -8,6 +8,7 @@ import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
 import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV1;
 import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel683;
+import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
 import org.apache.commons.collections4.CollectionUtils;
@@ -53,29 +54,86 @@ public class PAIScorer extends AbstractScorer {
         return result;
     }
 
+//    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+//                                        final Map<String, String> userFeatureMap,
+//                                        final List<AdRankItem> items) {
+//        long startTime = System.currentTimeMillis();
+//        PAIModelV1 model = PAIModelV1.getModel();
+//        // 所有都参与打分,按照ctr排序
+//        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+//
+//        // debug log
+//        if (LOGGER.isDebugEnabled()) {
+//            for (int i = 0; i < items.size(); i++) {
+//                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+//            }
+//        }
+//
+//        Collections.sort(items);
+//
+//        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+//        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+//                System.currentTimeMillis() - startTime);
+//        return items;
+//    }
+
     private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
                                         final Map<String, String> userFeatureMap,
                                         final List<AdRankItem> items) {
+        if (items == null || items.isEmpty()) {
+            return Collections.emptyList();
+        }
+
         long startTime = System.currentTimeMillis();
         PAIModelV1 model = PAIModelV1.getModel();
-        // 所有都参与打分,按照ctr排序
-        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
 
-        // debug log
-        if (LOGGER.isDebugEnabled()) {
-            for (int i = 0; i < items.size(); i++) {
-                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+        // 切分批次
+        final int batchSize = 500; // 每批 500 个,可按你的业务场景调整
+        List<List<AdRankItem>> batches = new ArrayList<>();
+        for (int i = 0; i < items.size(); i += batchSize) {
+            batches.add(items.subList(i, Math.min(i + batchSize, items.size())));
+        }
+
+        // 并发执行
+        ExecutorService executor = ThreadPoolFactory.defaultPool();
+        List<Future<?>> futures = new ArrayList<>();
+
+        for (List<AdRankItem> batch : batches) {
+            futures.add(executor.submit(() -> {
+                try {
+                    multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
+                } catch (Exception e) {
+                    LOGGER.error("Error during multipleCtrScore batch execution", e);
+                }
+            }));
+        }
+
+        // 等待所有任务完成
+        for (Future<?> future : futures) {
+            try {
+                future.get();
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                LOGGER.warn("Thread interrupted while waiting for batch tasks", e);
+            } catch (ExecutionException e) {
+                LOGGER.error("Execution error in batch tasks", e);
             }
         }
 
+        // 打分完成后排序
         Collections.sort(items);
 
-        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
-        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
-                System.currentTimeMillis() - startTime);
+        LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={}ms",
+                items.size(), System.currentTimeMillis() - startTime);
+
         return items;
     }
 
+
+
+
+
     private void multipleCtrScore(final List<AdRankItem> items,
                                   final Map<String, String> userFeatureMap,
                                   final Map<String, String> sceneFeatureMap,

+ 67 - 9
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorerV2.java

@@ -4,16 +4,22 @@ package com.tzld.piaoquan.ad.engine.service.score.scorer;
 import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV1;
 import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV2;
+import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
 import org.apache.commons.collections4.CollectionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
 
 public class PAIScorerV2 extends AbstractScorer {
 
@@ -48,26 +54,78 @@ public class PAIScorerV2 extends AbstractScorer {
         return result;
     }
 
+//    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+//                                        final Map<String, String> userFeatureMap,
+//                                        final List<AdRankItem> items) {
+//        long startTime = System.currentTimeMillis();
+//        PAIModelV2 model = PAIModelV2.getModel();
+//        // 所有都参与打分,按照ctr排序
+//        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+//
+//        // debug log
+//        if (LOGGER.isDebugEnabled()) {
+//            for (int i = 0; i < items.size(); i++) {
+//                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+//            }
+//        }
+//
+//        Collections.sort(items);
+//
+//        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+//        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+//                System.currentTimeMillis() - startTime);
+//        return items;
+//    }
+
     private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
                                         final Map<String, String> userFeatureMap,
                                         final List<AdRankItem> items) {
+        if (items == null || items.isEmpty()) {
+            return Collections.emptyList();
+        }
+
         long startTime = System.currentTimeMillis();
         PAIModelV2 model = PAIModelV2.getModel();
-        // 所有都参与打分,按照ctr排序
-        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+        // 切分批次
+        final int batchSize = 500; // 每批 500 个,可按你的业务场景调整
+        List<List<AdRankItem>> batches = new ArrayList<>();
+        for (int i = 0; i < items.size(); i += batchSize) {
+            batches.add(items.subList(i, Math.min(i + batchSize, items.size())));
+        }
 
-        // debug log
-        if (LOGGER.isDebugEnabled()) {
-            for (int i = 0; i < items.size(); i++) {
-                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+        // 并发执行
+        ExecutorService executor = ThreadPoolFactory.defaultPool();
+        List<Future<?>> futures = new ArrayList<>();
+
+        for (List<AdRankItem> batch : batches) {
+            futures.add(executor.submit(() -> {
+                try {
+                    multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
+                } catch (Exception e) {
+                    LOGGER.error("Error during multipleCtrScore batch execution", e);
+                }
+            }));
+        }
+
+        // 等待所有任务完成
+        for (Future<?> future : futures) {
+            try {
+                future.get();
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                LOGGER.warn("Thread interrupted while waiting for batch tasks", e);
+            } catch (ExecutionException e) {
+                LOGGER.error("Execution error in batch tasks", e);
             }
         }
 
+        // 打分完成后排序
         Collections.sort(items);
 
-        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
-        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
-                System.currentTimeMillis() - startTime);
+        LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={}ms",
+                items.size(), System.currentTimeMillis() - startTime);
+
         return items;
     }