丁云鹏 8 달 전
부모
커밋
bc14630ef9

+ 5 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/Model.java

@@ -1,11 +1,16 @@
 package com.tzld.piaoquan.ad.engine.commons.score.model;
 
 
+import java.io.InputStream;
 import java.io.InputStreamReader;
 
 public abstract class Model {
     public abstract int getModelSize();
 
     public abstract boolean loadFromStream(InputStreamReader in) throws Exception;
+
+    public boolean loadFromStream(InputStream is) throws Exception {
+        throw new NoSuchMethodException();
+    }
 }
 

+ 51 - 7
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/XGBoostModel.java

@@ -1,11 +1,16 @@
 package com.tzld.piaoquan.ad.engine.commons.score.model;
 
 
+import com.tzld.piaoquan.ad.engine.commons.util.CompressUtil;
+import ml.dmlc.xgboost4j.java.DMatrix;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.spark.ml.linalg.Vector;
+import org.apache.spark.ml.linalg.Vectors;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
+import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.util.Map;
 
@@ -14,6 +19,33 @@ public class XGBoostModel extends Model {
     private static final Logger LOGGER = LoggerFactory.getLogger(XGBoostModel.class);
     private XGBoostClassificationModel model;
 
+    private String[] features = {
+            "cpa",
+            "b2_1h_ctr",
+            "b2_1h_ctcvr",
+            "b2_1h_cvr",
+            "b2_1h_conver",
+            "b2_1h_click",
+            "b2_1h_conver*log(view)",
+            "b2_1h_conver*ctcvr",
+            "b2_2h_ctr",
+            "b2_2h_ctcvr",
+            "b2_2h_cvr",
+            "b2_2h_conver",
+            "b2_2h_click",
+            "b2_2h_conver*log(view)",
+            "b2_2h_conver*ctcvr",
+            "b2_3h_ctr",
+            "b2_3h_ctcvr",
+            "b2_3h_cvr",
+            "b2_3h_conver",
+            "b2_3h_click",
+            "b2_3h_conver*log(view)",
+            "b2_3h_conver*ctcvr",
+            "b2_6h_ctr",
+            "b2_6h_ctcvr"
+    };
+
     @Override
     public int getModelSize() {
         if (this.model == null)
@@ -21,24 +53,36 @@ public class XGBoostModel extends Model {
         return 1;
     }
 
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws Exception {
+        return false;
+    }
+
     public void cleanModel() {
         this.model = null;
     }
 
     public Float score(Map<String, String> featureMap) {
-        return 0f;
-    }
-
-    @Override
-    public boolean loadFromStream(InputStreamReader in) throws IOException {
 
+        double[] values = new double[features.length];
+        for (int i = 0; i < features.length; i++) {
+            double v = NumberUtils.toDouble(featureMap.getOrDefault(features[i], "0.0"), 0.0);
+            values[i] = v;
+        }
 
+        Vector v = Vectors.dense(values);
+        double score = model.predict(v);
+        return (float) score;
+    }
 
+    @Override
+    public boolean loadFromStream(InputStream in) throws Exception {
         String modelDir = "";
+        CompressUtil.decompressGzFile(in, modelDir);
         XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + modelDir);
         model2.setMissing(0.0f)
                 .setFeaturesCol("features");
-        model = model2;
+        this.model = model2;
         return true;
     }
 

+ 123 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/util/CompressUtil.java

@@ -0,0 +1,123 @@
+package com.tzld.piaoquan.ad.engine.commons.util;
+
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.archivers.tar.TarArchiveOutputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorOutputStream;
+
+import java.io.*;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class CompressUtil {
+    public static void compressDirectoryToGzip(String sourceDirPath, String outputFilePath) {
+        // 创建.gz文件的输出流
+        try (OutputStream out = new FileOutputStream(outputFilePath);
+             GzipCompressorOutputStream gzipOut = new GzipCompressorOutputStream(out);
+             TarArchiveOutputStream taos = new TarArchiveOutputStream(gzipOut)) {
+
+            taos.setLongFileMode(TarArchiveOutputStream.LONGFILE_GNU);
+
+            // 遍历目录
+            Files.walk(Paths.get(sourceDirPath))
+                    .filter(Files::isRegularFile)
+                    .forEach(filePath -> {
+                        try {
+                            // 为每个文件创建TarEntry
+                            TarArchiveEntry entry = new TarArchiveEntry(filePath.toFile(), filePath.toString().substring(sourceDirPath.length() + 1));
+                            taos.putArchiveEntry(entry);
+
+                            // 读取文件内容并写入TarArchiveOutputStream
+                            try (InputStream is = Files.newInputStream(filePath)) {
+                                byte[] buffer = new byte[1024];
+                                int len;
+                                while ((len = is.read(buffer)) > 0) {
+                                    taos.write(buffer, 0, len);
+                                }
+                            }
+                            // 关闭entry
+                            taos.closeArchiveEntry();
+                        } catch (IOException e) {
+                            log.error("", e);
+                        }
+                    });
+        } catch (Exception e) {
+            log.error("", e);
+        }
+    }
+
+    public static void decompressGzFile(String gzipFilePath, String destDirPath) {
+        try (InputStream gzipIn = new FileInputStream(gzipFilePath);
+             GzipCompressorInputStream gzIn = new GzipCompressorInputStream(gzipIn);
+             TarArchiveInputStream tais = new TarArchiveInputStream(gzIn)) {
+
+            TarArchiveEntry entry;
+            Files.createDirectories(Paths.get(destDirPath));
+            while ((entry = tais.getNextTarEntry()) != null) {
+                if (entry.isDirectory()) {
+                    // 如果是目录,创建目录
+                    Files.createDirectories(Paths.get(destDirPath, entry.getName()));
+                } else {
+                    // 如果是文件,创建文件并写入内容
+                    File outputFile = new File(destDirPath, entry.getName());
+                    if (!outputFile.exists()) {
+                        File parent = outputFile.getParentFile();
+                        if (!parent.exists()) {
+                            parent.mkdirs();
+                        }
+                        outputFile.createNewFile();
+                    }
+                    try (OutputStream out = new FileOutputStream(outputFile)) {
+                        byte[] buffer = new byte[1024];
+                        int len;
+                        while ((len = tais.read(buffer)) > 0) {
+                            out.write(buffer, 0, len);
+                        }
+                    }
+                }
+            }
+        } catch (Exception e) {
+            log.error("", e);
+        }
+    }
+
+    public static void decompressGzFile(InputStream gzipIn, String destDirPath) {
+        try (GzipCompressorInputStream gzIn = new GzipCompressorInputStream(gzipIn);
+             TarArchiveInputStream tais = new TarArchiveInputStream(gzIn)) {
+
+            TarArchiveEntry entry;
+            Files.createDirectories(Paths.get(destDirPath));
+            while ((entry = tais.getNextTarEntry()) != null) {
+                if (entry.isDirectory()) {
+                    // 如果是目录,创建目录
+                    Files.createDirectories(Paths.get(destDirPath, entry.getName()));
+                } else {
+                    // 如果是文件,创建文件并写入内容
+                    File outputFile = new File(destDirPath, entry.getName());
+                    if (!outputFile.exists()) {
+                        File parent = outputFile.getParentFile();
+                        if (!parent.exists()) {
+                            parent.mkdirs();
+                        }
+                        outputFile.createNewFile();
+                    }
+                    try (OutputStream out = new FileOutputStream(outputFile)) {
+                        byte[] buffer = new byte[1024];
+                        int len;
+                        while ((len = tais.read(buffer)) > 0) {
+                            out.write(buffer, 0, len);
+                        }
+                    }
+                }
+            }
+        } catch (Exception e) {
+            log.error("", e);
+        }
+    }
+}