Explorar o código

优化dnn模型调用,使用分组方式调用

xueyiming hai 3 meses
pai
achega
c18d7ace00

+ 102 - 34
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorer.java

@@ -3,26 +3,31 @@ package com.tzld.piaoquan.ad.engine.service.score.scorer;
 
 import com.google.common.collect.Lists;
 import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
-import com.tzld.piaoquan.ad.engine.commons.score.BaseXGBoostModelScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
 import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV1;
-import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel683;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
 import org.apache.commons.collections4.CollectionUtils;
-import org.apache.commons.collections4.MapUtils;
-import org.apache.commons.lang.exception.ExceptionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import org.springframework.stereotype.Component;
 
-import java.util.*;
-import java.util.concurrent.*;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 
 public class PAIScorer extends AbstractScorer {
 
     private final static Logger LOGGER = LoggerFactory.getLogger(PAIScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(256);
+    private static final int DEFAULT_BATCH_SIZE = 200;
+    public static final int SCORE_TIME_OUT = 350;
+
 
 
     public PAIScorer(ScorerConfigInfo configInfo) {
@@ -43,52 +48,115 @@ public class PAIScorer extends AbstractScorer {
             return rankItems;
         }
 
-        long startTime = System.currentTimeMillis();
-
-        List<AdRankItem> result = rankByJava(sceneFeatureMap, userFeatureMap, rankItems);
+        final long startTime = System.currentTimeMillis();
+        final int batchSize = DEFAULT_BATCH_SIZE;
 
-        LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
-                System.currentTimeMillis() - startTime);
+        // 小数据量直接同步处理
+        if (rankItems.size() <= batchSize) {
+            return processBatchSynchronously(sceneFeatureMap, userFeatureMap, rankItems, startTime);
+        }
 
-        return result;
+        try {
+            // 1. 分批处理
+            List<List<AdRankItem>> batches = Lists.partition(rankItems, batchSize);
+
+            // 2. 创建异步任务
+            List<CompletableFuture<List<AdRankItem>>> futures = batches.stream()
+                    .map(batch -> CompletableFuture.supplyAsync(
+                            () -> processBatch(sceneFeatureMap, userFeatureMap, batch),
+                            executorService
+                    ))
+                    .collect(Collectors.toList());
+
+            // 3. 合并结果
+            CompletableFuture<Void> allFutures = CompletableFuture.allOf(
+                    futures.toArray(new CompletableFuture[0])
+            );
+
+            List<AdRankItem> result = allFutures.thenApply(v ->
+                    futures.stream()
+                            .flatMap(future -> future.join().stream())
+                            .collect(Collectors.toList())
+            ).get(SCORE_TIME_OUT, TimeUnit.MILLISECONDS); // 设置超时时间
+
+            // 4. 全局排序
+            Collections.sort(result);
+
+            // 5. 记录日志
+            LOGGER.debug("Async scoring completed. Total items={}, batches={}, time={}ms",
+                    result.size(),
+                    batches.size(),
+                    System.currentTimeMillis() - startTime);
+
+            return result;
+        } catch (Exception e) {
+            LOGGER.error("Async scoring failed, falling back to sync. Error: {}", e.getMessage(), e);
+            // 降级:同步处理
+            return processBatchSynchronously(sceneFeatureMap, userFeatureMap, rankItems, startTime);
+        }
     }
 
-    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
-                                        final Map<String, String> userFeatureMap,
-                                        final List<AdRankItem> items) {
-        long startTime = System.currentTimeMillis();
+    // 处理单个批次(不排序)
+    private List<AdRankItem> processBatch(final Map<String, String> sceneFeatureMap,
+                                          final Map<String, String> userFeatureMap,
+                                          final List<AdRankItem> batch) {
+        long batchStart = System.currentTimeMillis();
         PAIModelV1 model = PAIModelV1.getModel();
-        // 所有都参与打分,按照ctr排序
-        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+        multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
 
-        // debug log
         if (LOGGER.isDebugEnabled()) {
-            for (int i = 0; i < items.size(); i++) {
-                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+            for (AdRankItem item : batch) {
+                LOGGER.debug("Batch item scored: {}", item);
             }
+            LOGGER.debug("Batch processed: size={}, cost={}ms",
+                    batch.size(), System.currentTimeMillis() - batchStart);
         }
 
-        Collections.sort(items);
+        return batch;
+    }
+
+    // 同步处理整个批次(包含排序)
+    private List<AdRankItem> processBatchSynchronously(
+            final Map<String, String> sceneFeatureMap,
+            final Map<String, String> userFeatureMap,
+            final List<AdRankItem> batch,
+            final long startTime) {
+
+        PAIModelV1 model = PAIModelV1.getModel();
+        multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
+        Collections.sort(batch);
 
-        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
-        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
-                System.currentTimeMillis() - startTime);
-        return items;
+        LOGGER.debug("Sync scoring completed. Items={}, time={}ms",
+                batch.size(), System.currentTimeMillis() - startTime);
+
+        return batch;
     }
 
     private void multipleCtrScore(final List<AdRankItem> items,
                                   final Map<String, String> userFeatureMap,
                                   final Map<String, String> sceneFeatureMap,
                                   final PAIModelV1 model) {
+        // 添加空检查确保安全
+        if (CollectionUtils.isEmpty(items)) return;
+
+        List<Float> scores = model.score(items, userFeatureMap, sceneFeatureMap);
+
+        if (scores == null || scores.size() != items.size()) {
+            LOGGER.error("Score size mismatch! Items: {}, Scores: {}",
+                    items.size(),
+                    scores != null ? scores.size() : "null");
+            return;
+        }
 
-        List<Float> score = model.score(items, userFeatureMap, sceneFeatureMap);
-        LOGGER.debug("PAIScorer score={}", score);
         for (int i = 0; i < items.size(); i++) {
-            Double pro = Double.valueOf(score.get(i));
-            items.get(i).setLrScore(pro);
-            items.get(i).getScoreMap().put("ctcvrScore", pro);
+            try {
+                Double pro = Double.valueOf(scores.get(i));
+                AdRankItem item = items.get(i);
+                item.setLrScore(pro);
+                item.getScoreMap().put("ctcvrScore", pro);
+            } catch (Exception e) {
+                LOGGER.error("Error setting score for item: {}", items.get(i), e);
+            }
         }
     }
-
-
 }