|
@@ -33,9 +33,11 @@ public class XGBoostService {
|
|
|
|
|
|
public void train(String[] args) {
|
|
|
try {
|
|
|
-
|
|
|
+ CMDService cmd = new CMDService();
|
|
|
+ Map<String, String> argMap = cmd.parse(args);
|
|
|
+ String path = argMap.get("path");
|
|
|
// 训练
|
|
|
- Dataset<Row> trainData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz");
|
|
|
+ Dataset<Row> trainData = dataset(path);
|
|
|
trainData.show();
|
|
|
XGBoostClassifier xgbClassifier = new XGBoostClassifier()
|
|
|
.setEta(0.01f)
|
|
@@ -54,14 +56,14 @@ public class XGBoostService {
|
|
|
XGBoostClassificationModel model = xgbClassifier.fit(trainData);
|
|
|
|
|
|
// 保存模型
|
|
|
- String path = "/root/recommend-model/modeltrain";
|
|
|
- model.write().overwrite().save("file://" + path);
|
|
|
- String outputPath = "/root/recommend-model/model.tar.gz";
|
|
|
- CompressUtil.compressDirectoryToGzip(path, outputPath);
|
|
|
+ String modelPath = "/root/recommend-model/modeltrain";
|
|
|
+ model.write().overwrite().save("file://" + modelPath);
|
|
|
+ String gzPath = "/root/recommend-model/model.tar.gz";
|
|
|
+ CompressUtil.compressDirectoryToGzip(modelPath, gzPath);
|
|
|
String bucketName = "art-test-video";
|
|
|
String ossPath = "test/model.tar.gz";
|
|
|
OSSService ossService = new OSSService();
|
|
|
- ossService.upload(bucketName, outputPath, ossPath);
|
|
|
+ ossService.upload(bucketName, gzPath, ossPath);
|
|
|
|
|
|
} catch (Throwable e) {
|
|
|
log.error("", e);
|
|
@@ -71,25 +73,28 @@ public class XGBoostService {
|
|
|
public void predict(String[] args) {
|
|
|
try {
|
|
|
|
|
|
+ CMDService cmd = new CMDService();
|
|
|
+ Map<String, String> argMap = cmd.parse(args);
|
|
|
+ String path = argMap.get("path");
|
|
|
+
|
|
|
|
|
|
// 加载模型
|
|
|
String bucketName = "art-test-video";
|
|
|
String objectName = "test/model.tar.gz";
|
|
|
OSSService ossService = new OSSService();
|
|
|
|
|
|
- String destPath = "/root/recommend-model/model2.tar.gz";
|
|
|
- ossService.download(bucketName, destPath, objectName);
|
|
|
- String destDir = "/root/recommend-model/modelpredict";
|
|
|
- CompressUtil.decompressGzFile(destPath, destDir);
|
|
|
+ String gzPath = "/root/recommend-model/model2.tar.gz";
|
|
|
+ ossService.download(bucketName, gzPath, objectName);
|
|
|
+ String modelDir = "/root/recommend-model/modelpredict";
|
|
|
+ CompressUtil.decompressGzFile(gzPath, modelDir);
|
|
|
|
|
|
- XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + destDir);
|
|
|
+ XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
|
|
|
model.setMissing(0.0f)
|
|
|
.setFeaturesCol("features");
|
|
|
|
|
|
|
|
|
// 预测
|
|
|
- Dataset<Row> predictData =
|
|
|
- dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
|
|
|
+ Dataset<Row> predictData = dataset(path);
|
|
|
predictData.show();
|
|
|
Dataset<Row> predictions = model.transform(predictData);
|
|
|
predictions.show();
|