|
@@ -3,26 +3,31 @@ package com.tzld.piaoquan.ad.engine.service.score.scorer;
|
|
|
|
|
|
import com.google.common.collect.Lists;
|
|
|
import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
|
|
|
-import com.tzld.piaoquan.ad.engine.commons.score.BaseXGBoostModelScorer;
|
|
|
import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
|
|
|
import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
|
|
|
import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModelV1;
|
|
|
-import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel683;
|
|
|
import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
|
|
|
import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
|
|
|
import org.apache.commons.collections4.CollectionUtils;
|
|
|
-import org.apache.commons.collections4.MapUtils;
|
|
|
-import org.apache.commons.lang.exception.ExceptionUtils;
|
|
|
import org.slf4j.Logger;
|
|
|
import org.slf4j.LoggerFactory;
|
|
|
-import org.springframework.stereotype.Component;
|
|
|
|
|
|
-import java.util.*;
|
|
|
-import java.util.concurrent.*;
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.concurrent.CompletableFuture;
|
|
|
+import java.util.concurrent.ExecutorService;
|
|
|
+import java.util.concurrent.Executors;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
public class PAIScorer extends AbstractScorer {
|
|
|
|
|
|
private final static Logger LOGGER = LoggerFactory.getLogger(PAIScorer.class);
|
|
|
+ private static final ExecutorService executorService = Executors.newFixedThreadPool(256);
|
|
|
+ private static final int DEFAULT_BATCH_SIZE = 200;
|
|
|
+ public static final int SCORE_TIME_OUT = 350;
|
|
|
+
|
|
|
|
|
|
|
|
|
public PAIScorer(ScorerConfigInfo configInfo) {
|
|
@@ -43,52 +48,115 @@ public class PAIScorer extends AbstractScorer {
|
|
|
return rankItems;
|
|
|
}
|
|
|
|
|
|
- long startTime = System.currentTimeMillis();
|
|
|
-
|
|
|
- List<AdRankItem> result = rankByJava(sceneFeatureMap, userFeatureMap, rankItems);
|
|
|
+ final long startTime = System.currentTimeMillis();
|
|
|
+ final int batchSize = DEFAULT_BATCH_SIZE;
|
|
|
|
|
|
- LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
|
|
|
- System.currentTimeMillis() - startTime);
|
|
|
+ // 小数据量直接同步处理
|
|
|
+ if (rankItems.size() <= batchSize) {
|
|
|
+ return processBatchSynchronously(sceneFeatureMap, userFeatureMap, rankItems, startTime);
|
|
|
+ }
|
|
|
|
|
|
- return result;
|
|
|
+ try {
|
|
|
+ // 1. 分批处理
|
|
|
+ List<List<AdRankItem>> batches = Lists.partition(rankItems, batchSize);
|
|
|
+
|
|
|
+ // 2. 创建异步任务
|
|
|
+ List<CompletableFuture<List<AdRankItem>>> futures = batches.stream()
|
|
|
+ .map(batch -> CompletableFuture.supplyAsync(
|
|
|
+ () -> processBatch(sceneFeatureMap, userFeatureMap, batch),
|
|
|
+ executorService
|
|
|
+ ))
|
|
|
+ .collect(Collectors.toList());
|
|
|
+
|
|
|
+ // 3. 合并结果
|
|
|
+ CompletableFuture<Void> allFutures = CompletableFuture.allOf(
|
|
|
+ futures.toArray(new CompletableFuture[0])
|
|
|
+ );
|
|
|
+
|
|
|
+ List<AdRankItem> result = allFutures.thenApply(v ->
|
|
|
+ futures.stream()
|
|
|
+ .flatMap(future -> future.join().stream())
|
|
|
+ .collect(Collectors.toList())
|
|
|
+ ).get(SCORE_TIME_OUT, TimeUnit.MILLISECONDS); // 设置超时时间
|
|
|
+
|
|
|
+ // 4. 全局排序
|
|
|
+ Collections.sort(result);
|
|
|
+
|
|
|
+ // 5. 记录日志
|
|
|
+ LOGGER.debug("Async scoring completed. Total items={}, batches={}, time={}ms",
|
|
|
+ result.size(),
|
|
|
+ batches.size(),
|
|
|
+ System.currentTimeMillis() - startTime);
|
|
|
+
|
|
|
+ return result;
|
|
|
+ } catch (Exception e) {
|
|
|
+ LOGGER.error("Async scoring failed, falling back to sync. Error: {}", e.getMessage(), e);
|
|
|
+ // 降级:同步处理
|
|
|
+ return processBatchSynchronously(sceneFeatureMap, userFeatureMap, rankItems, startTime);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
|
|
|
- final Map<String, String> userFeatureMap,
|
|
|
- final List<AdRankItem> items) {
|
|
|
- long startTime = System.currentTimeMillis();
|
|
|
+ // 处理单个批次(不排序)
|
|
|
+ private List<AdRankItem> processBatch(final Map<String, String> sceneFeatureMap,
|
|
|
+ final Map<String, String> userFeatureMap,
|
|
|
+ final List<AdRankItem> batch) {
|
|
|
+ long batchStart = System.currentTimeMillis();
|
|
|
PAIModelV1 model = PAIModelV1.getModel();
|
|
|
- // 所有都参与打分,按照ctr排序
|
|
|
- multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
|
|
|
+ multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
|
|
|
|
|
|
- // debug log
|
|
|
if (LOGGER.isDebugEnabled()) {
|
|
|
- for (int i = 0; i < items.size(); i++) {
|
|
|
- LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
|
|
|
+ for (AdRankItem item : batch) {
|
|
|
+ LOGGER.debug("Batch item scored: {}", item);
|
|
|
}
|
|
|
+ LOGGER.debug("Batch processed: size={}, cost={}ms",
|
|
|
+ batch.size(), System.currentTimeMillis() - batchStart);
|
|
|
}
|
|
|
|
|
|
- Collections.sort(items);
|
|
|
+ return batch;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 同步处理整个批次(包含排序)
|
|
|
+ private List<AdRankItem> processBatchSynchronously(
|
|
|
+ final Map<String, String> sceneFeatureMap,
|
|
|
+ final Map<String, String> userFeatureMap,
|
|
|
+ final List<AdRankItem> batch,
|
|
|
+ final long startTime) {
|
|
|
+
|
|
|
+ PAIModelV1 model = PAIModelV1.getModel();
|
|
|
+ multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
|
|
|
+ Collections.sort(batch);
|
|
|
|
|
|
- LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
|
|
|
- LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
|
|
|
- System.currentTimeMillis() - startTime);
|
|
|
- return items;
|
|
|
+ LOGGER.debug("Sync scoring completed. Items={}, time={}ms",
|
|
|
+ batch.size(), System.currentTimeMillis() - startTime);
|
|
|
+
|
|
|
+ return batch;
|
|
|
}
|
|
|
|
|
|
private void multipleCtrScore(final List<AdRankItem> items,
|
|
|
final Map<String, String> userFeatureMap,
|
|
|
final Map<String, String> sceneFeatureMap,
|
|
|
final PAIModelV1 model) {
|
|
|
+ // 添加空检查确保安全
|
|
|
+ if (CollectionUtils.isEmpty(items)) return;
|
|
|
+
|
|
|
+ List<Float> scores = model.score(items, userFeatureMap, sceneFeatureMap);
|
|
|
+
|
|
|
+ if (scores == null || scores.size() != items.size()) {
|
|
|
+ LOGGER.error("Score size mismatch! Items: {}, Scores: {}",
|
|
|
+ items.size(),
|
|
|
+ scores != null ? scores.size() : "null");
|
|
|
+ return;
|
|
|
+ }
|
|
|
|
|
|
- List<Float> score = model.score(items, userFeatureMap, sceneFeatureMap);
|
|
|
- LOGGER.debug("PAIScorer score={}", score);
|
|
|
for (int i = 0; i < items.size(); i++) {
|
|
|
- Double pro = Double.valueOf(score.get(i));
|
|
|
- items.get(i).setLrScore(pro);
|
|
|
- items.get(i).getScoreMap().put("ctcvrScore", pro);
|
|
|
+ try {
|
|
|
+ Double pro = Double.valueOf(scores.get(i));
|
|
|
+ AdRankItem item = items.get(i);
|
|
|
+ item.setLrScore(pro);
|
|
|
+ item.getScoreMap().put("ctcvrScore", pro);
|
|
|
+ } catch (Exception e) {
|
|
|
+ LOGGER.error("Error setting score for item: {}", items.get(i), e);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
-
|
|
|
}
|