丁云鹏 9 months ago
parent
commit
17c7d104f3

+ 4 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/HDFSService.java

@@ -31,4 +31,8 @@ public class HDFSService implements Serializable {
         return fSystem.delete(new Path(path));
     }
 
+    public void get() {
+
+    }
+
 }
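
A minimal sketch of what the new, still-empty get() might eventually hold, assuming the class's existing fSystem field (a Hadoop FileSystem, as used by delete above); the signature and the hdfsPath/localPath parameter names are hypothetical:

    // Hypothetical implementation: download an HDFS file to the local disk,
    // mirroring the wrapper style of the existing delete(path) method.
    public void get(String hdfsPath, String localPath) throws IOException {
        fSystem.copyToLocalFile(new Path(hdfsPath), new Path(localPath));
    }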

+ 1 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -97,7 +97,7 @@ public class XGBoostService {
             Dataset<Row> predictData = dataset(path);
             predictData.show();
             Dataset<Row> predictions = model.transform(predictData);
-            predictions.show(50000);
+            predictions.show();
 
             // Compute AUC
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
@@ -142,7 +142,6 @@ public class XGBoostService {
 
         SparkSession spark = SparkSession.builder()
                 .appName("XGBoostTrain")
-                .master("local")
                 .getOrCreate();
 
         JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
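
Dropping the hard-coded .master("local") defers the choice of master to launch time (e.g. spark-submit --master yarn) instead of forcing a single-JVM run. For local debugging the old behaviour can still be requested explicitly; a sketch, where "local[*]" is just one possible choice:

    // Only for local runs; in cluster mode omit .master(...) entirely so the
    // submitting environment decides where the job executes.
    SparkSession spark = SparkSession.builder()
            .appName("XGBoostTrain")
            .master("local[*]")
            .getOrCreate();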

+ 67 - 1
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -10,6 +10,8 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.ml.linalg.Vector;
+import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -67,6 +69,70 @@ public class XGBoostTrainLocalTest {
 
             log.info("AUC: {}", auc);
 
+
+
+            String[] features = {
+                    "cpa",
+                    "b2_1h_ctr",
+                    "b2_1h_ctcvr",
+                    "b2_1h_cvr",
+                    "b2_1h_conver",
+                    "b2_1h_click",
+                    "b2_1h_conver*log(view)",
+                    "b2_1h_conver*ctcvr",
+                    "b2_2h_ctr",
+                    "b2_2h_ctcvr",
+                    "b2_2h_cvr",
+                    "b2_2h_conver",
+                    "b2_2h_click",
+                    "b2_2h_conver*log(view)",
+                    "b2_2h_conver*ctcvr",
+                    "b2_3h_ctr",
+                    "b2_3h_ctcvr",
+                    "b2_3h_cvr",
+                    "b2_3h_conver",
+                    "b2_3h_click",
+                    "b2_3h_conver*log(view)",
+                    "b2_3h_conver*ctcvr",
+                    "b2_6h_ctr",
+                    "b2_6h_ctcvr"
+            };
+
+            Map<String, String> featureMap = new HashMap<>();
+            featureMap.put("cpa", "0.1");
+            featureMap.put("b2_1h_ctr", "0");
+            featureMap.put("b2_1h_ctcvr", "0");
+            featureMap.put("b2_1h_cvr", "0");
+            featureMap.put("b2_1h_conver", "0");
+            featureMap.put("b2_1h_click", "0");
+            featureMap.put("b2_1h_conver*log(view)", "0");
+            featureMap.put("b2_1h_conver*ctcvr", "0");
+            featureMap.put("b2_2h_ctr", "0");
+            featureMap.put("b2_2h_ctcvr", "0");
+            featureMap.put("b2_2h_cvr", "0");
+            featureMap.put("b2_2h_conver", "0");
+            featureMap.put("b2_2h_click", "0");
+            featureMap.put("b2_2h_conver*log(view)", "0");
+            featureMap.put("b2_2h_conver*ctcvr", "0");
+            featureMap.put("b2_3h_ctr", "0.89");
+            featureMap.put("b2_3h_ctcvr", "0");
+            featureMap.put("b2_3h_cvr", "0");
+            featureMap.put("b2_3h_conver", "0");
+            featureMap.put("b2_3h_click", "0.01");
+            featureMap.put("b2_3h_conver*log(view)", "0");
+            featureMap.put("b2_3h_conver*ctcvr", "0");
+            featureMap.put("b2_6h_ctr", "0.88");
+            featureMap.put("b2_6h_ctcvr", "0");
+
+            double[] values = new double[features.length];
+            for (int i = 0; i < features.length; i++) {
+                double v = NumberUtils.toDouble(featureMap.getOrDefault(features[i], "0.0"), 0.0);
+                values[i] = v;
+            }
+            Vector v = Vectors.dense(values);
+            double result = model.predict(v);
+            log.info("model.predict {}", result);
+
         } catch (Throwable e) {
             log.error("", e);
         }
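
The added block feeds a hand-built dense Vector straight into model.predict. An equivalent one-row check through the DataFrame API would look roughly like this; the in-scope SparkSession (spark) and the "features" column name are assumptions, not confirmed by this diff:

    // Hypothetical single-row sanity check via transform() instead of predict().
    // Assumes imports of java.util.Collections, org.apache.spark.sql.types.*,
    // and org.apache.spark.ml.linalg.VectorUDT in addition to those added above.
    Dataset<Row> one = spark.createDataFrame(
            Collections.singletonList(RowFactory.create(Vectors.dense(values))),
            new StructType(new StructField[]{
                    new StructField("features", new VectorUDT(), false, Metadata.empty())
            }));
    model.transform(one).select("prediction").show();
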
@@ -125,7 +191,7 @@ public class XGBoostTrainLocalTest {
             for (int i = 0; i < features.length; i++) {
                 v[i + 1] = map.getOrDefault(features[i], 0.0d);
             }
-            //v[0] = (double) v[1] > 0.05 ? 1.0 : 0.0;
+            v[0] = (double) v[1] > 0.05 ? 1.0 : 0.0;
             return RowFactory.create(v);
         });
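
Re-enabling this line makes the training label a deterministic threshold of the first feature (cpa, if this file's feature list starts the same way as the array added above); in isolation it amounts to:

    // Synthetic label: 1.0 when the first feature exceeds 0.05, else 0.0.
    // Handy as a pipeline smoke test, but the AUC it produces is optimistic,
    // since the label is derived from a feature the model also sees.
    v[0] = ((double) v[1] > 0.05) ? 1.0 : 0.0;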