@@ -0,0 +1,140 @@
+package com.tzld.piaoquan.recommend.model.produce.xgboost;
+
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
+import lombok.extern.slf4j.Slf4j;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Local smoke test: download an XGBoost model from OSS, load it and run predictions on a sample file.
+ *
+ * @author dyp
+ */
+@Slf4j
+public class XGBoostPredictLocalTest {
+
+    public static void main(String[] args) {
+        try {
+            SparkConf sparkConf = new SparkConf()
+                    .setMaster("local")
+                    .setAppName("XGBoostPredict");
+            JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+            // download the packaged model from OSS and unpack it locally
+            String bucketName = "art-test-video";
+            String objectName = "test/model.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String gzPath = "/Users/dingyunpeng/Desktop/model2.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/Users/dingyunpeng/Desktop/modelpredict";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
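+            // The archive is assumed to contain the directory written when the training job saved
+            // the model; after unpacking, XGBoostClassificationModel.load() reads it from local disk.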
+
+            XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
+            model.setMissing(0.0f)
+                    .setFeaturesCol("features");
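+            // Note: setMissing(0.0f) tells the booster to treat zeros as missing values, and
+            // dataset() below fills absent features with 0.0, so those land in the missing branch
+            // of each tree. This is assumed to mirror how missing values were handled at training time.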
+
+            // predict on a sample export
+            Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
+            predictData.show();
+            Dataset<Row> predictions = model.transform(predictData);
+            predictions.show();
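+
+            // Sketch of an AUC sanity check with the BinaryClassificationEvaluator imported above;
+            // it assumes the "label" column built in dataset() and the default "rawPrediction"
+            // column emitted by XGBoostClassificationModel.transform().
+            BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
+                    .setLabelCol("label")
+                    .setRawPredictionCol("rawPrediction")
+                    .setMetricName("areaUnderROC");
+            log.info("AUC = {}", evaluator.evaluate(predictions));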
+
+        } catch (Throwable e) {
+            log.error("xgboost predict failed", e);
+        }
+    }
+
+    private static Dataset<Row> dataset(String path) {
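+        // Each input line is expected to look like: label<TAB>name:value<TAB>name:value...
+        // The feature names below are assumed to match the columns (and ordering) that the
+        // VectorAssembler used when the model was trained.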
+        String[] features = {
+                "cpa",
+                "b2_1h_ctr",
+                "b2_1h_ctcvr",
+                "b2_1h_cvr",
+                "b2_1h_conver",
+                "b2_1h_click",
+                "b2_1h_conver*log(view)",
+                "b2_1h_conver*ctcvr",
+                "b2_2h_ctr",
+                "b2_2h_ctcvr",
+                "b2_2h_cvr",
+                "b2_2h_conver",
+                "b2_2h_click",
+                "b2_2h_conver*log(view)",
+                "b2_2h_conver*ctcvr",
+                "b2_3h_ctr",
+                "b2_3h_ctcvr",
+                "b2_3h_cvr",
+                "b2_3h_conver",
+                "b2_3h_click",
+                "b2_3h_conver*log(view)",
+                "b2_3h_conver*ctcvr",
+                "b2_6h_ctr",
+                "b2_6h_ctcvr"
+        };
+
+        SparkSession spark = SparkSession.builder()
+                .appName("XGBoostPredict")
+                .master("local")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        JavaRDD<String> rdd = jsc.textFile(path);
+
+        JavaRDD<Row> rowRDD = rdd.map(s -> {
+            String[] line = StringUtils.split(s, '\t');
+            double label = NumberUtils.toDouble(line[0]);
+            // parse the name:value pairs; the configured features are selected below
+            Map<String, Double> map = new HashMap<>();
+            for (int i = 1; i < line.length; i++) {
+                String[] fv = StringUtils.split(line[i], ':');
+                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+            }
+
+            // row layout: label first, then the configured features (absent ones default to 0.0)
+            Object[] v = new Object[features.length + 1];
+            v[0] = label;
+            for (int i = 0; i < features.length; i++) {
+                v[i + 1] = map.getOrDefault(features[i], 0.0d);
+            }
+            return RowFactory.create(v);
+        });
+
+        log.info("rowRDD count {}", rowRDD.count());
+        // convert the JavaRDD<Row> into a Dataset<Row> with an explicit schema
+        List<StructField> fields = new ArrayList<>();
+        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
+        for (String f : features) {
+            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+        }
+        StructType schema = DataTypes.createStructType(fields);
+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+
+        VectorAssembler assembler = new VectorAssembler()
+                .setInputCols(features)
+                .setOutputCol("features");
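+        // The assembled "features" vector column is what model.setFeaturesCol("features")
+        // in main() consumes.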
+
+        Dataset<Row> assembledData = assembler.transform(dataset);
+        return assembledData;
+    }
+}