丁云鹏 9 months ago
parent
commit
7bbc8a8dcb

+ 14 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/Demo.java

@@ -0,0 +1,14 @@
+package com.tzld.piaoquan.recommend.model.produce;
+
+import java.io.File;
+
+/**
+ * @author dyp
+ */
+public class Demo {
+    public static void main(String[] args) {
+        String rpath = "xgboost";
+        String apath = new File(rpath).getAbsolutePath();
+        System.out.println(apath);
+    }
+}

+ 9 - 4
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -1,6 +1,7 @@
 package com.tzld.piaoquan.recommend.model.produce.xgboost;
 
 import lombok.extern.slf4j.Slf4j;
+import ml.dmlc.xgboost4j.scala.DMatrix;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
@@ -47,6 +48,7 @@ public class XGBoostTrainLocalTest {
                     .setObjective("binary:logistic")
                     .setNthread(1)
                     .setNumRound(100)
+                    //.setNumClass(2)
                     .setNumWorkers(1);
             Dataset<Row> trainData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
             trainData.show();
@@ -124,13 +126,16 @@ public class XGBoostTrainLocalTest {
             featureMap.put("b2_6h_ctr", "0.88");
             featureMap.put("b2_6h_ctcvr", "0");
 
-            double[] values = new double[features.length];
+            float[] values = new float[features.length];
             for (int i = 0; i < features.length; i++) {
-                double v = NumberUtils.toDouble(featureMap.getOrDefault(features[i], "0.0"), 0.0);
+                float v = NumberUtils.toFloat(featureMap.getOrDefault(features[i], "0.0"), 0.0f);
                 values[i] = v;
             }
-            Vector v = Vectors.dense(values);
-            double result = model.predict(v);
+
+
+            DMatrix dm = new DMatrix(values, 1, features.length, 0.0f);
+            float[][] result = model._booster().predict(dm,false,100);
+
             log.info("model.predict {}", result);
 
         } catch (Throwable e) {