Przeglądaj źródła

feat:添加ros特征文件

zhaohaipeng 1 miesiąc temu
rodzic
commit
ebf6927602

+ 38 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/util/RosUtil.java

@@ -0,0 +1,38 @@
+package com.tzld.piaoquan.recommend.model.produce.util;
+
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class RosUtil {
+
+    public static double multiClassModelScore(String modelOutput, List<Integer> labelList) {
+        double score = 0d;
+        if (StringUtils.isBlank(modelOutput)) {
+            return score;
+        }
+        modelOutput = modelOutput.replace("[", "").replace("]", "").trim();
+        if (StringUtils.isBlank(modelOutput)) {
+            return score;
+        }
+
+        List<Double> scores = Arrays.stream(modelOutput.split(","))
+                .map(Double::parseDouble)
+                .collect(Collectors.toList());
+
+        if (labelList.size() != scores.size()) {
+            return score;
+        }
+
+        for (int i = 0; i < labelList.size(); i++) {
+            Integer label = labelList.get(i);
+            score += scores.get(i) * label;
+        }
+
+        return score;
+
+    }
+
+}

+ 16 - 6
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_multi_class_xgb_train.scala

@@ -1,6 +1,7 @@
 package com.tzld.piaoquan.recommend.model
 
 import com.alibaba.fastjson.JSON
+import com.tzld.piaoquan.recommend.model.produce.util.RosUtil
 import com.tzld.piaoquan.recommend.utils.{MyHdfsUtils, ParamUtils}
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
 import org.apache.commons.lang.math.NumberUtils
@@ -46,6 +47,14 @@ object recsys_01_ros_multi_class_xgb_train {
     val subsample = param.getOrElse("subsample", "0.95").toDouble
     val modelPath = param.getOrElse("modelPath", "/dw/recommend/model/45_recommend_model/")
     val modelFile = param.getOrElse("modelFile", "model.tar.gz")
+    val predictLabelList = param.getOrElse("predictLabelList", "").split(",").filter(_.nonEmpty).toList
+
+    val pll = new util.ArrayList[Integer]()
+    for (elem <- predictLabelList) {
+      pll.add(Integer.parseInt(elem))
+    }
+
+    val predictLabelList_br = sc.broadcast(pll)
 
     val loader = getClass.getClassLoader
     val resourceUrl = loader.getResource(featureFile)
@@ -148,15 +157,16 @@ object recsys_01_ros_multi_class_xgb_train {
     sc.textFile(hdfsPath).map(r => {
       val rList = r.split("\t")
       val vid = JSON.parseObject(rList(3)).getString("vid")
-      val score = rList(2).replace("[", "").replace("]", "")
-        .split(",")(1).toDouble
       val label = rList(0).toDouble
-      (vid, (1, label, score))
+      val score = RosUtil.multiClassModelScore(rList(2), predictLabelList_br.value)
+
+      ((vid, label), (1, score))
     }).reduceByKey {
-      case (a, b) => (a._1 + b._1, a._2 + b._2, a._3 + b._3)
+      case ((c1, s1), (c2, s2)) =>
+        (c1 + c2, (s1 + s2))
     }.map {
-      case (vid, (all, zheng, scores)) =>
-        (vid, all, zheng, scores, zheng / all, scores / all)
+      case ((vid, label), (count, sumScore)) =>
+        (vid, label, count, sumScore, sumScore / count)
     }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)
 
   }