|
@@ -12,6 +12,8 @@ import org.apache.spark.api.java.JavaRDD;
|
|
|
import org.apache.spark.api.java.JavaSparkContext;
|
|
|
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
|
|
|
import org.apache.spark.ml.feature.VectorAssembler;
|
|
|
+import org.apache.spark.ml.linalg.Vector;
|
|
|
+import org.apache.spark.ml.linalg.Vectors;
|
|
|
import org.apache.spark.sql.Dataset;
|
|
|
import org.apache.spark.sql.Row;
|
|
|
import org.apache.spark.sql.RowFactory;
|
|
@@ -32,7 +34,98 @@ import java.util.Map;
|
|
|
public class XGBoostPredictLocalTest {
|
|
|
|
|
|
public static void main(String[] args) {
|
|
|
+ //batchTest();
|
|
|
+ singleTest();
|
|
|
+ }
|
|
|
+
|
|
|
+ private static void singleTest() {
|
|
|
+ String[] features = {
|
|
|
+ "cpa",
|
|
|
+ "b2_1h_ctr",
|
|
|
+ "b2_1h_ctcvr",
|
|
|
+ "b2_1h_cvr",
|
|
|
+ "b2_1h_conver",
|
|
|
+ "b2_1h_click",
|
|
|
+ "b2_1h_conver*log(view)",
|
|
|
+ "b2_1h_conver*ctcvr",
|
|
|
+ "b2_2h_ctr",
|
|
|
+ "b2_2h_ctcvr",
|
|
|
+ "b2_2h_cvr",
|
|
|
+ "b2_2h_conver",
|
|
|
+ "b2_2h_click",
|
|
|
+ "b2_2h_conver*log(view)",
|
|
|
+ "b2_2h_conver*ctcvr",
|
|
|
+ "b2_3h_ctr",
|
|
|
+ "b2_3h_ctcvr",
|
|
|
+ "b2_3h_cvr",
|
|
|
+ "b2_3h_conver",
|
|
|
+ "b2_3h_click",
|
|
|
+ "b2_3h_conver*log(view)",
|
|
|
+ "b2_3h_conver*ctcvr",
|
|
|
+ "b2_6h_ctr",
|
|
|
+ "b2_6h_ctcvr"
|
|
|
+ };
|
|
|
+
|
|
|
+ Map<String, String> featureMap = new HashMap<>();
|
|
|
+ featureMap.put("cpa", "0.1");
|
|
|
+ featureMap.put("b2_1h_ctr", "0");
|
|
|
+ featureMap.put("b2_1h_ctcvr", "0");
|
|
|
+ featureMap.put("b2_1h_cvr", "0");
|
|
|
+ featureMap.put("b2_1h_conver", "0");
|
|
|
+ featureMap.put("b2_1h_click", "0");
|
|
|
+ featureMap.put("b2_1h_conver*log(view)", "0");
|
|
|
+ featureMap.put("b2_1h_conver*ctcvr", "0");
|
|
|
+ featureMap.put("b2_2h_ctr", "0");
|
|
|
+ featureMap.put("b2_2h_ctcvr", "0");
|
|
|
+ featureMap.put("b2_2h_cvr", "0");
|
|
|
+ featureMap.put("b2_2h_conver", "0");
|
|
|
+ featureMap.put("b2_2h_click", "0");
|
|
|
+ featureMap.put("b2_2h_conver*log(view)", "0");
|
|
|
+ featureMap.put("b2_2h_conver*ctcvr", "0");
|
|
|
+ featureMap.put("b2_3h_ctr", "0.89");
|
|
|
+ featureMap.put("b2_3h_ctcvr", "0");
|
|
|
+ featureMap.put("b2_3h_cvr", "0");
|
|
|
+ featureMap.put("b2_3h_conver", "0");
|
|
|
+ featureMap.put("b2_3h_click", "0.01");
|
|
|
+ featureMap.put("b2_3h_conver*log(view)", "0");
|
|
|
+ featureMap.put("b2_3h_conver*ctcvr", "0");
|
|
|
+ featureMap.put("b2_6h_ctr", "0.88");
|
|
|
+ featureMap.put("b2_6h_ctcvr", "0");
|
|
|
+
|
|
|
+ double[] values = new double[features.length];
|
|
|
+ for (int i = 0; i < features.length; i++) {
|
|
|
+ double v = NumberUtils.toDouble(featureMap.getOrDefault(features[i], "0.0"), 0.0);
|
|
|
+ values[i] = v;
|
|
|
+ }
|
|
|
+ Vector v = Vectors.dense(values);
|
|
|
+
|
|
|
+
|
|
|
+ SparkConf sparkConf = new SparkConf()
|
|
|
+ .setMaster("local")
|
|
|
+ .setAppName("XGBoostPredict");
|
|
|
+ JavaSparkContext jsc = new JavaSparkContext(sparkConf);
|
|
|
+
|
|
|
+
|
|
|
+ String bucketName = "art-test-video";
|
|
|
+ String objectName = "test/model.tar.gz";
|
|
|
+ OSSService ossService = new OSSService();
|
|
|
+
|
|
|
+ String gzPath = "/Users/dingyunpeng/Desktop/model2.tar.gz";
|
|
|
+ ossService.download(bucketName, gzPath, objectName);
|
|
|
+ String modelDir = "/Users/dingyunpeng/Desktop/modelpredict";
|
|
|
+ CompressUtil.decompressGzFile(gzPath, modelDir);
|
|
|
+
|
|
|
+ XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
|
|
|
+ model.setMissing(0.0f)
|
|
|
+ .setFeaturesCol("features");
|
|
|
+ double score = model.predict(v);
|
|
|
+
|
|
|
+ log.info("model.predict {}", score);
|
|
|
+ }
|
|
|
+
|
|
|
+ private static void batchTest() {
|
|
|
try {
|
|
|
+
|
|
|
SparkConf sparkConf = new SparkConf()
|
|
|
.setMaster("local")
|
|
|
.setAppName("XGBoostPredict");
|