Jelajahi Sumber

Merge branch 'master' into pre-master

# Conflicts:
#	ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/model/threshold/ScoreV2ThresholdPredictModel.java
gufengshou1 1 tahun lalu
induk
melakukan
2de12a7db0

+ 155 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/container/PredictPidContainer.java

@@ -0,0 +1,155 @@
+package com.tzld.piaoquan.ad.engine.service.predict.container;
+
+import com.alibaba.fastjson.JSONObject;
+import com.alibaba.fastjson.TypeReference;
+import com.aliyun.oss.OSS;
+import com.aliyun.oss.OSSClientBuilder;
+import com.aliyun.oss.common.auth.CredentialsProvider;
+import com.aliyun.oss.common.auth.DefaultCredentialProvider;
+import com.aliyun.oss.model.CopyObjectResult;
+import com.aliyun.oss.model.OSSObject;
+import com.aliyun.oss.model.PutObjectResult;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.PostConstruct;
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+@Component
+public class PredictPidContainer {
+    private final static Logger log = LoggerFactory.getLogger(PredictPidContainer.class);
+
+    private static final int SCHEDULE_PERIOD = 10;
+    private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
+    @Value("${model.oss.internal.endpoint:oss-cn-hangzhou.aliyuncs.com}")
+    String endpoint = "";
+    @Value("${model.oss.accessKeyId:LTAIP6x1l3DXfSxm}")
+    String accessKeyId = "";
+    @Value("${model.oss.accessKetSecret:KbTaM9ars4OX3PMS6Xm7rtxGr1FLon}")
+    String accessKetSecret = "";
+    @Value("${model.oss.bucketName:art-recommend}")
+    String bucketName = "";
+
+    @Value("${model.oss.pid.predict.filename.lambda:pid/predict_lambda}")
+    String lambdaFileName = "";
+
+    @Value("${model.oss.pid.predict.filename.dThreshold:pid/predict_dThreshold.txt}")
+    String dThresholdFileName = "";
+
+    @Value("${ad.model.pid.predict_threshold.kp:0.8}")
+    Double kp = 0d;
+
+    @Value("${ad.model.pid.predict_threshold.ki:0.01}")
+    Double ki = 0d;
+
+    @Value("${ad.model.pid.predict_threshold.kd:0.002}")
+    Double kd = 0d;
+
+    OSS client;
+
+    private static ConcurrentHashMap<String,Double>  lambdaCache=new ConcurrentHashMap<>();
+    private Date cacheDate;
+
+    @PostConstruct
+    private void init(){
+        instanceClient();
+        final Runnable task = new Runnable() {
+            public void run() {
+                try {
+                    loadAndCalIfNeed();
+                }catch (Exception e){
+                    e.printStackTrace();
+                }
+            }
+        };
+        scheduler.scheduleAtFixedRate(task, 0, SCHEDULE_PERIOD, TimeUnit.MINUTES); // 10分钟
+    }
+
+    private void instanceClient(){
+        CredentialsProvider credentialsProvider = new DefaultCredentialProvider(accessKeyId, accessKetSecret);
+        this.client = new OSSClientBuilder().build(endpoint, credentialsProvider);
+    }
+
+    private void loadAndCalIfNeed(){
+        loadLambdaFile();
+        OSSObject dCpaFileOjb=client.getObject(bucketName,dThresholdFileName);
+        if(cacheDate==null||dCpaFileOjb.getObjectMetadata().getLastModified().after(cacheDate)){
+            calNewLambda(dCpaFileOjb);
+            writeLambdaFileToOss();
+        }
+    }
+
+    private void calNewLambda(OSSObject object) {
+        try {
+            InputStream is=object.getObjectContent();
+            InputStreamReader isr=new InputStreamReader(is);
+            BufferedReader bufferedReader = new BufferedReader(isr);
+            String line = null;
+            while ((line = bufferedReader.readLine()) != null){
+                try {
+                    String[] cols=line.split(",");
+                    String group=cols[0].trim();
+                    Double lambdaNew=lambdaCache.getOrDefault(group,0d)+
+                            kp*Double.parseDouble(cols[1])+ki*Double.parseDouble(cols[2])+kd*Double.parseDouble(cols[3]);
+                    lambdaCache.put(group,lambdaNew);
+                }catch (Exception e){
+                    e.printStackTrace();
+                }
+            }
+        }catch (Exception e){
+            log.error("svc=calNewPredictLambda status=failed error={}", Arrays.toString(e.getStackTrace()));
+        }
+    }
+
+    private void writeLambdaFileToOss(){
+        //先不考虑各种更新失败及重复更新问题。
+        try {
+            String tempFile=lambdaFileName+"_temp";
+            String content= JSONObject.toJSONString(lambdaCache);
+            PutObjectResult putObjectResult=client.putObject(bucketName,tempFile,new ByteArrayInputStream(content.getBytes()));
+            CopyObjectResult copyObjectResult=client.copyObject(bucketName, tempFile, bucketName, lambdaFileName);
+            this.cacheDate= copyObjectResult.getLastModified();
+            client.deleteObject(bucketName, tempFile);
+        }catch (Exception e){
+            log.error("svc=writePredictLambdaFileToOss status=failed error={}", Arrays.toString(e.getStackTrace()));
+            e.printStackTrace();
+        }
+    }
+
+    private void loadLambdaFile(){
+        try {
+            OSSObject object=client.getObject(bucketName,lambdaFileName);
+            if(object==null) return;
+            if(cacheDate!=null&& !cacheDate.before(object.getObjectMetadata().getLastModified())) return;
+//            if(cacheDate!=null&& cacheDate.after(object.getObjectMetadata().getLastModified())) return;
+            StringBuilder builder=new StringBuilder();
+            InputStream is=object.getObjectContent();
+            InputStreamReader isr=new InputStreamReader(is);
+            BufferedReader bufferedReader = new BufferedReader(isr);
+            String line = null;
+            while ((line=bufferedReader.readLine())!=null){
+                builder.append(line);
+            }
+            lambdaCache=JSONObject.parseObject(builder.toString(),new TypeReference<ConcurrentHashMap<String,Double>>(){});
+            this.cacheDate=object.getObjectMetadata().getLastModified();
+        }catch (Exception e){
+            log.error("svc=loadPredictLambdaFile status=failed error={}", Arrays.toString(e.getStackTrace()));
+            e.printStackTrace();
+        }
+    }
+
+    public static Double getPidLambda(String group){
+        return lambdaCache.getOrDefault(group,0d);
+    }
+}

+ 16 - 7
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/container/ThresholdModelContainer.java

@@ -30,7 +30,7 @@ public class ThresholdModelContainer {
     private double position;
 
     public static Map<String,ThresholdPredictModel> modelMap=new HashMap<>();
-    public static MergingDigest mergingDigest = new MergingDigest(1000);
+    public static Map<Integer,MergingDigest> mergingDigestMap=new HashMap<>();
 
     private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
     @PostConstruct
@@ -39,7 +39,12 @@ public class ThresholdModelContainer {
         beanMap.forEach((s,model)->{
             modelMap.put(model.getName(), model);
         });
-
+        //只关注重点app
+        mergingDigestMap.put(0, new MergingDigest(10000));
+        mergingDigestMap.put(3, new MergingDigest(10000));
+        mergingDigestMap.put(4, new MergingDigest(10000));
+        mergingDigestMap.put(5, new MergingDigest(10000));
+        mergingDigestMap.put(21, new MergingDigest(10000));
         final Runnable task = new Runnable() {
             public void run() {
                 try {
@@ -60,17 +65,21 @@ public class ThresholdModelContainer {
         return modelMap.get("basic");
     }
 
-    public static void mergingDigestAddScore(Double score){
-        mergingDigest.add(score);
+    public static void mergingDigestAddScore(Integer appType,Double score){
+        mergingDigestMap.getOrDefault(appType,new MergingDigest(1)).add(score);
     }
 
-    public static double getThresholdByTDigest(Double sortPosition){
-        return mergingDigest.quantile(sortPosition);
+    public static double getThresholdByTDigest(Integer appType,Double sortPosition){
+        return  mergingDigestMap.getOrDefault(appType,new MergingDigest(1)).quantile(sortPosition);
     }
 
     public void printDigestThreshold(){
         try {
-            log.info("svc=predict modelName=modelV2 mergingDigestThreshold={}", ThresholdModelContainer.getThresholdByTDigest(position));
+            for(Map.Entry<Integer,MergingDigest> entry:mergingDigestMap.entrySet()){
+                log.info("svc=printDigestThreshold modelName=modelV2 appType={} mergingDigestThreshold={}"
+                        , entry.getKey(),entry.getValue().quantile(position));
+            }
+
         }catch (Exception e){
             e.printStackTrace();
         }

+ 25 - 1
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/impl/PredictModelServiceImpl.java

@@ -59,9 +59,25 @@ public class PredictModelServiceImpl implements PredictModelService {
 
     @Value("${ad.predict.param.testIds:0}")
     private String testIds;
+    @Value("${ad.predict.without.ad.video_ids:0}")
+    private String withoutAdVideoIds;
 
     public Map<String, Object> adPredict(ThresholdPredictModelRequestParam requestParam) {
-
+        String[] withoutAdVideoIdsArr=withoutAdVideoIds.split(",");
+        for(String videoId:withoutAdVideoIdsArr){
+            if(videoId.equals(requestParam.getVideoId()+"")){
+                if(requestParam.getAppType().equals(0)
+                        ||requestParam.getAppType().equals(4)
+                        ||requestParam.getAppType().equals(5)
+                        ||requestParam.getAppType().equals(21)
+                ){
+                    Map<String,Object> result=new HashMap<>();
+                    result.put("ad_predict", 1);
+                    result.put("no_ad_strategy", "no_ad_with_video_in_white_list");
+                    return result;
+                }
+            }
+        }
         boolean isHit = false;
 
         try {
@@ -364,6 +380,14 @@ public class PredictModelServiceImpl implements PredictModelService {
         String[] ids=testIds.split(",");
         List<String> idList=Arrays.asList(ids);
         List<Map<String,Object>> mapList=(List)modelParam.getAbExpInfo().get("ab_test002");
+        Collections.sort(mapList,new Comparator<Map<String, Object>>() {
+            @Override
+            public int compare(Map<String, Object> map1, Map<String, Object> map2) {
+                int abExpCode1 =Integer.parseInt(map1.get("abExpCode").toString()) ;
+                int abExpCode2 =Integer.parseInt(map2.get("abExpCode").toString());
+                return Integer.compare(abExpCode1, abExpCode2);
+            }
+        });
         Map<String,Object> configMap;
         for(Map<String,Object> map:mapList){
             if(idList.contains(map.getOrDefault("abExpCode",""))){

+ 13 - 2
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/model/threshold/ScoreV2ThresholdPredictModel.java

@@ -6,6 +6,7 @@ import com.tzld.piaoquan.ad.engine.commons.score.AdConfig;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerUtils;
 import com.tzld.piaoquan.ad.engine.commons.util.CommonCollectionUtils;
+import com.tzld.piaoquan.ad.engine.service.predict.container.PredictPidContainer;
 import com.tzld.piaoquan.ad.engine.service.predict.container.ThresholdModelContainer;
 import com.tzld.piaoquan.ad.engine.service.predict.impl.PredictModelServiceImpl;
 import com.tzld.piaoquan.ad.engine.service.predict.param.ThresholdPredictModelParam;
@@ -92,6 +93,9 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
                 scoreParam.getExtraParam().getOrDefault("ScoreV2ThresholdPredict_"+modelParam.getAppType(),threshold).toString()
         );
         int adPredict;
+        //加入pid逻辑
+        realThreshold=realThreshold+ PredictPidContainer.getPidLambda(
+                scoreParam.getExtraParam().getOrDefault("predict_test_id","default")+"_"+modelParam.getAppType());
         if (maxItem != null && maxItem.getScore() < realThreshold) {
             // If final score is below threshold, do not show the ad
             adPredict = 1;
@@ -100,14 +104,21 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
             adPredict = 2;
         }
         if(maxItem != null){
-            ThresholdModelContainer.mergingDigestAddScore(maxItem.getScore());
+            ThresholdModelContainer.mergingDigestAddScore(modelParam.getAppType(),maxItem.getScore());
+            //删除多余打印
+            maxItem.setItemFeature(null);
+            maxItem.setLrSampleString(null);
+            maxItem.setLrSampleStringOrgin(null);
+            log.info("svc=ScoreV2ThresholdPredictModel_predict modelName=ScoreV2ThresholdPredictModel maxItem={} extraParam={} app_type={} realThreshold={}",
+                    JSONObject.toJSONString(maxItem), JSONObject.toJSONString(scoreParam.getExtraParam()),modelParam.getAppType(),realThreshold);
         }
 
         Map<String, Object> result = new HashMap<>();
-        result.put("threshold", threshold);
+        result.put("threshold", realThreshold);
         result.put("score", maxItem == null ? -1 : maxItem.getScore());
         result.put("ad_predict", adPredict);
 
+
         return result;
     }
 }

+ 6 - 10
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogMergeBreakScorer.java

@@ -30,9 +30,11 @@ public class VlogMergeBreakScorer extends BaseLRModelScorer {
     public List<AdRankItem> scoring(final ScoreParam param,
                                     final UserAdFeature userFeature,
                                     final List<AdRankItem> rankItems) {
-        double a,b,c;
+        double a,b,c,strW,rosW;
         a=Double.parseDouble(param.getExtraParam().getOrDefault("mergeBreakCtrCvrW",0.2).toString());
         b=Double.parseDouble(param.getExtraParam().getOrDefault("mergeBreakStrRosW",1d).toString());
+        strW=Double.parseDouble(param.getExtraParam().getOrDefault("mergeBreakStrW",b).toString());
+        rosW=Double.parseDouble(param.getExtraParam().getOrDefault("mergeBreakRosW",b).toString());
         c=Double.parseDouble(param.getExtraParam().getOrDefault("mergeBreakW",1d).toString());
 
         for (AdRankItem item : rankItems) {
@@ -42,16 +44,10 @@ public class VlogMergeBreakScorer extends BaseLRModelScorer {
             double ros = item.getRos();
 
             BigDecimal ctrCvr = new BigDecimal(Math.pow(70 * ctr * cvr, a));
-            BigDecimal strRos = new BigDecimal(Math.pow(str * ros, b));
+            BigDecimal strBG = new BigDecimal(Math.pow(str, strW));
+            BigDecimal rosBG = new BigDecimal(Math.pow(ros, rosW));
             BigDecimal breakRate = new BigDecimal(Math.pow(item.getBreakRate(), c));
-            try {
-                log.info("svc=scoring modelName=modelV2 a={} b={} c={} ctr={} cvr={} str={} ros={}",
-                        a,b,c,
-                        item.getCtr(),item.getCvr(),item.getStr(),item.getRos());
-            }catch (Exception e){
-
-            }
-            BigDecimal score = ctrCvr.divide(strRos.multiply(breakRate), 5, BigDecimal.ROUND_HALF_UP);
+            BigDecimal score = ctrCvr.divide(strBG.multiply(rosBG).multiply(breakRate), 5, BigDecimal.ROUND_HALF_UP);
 
             item.setScore(score.doubleValue());
         }