丁云鹏 8 ماه پیش
والد
کامیت
d665aa0920

+ 14 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/BaseXGBoostModelScorer.java

@@ -1,9 +1,15 @@
 package com.tzld.piaoquan.recommend.server.service.score;
 
+import com.google.common.reflect.TypeToken;
 import com.tzld.piaoquan.recommend.server.service.score.model.XGBoostModel;
+import com.tzld.piaoquan.recommend.server.util.JSONUtils;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Map;
+
 
 public abstract class BaseXGBoostModelScorer extends AbstractScorer {
 
@@ -16,5 +22,13 @@ public abstract class BaseXGBoostModelScorer extends AbstractScorer {
     @Override
     public void loadModel() {
         doLoadModel(XGBoostModel.class);
+        XGBoostModel model = (XGBoostModel) this.getModel();
+        Map<String, String> paramMap = scorerConfigInfo.getParamMap();
+        if (MapUtils.isNotEmpty(paramMap)) {
+            model.setFeatures(JSONUtils.fromJson(paramMap.get("features"), new TypeToken<String[]>() {
+            }, null));
+        }
+
+
     }
 }

+ 13 - 5
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/ScorerConfig.java

@@ -1,18 +1,17 @@
 package com.tzld.piaoquan.recommend.server.service.score;
 
 
+import com.google.common.reflect.TypeToken;
 import com.typesafe.config.Config;
 import com.typesafe.config.ConfigFactory;
 import com.typesafe.config.ConfigObject;
 import com.typesafe.config.ConfigValue;
+import com.tzld.piaoquan.recommend.server.util.JSONUtils;
 import org.apache.commons.lang.exception.ExceptionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
+import java.util.*;
 
 
 public class ScorerConfig {
@@ -100,6 +99,14 @@ public class ScorerConfig {
             if (conf.hasPath("disable-switch")) {
                 disableSwitch = conf.getBoolean("disable-switch");
             }
+            Map<String, String> paramMap = null;
+            if (conf.hasPath("param")) {
+                String param = conf.getString("param");
+                paramMap = JSONUtils.fromJson(param, new TypeToken<Map<String, String>>() {
+                }, Collections.emptyMap());
+            }
+
+
             Config paramConfig = loadOptionConfig(conf, "param-config");
             // model path
             String modelPath = loadOptionStringConfig(conf, "model-path");
@@ -118,7 +125,8 @@ public class ScorerConfig {
                     disableSwitch,
                     enableQueues,
                     modelPath,
-                    paramConfig
+                    paramConfig,
+                    paramMap
             );
             LOGGER.debug("parse scorer config info [{}]", configInfo);
             // add to ConfigInfoList

+ 10 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/ScorerConfigInfo.java

@@ -2,6 +2,8 @@ package com.tzld.piaoquan.recommend.server.service.score;
 
 import com.google.gson.Gson;
 import com.typesafe.config.Config;
+
+import java.util.Map;
 import java.util.Set;
 
 
@@ -14,6 +16,7 @@ public class ScorerConfigInfo {
     private Set<String> enableQueues;
     private String modelPath;
     private Config paramConfig; // param config
+    private Map<String,String> paramMap;
 
     public ScorerConfigInfo(String configName,
                             String scorerName,
@@ -21,7 +24,8 @@ public class ScorerConfigInfo {
                             Boolean disableSwitch,
                             Set<String> enableQueues,
                             String modelPath,
-                            Config paramConfig) {
+                            Config paramConfig,
+                            Map<String, String> paramMap) {
 
         this.configName = configName;
         this.scorerName = scorerName;
@@ -30,6 +34,7 @@ public class ScorerConfigInfo {
         this.enableQueues = enableQueues;
         this.modelPath = modelPath;
         this.paramConfig = paramConfig;
+        this.paramMap = paramMap;
     }
 
     public Config getParamConfig() {
@@ -70,6 +75,10 @@ public class ScorerConfigInfo {
         return disableSwitch;
     }
 
+    public Map<String, String> getParamMap(){
+        return paramMap;
+    }
+
     @Override
     public String toString() {
         return new Gson().toJson(this);

+ 5 - 26
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/XGBoostModel.java

@@ -19,32 +19,11 @@ public class XGBoostModel extends Model {
     private static final Logger LOGGER = LoggerFactory.getLogger(XGBoostModel.class);
     private XGBoostClassificationModel model;
 
-    private String[] features = {
-            "cpa",
-            "b2_1h_ctr",
-            "b2_1h_ctcvr",
-            "b2_1h_cvr",
-            "b2_1h_conver",
-            "b2_1h_click",
-            "b2_1h_conver*log(view)",
-            "b2_1h_conver*ctcvr",
-            "b2_2h_ctr",
-            "b2_2h_ctcvr",
-            "b2_2h_cvr",
-            "b2_2h_conver",
-            "b2_2h_click",
-            "b2_2h_conver*log(view)",
-            "b2_2h_conver*ctcvr",
-            "b2_3h_ctr",
-            "b2_3h_ctcvr",
-            "b2_3h_cvr",
-            "b2_3h_conver",
-            "b2_3h_click",
-            "b2_3h_conver*log(view)",
-            "b2_3h_conver*ctcvr",
-            "b2_6h_ctr",
-            "b2_6h_ctcvr"
-    };
+    private String[] features;
+
+    public void setFeatures(String[] features){
+        this.features = features;
+    }
 
     @Override
     public int getModelSize() {

+ 3 - 0
recommend-server-service/src/main/resources/feeds_score_config_20240826.conf

@@ -3,6 +3,9 @@ scorer-config = {
     scorer-name = "com.tzld.piaoquan.recommend.server.service.score.XGBoostScorer"
     scorer-priority = 99
     model-path = "zhangbo/model_xgb_1000.tar.gz"
+    param = {
+      "features":["cpa","b2_1h_ctr","b2_1h_ctcvr"]
+    }
   }
 
 }