
dssm train

丁云鹏 4 months ago
parent
current commit
146e66a459

+ 16 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMPredict.java

@@ -0,0 +1,16 @@
+package com.tzld.piaoquan.recommend.model.produce.i2i;
+
+import com.tzld.piaoquan.recommend.model.produce.service.XGBoostService;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class I2IDSSMPredict {
+
+    public static void main(String[] args) {
+        I2IDSSMService dssm = new I2IDSSMService();
+        dssm.predict(args);
+    }
+}

+ 134 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMService.java

@@ -0,0 +1,134 @@
+package com.tzld.piaoquan.recommend.model.produce.i2i;
+
+import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
+import lombok.extern.slf4j.Slf4j;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class I2IDSSMService {
+
+    public void predict(String[] args) {
+
+        try {
+
+            CMDService cmd = new CMDService();
+            Map<String, String> argMap = cmd.parse(args);
+            String file = argMap.get("path");
+
+            if (StringUtils.isBlank(file)) {
+                file = argMap.get("dir");
+            }
+            
+            // Load the model
+            String bucketName = "art-test-video";
+            String objectName = "test/model.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String gzPath = "/root/recommend-model/model2.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/root/recommend-model/modelpredict";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
+
+
+        } catch (Throwable e) {
+            log.error("", e);
+        }
+    }
+
+    private static Dataset<Row> dataset(String path) {
+        String[] features = {
+                "cpa",
+                "b2_1h_ctr",
+                "b2_1h_ctcvr",
+                "b2_1h_cvr",
+                "b2_1h_conver",
+                "b2_1h_click",
+                "b2_1h_conver*log(view)",
+                "b2_1h_conver*ctcvr",
+                "b2_2h_ctr",
+                "b2_2h_ctcvr",
+                "b2_2h_cvr",
+                "b2_2h_conver",
+                "b2_2h_click",
+                "b2_2h_conver*log(view)",
+                "b2_2h_conver*ctcvr",
+                "b2_3h_ctr",
+                "b2_3h_ctcvr",
+                "b2_3h_cvr",
+                "b2_3h_conver",
+                "b2_3h_click",
+                "b2_3h_conver*log(view)",
+                "b2_3h_conver*ctcvr",
+                "b2_6h_ctr",
+                "b2_6h_ctcvr"
+        };
+
+
+        SparkSession spark = SparkSession.builder()
+                .appName("XGBoostTrain")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        String file = path;
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        JavaRDD<Row> rowRDD = rdd.map(s -> {
+            String[] line = StringUtils.split(s, '\t');
+            double label = NumberUtils.toDouble(line[0]);
+            // Select the features
+            Map<String, Double> map = new HashMap<>();
+            for (int i = 1; i < line.length; i++) {
+                String[] fv = StringUtils.split(line[i], ':');
+                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+            }
+
+            Object[] v = new Object[features.length + 1];
+            v[0] = label;
+            for (int i = 0; i < features.length; i++) {
+                v[i + 1] = map.getOrDefault(features[i], 0.0d);
+            }
+
+            return RowFactory.create(v);
+        });
+
+        log.info("rowRDD count {}", rowRDD.count());
+        // Convert the JavaRDD<Row> into a Dataset<Row>
+        List<StructField> fields = new ArrayList<>();
+        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
+        for (String f : features) {
+            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+        }
+        StructType schema = DataTypes.createStructType(fields);
+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+
+        VectorAssembler assembler = new VectorAssembler()
+                .setInputCols(features)
+                .setOutputCol("features");
+
+        Dataset<Row> assembledData = assembler.transform(dataset);
+        return assembledData;
+    }
+}

+ 127 - 0
recommend-model-produce/src/main/python/models/dssm/infer.py

@@ -0,0 +1,127 @@
+import os
+import sys
+__dir__ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+#sys.path.append(__dir__)
+sys.path.append(os.path.join(__dir__,"tools"))
+
+import numpy as np
+import json
+from concurrent.futures import ThreadPoolExecutor
+from utils.oss_client import HangZhouOSSClient
+import utils.compress as compress
+from utils.my_hdfs_client import MyHDFSClient
+import paddle.inference as paddle_infer
+
+# Hadoop installation directory and configuration
+hadoop_home = "/app/env/hadoop-3.2.4"
+configs = {
+    "fs.defaultFS": "hdfs://192.168.141.208:9000",
+    "hadoop.job.ugi": ""
+}
+hdfs_client = MyHDFSClient(hadoop_home, configs)
+
+def download_and_extract_model(init_model_path, oss_client, oss_object_name):
+    """下载并解压模型"""
+    model_tar_path = "model.tar.gz"
+    oss_client.get_object_to_file(oss_object_name, model_tar_path)
+    compress.uncompress_tar(model_tar_path, init_model_path)
+    assert os.path.exists(init_model_path)
+
+def create_paddle_predictor(model_file, params_file):
+    """创建PaddlePaddle的predictor"""
+    config = paddle_infer.Config(model_file, params_file)
+    predictor = paddle_infer.create_predictor(config)
+    return predictor
+
+def process_file(file_path, model_file, params_file):
+    """处理单个文件"""
+    predictor = create_paddle_predictor(model_file, params_file)
+    ret, out = hdfs_client._run_cmd(f"text {file_path}")
+    input_data = {}
+    for line in out:
+        sample_values = line.rstrip('\n').split('\t')
+        vid, left_features_str = sample_values
+        left_features = [float(x) for x in left_features_str.split(',')]
+        input_data[vid] = left_features
+
+    result = []
+    for k, v in input_data.items():
+        v2 = np.array([v], dtype=np.float32)
+        input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+        input_handle.copy_from_cpu(v2)
+        predictor.run()
+        output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
+        output_data = output_handle.copy_to_cpu()
+        result.append(k + "\t" + str(output_data.tolist()[0]))
+    return result
+
+def write_results(results, output_file):
+    """将结果写入文件"""
+    with open(output_file, 'w') as json_file:
+        for s in results:
+            json_file.write(s + "\n")
+
+def thread_task(name, file_list, model_file, params_file):
+    """线程任务"""
+    print(f"Thread {name}: starting file_list:{file_list}")
+    results = []
+    count = len(file_list)
+    for i, file_path in enumerate(file_list, start=1):
+        print(f"Thread {name}: starting file:{file_path} {i}/{count}")
+        results.extend(process_file(file_path, model_file, params_file))
+        file_name, file_suffix = os.path.splitext(os.path.basename(file_path))
+        output_file = f"/app/vec-{file_name}.json"
+        write_results(results, output_file)
+        compress.compress_file_tar(output_file, f"{output_file}.tar.gz")
+        hdfs_client.delete(f"/dyp/vec/{file_name}.gz")
+        hdfs_client.upload(f"{output_file}.tar.gz", f"/dyp/vec/{file_name}.gz", multi_processes=1, overwrite=False)
+        results=[]
+        print(f"Thread {name}: ending file:{file_path} {i}/{count}")
+    
+    print(f"Thread {name}: finishing")
+
+def main():
+    init_model_path = "/app/output_model_dssm"
+    client = HangZhouOSSClient("art-recommend")
+    oss_object_name = "dyp/dssm.tar.gz"
+    download_and_extract_model(init_model_path, client, oss_object_name)
+
+    model_file = os.path.join(init_model_path, "dssm.pdmodel")
+    params_file = os.path.join(init_model_path, "dssm.pdiparams")
+
+
+    sub_dirs, file_list = hdfs_client.ls_dir('/dw/recommend/model/56_dssm_i2i_itempredData/20241212')
+    all_files = []
+    for file in file_list:
+        # Keep only .gz files
+        if not file.endswith(".gz"):
+            continue
+        all_files.append(file)
+    print(f"File list : {all_files}")
+    max_workers = 16
+    chunk_size = len(all_files) // max_workers
+    remaining = len(all_files) % max_workers
+
+    # Split the file list into one chunk per worker
+    split_file_list = []
+    for i in range(max_workers):
+        # Compute the start and end index of this chunk
+        start = i * chunk_size + min(i, remaining)
+        end = start + chunk_size + (1 if i < remaining else 0)
+        # Append this worker's sub-list
+        split_file_list.append(all_files[start:end])
+
+    future_list = []
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        for i, files in enumerate(split_file_list):
+            future_list.append(executor.submit(thread_task, f"thread{i}", files, model_file, params_file))
+
+    for future in future_list:
+        future.result()
+
+    print("Main program ending")
+
+if __name__ == "__main__":
+    main()
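
For reference, the record formats handled by `process_file` follow directly from the parsing code above: each input line is a video id and a comma-separated feature vector separated by a tab, and each output line is the video id followed by the embedding rendered as a Python list. Below is a minimal, self-contained sketch of that round trip; the sample line and the `fake_predict` stub that stands in for the PaddlePaddle predictor are illustrative assumptions, not part of this commit.

```python
import numpy as np

def fake_predict(batch: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the Paddle predictor so the formatting
    # below can be run without a trained model; it simply truncates the
    # feature vector to a fixed "embedding" size.
    return batch[:, :4]

# One input record, as read from HDFS: "<vid>\t<f1>,<f2>,..."
sample_line = "vid_123\t0.1,0.2,0.3,0.4,0.5\n"

vid, left_features_str = sample_line.rstrip('\n').split('\t')
left_features = [float(x) for x in left_features_str.split(',')]

batch = np.array([left_features], dtype=np.float32)  # shape (1, n_features)
output_data = fake_predict(batch)

# Same output format as process_file: "<vid>\t[e1, e2, ...]"
result_line = vid + "\t" + str(output_data.tolist()[0])
print(result_line)
```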