丁云鹏 vor 8 Monaten
Ursprung
Commit
c2a6a72257

+ 4 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/CMDService.java

@@ -1,5 +1,6 @@
 package com.tzld.piaoquan.recommend.model.produce.service;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -11,6 +12,9 @@ import java.util.Map;
 public class CMDService {
 
     public Map<String, String> parse(String[] args) {
+        if (args == null) {
+            return Collections.emptyMap();
+        }
         Map<String, String> map = new HashMap<>();
         for (int i = 0; i < args.length - 1; i++) {
             map.put(args[i].substring(1), args[i + 1]);

+ 19 - 14
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -33,9 +33,11 @@ public class XGBoostService {
 
     public void train(String[] args) {
         try {
-
+            CMDService cmd = new CMDService();
+            Map<String, String> argMap = cmd.parse(args);
+            String path = argMap.get("path");
             // 训练
-            Dataset<Row> trainData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz");
+            Dataset<Row> trainData = dataset(path);
             trainData.show();
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
                     .setEta(0.01f)
@@ -54,14 +56,14 @@ public class XGBoostService {
             XGBoostClassificationModel model = xgbClassifier.fit(trainData);
 
             // 保存模型
-            String path = "/root/recommend-model/modeltrain";
-            model.write().overwrite().save("file://" + path);
-            String outputPath = "/root/recommend-model/model.tar.gz";
-            CompressUtil.compressDirectoryToGzip(path, outputPath);
+            String modelPath = "/root/recommend-model/modeltrain";
+            model.write().overwrite().save("file://" + modelPath);
+            String gzPath = "/root/recommend-model/model.tar.gz";
+            CompressUtil.compressDirectoryToGzip(modelPath, gzPath);
             String bucketName = "art-test-video";
             String ossPath = "test/model.tar.gz";
             OSSService ossService = new OSSService();
-            ossService.upload(bucketName, outputPath, ossPath);
+            ossService.upload(bucketName, gzPath, ossPath);
 
         } catch (Throwable e) {
             log.error("", e);
@@ -71,25 +73,28 @@ public class XGBoostService {
     public void predict(String[] args) {
         try {
 
+            CMDService cmd = new CMDService();
+            Map<String, String> argMap = cmd.parse(args);
+            String path = argMap.get("path");
+
 
             // 加载模型
             String bucketName = "art-test-video";
             String objectName = "test/model.tar.gz";
             OSSService ossService = new OSSService();
 
-            String destPath = "/root/recommend-model/model2.tar.gz";
-            ossService.download(bucketName, destPath, objectName);
-            String destDir = "/root/recommend-model/modelpredict";
-            CompressUtil.decompressGzFile(destPath, destDir);
+            String gzPath = "/root/recommend-model/model2.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/root/recommend-model/modelpredict";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
 
-            XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + destDir);
+            XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
             model.setMissing(0.0f)
                     .setFeaturesCol("features");
 
 
             // 预测
-            Dataset<Row> predictData =
-                    dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
+            Dataset<Row> predictData = dataset(path);
             predictData.show();
             Dataset<Row> predictions = model.transform(predictData);
             predictions.show();