Ver Fonte

修改模型调用为PAI

xueyiming há 3 meses atrás
pai
commit
b58408c341

+ 2 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/ScorerUtils.java

@@ -29,6 +29,7 @@ public final class ScorerUtils {
     public static String XGBOOST_SCORE_CONF_683 = "ad_score_config_xgboost_683.conf";
     public static String XGBOOST_SCORE_CONF_20240909 = "ad_score_config_xgboost_20240909.conf";
     public static String XGBOOST_SCORE_CONF_20241105 = "ad_score_config_xgboost_20241105.conf";
+    public static String PAI_SCORE_CONF_20250214 = "ad_score_config_pai_20250214.conf";
 
     public static void warmUp() {
         log.info("scorer warm up ");
@@ -39,6 +40,7 @@ public final class ScorerUtils {
         ScorerUtils.init(XGBOOST_SCORE_CONF);
         ScorerUtils.init(XGBOOST_SCORE_CONF_20240909);
         ScorerUtils.init(XGBOOST_SCORE_CONF_20241105);
+        ScorerUtils.init(PAI_SCORE_CONF_20250214);
     }
 
     private ScorerUtils() {

Diff do ficheiro suprimidas por serem muito extensas
+ 32 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/PAIModel.java


+ 6 - 0
ad-engine-server/src/main/resources/ad_score_config_pai_20250214.conf

@@ -0,0 +1,6 @@
+scorer-config = {
+  pai-score-config = {
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.scorer.PAIScorer"
+    scorer-priority = 99
+  }
+}

+ 99 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorer.java

@@ -0,0 +1,99 @@
+package com.tzld.piaoquan.ad.engine.service.score.scorer;
+
+
+import com.google.common.collect.Lists;
+import com.tzld.piaoquan.ad.engine.commons.score.BaseXGBoostModelScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.PAIModel;
+import com.tzld.piaoquan.ad.engine.commons.score.model.XGBoostModel683;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Component;
+
+import java.util.*;
+import java.util.concurrent.*;
+
+public class PAIScorer extends BaseXGBoostModelScorer {
+
+    private final static Logger LOGGER = LoggerFactory.getLogger(PAIScorer.class);
+
+
+    public PAIScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        doLoadModel(XGBoostModel683.class);
+    }
+
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userAdFeature,
+                                    final List<AdRankItem> rankItems) {
+        throw new NoSuchMethodError();
+    }
+
+    public List<AdRankItem> scoring(final Map<String, String> sceneFeatureMap,
+                                    final Map<String, String> userFeatureMap,
+                                    final List<AdRankItem> rankItems) {
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+
+        List<AdRankItem> result = rankByJava(sceneFeatureMap, userFeatureMap, rankItems);
+
+        LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
+                System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<AdRankItem> rankByJava(final Map<String, String> sceneFeatureMap,
+                                        final Map<String, String> userFeatureMap,
+                                        final List<AdRankItem> items) {
+        long startTime = System.currentTimeMillis();
+        PAIModel model = (PAIModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+        // 所有都参与打分,按照ctr排序
+        multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (int i = 0; i < items.size(); i++) {
+                LOGGER.debug("before enter feeds model predict ctr score [{}] [{}]", items.get(i), items.get(i));
+            }
+        }
+
+        Collections.sort(items);
+
+        LOGGER.debug("ctr ranker java execute time: [{}]", System.currentTimeMillis() - startTime);
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ", items != null ? items.size() : 0,
+                System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+    private void multipleCtrScore(final List<AdRankItem> items,
+                                  final Map<String, String> userFeatureMap,
+                                  final Map<String, String> sceneFeatureMap,
+                                  final PAIModel model) {
+
+        List<Float> score = model.score(items, userFeatureMap, sceneFeatureMap);
+        LOGGER.debug("PAIScorer score={}", score);
+        for (int i = 0; i < items.size(); i++) {
+            Double pro = Double.valueOf(score.get(i));
+            items.get(i).setLrScore(pro);
+            items.get(i).getScoreMap().put("ctcvrScore", pro);
+        }
+    }
+
+
+}

+ 1 - 1
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/strategy/RankStrategyBy683.java

@@ -211,7 +211,7 @@ public class RankStrategyBy683 extends RankStrategyBasic {
         long time4 = System.currentTimeMillis();
         // 打分排序
         // getScorerPipeline
-        List<AdRankItem> result = ScorerUtils.getScorerPipeline(ScorerUtils.XGBOOST_SCORE_CONF_683).scoring(sceneFeatureMap, userFeatureMap, adRankItems);
+        List<AdRankItem> result = ScorerUtils.getScorerPipeline(ScorerUtils.PAI_SCORE_CONF_20250214).scoring(sceneFeatureMap, userFeatureMap, adRankItems);
         long time5 = System.currentTimeMillis();
 
         // calibrate score for negative sampling

+ 10 - 0
pom.xml

@@ -322,6 +322,16 @@
             <groupId>org.springframework.cloud</groupId>
             <artifactId>spring-cloud-starter-openfeign</artifactId>
         </dependency>
+
+        <!-- https://mvnrepository.com/artifact/com.aliyun.openservices.eas/eas-sdk -->
+        <dependency>
+            <groupId>com.aliyun.openservices.eas</groupId>
+            <artifactId>eas-sdk</artifactId>
+            <version>2.0.23</version>
+        </dependency>
+
+
+
         <!--easyexcel-->
     </dependencies>
 

Alguns ficheiros não foram mostrados porque muitos ficheiros mudaram neste diff