@@ -4,10 +4,10 @@ import lombok.extern.slf4j.Slf4j;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
import org.apache.commons.lang.math.NumberUtils;
-import org.apache.commons.lang3.RandomUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
@@ -31,111 +31,22 @@ public class XGBoostTrainLocalTest {
    public static void main(String[] args) {
        try {

-            String[] features = {"cpa",
-                    "b2_12h_ctr",
-                    "b2_12h_ctcvr",
-                    "b2_12h_cvr",
-                    "b2_12h_conver",
-                    "b2_12h_click",
-                    "b2_12h_conver*log(view)",
-                    "b2_12h_conver*ctcvr",
-                    "b2_7d_ctr",
-                    "b2_7d_ctcvr",
-                    "b2_7d_cvr",
-                    "b2_7d_conver",
-                    "b2_7d_click",
-                    "b2_7d_conver*log(view)",
-                    "b2_7d_conver*ctcvr"
-            };
-
-
-            SparkSession spark = SparkSession.builder()
-                    .appName("XGBoostTrain")
-                    .master("local")
-                    .getOrCreate();
-
-            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
-            String file = "/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz";
-            file = "/Users/dingyunpeng/Desktop/part-00099.gz";
-            JavaRDD<String> rdd = jsc.textFile(file);
-
-            // Convert RDD[LabeledPoint] to JavaRDD<Row>
-//            JavaRDD<Row> rowRDD = rdd.map(s -> {
-//                String[] line = StringUtils.split(s, '\t');
-//                int label = NumberUtils.toInt(line[0]);
-//                // Select features
-//                Map<String, Double> map = new HashMap<>();
-//                for (int i = 1; i < line.length; i++) {
-//                    String[] fv = StringUtils.split(line[i], ':');
-//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-//                }
-//
-//                int[] indices = new int[features.length];
-//                double[] values = new double[features.length];
-//                for (int i = 0; i < features.length; i++) {
-//                    indices[i] = i;
-//                    values[i] = map.getOrDefault(features[i], 0.0);
-//                }
-//                SparseVector vector = new SparseVector(indices.length, indices, values);
-//                return RowFactory.create(label, vector);
-//            });
-
-            JavaRDD<Row> rowRDD = rdd.map(s -> {
-                String[] line = StringUtils.split(s, '\t');
-                int label = NumberUtils.toInt(line[0]);
-                // Select features
-                Map<String, Double> map = new HashMap<>();
-                for (int i = 1; i < line.length; i++) {
-                    String[] fv = StringUtils.split(line[i], ':');
-                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-                }
-
-                Object[] v = new Object[features.length + 1];
-                v[0] = label;
-                v[0] = RandomUtils.nextInt(0, 2);
-                for (int i = 0; i < features.length; i++) {
-                    v[i + 1] = map.getOrDefault(features[i], 0.0d);
-                }
-
-                return RowFactory.create(v);
-            });
-
-            log.info("rowRDD count {}", rowRDD.count());
-            // Convert JavaRDD<Row> to Dataset<Row>
-            List<StructField> fields = new ArrayList<>();
-            fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
-            for (String f : features) {
-                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
-            }
-            StructType schema = DataTypes.createStructType(fields);
-            Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
-
-            VectorAssembler assembler = new VectorAssembler()
-                    .setInputCols(features)
-                    .setOutputCol("features");
-
-            Dataset<Row> assembledData = assembler.transform(dataset);
-            assembledData.show();
-            // Split into training and test sets
-            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3});
-            Dataset<Row> trainData = splits[0];
-            trainData.show(500);
-            Dataset<Row> testData = splits[1];
-            testData.show(500);
-
-            // Parameters
-
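+            // Load and assemble the training data from the local sample file via the dataset() helper below.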
+            Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");

            // Create the XGBoostClassifier
            XGBoostClassifier xgbClassifier = new XGBoostClassifier()
-                    .setEta(0.1f)
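+                    // Tuning change: lower learning rate (0.1 -> 0.01) paired with more boosting rounds
+                    // (5 -> 100), plus row/column subsampling and a fixed seed for reproducibility.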
+                    .setEta(0.01f)
+                    .setSubsample(0.8)
+                    .setColsampleBytree(0.8)
+                    .setScalePosWeight(1)
+                    .setSeed(2024)
                    .setMissing(0.0f)
                    .setFeaturesCol("features")
                    .setLabelCol("label")
                    .setMaxDepth(5)
                    .setObjective("binary:logistic")
                    .setNthread(1)
-                    .setNumRound(5)
+                    .setNumRound(100)
                    .setNumWorkers(1);

@@ -143,12 +54,88 @@ public class XGBoostTrainLocalTest {
            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);

            // Show prediction results
-            Dataset<Row> predictions = model.transform(assembledData);
-            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
+            Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
+            Dataset<Row> predictions = model.transform(predictData);
+            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();

+            // Compute AUC
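+            // Note: predictData is loaded from the same file used for training above, so this AUC is an
+            // in-sample (training-set) metric rather than a held-out evaluation.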
+            Dataset<Row> selected = predictions.select("label", "rawPrediction");
+            BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
+                    .setLabelCol("label")
+                    .setRawPredictionCol("rawPrediction")
+                    .setMetricName("areaUnderROC");
+            double auc = evaluator.evaluate(selected);
+
+            log.info("AUC: {}", auc);

        } catch (Throwable e) {
            log.error("", e);
        }
    }
+
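+    /**
+     * Loads a tab-separated sample file ("label<TAB>name:value<TAB>name:value ...") and returns a
+     * DataFrame with one Double column per feature plus a "features" vector column assembled by
+     * VectorAssembler. Features missing from a line default to 0.0.
+     */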
+    private static Dataset<Row> dataset(String path) {
+        String[] features = {"cpa",
+                "b2_12h_ctr",
+                "b2_12h_ctcvr",
+                "b2_12h_cvr",
+                "b2_12h_conver",
+                "b2_12h_click",
+                "b2_12h_conver*log(view)",
+                "b2_12h_conver*ctcvr",
+                "b2_7d_ctr",
+                "b2_7d_ctcvr",
+                "b2_7d_cvr",
+                "b2_7d_conver",
+                "b2_7d_click",
+                "b2_7d_conver*log(view)",
+                "b2_7d_conver*ctcvr"
+        };
+
+
+        SparkSession spark = SparkSession.builder()
+                .appName("XGBoostTrain")
+                .master("local")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        String file = path;
+        JavaRDD<String> rdd = jsc.textFile(file);
+
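+        // Each input line: label \t feature:value \t feature:value ...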
+        JavaRDD<Row> rowRDD = rdd.map(s -> {
+            String[] line = StringUtils.split(s, '\t');
+            int label = NumberUtils.toInt(line[0]);
+            // Select features
+            Map<String, Double> map = new HashMap<>();
+            for (int i = 1; i < line.length; i++) {
+                String[] fv = StringUtils.split(line[i], ':');
+                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+            }
+
+            Object[] v = new Object[features.length + 1];
+            v[0] = label;
+            for (int i = 0; i < features.length; i++) {
+                v[i + 1] = map.getOrDefault(features[i], 0.0d);
+            }
+
+            return RowFactory.create(v);
+        });
+
+        log.info("rowRDD count {}", rowRDD.count());
+        // Convert JavaRDD<Row> to Dataset<Row>
+        List<StructField> fields = new ArrayList<>();
+        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+        for (String f : features) {
+            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+        }
+        StructType schema = DataTypes.createStructType(fields);
+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+
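+        // Assemble the individual feature columns into the single "features" vector column used by XGBoost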
+        VectorAssembler assembler = new VectorAssembler()
+                .setInputCols(features)
+                .setOutputCol("features");
+
+        Dataset<Row> assembledData = assembler.transform(dataset);
+        assembledData.show();
+        return assembledData;
+    }
}