#197 优化

باز‌کردن
fanjinyang قصد ادغام 7 تغییر را از algorithm/20260304_feature_fjy_add_log به algorithm/master دارد

+ 99 - 45
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/PAIModelV2.java

@@ -30,14 +30,6 @@ public class PAIModelV2 {
 
     private static final PredictClient client;
 
-    static {
-        model = new PAIModelV2();
-        client = new PredictClient(new HttpConfig());
-        client.setEndpoint("1894469520484605.vpc.cn-hangzhou.pai-eas.aliyuncs.com");
-        client.setToken("M2Y0ZGJlOGRiNDJiYmJlM2E2MGRiNzk4NTdhN2MxOWUzYWMxNzdjYg==");
-        client.setModelName("ad_rank_dnn_v11_easyrec_v2");
-    }
-
     private static final String[] sparseUserStrFeatures = {
             "brand", "region", "city", "cate1", "cate2", "user_cid_click_list", "user_cid_conver_list",
             "user_vid_return_tags_2h", "user_vid_return_tags_1d", "user_vid_return_tags_3d", "user_vid_return_tags_7d",
@@ -68,11 +60,13 @@ public class PAIModelV2 {
             "profession", "category_name"
     };
 
-    private final String[] userFeatures = {
+    private static final String[] userFeatures = {
             "viewAll", "clickAll", "converAll", "incomeAll", "ctr_all", "ctcvr_all", "cvr_all", "ecpm_all"
     };
 
-    private final String[] itemFeatures = {
+    private static final String[] userFeaturesLowerCase;
+
+    private static final String[] itemFeatures = {
             "actionstatic_click", "actionstatic_ctcvr", "actionstatic_ctr", "actionstatic_view", "b2_12h_click", "b2_12h_conver",
             "b2_12h_conver_x_ctcvr", "b2_12h_conver_x_log_view", "b2_12h_ctcvr", "b2_12h_ctr", "b2_12h_cvr", "b2_12h_ecpm",
             "b2_1d_click", "b2_1d_conver", "b2_1d_conver_x_ctcvr", "b2_1d_conver_x_log_view", "b2_1d_ctcvr", "b2_1d_ctr",
@@ -122,13 +116,42 @@ public class PAIModelV2 {
             "vid_rank_ecpm_3d", "vid_rank_ecpm_7d"
     };
 
+    // 预计算 itemFeature → featureMap key 的映射
+    private static final String[] itemFeatureKeys;
+
+    static {
+        model = new PAIModelV2();
+        client = new PredictClient(new HttpConfig());
+        client.setEndpoint("1894469520484605.vpc.cn-hangzhou.pai-eas.aliyuncs.com");
+        client.setToken("M2Y0ZGJlOGRiNDJiYmJlM2E2MGRiNzk4NTdhN2MxOWUzYWMxNzdjYg==");
+        client.setModelName("ad_rank_dnn_v11_easyrec_v2");
+
+        // 预计算 itemFeatures 的 key 映射
+        itemFeatureKeys = new String[itemFeatures.length];
+        for (int i = 0; i < itemFeatures.length; i++) {
+            itemFeatureKeys[i] = itemFeatures[i].replace("_x_", "*").replace("_view", "(view)");
+        }
+
+        // 预计算 userFeatures 的小写映射
+        userFeaturesLowerCase = new String[userFeatures.length];
+        for (int i = 0; i < userFeatures.length; i++) {
+            userFeaturesLowerCase[i] = userFeatures[i].toLowerCase();
+        }
+    }
+
 
     public List<Float> score(final List<AdRankItem> items,
                              final Map<String, String> userFeatureMap,
                              final Map<String, String> sceneFeatureMap) {
+        long totalStart = System.currentTimeMillis();
+        long buildFeatureTime = 0;
+        long predictTime = 0;
+
         try {
             TFRequest request = new TFRequest();
 
+            // 阶段1:构建用户/场景特征
+            long stageStart = System.currentTimeMillis();
             for (String feature : sparseUserStrFeatures) {
                 String v = userFeatureMap.getOrDefault(feature, "");
                 request.addFeed(feature, TFDataType.DT_STRING, new long[]{1}, new String[]{v});
@@ -144,65 +167,96 @@ public class PAIModelV2 {
                 request.addFeed(feature, TFDataType.DT_INT64, new long[]{1}, new long[]{v});
             }
 
+            for (int i = 0; i < userFeatures.length; i++) {
+                double v = NumberUtils.toDouble(userFeatureMap.getOrDefault(userFeatures[i], "0.0"), 0.0);
+                request.addFeed(userFeaturesLowerCase[i], TFDataType.DT_DOUBLE, new long[]{1}, new double[]{v});
+            }
+
+            // 阶段2:构建 item 特征(优化版本)
+            final int size = items.size();
 
-            for (String feature : userFeatures) {
-                double v = NumberUtils.toDouble(userFeatureMap.getOrDefault(feature, "0.0"), 0.0);
-                request.addFeed(feature.toLowerCase(), TFDataType.DT_DOUBLE, new long[]{1}, new double[]{v});
+            // 预提取所有 item 的 featureMap 引用,避免重复调用 get 方法
+            Map[] featureMaps = new Map[size];
+            for (int i = 0; i < size; i++) {
+                featureMaps[i] = items.get(i).getFeatureMap();
             }
-            Map<String, double[]> doubleFeed = new HashMap<>();
-            Map<String, long[]> longFeed = new HashMap<>();
-            Map<String, String[]> strFeed = new HashMap<>();
-            for (int i = 0; i < items.size(); i++) {
-                for (String feature : itemFeatures) {
-                    String key = feature.replace("_x_", "*").replace("_view", "(view)");
-                    double[] doubles = doubleFeed.computeIfAbsent(feature, k -> new double[items.size()]);
-                    if (MapUtils.isEmpty(items.get(i).getFeatureMap())) {
+
+            // 预分配所有数组
+            double[][] doubleArrays = new double[itemFeatures.length][size];
+            long[][] longArrays = new long[sparseAdLongFeatures.length][size];
+            String[][] strArrays = new String[sparseAdStrFeatures.length][size];
+
+            // 按 feature 遍历(外层),提高缓存局部性
+            // 处理 double 类型特征
+            for (int f = 0; f < itemFeatures.length; f++) {
+                String key = itemFeatureKeys[f];
+                double[] doubles = doubleArrays[f];
+                for (int i = 0; i < size; i++) {
+                    Map<String, String> featureMap = featureMaps[i];
+                    if (featureMap == null || featureMap.isEmpty()) {
                         doubles[i] = 0.0;
-                        continue;
+                    } else {
+                        doubles[i] = NumberUtils.toDouble(featureMap.getOrDefault(key, "0.0"), 0.0);
                     }
-                    double v = NumberUtils.toDouble(items.get(i).getFeatureMap().getOrDefault(key, "0.0"), 0.0);
-                    doubles[i] = v;
                 }
+            }
 
-                for (String feature : sparseAdLongFeatures) {
-                    long[] longs = longFeed.computeIfAbsent(feature, k -> new long[items.size()]);
-                    if (MapUtils.isEmpty(items.get(i).getFeatureMap())) {
+            // 处理 long 类型特征
+            for (int f = 0; f < sparseAdLongFeatures.length; f++) {
+                String feature = sparseAdLongFeatures[f];
+                long[] longs = longArrays[f];
+                for (int i = 0; i < size; i++) {
+                    Map<String, String> featureMap = featureMaps[i];
+                    if (featureMap == null || featureMap.isEmpty()) {
                         longs[i] = 0L;
-                        continue;
+                    } else {
+                        longs[i] = NumberUtils.toLong(featureMap.getOrDefault(feature, "0"), 0L);
                     }
-                    long v = NumberUtils.toLong(items.get(i).getFeatureMap().getOrDefault(feature, "0"), 0L);
-                    longs[i] = v;
                 }
+            }
 
-                for (String feature : sparseAdStrFeatures) {
-                    String[] strs = strFeed.computeIfAbsent(feature, k -> new String[items.size()]);
-                    if (MapUtils.isEmpty(items.get(i).getFeatureMap())) {
+            // 处理 String 类型特征
+            for (int f = 0; f < sparseAdStrFeatures.length; f++) {
+                String feature = sparseAdStrFeatures[f];
+                String[] strs = strArrays[f];
+                for (int i = 0; i < size; i++) {
+                    Map<String, String> featureMap = featureMaps[i];
+                    if (featureMap == null || featureMap.isEmpty()) {
                         strs[i] = "";
-                        continue;
+                    } else {
+                        strs[i] = featureMap.getOrDefault(feature, "");
                     }
-                    String v = items.get(i).getFeatureMap().getOrDefault(feature, "");
-                    strs[i] = v;
                 }
             }
-            for (Map.Entry<String, double[]> entry : doubleFeed.entrySet()) {
-                request.addFeed(entry.getKey(), TFDataType.DT_DOUBLE, new long[]{items.size()}, entry.getValue());
-            }
 
-            for (Map.Entry<String, long[]> entry : longFeed.entrySet()) {
-                request.addFeed(entry.getKey(), TFDataType.DT_INT64, new long[]{items.size()}, entry.getValue());
+            // 构建请求体
+            long[] shape = new long[]{size};
+            for (int f = 0; f < itemFeatures.length; f++) {
+                request.addFeed(itemFeatures[f], TFDataType.DT_DOUBLE, shape, doubleArrays[f]);
             }
-
-            for (Map.Entry<String, String[]> entry : strFeed.entrySet()) {
-                request.addFeed(entry.getKey(), TFDataType.DT_STRING, new long[]{items.size()}, entry.getValue());
+            for (int f = 0; f < sparseAdLongFeatures.length; f++) {
+                request.addFeed(sparseAdLongFeatures[f], TFDataType.DT_INT64, shape, longArrays[f]);
+            }
+            for (int f = 0; f < sparseAdStrFeatures.length; f++) {
+                request.addFeed(sparseAdStrFeatures[f], TFDataType.DT_STRING, shape, strArrays[f]);
             }
             request.addFetch("probs");
+            buildFeatureTime = System.currentTimeMillis() - stageStart;
+
+            // 阶段3:PAI-EAS 远程调用
+            stageStart = System.currentTimeMillis();
             TFResponse response = client.predict(request);
+            predictTime = System.currentTimeMillis() - stageStart;
+
             List<Float> result = response.getFloatVals("probs");
+
             if (!CollectionUtils.isEmpty(result)) {
                 return result;
             }
         } catch (Exception e) {
-            LOGGER.error("PAIModel score error", e);
+            long totalTime = System.currentTimeMillis() - totalStart;
+            LOGGER.error("PAIModelV2 score error itemSize={} totalTime={}ms buildFeature={}ms predict={}ms",
+                    items.size(), totalTime, buildFeatureTime, predictTime, e);
         }
         return new ArrayList<>(Collections.nCopies(items.size(), 0.0f));
     }

+ 14 - 3
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorerV2.java

@@ -21,6 +21,7 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 public class PAIScorerV2 extends AbstractScorer {
 
@@ -88,7 +89,7 @@ public class PAIScorerV2 extends AbstractScorer {
         long startTime = System.currentTimeMillis();
         PAIModelV2 model = PAIModelV2.getModel();
 
-        final int batchSize = 300;
+        final int batchSize = 500;
         List<List<AdRankItem>> batches = new ArrayList<>();
         for (int i = 0; i < items.size(); i += batchSize) {
             batches.add(new ArrayList<>(items.subList(i, Math.min(i + batchSize, items.size()))));
@@ -97,30 +98,40 @@ public class PAIScorerV2 extends AbstractScorer {
         ExecutorService executor = ThreadPoolFactory.score();
         List<Future<List<AdRankItem>>> futures = new ArrayList<>();
 
+        // 记录提交时间,用于计算 SCORE 线程池排队等待耗时
+        long submitTime = System.currentTimeMillis();
         for (List<AdRankItem> batch : batches) {
             futures.add(executor.submit(() -> {
+
                 try {
                     multipleCtrScore(batch, userFeatureMap, sceneFeatureMap, model);
                 } catch (Exception e) {
                     LOGGER.error("Error during multipleCtrScore batch execution", e);
                 }
+
                 return batch;
             }));
         }
 
         // 合并结果
         List<AdRankItem> merged = new ArrayList<>();
+        int batchIndex = 0;
         for (Future<List<AdRankItem>> future : futures) {
+            long getStartTime = System.currentTimeMillis();
             try {
                 merged.addAll(future.get(400, TimeUnit.MILLISECONDS));
+            } catch (TimeoutException e) {
+                long waitTime = System.currentTimeMillis() - getStartTime;
+                LOGGER.error("PAIScorerV2 batch timeout batchIndex={} waitTime={}ms totalElapsed={}ms",
+                        batchIndex, waitTime, System.currentTimeMillis() - startTime);
             } catch (Exception e) {
-                LOGGER.error("Execution error in batch", e);
+                LOGGER.error("Execution error in batch batchIndex={}", batchIndex, e);
             }
+            batchIndex++;
         }
 
         Collections.sort(merged);
 
-        LOGGER.debug("ctr ranker java execute time: [{}ms]", System.currentTimeMillis() - startTime);
         return merged;
     }