jiandong.liu 2 napja
szülő
commit
e8c4d8055d

+ 91 - 46
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/PAIModelV1.java

@@ -6,7 +6,6 @@ import com.aliyun.openservices.eas.predict.request.TFDataType;
 import com.aliyun.openservices.eas.predict.request.TFRequest;
 import com.aliyun.openservices.eas.predict.response.TFResponse;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
-import org.apache.commons.collections4.MapUtils;
 import org.apache.commons.lang.math.NumberUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -30,14 +29,6 @@ public class PAIModelV1 {
 
     private static final PredictClient client;
 
-    static {
-        model = new PAIModelV1();
-        client = new PredictClient(new HttpConfig());
-        client.setEndpoint("1894469520484605.vpc.cn-hangzhou.pai-eas.aliyuncs.com");
-        client.setToken("ODI1MmUxODgzZDc3ODM0ZmQwZWU0YTVjZjdlOWVlMGFlZGJjNTlkYQ==");
-        client.setModelName("ad_rank_dnn_v11_easyrec");
-    }
-
     private static final String[] sparseUserStrFeatures = {
             "brand", "region", "city", "cate1", "cate2", "user_cid_click_list", "user_cid_conver_list",
             "user_vid_return_tags_2h", "user_vid_return_tags_1d", "user_vid_return_tags_3d", "user_vid_return_tags_7d",
@@ -68,11 +59,16 @@ public class PAIModelV1 {
             "profession", "category_name"
     };
 
-    private final String[] userFeatures = {
+    private static final String[] userFeatures = {
             "viewAll", "clickAll", "converAll", "incomeAll", "ctr_all", "ctcvr_all", "cvr_all", "ecpm_all"
     };
 
-    private final String[] itemFeatures = {
+    /**
+     * 预计算的 userFeatures 小写映射,避免运行时重复调用 toLowerCase()
+     */
+    private static final String[] userFeaturesLowerCase;
+
+    private static final String[] itemFeatures = {
             "actionstatic_click", "actionstatic_ctcvr", "actionstatic_ctr", "actionstatic_view", "b2_12h_click", "b2_12h_conver",
             "b2_12h_conver_x_ctcvr", "b2_12h_conver_x_log_view", "b2_12h_ctcvr", "b2_12h_ctr", "b2_12h_cvr", "b2_12h_ecpm",
             "b2_1d_click", "b2_1d_conver", "b2_1d_conver_x_ctcvr", "b2_1d_conver_x_log_view", "b2_1d_ctcvr", "b2_1d_ctr",
@@ -122,6 +118,32 @@ public class PAIModelV1 {
             "vid_rank_ecpm_3d", "vid_rank_ecpm_7d"
     };
 
+    /**
+     * 预计算的 itemFeatures key 映射,避免运行时重复执行 replace 操作
+     * key: feature name, value: 转换后的 key(用于从 featureMap 中获取值)
+     */
+    private static final String[] itemFeatureKeys;
+
+    static {
+        model = new PAIModelV1();
+        client = new PredictClient(new HttpConfig());
+        client.setEndpoint("1894469520484605.vpc.cn-hangzhou.pai-eas.aliyuncs.com");
+        client.setToken("ODI1MmUxODgzZDc3ODM0ZmQwZWU0YTVjZjdlOWVlMGFlZGJjNTlkYQ==");
+        client.setModelName("ad_rank_dnn_v11_easyrec");
+
+        // 预计算 itemFeatures 的 key 映射
+        itemFeatureKeys = new String[itemFeatures.length];
+        for (int i = 0; i < itemFeatures.length; i++) {
+            itemFeatureKeys[i] = itemFeatures[i].replace("_x_", "*").replace("_view", "(view)");
+        }
+
+        // 预计算 userFeatures 的小写映射
+        userFeaturesLowerCase = new String[userFeatures.length];
+        for (int i = 0; i < userFeatures.length; i++) {
+            userFeaturesLowerCase[i] = userFeatures[i].toLowerCase();
+        }
+    }
+
 
     public List<Float> score(final List<AdRankItem> items,
                              final Map<String, String> userFeatureMap,
@@ -133,6 +155,8 @@ public class PAIModelV1 {
         long predictTime = 0;
         long parseResponseTime = 0;
         
+        final int size = items.size();
+        
         try {
             TFRequest request = new TFRequest();
 
@@ -153,63 +177,84 @@ public class PAIModelV1 {
                 request.addFeed(feature, TFDataType.DT_INT64, new long[]{1}, new long[]{v});
             }
 
-            for (String feature : userFeatures) {
-                double v = NumberUtils.toDouble(userFeatureMap.getOrDefault(feature, "0.0"), 0.0);
-                request.addFeed(feature.toLowerCase(), TFDataType.DT_DOUBLE, new long[]{1}, new double[]{v});
+            for (int i = 0; i < userFeatures.length; i++) {
+                double v = NumberUtils.toDouble(userFeatureMap.getOrDefault(userFeatures[i], "0.0"), 0.0);
+                request.addFeed(userFeaturesLowerCase[i], TFDataType.DT_DOUBLE, new long[]{1}, new double[]{v});
             }
             buildUserFeatureTime = System.currentTimeMillis() - stageStart;
 
-            // 阶段2: 构建广告Item特征
+            // 阶段2: 构建广告Item特征(优化版本)
             stageStart = System.currentTimeMillis();
-            Map<String, double[]> doubleFeed = new HashMap<>();
-            Map<String, long[]> longFeed = new HashMap<>();
-            Map<String, String[]> strFeed = new HashMap<>();
-            for (int i = 0; i < items.size(); i++) {
-                for (String feature : itemFeatures) {
-                    String key = feature.replace("_x_", "*").replace("_view", "(view)");
-                    double[] doubles = doubleFeed.computeIfAbsent(feature, k -> new double[items.size()]);
-                    if (MapUtils.isEmpty(items.get(i).getFeatureMap())) {
+            
+            // 预提取所有 item 的 featureMap 引用,避免重复调用 get 方法
+            Map[] featureMaps = new Map[size];
+            for (int i = 0; i < size; i++) {
+                featureMaps[i] = items.get(i).getFeatureMap();
+            }
+            
+            // 预分配所有数组
+            double[][] doubleArrays = new double[itemFeatures.length][size];
+            long[][] longArrays = new long[sparseAdLongFeatures.length][size];
+            String[][] strArrays = new String[sparseAdStrFeatures.length][size];
+            
+            // 按 feature 遍历(外层),提高缓存局部性
+            // 处理 double 类型特征
+            for (int f = 0; f < itemFeatures.length; f++) {
+                String key = itemFeatureKeys[f];
+                double[] doubles = doubleArrays[f];
+                for (int i = 0; i < size; i++) {
+                    Map<String, String> featureMap = featureMaps[i];
+                    if (featureMap == null || featureMap.isEmpty()) {
                         doubles[i] = 0.0;
-                        continue;
+                    } else {
+                        doubles[i] = NumberUtils.toDouble(featureMap.getOrDefault(key, "0.0"), 0.0);
                     }
-                    double v = NumberUtils.toDouble(items.get(i).getFeatureMap().getOrDefault(key, "0.0"), 0.0);
-                    doubles[i] = v;
                 }
-
-                for (String feature : sparseAdLongFeatures) {
-                    long[] longs = longFeed.computeIfAbsent(feature, k -> new long[items.size()]);
-                    if (MapUtils.isEmpty(items.get(i).getFeatureMap())) {
+            }
+            
+            // 处理 long 类型特征
+            for (int f = 0; f < sparseAdLongFeatures.length; f++) {
+                String feature = sparseAdLongFeatures[f];
+                long[] longs = longArrays[f];
+                for (int i = 0; i < size; i++) {
+                    Map<String, String> featureMap = featureMaps[i];
+                    if (featureMap == null || featureMap.isEmpty()) {
                         longs[i] = 0L;
-                        continue;
+                    } else {
+                        longs[i] = NumberUtils.toLong(featureMap.getOrDefault(feature, "0"), 0L);
                     }
-                    long v = NumberUtils.toLong(items.get(i).getFeatureMap().getOrDefault(feature, "0"), 0L);
-                    longs[i] = v;
                 }
-
-                for (String feature : sparseAdStrFeatures) {
-                    String[] strs = strFeed.computeIfAbsent(feature, k -> new String[items.size()]);
-                    if (MapUtils.isEmpty(items.get(i).getFeatureMap())) {
+            }
+            
+            // 处理 String 类型特征
+            for (int f = 0; f < sparseAdStrFeatures.length; f++) {
+                String feature = sparseAdStrFeatures[f];
+                String[] strs = strArrays[f];
+                for (int i = 0; i < size; i++) {
+                    Map<String, String> featureMap = featureMaps[i];
+                    if (featureMap == null || featureMap.isEmpty()) {
                         strs[i] = "";
-                        continue;
+                    } else {
+                        strs[i] = featureMap.getOrDefault(feature, "");
                     }
-                    String v = items.get(i).getFeatureMap().getOrDefault(feature, "");
-                    strs[i] = v;
                 }
             }
             buildItemFeatureTime = System.currentTimeMillis() - stageStart;
 
             // 阶段3: 构建请求体
             stageStart = System.currentTimeMillis();
-            for (Map.Entry<String, double[]> entry : doubleFeed.entrySet()) {
-                request.addFeed(entry.getKey(), TFDataType.DT_DOUBLE, new long[]{items.size()}, entry.getValue());
+            long[] shape = new long[]{size};
+            
+            for (int f = 0; f < itemFeatures.length; f++) {
+                request.addFeed(itemFeatures[f], TFDataType.DT_DOUBLE, shape, doubleArrays[f]);
             }
 
-            for (Map.Entry<String, long[]> entry : longFeed.entrySet()) {
-                request.addFeed(entry.getKey(), TFDataType.DT_INT64, new long[]{items.size()}, entry.getValue());
+            for (int f = 0; f < sparseAdLongFeatures.length; f++) {
+                request.addFeed(sparseAdLongFeatures[f], TFDataType.DT_INT64, shape, longArrays[f]);
             }
 
-            for (Map.Entry<String, String[]> entry : strFeed.entrySet()) {
-                request.addFeed(entry.getKey(), TFDataType.DT_STRING, new long[]{items.size()}, entry.getValue());
+            for (int f = 0; f < sparseAdStrFeatures.length; f++) {
+                request.addFeed(sparseAdStrFeatures[f], TFDataType.DT_STRING, shape, strArrays[f]);
             }
             request.addFetch("probs");
             buildRequestTime = System.currentTimeMillis() - stageStart;