ソースを参照

新增广告阈值pid控制

gufengshou1 1 年間 前
コミット
bfd8f062d8

+ 155 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/container/PredictPidContainer.java

@@ -0,0 +1,155 @@
+package com.tzld.piaoquan.ad.engine.service.predict.container;
+
+import com.alibaba.fastjson.JSONObject;
+import com.alibaba.fastjson.TypeReference;
+import com.aliyun.oss.OSS;
+import com.aliyun.oss.OSSClientBuilder;
+import com.aliyun.oss.common.auth.CredentialsProvider;
+import com.aliyun.oss.common.auth.DefaultCredentialProvider;
+import com.aliyun.oss.model.CopyObjectResult;
+import com.aliyun.oss.model.OSSObject;
+import com.aliyun.oss.model.PutObjectResult;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.PostConstruct;
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+@Component
+public class PredictPidContainer {
+    private final static Logger log = LoggerFactory.getLogger(PredictPidContainer.class);
+
+    private static final int SCHEDULE_PERIOD = 10;
+    private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
+    @Value("${model.oss.internal.endpoint:oss-cn-hangzhou.aliyuncs.com}")
+    String endpoint = "";
+    @Value("${model.oss.accessKeyId:LTAIP6x1l3DXfSxm}")
+    String accessKeyId = "";
+    @Value("${model.oss.accessKetSecret:KbTaM9ars4OX3PMS6Xm7rtxGr1FLon}")
+    String accessKetSecret = "";
+    @Value("${model.oss.bucketName:art-recommend}")
+    String bucketName = "";
+
+    @Value("${model.oss.pid.predict.filename.lambda:pid/predict_lambda}")
+    String lambdaFileName = "";
+
+    @Value("${model.oss.pid.predict.filename.dThreshold:pid/predict_dThreshold}")
+    String dThresholdFileName = "";
+
+    @Value("${ad.model.pid.predict_threshold.kp:0.5}")
+    Double kp = 0d;
+
+    @Value("${ad.model.pid.predict_threshold.ki:0.05}")
+    Double ki = 0d;
+
+    @Value("${ad.model.pid.predict_threshold.kd:0.005}")
+    Double kd = 0d;
+
+    OSS client;
+
+    private static ConcurrentHashMap<String,Double>  lambdaCache=new ConcurrentHashMap<>();
+    private Date cacheDate;
+
+    @PostConstruct
+    private void init(){
+        instanceClient();
+        final Runnable task = new Runnable() {
+            public void run() {
+                try {
+                    loadAndCalIfNeed();
+                }catch (Exception e){
+                    e.printStackTrace();
+                }
+            }
+        };
+        scheduler.scheduleAtFixedRate(task, 0, SCHEDULE_PERIOD, TimeUnit.MINUTES); // 10分钟
+    }
+
+    private void instanceClient(){
+        CredentialsProvider credentialsProvider = new DefaultCredentialProvider(accessKeyId, accessKetSecret);
+        this.client = new OSSClientBuilder().build(endpoint, credentialsProvider);
+    }
+
+    private void loadAndCalIfNeed(){
+        loadLambdaFile();
+        OSSObject dCpaFileOjb=client.getObject(bucketName,dThresholdFileName);
+        if(cacheDate==null||dCpaFileOjb.getObjectMetadata().getLastModified().after(cacheDate)){
+            calNewLambda(dCpaFileOjb);
+            writeLambdaFileToOss();
+        }
+    }
+
+    private void calNewLambda(OSSObject object) {
+        try {
+            InputStream is=object.getObjectContent();
+            InputStreamReader isr=new InputStreamReader(is);
+            BufferedReader bufferedReader = new BufferedReader(isr);
+            String line = null;
+            while ((line = bufferedReader.readLine()) != null){
+                try {
+                    String[] cols=line.split(",");
+                    String group=cols[0];
+                    Double lambdaNew=lambdaCache.getOrDefault(group,0d)+
+                            kp*Double.parseDouble(cols[1])+ki*Double.parseDouble(cols[2])+kd*Double.parseDouble(cols[3]);
+                    lambdaCache.put(group,lambdaNew);
+                }catch (Exception e){
+                    e.printStackTrace();
+                }
+            }
+        }catch (Exception e){
+            log.error("svc=calNewPredictLambda status=failed error={}", Arrays.toString(e.getStackTrace()));
+        }
+    }
+
+    private void writeLambdaFileToOss(){
+        //先不考虑各种更新失败及重复更新问题。
+        try {
+            String tempFile=lambdaFileName+"_temp";
+            String content= JSONObject.toJSONString(lambdaCache);
+            PutObjectResult putObjectResult=client.putObject(bucketName,tempFile,new ByteArrayInputStream(content.getBytes()));
+            CopyObjectResult copyObjectResult=client.copyObject(bucketName, tempFile, bucketName, lambdaFileName);
+            this.cacheDate= copyObjectResult.getLastModified();
+            client.deleteObject(bucketName, tempFile);
+        }catch (Exception e){
+            log.error("svc=writePredictLambdaFileToOss status=failed error={}", Arrays.toString(e.getStackTrace()));
+            e.printStackTrace();
+        }
+    }
+
+    private void loadLambdaFile(){
+        try {
+            OSSObject object=client.getObject(bucketName,lambdaFileName);
+            if(object==null) return;
+            if(cacheDate!=null&& !cacheDate.before(object.getObjectMetadata().getLastModified())) return;
+//            if(cacheDate!=null&& cacheDate.after(object.getObjectMetadata().getLastModified())) return;
+            StringBuilder builder=new StringBuilder();
+            InputStream is=object.getObjectContent();
+            InputStreamReader isr=new InputStreamReader(is);
+            BufferedReader bufferedReader = new BufferedReader(isr);
+            String line = null;
+            while ((line=bufferedReader.readLine())!=null){
+                builder.append(line);
+            }
+            lambdaCache=JSONObject.parseObject(builder.toString(),new TypeReference<ConcurrentHashMap<String,Double>>(){});
+            this.cacheDate=object.getObjectMetadata().getLastModified();
+        }catch (Exception e){
+            log.error("svc=loadPredictLambdaFile status=failed error={}", Arrays.toString(e.getStackTrace()));
+            e.printStackTrace();
+        }
+    }
+
+    public static Double getPidLambda(String group){
+        return lambdaCache.getOrDefault(group,0d);
+    }
+}

+ 5 - 1
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/model/threshold/ScoreV2ThresholdPredictModel.java

@@ -6,6 +6,7 @@ import com.tzld.piaoquan.ad.engine.commons.score.AdConfig;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerUtils;
 import com.tzld.piaoquan.ad.engine.commons.util.CommonCollectionUtils;
+import com.tzld.piaoquan.ad.engine.service.predict.container.PredictPidContainer;
 import com.tzld.piaoquan.ad.engine.service.predict.container.ThresholdModelContainer;
 import com.tzld.piaoquan.ad.engine.service.predict.impl.PredictModelServiceImpl;
 import com.tzld.piaoquan.ad.engine.service.predict.param.ThresholdPredictModelParam;
@@ -92,6 +93,9 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
                 scoreParam.getExtraParam().getOrDefault("ScoreV2ThresholdPredict_"+modelParam.getAppType(),threshold).toString()
         );
         int adPredict;
+        //加入pid逻辑
+        realThreshold=realThreshold+ PredictPidContainer.getPidLambda(
+                scoreParam.getExtraParam().getOrDefault("predict_test_id","default")+"_"+modelParam.getAppType());
         if (maxItem != null && maxItem.getScore() < realThreshold) {
             // If final score is below threshold, do not show the ad
             adPredict = 1;
@@ -110,7 +114,7 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
         }
 
         Map<String, Object> result = new HashMap<>();
-        result.put("threshold", threshold);
+        result.put("threshold", realThreshold);
         result.put("score", maxItem == null ? -1 : maxItem.getScore());
         result.put("ad_predict", adPredict);