浏览代码

修改model创建方式

xueyiming 1 月之前
父节点
当前提交
94eac76d24

+ 19 - 19
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/PAIModel.java

@@ -17,13 +17,28 @@ import java.io.InputStreamReader;
 import java.util.*;
 
 
-public class PAIModel extends Model {
+public class PAIModel {
 
     private static final Logger LOGGER = LoggerFactory.getLogger(PAIModel.class);
 
+    private PAIModel() {
+    }
+
+    private static final PAIModel model;
 
-    // 在类加载时就创建单例实例
-    private PredictClient model;
+    public static PAIModel getModel() {
+        return model;
+    }
+
+    private static final PredictClient client;
+
+    static {
+        model = new PAIModel();
+        client = new PredictClient(new HttpConfig());
+        client.setEndpoint("1894469520484605.cn-hangzhou.pai-eas.aliyuncs.com");
+        client.setToken("NmFhZGRlMjBmOGVhZTM1ZjU3YTgyZTYxMWRjNzgxZWJlOTFkZmI1NA==");
+        client.setModelName("ad_rank_widedeep_v8_tf115");
+    }
 
     private final String[] userFeatures = {
             "viewall", "clickall", "converall", "incomeall", "ctr_all", "ctcvr_all", "cvr_all"
@@ -61,7 +76,7 @@ public class PAIModel extends Model {
                 request.addFeed(entry.getKey(), TFDataType.DT_DOUBLE, new long[]{items.size()}, entry.getValue());
             }
             request.addFetch("probs");
-            TFResponse response = model.predict(request);
+            TFResponse response = client.predict(request);
             List<Float> result = response.getFloatVals("probs");
             if (!CollectionUtils.isEmpty(result)) {
                 return result;
@@ -72,19 +87,4 @@ public class PAIModel extends Model {
         return new ArrayList<>(Collections.nCopies(items.size(), 0.0f));
     }
 
-    @Override
-    public int getModelSize() {
-        if (this.model == null)
-            return 0;
-        return 0;
-    }
-
-    @Override
-    public boolean loadFromStream(InputStreamReader in) throws Exception {
-        model = new PredictClient(new HttpConfig());
-        model.setEndpoint("1894469520484605.cn-hangzhou.pai-eas.aliyuncs.com");
-        model.setToken("NmFhZGRlMjBmOGVhZTM1ZjU3YTgyZTYxMWRjNzgxZWJlOTFkZmI1NA==");
-        model.setModelName("ad_rank_widedeep_v8_tf115");
-        return true;
-    }
 }

+ 0 - 1
ad-engine-server/src/main/resources/ad_score_config_pai_20250214.conf

@@ -2,6 +2,5 @@ scorer-config = {
   pai-score-config = {
     scorer-name = "com.tzld.piaoquan.ad.engine.service.score.scorer.PAIScorer"
     scorer-priority = 99
-    model-path = "pai"
   }
 }

+ 3 - 8
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/scorer/PAIScorer.java

@@ -2,6 +2,7 @@ package com.tzld.piaoquan.ad.engine.service.score.scorer;
 
 
 import com.google.common.collect.Lists;
+import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.BaseXGBoostModelScorer;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
@@ -19,7 +20,7 @@ import org.springframework.stereotype.Component;
 import java.util.*;
 import java.util.concurrent.*;
 
-public class PAIScorer extends BaseXGBoostModelScorer {
+public class PAIScorer extends AbstractScorer {
 
     private final static Logger LOGGER = LoggerFactory.getLogger(PAIScorer.class);
 
@@ -28,11 +29,6 @@ public class PAIScorer extends BaseXGBoostModelScorer {
         super(configInfo);
     }
 
-    @Override
-    public void loadModel() {
-        doLoadModel(PAIModel.class);
-    }
-
     @Override
     public List<AdRankItem> scoring(final ScoreParam param,
                                     final UserAdFeature userAdFeature,
@@ -61,8 +57,7 @@ public class PAIScorer extends BaseXGBoostModelScorer {
                                         final Map<String, String> userFeatureMap,
                                         final List<AdRankItem> items) {
         long startTime = System.currentTimeMillis();
-        PAIModel model = (PAIModel) this.getModel();
-        LOGGER.debug("model size: [{}]", model.getModelSize());
+        PAIModel model = PAIModel.getModel();
         // 所有都参与打分,按照ctr排序
         multipleCtrScore(items, userFeatureMap, sceneFeatureMap, model);