Explorar o código

ScoreParam增加extraParam字段以便获取实验参数动态调参

gufengshou1 hai 1 ano
pai
achega
795a56b1ce

+ 2 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/ScoreParam.java

@@ -2,6 +2,7 @@ package com.tzld.piaoquan.ad.engine.commons.score;
 
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRequestContext;
 import lombok.Data;
+import java.util.*;
 
 /**
  * @author dyp
@@ -15,6 +16,7 @@ public class ScoreParam {
     private String uid;
     private String city;
     private String province;
+    private Map<String,Object> extraParam=new HashMap<>();
 
 
 }

+ 24 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/impl/PredictModelServiceImpl.java

@@ -2,6 +2,7 @@ package com.tzld.piaoquan.ad.engine.service.predict.impl;
 
 import com.alibaba.fastjson.JSONArray;
 import com.alibaba.fastjson.JSONObject;
+import com.alibaba.fastjson.TypeReference;
 import com.google.common.reflect.TypeToken;
 import com.tzld.piaoquan.ad.engine.commons.enums.AppTypeEnum;
 import com.tzld.piaoquan.ad.engine.commons.redis.AlgorithmRedisHelper;
@@ -28,6 +29,7 @@ import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Service;
 
 import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
 
 @Service
 public class PredictModelServiceImpl implements PredictModelService {
@@ -55,6 +57,9 @@ public class PredictModelServiceImpl implements PredictModelService {
     @Value("${ad.predict.break.exp.code:0}")
     private String adPredictBreakExpCode;
 
+    @Value("${ad.predict.param.testIds:0}")
+    private String testIds;
+
     public Map<String, Object> adPredict(ThresholdPredictModelRequestParam requestParam) {
 
         boolean isHit = false;
@@ -196,6 +201,8 @@ public class PredictModelServiceImpl implements PredictModelService {
         modelParam.setAbTestConfigTag(abTestConfigTag);
         modelParam.setAbtestParam(abtestParam);
         modelParam.setMidGroup(midGroup);
+        modelParam.setExtraParam(new HashMap<>());
+        setExtraParam(modelParam);
         result = ThresholdModelContainer.
                 getThresholdPredictModel("modelV2")
                 .predict(modelParam);
@@ -314,6 +321,8 @@ public class PredictModelServiceImpl implements PredictModelService {
         modelParam.setAbTestConfigTag(abTestConfigTag);
         modelParam.setAbtestParam(abtestParam);
         modelParam.setMidGroup(midGroup);
+        modelParam.setExtraParam(new HashMap<>());
+        setExtraParam(modelParam);
         Object thresholdMixFunc = abtestParam.getOrDefault("threshold_mix_func", "basic");
         result = ThresholdModelContainer.
                 getThresholdPredictModel(thresholdMixFunc.toString())
@@ -350,4 +359,19 @@ public class PredictModelServiceImpl implements PredictModelService {
                 .predict(modelParam);
     }
 
+
+    public void setExtraParam(ThresholdPredictModelParam modelParam){
+        String[] ids=testIds.split(",");
+        List<String> idList=Arrays.asList(ids);
+        List<Map<String,Object>> mapList=(List)modelParam.getAbExpInfo().get("ab_test002");
+        Map<String,Object> configMap;
+        for(Map<String,Object> map:mapList){
+            if(idList.contains(map.getOrDefault("abExpCode",""))){
+                configMap=JSONObject.parseObject(map.get("configValue").toString(),new TypeReference<Map<String,Object>>(){});
+                for(Map.Entry<String,Object> entry:configMap.entrySet()){
+                    modelParam.getExtraParam().put(entry.getKey(),entry.getValue());
+                }
+            }
+        }
+    }
 }

+ 1 - 1
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/model/threshold/ScoreV2ThresholdPredictModel.java

@@ -70,7 +70,7 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
         scoreParam.setUid("");
         scoreParam.setProvince(modelParam.getRegion());
         scoreParam.setCity(modelParam.getCity());
-
+        scoreParam.setExtraParam(modelParam.getExtraParam());
 
         List<AdRankItem> scoreResult = ScorerUtils
                 .getScorerPipeline(BREAK_CONFIG)

+ 2 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/param/ThresholdPredictModelParam.java

@@ -8,6 +8,7 @@ import lombok.Data;
 import lombok.NoArgsConstructor;
 
 import java.util.Date;
+import java.util.HashMap;
 import java.util.Map;
 
 /**
@@ -40,4 +41,5 @@ public class ThresholdPredictModelParam {
     String city = "-1";
     MachineInfoParam machineInfo = new MachineInfoParam();
 
+    Map<String,Object> extraParam=new HashMap<>();
 }

+ 7 - 5
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogMergeBreakScorer.java

@@ -30,6 +30,10 @@ public class VlogMergeBreakScorer extends BaseLRModelScorer {
     public List<AdRankItem> scoring(final ScoreParam param,
                                     final UserAdFeature userFeature,
                                     final List<AdRankItem> rankItems) {
+        double a,b,c;
+        a=Double.parseDouble(param.getExtraParam().getOrDefault("mergeBreakCtrCvrW",0.2).toString());
+        b=Double.parseDouble(param.getExtraParam().getOrDefault("mergeBreakStrRosW",1d).toString());
+        c=Double.parseDouble(param.getExtraParam().getOrDefault("mergeBreakW",1d).toString());
 
         for (AdRankItem item : rankItems) {
             double ctr = item.getCtr();
@@ -37,15 +41,13 @@ public class VlogMergeBreakScorer extends BaseLRModelScorer {
             double str = item.getStr();
             double ros = item.getRos();
 
-            double a = 0.2;
-            double b = 1.0;
-            double c = 1.0;
-
             BigDecimal ctrCvr = new BigDecimal(Math.pow(70 * ctr * cvr, a));
             BigDecimal strRos = new BigDecimal(Math.pow(str * ros, b));
             BigDecimal breakRate = new BigDecimal(Math.pow(item.getBreakRate(), c));
             try {
-                log.info("svc=scoring modelName=modelV2 ctr={} cvr={} str={} ros={}", item.getCtr(),item.getCvr(),item.getStr(),item.getRos());
+                log.info("svc=scoring modelName=modelV2 a={} b={} c={} ctr={} cvr={} str={} ros={}",
+                        a,b,c,
+                        item.getCtr(),item.getCvr(),item.getStr(),item.getRos());
             }catch (Exception e){
 
             }