丁云鹏 8 months ago · commit c213cbf361

+ 17 - 53
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/OSSService.java

@@ -9,10 +9,6 @@ import lombok.extern.slf4j.Slf4j;
 
 import java.io.Serializable;
 import java.util.List;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
 
 /**
  * @author dyp
@@ -23,54 +19,22 @@ public class OSSService implements Serializable {
     private String accessKey = "XLi5YUJusVwbbQOaGeGsaRJ1Qyzbui";
     private String endpoint = "https://oss-cn-hangzhou-internal.aliyuncs.com";
 
-    public void transToDeepColdArchive(String bucketName, List<String> objectNames) {
-        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
-        for (String objectName : objectNames) {
-            try {
-                if (objectName.startsWith("http")) {
-                    continue;
-                }
-                CopyObjectRequest request = new CopyObjectRequest(bucketName, objectName, bucketName, objectName);
-                ObjectMetadata objectMetadata = new ObjectMetadata();
-                objectMetadata.setHeader("x-oss-storage-class", "DeepColdArchive");
-                request.setNewObjectMetadata(objectMetadata);
-                CopyObjectResult result = ossClient.copyObject(request);
-            } catch (Exception e) {
-                log.error("transToDeepColdArchive error {} {}", objectName, e.getMessage(), e);
-            }
-        }
-        if (ossClient != null) {
-            ossClient.shutdown();
-        }
-    }
-
-    public void transToDeepColdArchive2(String bucketName, List<String> objectNames) {
-        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
-        CountDownLatch cdl = new CountDownLatch(objectNames.size());
-        ExecutorService es = Executors.newFixedThreadPool(3);
-        for (String objectName : objectNames) {
-            es.submit(() -> {
-                try {
-                    if (!objectName.startsWith("http")) {
-                        CopyObjectRequest request = new CopyObjectRequest(bucketName, objectName, bucketName, objectName);
-                        ObjectMetadata objectMetadata = new ObjectMetadata();
-                        objectMetadata.setHeader("x-oss-storage-class", "DeepColdArchive");
-                        request.setNewObjectMetadata(objectMetadata);
-                        ossClient.copyObject(request);
-                    }
-                } catch (Exception e) {
-                    log.error("transToDeepColdArchive error {} {}", objectName, e.getMessage(), e);
-                }
-                cdl.countDown();
-            });
-        }
-        try {
-            cdl.await(1, TimeUnit.HOURS);
-        } catch (InterruptedException e) {
-            log.error("transToDeepColdArchive error", e);
-        }
-        if (ossClient != null) {
-            ossClient.shutdown();
-        }
+    public void upload(String bucketName, String srcPath, String orcPath) {
+//        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
+//        try {
+//            if (objectName.startsWith("http")) {
+//                continue;
+//            }
+//            CopyObjectRequest request = new CopyObjectRequest(bucketName, objectName, bucketName, objectName);
+//            ObjectMetadata objectMetadata = new ObjectMetadata();
+//            objectMetadata.setHeader("x-oss-storage-class", "DeepColdArchive");
+//            request.setNewObjectMetadata(objectMetadata);
+//            CopyObjectResult result = ossClient.copyObject(request);
+//        } catch (Exception e) {
+//            log.error("transToDeepColdArchive error {} {}", objectName, e.getMessage(), e);
+//        }
+//        if (ossClient != null) {
+//            ossClient.shutdown();
+//        }
     }
 }
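
The new upload method lands as an empty stub: its body is the old transToDeepColdArchive logic commented out wholesale (complete with references to objectName and a stray continue that would not compile if simply uncommented). A minimal sketch of what the body might look like, assuming the stock Aliyun OSS SDK putObject(bucketName, key, File) overload and assuming orcPath is intended as the destination object key; both are assumptions, since the commit itself doesn't say:

    public void upload(String bucketName, String srcPath, String orcPath) {
        // Assumption: orcPath is the destination object key (the name reads like a typo for "ossPath").
        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
        try {
            // Upload the local file at srcPath to bucketName/orcPath.
            ossClient.putObject(bucketName, orcPath, new java.io.File(srcPath));
        } catch (Exception e) {
            log.error("upload error {} {}", srcPath, e.getMessage(), e);
        } finally {
            ossClient.shutdown();
        }
    }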

+ 61 - 38
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -5,9 +5,11 @@ import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.ml.linalg.SparseVector;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.sql.Dataset;
@@ -18,10 +20,8 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
+import java.util.stream.Collectors;
 
 /**
  * @author dyp
@@ -32,30 +32,22 @@ public class XGBoostTrain {
     public static void main(String[] args) {
         try {
 
-            List<String> features = Lists.newArrayList("cpa",
-                    "b2_1h_ctr",
-                    "b2_1h_ctcvr",
-                    "b2_1h_cvr",
-                    "b2_1h_conver",
-                    "b2_1h_click",
-                    "b2_1h_conver*log(view)",
-                    "b2_1h_conver*ctcvr",
-                    "b2_2h_ctr",
-                    "b2_2h_ctcvr",
-                    "b2_2h_cvr",
-                    "b2_2h_conver",
-                    "b2_2h_click",
-                    "b2_2h_conver*log(view)",
-                    "b2_2h_conver*ctcvr",
-                    "b2_3h_ctr",
-                    "b2_3h_ctcvr",
-                    "b2_3h_cvr",
-                    "b2_3h_conver",
-                    "b2_3h_click",
-                    "b2_3h_conver*log(view)",
-                    "b2_3h_conver*ctcvr",
-                    "b2_6h_ctr",
-                    "b2_6h_ctcvr");
+            String[] features = {"cpa",
+                    "b2_12h_ctr",
+                    "b2_12h_ctcvr",
+                    "b2_12h_cvr",
+                    "b2_12h_conver",
+                    "b2_12h_click",
+                    "b2_12h_conver*log(view)",
+                    "b2_12h_conver*ctcvr",
+                    "b2_7d_ctr",
+                    "b2_7d_ctcvr",
+                    "b2_7d_cvr",
+                    "b2_7d_conver",
+                    "b2_7d_click",
+                    "b2_7d_conver*log(view)",
+                    "b2_7d_conver*ctcvr"
+            };
 
 
             SparkSession spark = SparkSession.builder()
@@ -69,6 +61,26 @@ public class XGBoostTrain {
             JavaRDD<String> rdd = jsc.textFile(file);
 
             // Convert RDD[LabeledPoint] to JavaRDD<Row>
+//            JavaRDD<Row> rowRDD = rdd.map(s -> {
+//                String[] line = StringUtils.split(s, '\t');
+//                int label = NumberUtils.toInt(line[0]);
+//                // select features
+//                Map<String, Double> map = new HashMap<>();
+//                for (int i = 1; i < line.length; i++) {
+//                    String[] fv = StringUtils.split(line[i], ':');
+//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+//                }
+//
+//                int[] indices = new int[features.length];
+//                double[] values = new double[features.length];
+//                for (int i = 0; i < features.length; i++) {
+//                    indices[i] = i;
+//                    values[i] = map.getOrDefault(features[i], 0.0);
+//                }
+//                SparseVector vector = new SparseVector(indices.length, indices, values);
+//                return RowFactory.create(label, vector);
+//            });
+
             JavaRDD<Row> rowRDD = rdd.map(s -> {
                 String[] line = StringUtils.split(s, '\t');
                 int label = NumberUtils.toInt(line[0]);
@@ -79,29 +91,40 @@ public class XGBoostTrain {
                     map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
                 }
 
-                int[] indices = new int[features.size()];
-                double[] values = new double[features.size()];
-                for (int i = 0; i < features.size(); i++) {
-                    indices[i] = i;
-                    values[i] = map.getOrDefault(features.get(i), 0.0);
+                Object[] v = new Object[features.length + 1];
+                v[0] = label;
+                v[0] = RandomUtils.nextInt(0, 2); // NOTE: overwrites the real label with a random 0/1; reads like a smoke-test leftover
+                double[] values = new double[features.length];
+                for (int i = 0; i < features.length; i++) {
+                    values[i] = map.getOrDefault(features[i], 0.0d);
+                    v[i + 1] = map.getOrDefault(features[i], 0.0d);
                 }
-                SparseVector vector = new SparseVector(indices.length, indices, values);
-                return RowFactory.create(label, vector);
+
+                return RowFactory.create(v);
             });
 
             log.info("rowRDD count {}", rowRDD.count());
             // Convert JavaRDD<Row> to Dataset<Row>
             List<StructField> fields = new ArrayList<>();
             fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
-            fields.add(DataTypes.createStructField("features", new VectorUDT(), true));
+            for (String f : features) {
+                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+            }
             StructType schema = DataTypes.createStructType(fields);
             Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
 
+            VectorAssembler assembler = new VectorAssembler()
+                    .setInputCols(features)
+                    .setOutputCol("features");
 
+            Dataset<Row> assembledData = assembler.transform(dataset);
+            assembledData.show();
             // Split into training and test sets
-            Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3});
+            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3});
             Dataset<Row> trainData = splits[0];
+            trainData.show();
             Dataset<Row> testData = splits[1];
+            testData.show();
 
             // Parameters
 
@@ -124,7 +147,7 @@ public class XGBoostTrain {
 
             // Show prediction results
             Dataset<Row> predictions = model.transform(testData);
-            predictions.select("label", "prediction").show(30000);
+            predictions.show(100);
 
 
         } catch (Throwable e) {
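
The hunk between the "// Parameters" marker and the prediction lines is elided from this diff, so the actual training parameters are not visible. For orientation, a minimal sketch of the training step that region would typically contain with xgboost4j-spark; the setters exist in the library, but the values below are illustrative guesses, not what the commit uses:

    // Hypothetical parameter wiring; values are placeholders, not the commit's.
    XGBoostClassifier classifier = new XGBoostClassifier();
    classifier.setFeaturesCol("features");       // vector column assembled by the VectorAssembler above
    classifier.setLabelCol("label");
    classifier.setObjective("binary:logistic");  // binary target, matching the 0/1 label
    classifier.setNumRound(100);
    classifier.setNumWorkers(2);
    XGBoostClassificationModel model = (XGBoostClassificationModel) classifier.fit(trainData);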