Forráskód Böngészése

修复xgb更新模型,导致特征丢失

jch 3 hónapja
szülő
commit
f9f503ff5d

+ 2 - 2
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/framework/model/ModelManager.java

@@ -185,8 +185,8 @@ public class ModelManager {
             ossObj = client.getObject(bucketName, loadTask.path);
             long timeStamp = ossObj.getObjectMetadata().getLastModified().getTime();
             if (loadTask.lastModifyTime < timeStamp || isForceLoads) {
-                log.info("model file changed, ready to update, last modify: [{}], current model time: [{}]",
-                        loadTask.lastModifyTime, timeStamp);
+                log.info("model file changed: [{}], ready to update, last modify: [{}], current model time: [{}]",
+                        loadTask.path, loadTask.lastModifyTime, timeStamp);
 
                 Model model = loadTask.modelClass.newInstance();
                 if (model.loadFromStream(new InputStreamReader(ossObj.getObjectContent()))) {

+ 2 - 3
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/ModelManager.java

@@ -11,7 +11,6 @@ import com.ctrip.framework.apollo.ConfigService;
 import lombok.extern.slf4j.Slf4j;
 
 import java.io.IOException;
-import java.io.InputStreamReader;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.Executors;
@@ -202,8 +201,8 @@ public class ModelManager {
             ossObj = client.getObject(bucketName, loadTask.path);
             long timeStamp = ossObj.getObjectMetadata().getLastModified().getTime();
             if (loadTask.lastModifyTime < timeStamp || isForceLoads) {
-                log.info("model file changed, ready to update, last modify: [{}], current model time: [{}]",
-                        loadTask.lastModifyTime, timeStamp);
+                log.info("model file changed: [{}], ready to update, last modify: [{}], current model time: [{}]",
+                        loadTask.path, loadTask.lastModifyTime, timeStamp);
 
                 Model model = loadTask.modelClass.newInstance();
                 model.setParams(loadTask.params);

+ 6 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/XGBRegressionModel.java

@@ -11,6 +11,7 @@ import org.slf4j.LoggerFactory;
 import java.io.File;
 import java.io.InputStream;
 import java.io.InputStreamReader;
+import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 
@@ -65,6 +66,11 @@ public class XGBRegressionModel extends Model {
         String absolutePath = new File(modelDir).getAbsolutePath();
         XGBoostRegressionModel model2 = XGBoostRegressionModel.load("file://" + absolutePath);
         model2.setMissing(0.0f);
+        if (params != null) {
+            Object value = params.get("features");
+            List<String> features = (List<String>) value;
+            setFeatures(features.toArray(new String[features.size()]));
+        }
         this.model = model2;
         return true;
     }

+ 6 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/XGBoostModel.java

@@ -6,13 +6,13 @@ import com.tzld.piaoquan.recommend.server.util.PropertiesUtil;
 import ml.dmlc.xgboost4j.scala.DMatrix;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import org.apache.commons.lang.math.NumberUtils;
-import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.File;
 import java.io.InputStream;
 import java.io.InputStreamReader;
+import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 
@@ -68,6 +68,11 @@ public class XGBoostModel extends Model {
         String absolutePath = new File(modelDir).getAbsolutePath();
         XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + absolutePath);
         model2.setMissing(0.0f);
+        if (params != null) {
+            Object value = params.get("features");
+            List<String> features = (List<String>) value;
+            setFeatures(features.toArray(new String[features.size()]));
+        }
         this.model = model2;
         return true;
     }