Browse Source

请求模型改为批量

xueyiming 1 week ago
parent
commit
4ddaa82396

+ 13 - 19
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorer.java

@@ -7,6 +7,7 @@ import com.tzld.piaoquan.ad.engine.commons.score.BaseXGBoostModelScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
 import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV1;
 import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV1;
+import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV2;
 import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel683;
 import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel683;
 import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
@@ -87,16 +88,14 @@ public class PAIScorer extends AbstractScorer {
         long startTime = System.currentTimeMillis();
         long startTime = System.currentTimeMillis();
         PAIModelV1 model = PAIModelV1.getModel();
         PAIModelV1 model = PAIModelV1.getModel();
 
 
-        // 切分批次
-        final int batchSize = 500; // 每批 500 个,可按你的业务场景调整
+        final int batchSize = 500;
         List<List<AdRankItem>> batches = new ArrayList<>();
         List<List<AdRankItem>> batches = new ArrayList<>();
         for (int i = 0; i < items.size(); i += batchSize) {
         for (int i = 0; i < items.size(); i += batchSize) {
-            batches.add(items.subList(i, Math.min(i + batchSize, items.size())));
+            batches.add(new ArrayList<>(items.subList(i, Math.min(i + batchSize, items.size()))));
         }
         }
 
 
-        // 并发执行
         ExecutorService executor = ThreadPoolFactory.defaultPool();
         ExecutorService executor = ThreadPoolFactory.defaultPool();
-        List<Future<?>> futures = new ArrayList<>();
+        List<Future<List<AdRankItem>>> futures = new ArrayList<>();
 
 
         for (List<AdRankItem> batch : batches) {
         for (List<AdRankItem> batch : batches) {
             futures.add(executor.submit(() -> {
             futures.add(executor.submit(() -> {
@@ -105,29 +104,24 @@ public class PAIScorer extends AbstractScorer {
                 } catch (Exception e) {
                 } catch (Exception e) {
                     LOGGER.error("Error during multipleCtrScore batch execution", e);
                     LOGGER.error("Error during multipleCtrScore batch execution", e);
                 }
                 }
+                return batch;
             }));
             }));
         }
         }
 
 
-        // 等待所有任务完成
-        for (Future<?> future : futures) {
+        // 合并结果
+        List<AdRankItem> merged = new ArrayList<>();
+        for (Future<List<AdRankItem>> future : futures) {
             try {
             try {
-                future.get();
-            } catch (InterruptedException e) {
-                Thread.currentThread().interrupt();
-                LOGGER.warn("Thread interrupted while waiting for batch tasks", e);
-            } catch (ExecutionException e) {
-                LOGGER.error("Execution error in batch tasks", e);
+                merged.addAll(future.get());
+            } catch (Exception e) {
+                LOGGER.error("Execution error in batch", e);
             }
             }
         }
         }
 
 
-        // 打分完成后排序
-        Collections.sort(items);
+        Collections.sort(merged);
 
 
         LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
         LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
-        LOGGER.debug("[ctr ranker time java] items size={}, cost={}ms",
-                items.size(), System.currentTimeMillis() - startTime);
-
-        return items;
+        return merged;
     }
     }
 
 
 
 

+ 13 - 19
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorerV2.java

@@ -86,16 +86,15 @@ public class PAIScorerV2 extends AbstractScorer {
 
 
         long startTime = System.currentTimeMillis();
         long startTime = System.currentTimeMillis();
         PAIModelV2 model = PAIModelV2.getModel();
         PAIModelV2 model = PAIModelV2.getModel();
-        // 切分批次
-        final int batchSize = 500; // 每批 500 个,可按你的业务场景调整
+
+        final int batchSize = 500;
         List<List<AdRankItem>> batches = new ArrayList<>();
         List<List<AdRankItem>> batches = new ArrayList<>();
         for (int i = 0; i < items.size(); i += batchSize) {
         for (int i = 0; i < items.size(); i += batchSize) {
-            batches.add(items.subList(i, Math.min(i + batchSize, items.size())));
+            batches.add(new ArrayList<>(items.subList(i, Math.min(i + batchSize, items.size()))));
         }
         }
 
 
-        // 并发执行
         ExecutorService executor = ThreadPoolFactory.defaultPool();
         ExecutorService executor = ThreadPoolFactory.defaultPool();
-        List<Future<?>> futures = new ArrayList<>();
+        List<Future<List<AdRankItem>>> futures = new ArrayList<>();
 
 
         for (List<AdRankItem> batch : batches) {
         for (List<AdRankItem> batch : batches) {
             futures.add(executor.submit(() -> {
             futures.add(executor.submit(() -> {
@@ -104,29 +103,24 @@ public class PAIScorerV2 extends AbstractScorer {
                 } catch (Exception e) {
                 } catch (Exception e) {
                     LOGGER.error("Error during multipleCtrScore batch execution", e);
                     LOGGER.error("Error during multipleCtrScore batch execution", e);
                 }
                 }
+                return batch;
             }));
             }));
         }
         }
 
 
-        // 等待所有任务完成
-        for (Future<?> future : futures) {
+        // 合并结果
+        List<AdRankItem> merged = new ArrayList<>();
+        for (Future<List<AdRankItem>> future : futures) {
             try {
             try {
-                future.get();
-            } catch (InterruptedException e) {
-                Thread.currentThread().interrupt();
-                LOGGER.warn("Thread interrupted while waiting for batch tasks", e);
-            } catch (ExecutionException e) {
-                LOGGER.error("Execution error in batch tasks", e);
+                merged.addAll(future.get());
+            } catch (Exception e) {
+                LOGGER.error("Execution error in batch", e);
             }
             }
         }
         }
 
 
-        // 打分完成后排序
-        Collections.sort(items);
+        Collections.sort(merged);
 
 
         LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
         LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
-        LOGGER.debug("[ctr ranker time java] items size={}, cost={}ms",
-                items.size(), System.currentTimeMillis() - startTime);
-
-        return items;
+        return merged;
     }
     }
 
 
     private void multipleCtrScore(final List<AdRankItem> items,
     private void multipleCtrScore(final List<AdRankItem> items,