丁云鹏 8 months ago
parent
commit
1856ed40b8

+ 20 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/BaseXGBoostModelScorer.java

@@ -0,0 +1,20 @@
+package com.tzld.piaoquan.ad.engine.commons.score;
+
+import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+public abstract class BaseXGBoostModelScorer extends AbstractScorer {
+
+    private static Logger LOGGER = LoggerFactory.getLogger(BaseXGBoostModelScorer.class);
+
+    public BaseXGBoostModelScorer(ScorerConfigInfo scorerConfigInfo) {
+        super(scorerConfigInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        doLoadModel(XGBoostModel.class);
+    }
+}

+ 7 - 8
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/XGBoostScorer.java

@@ -1,10 +1,10 @@
 package com.tzld.piaoquan.ad.engine.service.score;
 
 
-import com.tzld.piaoquan.ad.engine.commons.score.BaseFMModelScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.BaseXGBoostModelScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
-import com.tzld.piaoquan.ad.engine.commons.score.model.FMModel;
+import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
 import org.apache.commons.collections4.CollectionUtils;
@@ -17,10 +17,10 @@ import java.util.*;
 import java.util.concurrent.*;
 
 
-public class XGBoostScorer extends BaseFMModelScorer {
+public class XGBoostScorer extends BaseXGBoostModelScorer {
 
     private static final int LOCAL_TIME_OUT = 150;
-    private final static Logger LOGGER = LoggerFactory.getLogger(VlogRovFMScorer.class);
+    private final static Logger LOGGER = LoggerFactory.getLogger(XGBoostScorer.class);
     private static final ExecutorService executorService = Executors.newFixedThreadPool(128);
 
 
@@ -43,7 +43,6 @@ public class XGBoostScorer extends BaseFMModelScorer {
         }
 
         long startTime = System.currentTimeMillis();
-        FMModel model = (FMModel) this.getModel();
 
         List<AdRankItem> result = rankByJava(sceneFeatureMap, userFeatureMap, rankItems);
 
@@ -57,7 +56,7 @@ public class XGBoostScorer extends BaseFMModelScorer {
                                       final Map<String, String> userFeatureMap,
                                       final List<AdRankItem> items) {
         long startTime = System.currentTimeMillis();
-        FMModel model = (FMModel) this.getModel();
+        XGBoostModel model = (XGBoostModel) this.getModel();
         LOGGER.debug("model size: [{}]", model.getModelSize());
 
         // 所有都参与打分,按照ctr排序
@@ -81,7 +80,7 @@ public class XGBoostScorer extends BaseFMModelScorer {
     private void multipleCtrScore(final List<AdRankItem> items,
                                   final Map<String, String> userFeatureMap,
                                   final Map<String, String> sceneFeatureMap,
-                                  final FMModel model) {
+                                  final XGBoostModel model) {
 
         List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
         for (int index = 0; index < items.size(); index++) {
@@ -124,7 +123,7 @@ public class XGBoostScorer extends BaseFMModelScorer {
         }
     }
 
-    public double calcScore(final FMModel model,
+    public double calcScore(final XGBoostModel model,
                             final AdRankItem item,
                             final Map<String, String> userFeatureMap,
                             final Map<String, String> sceneFeatureMap) {