丁云鹏 authored 8 months ago
parent
commit
52aa2d56c4

+ 13 - 20
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/OSSService.java

@@ -2,13 +2,12 @@ package com.tzld.piaoquan.recommend.model.produce.service;
 
 import com.aliyun.oss.OSS;
 import com.aliyun.oss.OSSClientBuilder;
-import com.aliyun.oss.model.CopyObjectRequest;
-import com.aliyun.oss.model.CopyObjectResult;
-import com.aliyun.oss.model.ObjectMetadata;
+import com.aliyun.oss.model.PutObjectRequest;
+import com.aliyun.oss.model.PutObjectResult;
 import lombok.extern.slf4j.Slf4j;
 
+import java.io.File;
 import java.io.Serializable;
-import java.util.List;
 
 /**
  * @author dyp
@@ -20,21 +19,15 @@ public class OSSService implements Serializable {
     private String endpoint = "https://oss-cn-hangzhou-internal.aliyuncs.com";
 
     public void upload(String bucketName, String srcPath, String orcPath) {
-//        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
-//        try {
-//            if (objectName.startsWith("http")) {
-//                continue;
-//            }
-//            CopyObjectRequest request = new CopyObjectRequest(bucketName, objectName, bucketName, objectName);
-//            ObjectMetadata objectMetadata = new ObjectMetadata();
-//            objectMetadata.setHeader("x-oss-storage-class", "DeepColdArchive");
-//            request.setNewObjectMetadata(objectMetadata);
-//            CopyObjectResult result = ossClient.copyObject(request);
-//        } catch (Exception e) {
-//            log.error("transToDeepColdArchive error {} {}", objectName, e.getMessage(), e);
-//        }
-//        if (ossClient != null) {
-//            ossClient.shutdown();
-//        }
+        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
+        try {
+            PutObjectRequest request = new PutObjectRequest(bucketName, orcPath, new File(srcPath));
+            PutObjectResult result = ossClient.putObject(request);
+        } catch (Exception e) {
+            log.error("upload error bucketName {}, srcPath {}, orcPath {}", bucketName, srcPath, orcPath, e);
+        }
+        if (ossClient != null) {
+            ossClient.shutdown();
+        }
     }
 }

+ 13 - 1
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -1,5 +1,6 @@
 package com.tzld.piaoquan.recommend.model.produce.xgboost;
 
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
 import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
@@ -121,10 +122,21 @@ public class XGBoostTrain {
             // 训练模型
             XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
 
+            // 保存模型
+            String path = "/root/recommend-model/model";
+            model.save(path);
+
+//            OSSService ossService = new OSSService();
+//            String bucketName = "";
+//            String ossPath = "";
+//            ossService.upload(bucketName, path, ossPath);
+
             // 显示预测结果
-            Dataset<Row> predictions = model.transform(assembledData);
+            XGBoostClassificationModel model2 = XGBoostClassificationModel.load(path);
+            Dataset<Row> predictions = model2.transform(assembledData);
             predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
 
+            spark.close();
 
         } catch (Throwable e) {
             log.error("", e);

+ 13 - 22
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -1,5 +1,6 @@
 package com.tzld.piaoquan.recommend.model.produce.xgboost;
 
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
 import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
@@ -59,27 +60,6 @@ public class XGBoostTrainLocalTest {
             file = "/Users/dingyunpeng/Desktop/part-00099.gz";
             JavaRDD<String> rdd = jsc.textFile(file);
 
-            // 将 RDD[LabeledPoint] 转换为 JavaRDD<Row>
-//            JavaRDD<Row> rowRDD = rdd.map(s -> {
-//                String[] line = StringUtils.split(s, '\t');
-//                int label = NumberUtils.toInt(line[0]);
-//                // 选特征
-//                Map<String, Double> map = new HashMap<>();
-//                for (int i = 1; i < line.length; i++) {
-//                    String[] fv = StringUtils.split(line[i], ':');
-//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-//                }
-//
-//                int[] indices = new int[features.length];
-//                double[] values = new double[features.length];
-//                for (int i = 0; i < features.length; i++) {
-//                    indices[i] = i;
-//                    values[i] = map.getOrDefault(features[i], 0.0);
-//                }
-//                SparseVector vector = new SparseVector(indices.length, indices, values);
-//                return RowFactory.create(label, vector);
-//            });
-
             JavaRDD<Row> rowRDD = rdd.map(s -> {
                 String[] line = StringUtils.split(s, '\t');
                 int label = NumberUtils.toInt(line[0]);
@@ -142,10 +122,21 @@ public class XGBoostTrainLocalTest {
             // 训练模型
             XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
 
+            // 保存模型
+            String path = "/Users/dingyunpeng/Desktop/model";
+            model.save(path);
+
+            OSSService ossService = new OSSService();
+            String bucketName = "";
+            String ossPath = "";
+            ossService.upload(bucketName, path, ossPath);
+
             // 显示预测结果
-            Dataset<Row> predictions = model.transform(assembledData);
+            XGBoostClassificationModel model2 = XGBoostClassificationModel.load(path);
+            Dataset<Row> predictions = model2.transform(assembledData);
             predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
 
+            spark.close();
 
         } catch (Throwable e) {
             log.error("", e);