Browse Source

Merge branch 'feature_gufengshou_20240207_predict_threshold_pid'

gufengshou1 1 year ago
parent
commit
e06c19dd86

+ 155 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/container/PredictPidContainer.java

@@ -0,0 +1,155 @@
+package com.tzld.piaoquan.ad.engine.service.predict.container;
+
+import com.alibaba.fastjson.JSONObject;
+import com.alibaba.fastjson.TypeReference;
+import com.aliyun.oss.OSS;
+import com.aliyun.oss.OSSClientBuilder;
+import com.aliyun.oss.common.auth.CredentialsProvider;
+import com.aliyun.oss.common.auth.DefaultCredentialProvider;
+import com.aliyun.oss.model.CopyObjectResult;
+import com.aliyun.oss.model.OSSObject;
+import com.aliyun.oss.model.PutObjectResult;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.PostConstruct;
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+@Component
+public class PredictPidContainer {
+    private final static Logger log = LoggerFactory.getLogger(PredictPidContainer.class);
+
+    private static final int SCHEDULE_PERIOD = 10;
+    private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
+    @Value("${model.oss.internal.endpoint:oss-cn-hangzhou.aliyuncs.com}")
+    String endpoint = "";
+    @Value("${model.oss.accessKeyId:LTAIP6x1l3DXfSxm}")
+    String accessKeyId = "";
+    @Value("${model.oss.accessKetSecret:KbTaM9ars4OX3PMS6Xm7rtxGr1FLon}")
+    String accessKetSecret = "";
+    @Value("${model.oss.bucketName:art-recommend}")
+    String bucketName = "";
+
+    @Value("${model.oss.pid.predict.filename.lambda:pid/predict_lambda}")
+    String lambdaFileName = "";
+
+    @Value("${model.oss.pid.predict.filename.dThreshold:pid/predict_dThreshold.txt}")
+    String dThresholdFileName = "";
+
+    @Value("${ad.model.pid.predict_threshold.kp:0.8}")
+    Double kp = 0d;
+
+    @Value("${ad.model.pid.predict_threshold.ki:0.01}")
+    Double ki = 0d;
+
+    @Value("${ad.model.pid.predict_threshold.kd:0.002}")
+    Double kd = 0d;
+
+    OSS client;
+
+    private static ConcurrentHashMap<String,Double>  lambdaCache=new ConcurrentHashMap<>();
+    private Date cacheDate;
+
+    @PostConstruct
+    private void init(){
+        instanceClient();
+        final Runnable task = new Runnable() {
+            public void run() {
+                try {
+                    loadAndCalIfNeed();
+                }catch (Exception e){
+                    e.printStackTrace();
+                }
+            }
+        };
+        scheduler.scheduleAtFixedRate(task, 0, SCHEDULE_PERIOD, TimeUnit.MINUTES); // 10分钟
+    }
+
+    private void instanceClient(){
+        CredentialsProvider credentialsProvider = new DefaultCredentialProvider(accessKeyId, accessKetSecret);
+        this.client = new OSSClientBuilder().build(endpoint, credentialsProvider);
+    }
+
+    private void loadAndCalIfNeed(){
+        loadLambdaFile();
+        OSSObject dCpaFileOjb=client.getObject(bucketName,dThresholdFileName);
+        if(cacheDate==null||dCpaFileOjb.getObjectMetadata().getLastModified().after(cacheDate)){
+            calNewLambda(dCpaFileOjb);
+            writeLambdaFileToOss();
+        }
+    }
+
+    private void calNewLambda(OSSObject object) {
+        try {
+            InputStream is=object.getObjectContent();
+            InputStreamReader isr=new InputStreamReader(is);
+            BufferedReader bufferedReader = new BufferedReader(isr);
+            String line = null;
+            while ((line = bufferedReader.readLine()) != null){
+                try {
+                    String[] cols=line.split(",");
+                    String group=cols[0].trim();
+                    Double lambdaNew=lambdaCache.getOrDefault(group,0d)+
+                            kp*Double.parseDouble(cols[1])+ki*Double.parseDouble(cols[2])+kd*Double.parseDouble(cols[3]);
+                    lambdaCache.put(group,lambdaNew);
+                }catch (Exception e){
+                    e.printStackTrace();
+                }
+            }
+        }catch (Exception e){
+            log.error("svc=calNewPredictLambda status=failed error={}", Arrays.toString(e.getStackTrace()));
+        }
+    }
+
+    private void writeLambdaFileToOss(){
+        //先不考虑各种更新失败及重复更新问题。
+        try {
+            String tempFile=lambdaFileName+"_temp";
+            String content= JSONObject.toJSONString(lambdaCache);
+            PutObjectResult putObjectResult=client.putObject(bucketName,tempFile,new ByteArrayInputStream(content.getBytes()));
+            CopyObjectResult copyObjectResult=client.copyObject(bucketName, tempFile, bucketName, lambdaFileName);
+            this.cacheDate= copyObjectResult.getLastModified();
+            client.deleteObject(bucketName, tempFile);
+        }catch (Exception e){
+            log.error("svc=writePredictLambdaFileToOss status=failed error={}", Arrays.toString(e.getStackTrace()));
+            e.printStackTrace();
+        }
+    }
+
+    private void loadLambdaFile(){
+        try {
+            OSSObject object=client.getObject(bucketName,lambdaFileName);
+            if(object==null) return;
+            if(cacheDate!=null&& !cacheDate.before(object.getObjectMetadata().getLastModified())) return;
+//            if(cacheDate!=null&& cacheDate.after(object.getObjectMetadata().getLastModified())) return;
+            StringBuilder builder=new StringBuilder();
+            InputStream is=object.getObjectContent();
+            InputStreamReader isr=new InputStreamReader(is);
+            BufferedReader bufferedReader = new BufferedReader(isr);
+            String line = null;
+            while ((line=bufferedReader.readLine())!=null){
+                builder.append(line);
+            }
+            lambdaCache=JSONObject.parseObject(builder.toString(),new TypeReference<ConcurrentHashMap<String,Double>>(){});
+            this.cacheDate=object.getObjectMetadata().getLastModified();
+        }catch (Exception e){
+            log.error("svc=loadPredictLambdaFile status=failed error={}", Arrays.toString(e.getStackTrace()));
+            e.printStackTrace();
+        }
+    }
+
+    public static Double getPidLambda(String group){
+        return lambdaCache.getOrDefault(group,0d);
+    }
+}

+ 15 - 7
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/container/ThresholdModelContainer.java

@@ -30,7 +30,7 @@ public class ThresholdModelContainer {
     private double position;
     private double position;
 
 
     public static Map<String,ThresholdPredictModel> modelMap=new HashMap<>();
     public static Map<String,ThresholdPredictModel> modelMap=new HashMap<>();
-    public static MergingDigest mergingDigest = new MergingDigest(1000);
+    public static Map<Integer,MergingDigest> mergingDigestMap=new HashMap<>();
 
 
     private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
     private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
     @PostConstruct
     @PostConstruct
@@ -39,7 +39,11 @@ public class ThresholdModelContainer {
         beanMap.forEach((s,model)->{
         beanMap.forEach((s,model)->{
             modelMap.put(model.getName(), model);
             modelMap.put(model.getName(), model);
         });
         });
-        mergingDigest.add(0);
+        //只关注重点app
+        mergingDigestMap.put(0, new MergingDigest(10000));
+        mergingDigestMap.put(4, new MergingDigest(10000));
+        mergingDigestMap.put(5, new MergingDigest(10000));
+        mergingDigestMap.put(21, new MergingDigest(10000));
         final Runnable task = new Runnable() {
         final Runnable task = new Runnable() {
             public void run() {
             public void run() {
                 try {
                 try {
@@ -60,17 +64,21 @@ public class ThresholdModelContainer {
         return modelMap.get("basic");
         return modelMap.get("basic");
     }
     }
 
 
-    public static void mergingDigestAddScore(Double score){
-        mergingDigest.add(score);
+    public static void mergingDigestAddScore(Integer appType,Double score){
+        mergingDigestMap.getOrDefault(appType,new MergingDigest(1)).add(score);
     }
     }
 
 
-    public static double getThresholdByTDigest(Double sortPosition){
-        return mergingDigest.quantile(sortPosition);
+    public static double getThresholdByTDigest(Integer appType,Double sortPosition){
+        return  mergingDigestMap.getOrDefault(appType,new MergingDigest(1)).quantile(sortPosition);
     }
     }
 
 
     public void printDigestThreshold(){
     public void printDigestThreshold(){
         try {
         try {
-            log.info("svc=printDigestThreshold modelName=modelV2 mergingDigestThreshold={}", ThresholdModelContainer.getThresholdByTDigest(position));
+            for(Map.Entry<Integer,MergingDigest> entry:mergingDigestMap.entrySet()){
+                log.info("svc=printDigestThreshold modelName=modelV2 appType={} mergingDigestThreshold={}"
+                        , entry.getKey(),entry.getValue().quantile(position));
+            }
+
         }catch (Exception e){
         }catch (Exception e){
             e.printStackTrace();
             e.printStackTrace();
         }
         }

+ 6 - 2
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/model/threshold/ScoreV2ThresholdPredictModel.java

@@ -6,6 +6,7 @@ import com.tzld.piaoquan.ad.engine.commons.score.AdConfig;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerUtils;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerUtils;
 import com.tzld.piaoquan.ad.engine.commons.util.CommonCollectionUtils;
 import com.tzld.piaoquan.ad.engine.commons.util.CommonCollectionUtils;
+import com.tzld.piaoquan.ad.engine.service.predict.container.PredictPidContainer;
 import com.tzld.piaoquan.ad.engine.service.predict.container.ThresholdModelContainer;
 import com.tzld.piaoquan.ad.engine.service.predict.container.ThresholdModelContainer;
 import com.tzld.piaoquan.ad.engine.service.predict.impl.PredictModelServiceImpl;
 import com.tzld.piaoquan.ad.engine.service.predict.impl.PredictModelServiceImpl;
 import com.tzld.piaoquan.ad.engine.service.predict.param.ThresholdPredictModelParam;
 import com.tzld.piaoquan.ad.engine.service.predict.param.ThresholdPredictModelParam;
@@ -92,6 +93,9 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
                 scoreParam.getExtraParam().getOrDefault("ScoreV2ThresholdPredict_"+modelParam.getAppType(),threshold).toString()
                 scoreParam.getExtraParam().getOrDefault("ScoreV2ThresholdPredict_"+modelParam.getAppType(),threshold).toString()
         );
         );
         int adPredict;
         int adPredict;
+        //加入pid逻辑
+        realThreshold=realThreshold+ PredictPidContainer.getPidLambda(
+                scoreParam.getExtraParam().getOrDefault("predict_test_id","default")+"_"+modelParam.getAppType());
         if (maxItem != null && maxItem.getScore() < realThreshold) {
         if (maxItem != null && maxItem.getScore() < realThreshold) {
             // If final score is below threshold, do not show the ad
             // If final score is below threshold, do not show the ad
             adPredict = 1;
             adPredict = 1;
@@ -100,7 +104,7 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
             adPredict = 2;
             adPredict = 2;
         }
         }
         if(maxItem != null){
         if(maxItem != null){
-            ThresholdModelContainer.mergingDigestAddScore(maxItem.getScore());
+            ThresholdModelContainer.mergingDigestAddScore(modelParam.getAppType(),maxItem.getScore());
             //删除多余打印
             //删除多余打印
             maxItem.setItemFeature(null);
             maxItem.setItemFeature(null);
             maxItem.setLrSampleString(null);
             maxItem.setLrSampleString(null);
@@ -110,7 +114,7 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
         }
         }
 
 
         Map<String, Object> result = new HashMap<>();
         Map<String, Object> result = new HashMap<>();
-        result.put("threshold", threshold);
+        result.put("threshold", realThreshold);
         result.put("score", maxItem == null ? -1 : maxItem.getScore());
         result.put("score", maxItem == null ? -1 : maxItem.getScore());
         result.put("ad_predict", adPredict);
         result.put("ad_predict", adPredict);