Ver Fonte

Merge branch 'feature_gufengshou_20240401_pid_v6' into pre-master

gufengshou1 há 1 ano atrás
pai
commit
550337ef0b

+ 3 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/helper/NewExpInfoHelper.java

@@ -39,6 +39,9 @@ public class NewExpInfoHelper {
             Map<String, String> expMap = JSONUtils.fromJson(httpServletRequest.getHeader("newGroupInfo"), new TypeToken<Map<String, String>>() {
             }, Collections.emptyMap());
             String groupStr=expMap.get("group");
+            if(groupStr.contains("ab100")){
+                return -1;
+            }
             Map<String, String> groupMap = JSONUtils.fromJson(groupStr, new TypeToken<Map<String, String>>() {
             }, Collections.emptyMap());
             return Integer.parseInt(groupMap.getOrDefault("ad-server","-1"));

+ 3 - 2
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogMergeEcpmScorer.java

@@ -37,7 +37,7 @@ public class VlogMergeEcpmScorer extends BaseLRModelScorer {
 
 
         long startTime = System.currentTimeMillis();
-        List<AdRankItem> result = mergetEcpm(rankItems);
+        List<AdRankItem> result = mergeEcpm(rankItems);
         LOGGER.debug("ctr ranker time java items size={}, time={} ", result != null ? result.size() : 0,
                 System.currentTimeMillis() - startTime);
 
@@ -46,7 +46,8 @@ public class VlogMergeEcpmScorer extends BaseLRModelScorer {
 
 
 
-    public List<AdRankItem> mergetEcpm(List<AdRankItem> items) {
+
+    public List<AdRankItem> mergeEcpm(List<AdRankItem> items) {
         CountDownLatch countDownLatch = new CountDownLatch(items.size());
         for (AdRankItem item : items) {
             executorService.execute(() -> {

+ 4 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/container/PidLambdaContainer.java

@@ -222,6 +222,8 @@ public class PidLambdaContainer {
         public double integral=0d; // 积分项
         public double lastError=0d; // 上一个误差
 
+        public double lastPidValue=0d;
+
         public double calculate(double kp, double ki, double kd, double setPoint,double currentValue) {
             processVariable = currentValue;
             double error = setPoint - processVariable;
@@ -236,5 +238,7 @@ public class PidLambdaContainer {
             integral = 0;
             lastError = 0;
         }
+
+        //优化 添加步长及补偿机制  并判定异常值 若发现异常值则重置
     }
 }

+ 59 - 49
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/container/PidLambdaV2Container.java

@@ -9,6 +9,7 @@ import com.aliyun.oss.common.auth.DefaultCredentialProvider;
 import com.aliyun.oss.model.CopyObjectResult;
 import com.aliyun.oss.model.OSSObject;
 import com.aliyun.oss.model.PutObjectResult;
+import com.tzld.piaoquan.ad.engine.commons.util.DateUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Value;
@@ -52,6 +53,15 @@ public class PidLambdaV2Container {
 
     @Value("${ad.model.pid.v2.lambda.min:0.8}")
     Double minLambda = 0d;
+
+    @Value("${ad.model.pid.v2.kp:1.2}")
+    Double kp = 0d;
+
+    @Value("${ad.model.pid.v2.ki:0.4}")
+    Double ki = 0d;
+
+    @Value("${ad.model.pid.v2.kd:0.1}")
+    Double kd = 0d;
     OSS client;
 
     private static ConcurrentHashMap<Long,CacheItem>  lambdaCache=new ConcurrentHashMap<>();
@@ -92,49 +102,42 @@ public class PidLambdaV2Container {
             InputStreamReader isr=new InputStreamReader(is);
             BufferedReader bufferedReader = new BufferedReader(isr);
             String line = null;
-            ConcurrentHashMap<Long,CacheItem>  temp=new ConcurrentHashMap<>();
+            ConcurrentHashMap<Long, CacheItem>  temp=new ConcurrentHashMap<>();
             Double conversion=0d;
             Double cpa=0d;
             Double realCost=0d;
-            Double yesterdayConv=0d;
-            Double yesterdayCpa=0d;
-            Double yesterdayRealCost=0d;
-
+            Double latestRealCPA=0d;
+            double sumE=0d;
             while ((line = bufferedReader.readLine()) != null){
                 try {
                     String[] cols=line.split(",");
                     Long creativeId=Long.parseLong(cols[0]);
                     CacheItem cacheItem=lambdaCache.getOrDefault(creativeId,new CacheItem(creativeId));
+                    if(DateUtils.getCurrentHour()<=8){
+                        temp.put(creativeId,cacheItem);
+                        continue;
+                    }
                     conversion=Double.parseDouble(cols[1]);
                     cpa=Double.parseDouble(cols[2]);
                     realCost=Double.parseDouble(cols[3]);
-                    yesterdayConv=Double.parseDouble(cols[4]);
-                    yesterdayCpa=Double.parseDouble(cols[5]);
-                    yesterdayRealCost=Double.parseDouble(cols[6]);
-                    Double lambdaNew=1d;
-                    if((conversion+yesterdayConv)!=0){
-                        if((realCost*yesterdayRealCost)!=0){
-                            double yesterdayW=yesterdayConv/(yesterdayConv+2*conversion);
-                            lambdaNew=(yesterdayW*yesterdayConv*yesterdayCpa)/yesterdayRealCost
-                                    +(1-yesterdayW)*(cpa*conversion)/realCost;
-                        }else if(realCost!=0){
-                            lambdaNew=(cpa*conversion)/realCost;
-                        }else if(yesterdayRealCost!=0){
-                            lambdaNew=(yesterdayConv*yesterdayCpa)/yesterdayRealCost;
-                        }
+                    if(conversion<1d){
+                        temp.put(creativeId,cacheItem);
+                        continue;
                     }
-
-
-                    lambdaNew=cacheItem.calculate( conversion,  realCost,  cpa,lambdaNew) ;
-
-
-                    if(lambdaNew>maxLambda){
-                        lambdaNew=maxLambda;
-                    }else if(lambdaNew<minLambda){
+                    latestRealCPA=realCost/conversion;
+                    if(Math.abs(latestRealCPA-cacheItem.latestRealCpa)<0.01){
+                        temp.put(creativeId,cacheItem);
+                        continue;
+                    }
+                    Double lambdaNew =cacheItem.calculate(kp,ki,kd,cpa,latestRealCPA);
+                    if(lambdaNew<minLambda){
                         lambdaNew=minLambda;
                     }
                     cacheItem.lambda=lambdaNew;
+                    cacheItem.latestRealCpa=latestRealCPA;
+                    cacheItem.sumError=sumE;
                     cacheItem.latestConv=conversion;
+
                     temp.put(creativeId,cacheItem);
 
                     log.info("svc=calNewLambdaV2 creativeId={} lambdaNew={}", creativeId,lambdaNew);
@@ -177,7 +180,7 @@ public class PidLambdaV2Container {
             while ((line=bufferedReader.readLine())!=null){
                 builder.append(line);
             }
-            lambdaCache=JSONObject.parseObject(builder.toString(),new TypeReference<ConcurrentHashMap<Long,CacheItem>>(){});
+            lambdaCache=JSONObject.parseObject(builder.toString(),new TypeReference<ConcurrentHashMap<Long, CacheItem>>(){});
             this.cacheDate=object.getObjectMetadata().getLastModified();
         }catch (Exception e){
             log.error("svc=loadLambdaV2File status=failed error={}", Arrays.toString(e.getStackTrace()));
@@ -204,33 +207,40 @@ public class PidLambdaV2Container {
             this.creativeId=creativeId;
         }
 
-        public double calculate(double conversion, double realCost, double cpa,double lambdaNew ) {
-            if(conversion!=0){
-                double latestRealCPA=realCost/conversion;
-
-                if(this.latestRealCpa==0d){
-                    this.latestRealCpa=cpa;
-                }
-                if(Math.abs(latestRealCPA-cpa)-Math.abs(this.latestRealCpa-cpa)>0){
-                    this.pow=2d;
-                }else {
-                    this.pow=1.2;
-                }
-                this.latestRealCpa=latestRealCPA;
-                return Math.pow(lambdaNew,this.pow);
-            }
-            this.pow=1.5d;
-            return Math.pow(lambdaNew,this.pow);
-        }
-
         public Long creativeId;
 
-        public double lambda=1d;
+        public double lambda=-1d;
 
         public double latestConv=0d;
 
-        public double pow=1d;
+        public double sumError=0d;
 
         public double latestRealCpa=0d;
+
+        public double processVariable=0d; // 处理变量
+        public double integral=0d; // 积分项
+        public double lastError=0d; // 上一个误差
+        public double lastPidValue=0d;
+        public double calculate(double kp, double ki, double kd, double setPoint,double currentValue) {
+            processVariable = currentValue;
+            double error = setPoint - processVariable;
+
+            integral += error;
+            double derivative = (error - lastError) / 1; // 假设采样间隔为1
+            lastError = error;
+            lastPidValue=kp * error + ki * integral + kd * derivative;
+            double result=currentValue+lastPidValue;
+            double min=setPoint>currentValue?currentValue:setPoint;
+            if(result<=min/8d){
+                result=min/8d;
+            }
+            return result;
+        }
+
+        public void reset() {
+            integral = 0;
+            lastError = 0;
+        }
+
     }
 }

+ 32 - 6
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/impl/RankServiceImpl.java

@@ -88,7 +88,7 @@ public class RankServiceImpl implements RankService {
             item.setBid1(1d);
             item.setBid2(1d);
             item.setCpa(75d);
-            item.setPidLambda(PidLambdaContainer.getPidLambda(item.getAdId()));
+            item.setPidLambda(1d);
         }
 
         //兜底方案
@@ -270,14 +270,27 @@ public class RankServiceImpl implements RankService {
         if(!cache.isEmpty()){
             rankItems=new LinkedList<>(cache.values());
         }
+        double lambda=-1d;
         for(AdRankItem item:rankItems){
             try {
+//                AdPlatformBidCreativeDTO dto=groupMap.get(item.getAdId()+"").get(0);
                 AdPlatformBidCreativeDTO dto=groupMap.get(item.getAdId()).get(0);
                 item.setBid1(dto.getBid1());
                 item.setBid2(dto.getBid2());
-                item.setCpa(dto.getCpa());
-                item.setPidLambda(PidLambdaV2Container.getPidLambda(item.getAdId()));
+                lambda=PidLambdaV2Container.getPidLambda(item.getAdId());
+                if(lambda<0){
+                    item.setCpa(dto.getCpa());
+                    item.setPidLambda(0.6);
+                }else {
+                    if(dto.getCpa()>1&&lambda<=1){
+                        lambda=2d;
+                    }
+                    item.setCpa(lambda);
+                    item.setPidLambda(1d);
+                }
+
             }catch (Exception e){
+                log.error("rankItems info error itemId={}",item.getAdId());
                 e.printStackTrace();
             }
         }
@@ -288,10 +301,21 @@ public class RankServiceImpl implements RankService {
                 AdRankItem item=new AdRankItem();
                 item.setBid1(dto.getBid1());
                 item.setBid2(dto.getBid2());
-                item.setCpa(dto.getCpa());
                 item.setAdId(dto.getCreativeId());
                 item.setItemFeature(new AdItemFeature());
-                item.setPidLambda(PidLambdaV2Container.getPidLambda(item.getAdId()));
+                lambda=PidLambdaV2Container.getPidLambda(item.getAdId());
+                if(lambda<0){
+                    item.setCpa(dto.getCpa());
+                    item.setPidLambda(0.6);
+                }else {
+                    if(dto.getCpa()>1&&lambda<=1){
+                        lambda=2d;
+                    }
+                    item.setCpa(lambda);
+                    item.setPidLambda(1d);
+                }
+//                item.setCpa(dto.getCpa());
+//                item.setPidLambda(PidLambdaContainer.getPidLambda(item.getAdId()));
                 rankItems.add(item);
             }
             rankResult=rankServiceThompson.rank(param, userAdFeature, rankItems,null);
@@ -320,6 +344,7 @@ public class RankServiceImpl implements RankService {
             realECpm=cpmMin/1000d;
         }
         result.setEcpm2(realECpm);
+        AdPlatformBidCreativeDTO dto=groupMap.get(topItem.getAdId()).get(0);
         JSONObject object=new JSONObject();
         object.put("mid",request.getMid());
         object.put("adid",result.getCreativeId());
@@ -332,7 +357,8 @@ public class RankServiceImpl implements RankService {
         //临时加入供pid v2使用
         object.put("realECpm",realECpm);
         object.put("creativeId",result.getCreativeId());
-        object.put("cpa",topItem.getCpa());
+        //CPA还原
+        object.put("cpa",dto.getCpa());
         object.put("dataTime",currentTime.format(timeFormatter));
         log.info("svc=adBidRankNewPid {}", JSONObject.toJSONString(object));
         object.remove("lrsamples");