丁云鹏 8 months ago
parent
commit
7047482d2e

+ 42 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -1,6 +1,7 @@
 package com.tzld.piaoquan.recommend.model.produce.xgboost;
 
 import com.aliyun.odps.utils.StringUtils;
+import com.google.common.collect.Lists;
 import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
@@ -18,15 +19,45 @@ import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 /**
  * @author dyp
  */
 @Slf4j
 public class XGBoostTrain {
+
     public static void main(String[] args) {
         try {
+
+            List<String> features = Lists.newArrayList("cpa",
+                    "b2_1h_ctr",
+                    "b2_1h_ctcvr",
+                    "b2_1h_cvr",
+                    "b2_1h_conver",
+                    "b2_1h_click",
+                    "b2_1h_conver*log(view)",
+                    "b2_1h_conver*ctcvr",
+                    "b2_2h_ctr",
+                    "b2_2h_ctcvr",
+                    "b2_2h_cvr",
+                    "b2_2h_conver",
+                    "b2_2h_click",
+                    "b2_2h_conver*log(view)",
+                    "b2_2h_conver*ctcvr",
+                    "b2_3h_ctr",
+                    "b2_3h_ctcvr",
+                    "b2_3h_cvr",
+                    "b2_3h_conver",
+                    "b2_3h_click",
+                    "b2_3h_conver*log(view)",
+                    "b2_3h_conver*ctcvr",
+                    "b2_6h_ctr",
+                    "b2_6h_ctcvr");
+
+
             SparkSession spark = SparkSession.builder()
                     .appName("XGBoostTrain")
                     //.master("local")
@@ -42,11 +73,18 @@ public class XGBoostTrain {
                 int label = NumberUtils.toInt(line[0]);
                 int[] indices = new int[line.length - 1];
                 double[] values = new double[line.length - 1];
+
+                // 选特征
+                Map<String, Double> map = new HashMap<>();
                 for (int i = 1; i < line.length; i++) {
                     String[] fv = StringUtils.split(":");
-                    indices[i - 1] = i - 1;
-                    values[i - 1] = NumberUtils.toDouble(fv[1], 0.0);
+                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+                }
+                for (int i = 0; i < features.size(); i++) {
+                    indices[i] = i;
+                    values[i] = map.getOrDefault(features.get(i), 0.0);
                 }
+
                 SparseVector vector = new SparseVector(indices.length, indices, values);
                 return RowFactory.create(label, vector);
             });
@@ -71,6 +109,7 @@ public class XGBoostTrain {
             // 创建 XGBoostClassifier 对象
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
                     .setEta(0.1f)
+                    .setMissing(0.0f)
                     .setFeaturesCol("features")
                     .setLabelCol("label")
                     .setMaxDepth(5)
@@ -87,6 +126,7 @@ public class XGBoostTrain {
             Dataset<Row> predictions = model.transform(testData);
             predictions.select("label", "prediction").show();
 
+
         } catch (Exception e) {
             log.error("", e);
         }