Browse Source

合并master

zhaohaipeng 1 năm trước cách đây
mục cha
commit
330608261a

+ 3 - 1
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/ScorerUtils.java

@@ -25,6 +25,8 @@ public final class ScorerUtils {
 
     public static String BREAK_CONFIG = "feeds_score_config_break.conf";
     public static String SHARE0_CONFIG = "feeds_score_config_share0.conf";
+
+    public static String CVR_ADJUSTING = "feeds_score_config_cvr_adjusting.conf";
     public static String UNION_THOMPSON_CONF = "union_score_config_thompson.conf";
 
     public static void warmUp() {
@@ -34,7 +36,7 @@ public final class ScorerUtils {
         ScorerUtils.init(BREAK_CONFIG);
         ScorerUtils.init(SHARE0_CONFIG);
         ScorerUtils.init(UNION_THOMPSON_CONF);
-
+        ScorerUtils.init(CVR_ADJUSTING);
     }
 
     private ScorerUtils() {

+ 89 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/score/model/CvrAdjustingModel.java

@@ -0,0 +1,89 @@
+package com.tzld.piaoquan.ad.engine.commons.score.model;
+
+import com.google.common.collect.HashBasedTable;
+import com.google.common.collect.Table;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+
+public class CvrAdjustingModel extends Model {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(CvrAdjustingModel.class);
+
+    private Table<Double, Double, Double> table = HashBasedTable.create();
+
+    @Override
+    public int getModelSize() {
+        return table.size();
+    }
+
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws Exception {
+        Table<Double, Double, Double> initTable = HashBasedTable.create();
+        try (BufferedReader input = new BufferedReader(in)) {
+            String line;
+            int cnt = 0;
+            Map<Double, Double> initModel = new TreeMap<>();
+            while ((line = input.readLine()) != null) {
+                String[] items = line.split("\t");
+                if (items.length < 4) {
+                    continue;
+                }
+
+                double key = new BigDecimal(items[2]).doubleValue();
+                double value = new BigDecimal(items[3]).doubleValue();
+                initModel.put(key, value);
+            }
+
+            // 最终生成的格式为  区间最小值,区间最大值,系数
+            List<Double> keySet = initModel.keySet().stream().sorted().collect(Collectors.toList());
+            double preKey = 0.0;
+            for (Double key : keySet) {
+                initTable.put(preKey, key, initModel.get(key));
+                preKey = key;
+            }
+            initTable.put(preKey, Double.MAX_VALUE, initModel.get(preKey));
+
+            this.table = initTable;
+
+            for (Table.Cell<Double, Double, Double> cell : this.table.cellSet()) {
+                LOGGER.info("cell.row: {}, cell.column: {}, cell.value: {}", cell.getRowKey(), cell.getColumnKey(), cell.getValue());
+            }
+            
+            LOGGER.info("[CvrAdjustingModel] model load over and size {}", cnt);
+        } catch (
+                Exception e) {
+            LOGGER.info("[CvrAdjustingModel] model load error ", e);
+        } finally {
+            in.close();
+
+        }
+        return true;
+    }
+
+    public Double getAdjustingCoefficien(double score) {
+        if (Objects.isNull(table)) {
+            return 1.0;
+        }
+
+        for (Table.Cell<Double, Double, Double> cell : table.cellSet()) {
+            double rowKey = cell.getRowKey();
+            double columnKey = cell.getColumnKey();
+            if (rowKey <= score & score < columnKey) {
+                LOGGER.info("score {} in {} - {} , value is {}", score, rowKey, columnKey, cell.getValue());
+                return cell.getValue();
+            }
+        }
+
+        return 1.0;
+    }
+}

+ 27 - 0
ad-engine-server/src/main/resources/feeds_score_config_cvr_adjusting.conf

@@ -0,0 +1,27 @@
+scorer-config = {
+  lr-ctr-score-config = {
+    scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCtrLRScorer"
+    scorer-priority = 99
+    model-path = "ad_ctr_model/model_ad_ctr.txt"
+  }
+  lr-cvr-score-config = {
+      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCvrLRScorer"
+      scorer-priority = 98
+      model-path = "ad_cvr_model/model_ad_cvr.txt"
+  }
+    tf-ctr-score-config = {
+      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdThompsonScorer"
+      scorer-priority = 97
+      model-path = "ad_thompson_model/model_ad_thompson.txt"
+    }
+    lr-cvr-adjusting-score-config = {
+          scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogAdCvrLRAdjustingScorer"
+          scorer-priority = 96
+          model-path = "ad_cvr_model/cvr_adjusting_strategy_coefficient.txt"
+    }
+  lr-ecpm-merge-config = {
+      scorer-name = "com.tzld.piaoquan.ad.engine.service.score.VlogMergeEcpmScorer"
+      scorer-priority = 1
+  }
+
+}

+ 162 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCvrLRAdjustingScorer.java

@@ -0,0 +1,162 @@
+package com.tzld.piaoquan.ad.engine.service.score;
+
+
+import com.tzld.piaoquan.ad.engine.commons.score.AbstractScorer;
+import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.score.ScorerConfigInfo;
+import com.tzld.piaoquan.ad.engine.commons.score.model.CvrAdjustingModel;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRequestContext;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.UserAdFeature;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.*;
+
+
+//@Service
+public class VlogAdCvrLRAdjustingScorer extends AbstractScorer {
+
+    private static final int LOCAL_TIME_OUT = 150;
+    private final static Logger LOGGER = LoggerFactory.getLogger(VlogAdCvrLRAdjustingScorer.class);
+    private static final ExecutorService executorService = Executors.newFixedThreadPool(8);
+
+    public VlogAdCvrLRAdjustingScorer(ScorerConfigInfo configInfo) {
+        super(configInfo);
+    }
+
+    @Override
+    public void loadModel() {
+        doLoadModel(CvrAdjustingModel.class);
+    }
+
+    @Override
+    public List<AdRankItem> scoring(final ScoreParam param,
+                                    final UserAdFeature userFeature,
+                                    final List<AdRankItem> rankItems) {
+
+        if (CollectionUtils.isEmpty(rankItems)) {
+            return rankItems;
+        }
+
+        long startTime = System.currentTimeMillis();
+        List<AdRankItem> result = rankByJava(rankItems, param.getRequestContext(), userFeature);
+        LOGGER.debug("ctr ranker time java items size={}, time={} ",
+                result.size(), System.currentTimeMillis() - startTime);
+
+        return result;
+    }
+
+    private List<AdRankItem> rankByJava(final List<AdRankItem> items,
+                                        final AdRequestContext requestContext,
+                                        final UserAdFeature user) {
+        long startTime = System.currentTimeMillis();
+        CvrAdjustingModel model = (CvrAdjustingModel) this.getModel();
+        LOGGER.debug("model size: [{}]", model.getModelSize());
+
+        // 所有都参与打分,按照cvr排序
+        multipleScore(items, requestContext, model);
+
+        // debug log
+        if (LOGGER.isDebugEnabled()) {
+            for (AdRankItem item : items) {
+                LOGGER.debug("after enter feeds model predict cvr adjusting score [{}] [{}]", item, item.getScore());
+            }
+        }
+
+        LOGGER.debug("[ctr ranker time java] items size={}, cost={} ",
+                items.size(), System.currentTimeMillis() - startTime);
+        return items;
+    }
+
+
+    /**
+     * 校准cvr
+     */
+    public void calcScore(final CvrAdjustingModel model,
+                            final AdRankItem item,
+                            final AdRequestContext requestContext) {
+
+        double pro = item.getCvr();
+        try {
+            Double coef = model.getAdjustingCoefficien(pro);
+            if (Objects.nonNull(coef)) {
+                LOGGER.info("[VlogAdCvrLRAdjustingScorer.cvr adjusting] before: {}", pro);
+                pro = pro / coef;
+                LOGGER.info("[VlogAdCvrLRAdjustingScorer.cvr adjusting] after: {}, coef: {}", pro, coef);
+
+            }
+
+        } catch (
+                Exception e) {
+            LOGGER.error("score error for doc={} exception={}",
+                    item.getAdId(), ExceptionUtils.getFullStackTrace(e));
+        }
+        item.setCvr(pro);
+    }
+
+
+    /**
+     * 并行打分
+     *
+     * @param items
+     * @param userInfoBytes
+     * @param requestContext
+     * @param model
+     */
+    private void multipleScore(final List<AdRankItem> items,
+                               final AdRequestContext requestContext,
+                               final CvrAdjustingModel model) {
+
+        List<Callable<Object>> calls = new ArrayList<Callable<Object>>();
+        for (int index = 0; index < items.size(); index++) {
+            final int fIndex = index;
+            calls.add(new Callable<Object>() {
+                @Override
+                public Object call() throws Exception {
+                    try {
+                        calcScore(model, items.get(fIndex), requestContext);
+                    } catch (
+                            Exception e) {
+                        LOGGER.error("ctr exception: [{}] [{}]", items.get(fIndex).adId, ExceptionUtils.getFullStackTrace(e));
+                    }
+                    return new Object();
+                }
+            });
+        }
+
+        List<Future<Object>> futures = null;
+        try {
+            futures = executorService.invokeAll(calls, LOCAL_TIME_OUT, TimeUnit.MILLISECONDS);
+        } catch (
+                InterruptedException e) {
+            LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+        }
+
+        // 等待所有请求的结果返回, 超时也返回
+        int cancel = 0;
+        if (futures != null) {
+            for (Future<Object> future : futures) {
+                try {
+                    if (!future.isDone() || future.isCancelled() || future.get() == null) {
+                        cancel++;
+                    }
+                } catch (
+                        InterruptedException e) {
+                    LOGGER.error("InterruptedException {},{}",
+                            requestContext.getApptype(), ExceptionUtils.getFullStackTrace(e));
+                } catch (
+                        ExecutionException e) {
+                    LOGGER.error("ExecutionException {},{}",
+                            requestContext.getApptype(), ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+        }
+        LOGGER.debug("Ctr Score {}, Total: {}, Cancel: {}", requestContext.getApptype(), items.size(), cancel);
+    }
+}

+ 2 - 3
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogAdCvrLRScorer.java

@@ -14,8 +14,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.concurrent.*;
 
 
@@ -31,12 +31,10 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
     private static final int enterFeedsScoreRatio = 10;
     private static final int enterFeedsScoreNum = 20;
 
-
     public VlogAdCvrLRScorer(ScorerConfigInfo configInfo) {
         super(configInfo);
     }
 
-
     @Override
     public List<AdRankItem> scoring(final ScoreParam param,
                                     final UserAdFeature userFeature,
@@ -113,6 +111,7 @@ public class VlogAdCvrLRScorer extends BaseLRModelScorer {
         if (lrSamples != null && lrSamples.getFeaturesList() != null) {
             try {
                 pro = lrModel.score(lrSamples);
+
             } catch (Exception e) {
                 LOGGER.error("score error for doc={} exception={}", new Object[]{
                         item.getAdId(), ExceptionUtils.getFullStackTrace(e)});

+ 1 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/VlogMergeEcpmScorer.java

@@ -109,6 +109,7 @@ public class VlogMergeEcpmScorer extends BaseLRModelScorer {
             double bid2 = item.getBid2();
             double pctr = item.getCtr();
             double pcvr = item.getCvr();
+            LOGGER.info("VlogMergeEcmpScore.pcvr: {}", pcvr);
 //            item.setScore_type( isTfType?1:0);
             item.setScore_type( 0);
             //todo

+ 248 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/container/PidLambdaForCpcContainer.java

@@ -0,0 +1,248 @@
+package com.tzld.piaoquan.ad.engine.service.score.container;
+
+import com.alibaba.fastjson.JSONObject;
+import com.alibaba.fastjson.TypeReference;
+import com.aliyun.oss.OSS;
+import com.aliyun.oss.OSSClientBuilder;
+import com.aliyun.oss.common.auth.CredentialsProvider;
+import com.aliyun.oss.common.auth.DefaultCredentialProvider;
+import com.aliyun.oss.model.CopyObjectResult;
+import com.aliyun.oss.model.OSSObject;
+import com.aliyun.oss.model.PutObjectResult;
+import com.tzld.piaoquan.ad.engine.commons.util.DateUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.PostConstruct;
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+@Component
+public class PidLambdaForCpcContainer {
+    private final static Logger log = LoggerFactory.getLogger(PidLambdaForCpcContainer.class);
+
+    private static final int SCHEDULE_PERIOD = 10;
+    private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
+    @Value("${model.oss.internal.endpoint:oss-cn-hangzhou.aliyuncs.com}")
+    String endpoint = "";
+    @Value("${model.oss.accessKeyId:LTAIP6x1l3DXfSxm}")
+    String accessKeyId = "";
+    @Value("${model.oss.accessKetSecret:KbTaM9ars4OX3PMS6Xm7rtxGr1FLon}")
+    String accessKetSecret = "";
+    @Value("${model.oss.bucketName:art-recommend}")
+    String bucketName = "";
+
+    @Value("${model.oss.pid.cpc.filename.lambda:pid/lambda_cpc.txt}")
+    String lambdaFileName = "";
+
+    @Value("${model.oss.pid.cpc.filename.dCpa:pid/dCpa_cpc.txt}")
+    String dCpaFileName = "";
+
+    @Value("${ad.model.pid.cpc.kp:0.4}")
+    Double kp = 0d;
+
+    @Value("${ad.model.pid.cpc.ki:0.4}")
+    Double ki = 0d;
+
+    @Value("${ad.model.pid.cpc.kd:0.2}")
+    Double kd = 0d;
+
+    @Value("${ad.model.pid.cpc.lambda.max:5.0}")
+    Double maxLambda = 0d;
+
+    @Value("${ad.model.pid.cpc.lambda.min:0.2}")
+    Double minLambda = 0d;
+    OSS client;
+
+    private static ConcurrentHashMap<Long, CpcCacheItem>  lambdaCache=new ConcurrentHashMap<>();
+    private Date cacheDate;
+
+    @PostConstruct
+    private void init(){
+        instanceClient();
+        final Runnable task = new Runnable() {
+            public void run() {
+                try {
+                    loadAndCalIfNeed();
+                }catch (Exception e){
+                    e.printStackTrace();
+                }
+            }
+        };
+        scheduler.scheduleAtFixedRate(task, 0, SCHEDULE_PERIOD, TimeUnit.MINUTES); // 10分钟
+    }
+
+    private void instanceClient(){
+        CredentialsProvider credentialsProvider = new DefaultCredentialProvider(accessKeyId, accessKetSecret);
+        this.client = new OSSClientBuilder().build(endpoint, credentialsProvider);
+    }
+
+    private void loadAndCalIfNeed(){
+        loadLambdaFile();
+        OSSObject dCpaFileOjb=client.getObject(bucketName,dCpaFileName);
+        if(cacheDate==null||dCpaFileOjb.getObjectMetadata().getLastModified().after(cacheDate)){
+            calNewLambda(dCpaFileOjb);
+            writeLambdaFileToOss();
+        }
+    }
+
+    private void calNewLambda(OSSObject object) {
+        try {
+            InputStream is=object.getObjectContent();
+            InputStreamReader isr=new InputStreamReader(is);
+            BufferedReader bufferedReader = new BufferedReader(isr);
+            String line = null;
+            ConcurrentHashMap<Long, CpcCacheItem>  temp=new ConcurrentHashMap<>();
+            Double conversion=0d;
+            Double cpa=0d;
+            Double realCost=0d;
+            Double latestRealCPA=0d;
+            double sumE=0d;
+            while ((line = bufferedReader.readLine()) != null){
+                try {
+                    String[] cols=line.split(",");
+                    Long creativeId=Long.parseLong(cols[0]);
+                    CpcCacheItem cacheItem=lambdaCache.getOrDefault(creativeId,new CpcCacheItem(creativeId));
+                    if(DateUtils.getCurrentHour()<=8){
+                        temp.put(creativeId,cacheItem);
+                        continue;
+                    }
+                    conversion=Double.parseDouble(cols[1]);
+                    cpa=Double.parseDouble(cols[2]);
+                    realCost=Double.parseDouble(cols[3]);
+                    if(conversion<1d){
+                        temp.put(creativeId,cacheItem);
+                        continue;
+                    }
+                    latestRealCPA=realCost/conversion;
+                    if(Math.abs(latestRealCPA-cacheItem.latestRealCpa)<0.01){
+                        temp.put(creativeId,cacheItem);
+                        continue;
+                    }
+                    Double lambdaNew =cacheItem.calculate(kp,ki,kd,cpa,latestRealCPA);
+                    if(lambdaNew<minLambda){
+                        lambdaNew=minLambda;
+                    }
+                    cacheItem.lambda=lambdaNew;
+                    cacheItem.latestRealCpa=latestRealCPA;
+                    cacheItem.sumError=sumE;
+                    cacheItem.latestConv=conversion;
+
+                    temp.put(creativeId,cacheItem);
+
+                    log.info("svc=calCPCLambda creativeId={} lambdaNew={}", creativeId,lambdaNew);
+                }catch (Exception e){
+                    e.printStackTrace();
+                }
+            }
+            lambdaCache.clear();
+            lambdaCache=temp;
+        }catch (Exception e){
+            log.error("svc=calCPCLambda status=failed error={}", Arrays.toString(e.getStackTrace()));
+        }
+    }
+
+    private void writeLambdaFileToOss(){
+        //先不考虑各种更新失败及重复更新问题。
+        try {
+            String tempFile=lambdaFileName+"_temp";
+            String content= JSONObject.toJSONString(lambdaCache);
+            PutObjectResult putObjectResult=client.putObject(bucketName,tempFile,new ByteArrayInputStream(content.getBytes()));
+            CopyObjectResult copyObjectResult=client.copyObject(bucketName, tempFile, bucketName, lambdaFileName);
+            this.cacheDate= copyObjectResult.getLastModified();
+            client.deleteObject(bucketName, tempFile);
+        }catch (Exception e){
+            log.error("svc=writeCPCLambdaFileToOss status=failed error={}", Arrays.toString(e.getStackTrace()));
+            e.printStackTrace();
+        }
+    }
+
+    private void loadLambdaFile(){
+        try {
+            OSSObject object=client.getObject(bucketName,lambdaFileName);
+            if(object==null) return;
+            if(cacheDate!=null&& !cacheDate.before(object.getObjectMetadata().getLastModified())) return;
+            StringBuilder builder=new StringBuilder();
+            InputStream is=object.getObjectContent();
+            InputStreamReader isr=new InputStreamReader(is);
+            BufferedReader bufferedReader = new BufferedReader(isr);
+            String line = null;
+            while ((line=bufferedReader.readLine())!=null){
+                builder.append(line);
+            }
+            lambdaCache=JSONObject.parseObject(builder.toString(),new TypeReference<ConcurrentHashMap<Long, CpcCacheItem>>(){});
+            this.cacheDate=object.getObjectMetadata().getLastModified();
+        }catch (Exception e){
+            log.error("svc=loadCPCLambdaFile status=failed error={}", Arrays.toString(e.getStackTrace()));
+            e.printStackTrace();
+        }
+    }
+
+    public static Double getPidLambda(Long creativeId){
+        try {
+            return lambdaCache.getOrDefault(creativeId,new CpcCacheItem(creativeId)).lambda;
+        }catch (Exception e){
+            return 1d;
+        }
+    }
+
+
+    public static class CpcCacheItem {
+
+        public CpcCacheItem(){
+
+        }
+
+        public CpcCacheItem(Long creativeId){
+            this.creativeId=creativeId;
+        }
+
+        public Long creativeId;
+
+        public double lambda=-1d;
+
+        public double latestConv=0d;
+
+        public double sumError=0d;
+
+        public double latestRealCpa=0d;
+
+        public double processVariable=0d; // 处理变量
+        public double integral=0d; // 积分项
+        public double lastError=0d; // 上一个误差
+
+        public double lastPidValue=0d;
+
+        public double calculate(double kp, double ki, double kd, double setPoint,double currentValue) {
+            processVariable = currentValue;
+            double error = setPoint - processVariable;
+
+            integral += error;
+            if(Math.abs(integral)>2*setPoint){
+                integral=(Math.abs(integral)/integral)*2*setPoint;
+            }
+            double derivative = (error - lastError) / 1; // 假设采样间隔为1
+            lastError = error;
+            if(lambda<0){
+                lambda=setPoint;
+            }
+            return lambda+kp * error + ki * integral + kd * derivative;
+        }
+
+        public void reset() {
+            integral = 0;
+            lastError = 0;
+        }
+
+    }
+}

+ 67 - 34
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/impl/RankServiceImpl.java

@@ -3,6 +3,7 @@ package com.tzld.piaoquan.ad.engine.service.score.impl;
 import com.alibaba.fastjson.JSONObject;
 import com.tzld.piaoquan.ad.engine.service.score.container.AdCreativeFeatureContainer;
 import com.tzld.piaoquan.ad.engine.service.score.container.PidLambdaContainer;
+import com.tzld.piaoquan.ad.engine.service.score.container.PidLambdaForCpcContainer;
 import com.tzld.piaoquan.ad.engine.service.score.container.PidLambdaV2Container;
 import com.tzld.piaoquan.ad.engine.service.score.dto.AdPlatformCreativeDTO;
 import com.tzld.piaoquan.ad.engine.service.score.param.BidRankRecommendRequestParam;
@@ -46,6 +47,8 @@ public class RankServiceImpl implements RankService {
     Double cpmMax=200d;
     @Value("${ad.model.cpm.min:30}")
     Double cpmMin=30d;
+    @Value("${ad.pid.cpc.exp:30}")
+    private String cpcPidExpCode;
 
     public AdRankItem adItemRank(RankRecommendRequestParam request){
         ScoreParam param= RequestConvert.requestConvert(request);
@@ -66,7 +69,6 @@ public class RankServiceImpl implements RankService {
                 .getAdIdList()
                 .stream()
                 .collect(Collectors.groupingBy(creativeDTO -> creativeDTO.getCreativeId()));
-//        Map<Long, AdRankItem> cache=adCreativeFeatureContainer.getAll(request.getAdIdList());
         Map<Long, AdRankItem> cache=adCreativeFeatureContainer.getAll(new ArrayList<>(groupMap.keySet()));
         List<AdRankItem> rankItems=Collections.emptyList();
         if(!cache.isEmpty()){
@@ -82,42 +84,63 @@ public class RankServiceImpl implements RankService {
                 rankItems.add(item);
             }
         }
+        boolean inCpcPidExp=false;
+        if (request.getAdAbExpArr() != null && request.getAdAbExpArr().size() != 0) {
+            for (Map<String, Object> map : request.getAdAbExpArr() ) {
+                if (map.getOrDefault("abExpCode", "").equals(cpcPidExpCode)) {
+                    inCpcPidExp = true;
+                }
+            }
+        }
         double lambda=-1d;
-        for(AdRankItem item:rankItems){
-            try {
-//                AdPlatformBidCreativeDTO dto=groupMap.get(item.getAdId()+"").get(0);
-                AdPlatformCreativeDTO dto=groupMap.get(item.getAdId()).get(0);
-                item.setBid1(dto.getBid1());
-                item.setBid2(dto.getBid2());
-//                lambda=PidLambdaContainer.getPidLambda(item.getAdId());
-//                if(lambda<0){
-//                    item.setCpa(dto.getCpa());
-//                    item.setPidLambda(0.6);
-//                }else {
-//                    if(dto.getCpa()>1&&lambda<=1){
-//                        lambda=2d;
-//                    }
-//                    item.setCpa(lambda);
-//                    item.setPidLambda(1d);
-//                }
-                item.setCpa(dto.getCpa());
-                item.setPidLambda(1d);
-            }catch (Exception e){
-                log.error("rankItems info error itemId={}",item.getAdId());
-                e.printStackTrace();
+        if(inCpcPidExp){
+            for(AdRankItem item:rankItems){
+                try {
+                    AdPlatformCreativeDTO dto=groupMap.get(item.getAdId()).get(0);
+                    item.setBid1(dto.getBid1());
+                    item.setBid2(dto.getBid2());
+                    item.setCpa(dto.getCpa());
+                    item.setPidLambda(1d);
+                }catch (Exception e){
+                    log.error("rankItems info error itemId={}",item.getAdId());
+                    e.printStackTrace();
+                }
+            }
+        }else {
+            for(AdRankItem item:rankItems){
+                try {
+                    AdPlatformCreativeDTO dto=groupMap.get(item.getAdId()).get(0);
+                    item.setBid1(dto.getBid1());
+                    item.setBid2(dto.getBid2());
+                    lambda= PidLambdaForCpcContainer.getPidLambda(item.getAdId());
+                    if(lambda<0){
+                        item.setCpa(dto.getCpa());
+                        item.setPidLambda(1);
+                    }else {
+                        if(dto.getCpa()>1&&lambda<=1){
+                            lambda=2d;
+                        }
+                        item.setCpa(lambda);
+                        item.setPidLambda(1d);
+                    }
+                    item.setCpa(dto.getCpa());
+                    item.setPidLambda(1d);
+                }catch (Exception e){
+                    log.error("rankItems info error itemId={}",item.getAdId());
+                    e.printStackTrace();
+                }
             }
         }
-//        for(AdRankItem item:rankItems){
-//            item.setBid1(1d);
-//            item.setBid2(1d);
-//            item.setCpa(75d);
-//            item.setPidLambda(1d);
-//        }
-        //兜底方案
-        List<AdRankItem> rankResult=rank(param, userAdFeature, rankItems,ScorerUtils.BASE_CONF);
+
+        // 兜底方案
+        List<AdRankItem> rankResult;
+        if (inCpcPidExp) {
+            rankResult = rank(param, userAdFeature, rankItems, ScorerUtils.CVR_ADJUSTING);
+        } else {
+            rankResult = rank(param, userAdFeature, rankItems, ScorerUtils.BASE_CONF);
+        }
 
         if (!CollectionUtils.isEmpty(rankResult)) {
-//            log.info("svc=adItemRank request={} rankResult={} dataTime={}", JSONObject.toJSONString(request),JSONObject.toJSONString(rankResult),currentTime.format(timeFormatter));
             JSONObject object=new JSONObject();
             object.put("mid",request.getMid());
             object.put("adid",rankResult.get(0).getAdId());
@@ -128,9 +151,17 @@ public class RankServiceImpl implements RankService {
             object.put("pidLambda",rankResult.get(0).getPidLambda());
             object.put("lrsamples",rankResult.get(0).getLrSampleString());
             object.put("dataTime",currentTime.format(timeFormatter));
+            object.put("creativeId",rankResult.get(0).getAdId());
             log.info("svc=adItemRank {}", JSONObject.toJSONString(object));
             object.remove("lrsamples");
-            log.info("svc=pid_log obj={}", JSONObject.toJSONString(object));
+            if(inCpcPidExp){
+                AdPlatformCreativeDTO dto=groupMap.get(rankResult.get(0).getAdId()).get(0);
+                object.put("cpa",dto.getCpa()*dto.getBid1());
+                object.put("oCpa",dto.getCpa());
+                log.info("svc=cpc_pid obj={}", JSONObject.toJSONString(object));
+            }else {
+                log.info("svc=pid_log obj={}", JSONObject.toJSONString(object));
+            }
             return rankResult.get(0);
         }else {
             //空返回值
@@ -138,6 +169,8 @@ public class RankServiceImpl implements RankService {
         }
     }
 
+
+
     @Override
     public AdPlatformCreativeDTO adBidRank(BidRankRecommendRequestParam request) {
 
@@ -170,7 +203,6 @@ public class RankServiceImpl implements RankService {
         double lambda=-1d;
         for(AdRankItem item:rankItems){
             try {
-//                AdPlatformBidCreativeDTO dto=groupMap.get(item.getAdId()+"").get(0);
                 AdPlatformCreativeDTO dto=groupMap.get(item.getAdId()).get(0);
                 item.setBid1(dto.getBid1());
                 item.setBid2(dto.getBid2());
@@ -286,6 +318,7 @@ public class RankServiceImpl implements RankService {
         object.put("pcvr",topItem.getCvr());
         object.put("lrsamples",topItem.getLrSampleString());
         object.put("pidLambda",topItem.getPidLambda());
+
         //临时加入供pid v2使用
         object.put("realECpm",realECpm);
         object.put("creativeId",result.getCreativeId());

+ 18 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/impl/RankServiceThompsonImpl.java

@@ -46,6 +46,24 @@ public class RankServiceThompsonImpl implements RankService {
         }
     }
 
+    public AdRankItem adItemRankV2(RankRecommendRequestParam request) {
+
+        ScoreParam param = RequestConvert.requestConvert(request);
+        UserAdFeature userAdFeature = new UserAdFeature();
+        List<AdRankItem> rankItems = featureRemoteService.getAllAdFeatureList(
+                CommonCollectionUtils.toList(request.getAdIdList(), creativeDTO -> creativeDTO.getCreativeId().toString())
+        );
+        List<AdRankItem> rankResult = ScorerUtils
+                .getScorerPipeline(ScorerUtils.THOMPSON_CONF)
+                .scoring(param, userAdFeature, rankItems);
+
+        if (!CollectionUtils.isEmpty(rankResult)) {
+            return rankResult.get(0);
+        } else {
+            return null;
+        }
+    }
+
     @Override
     public AdPlatformCreativeDTO adBidRank(BidRankRecommendRequestParam request) {
         return null;

+ 5 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/param/RecommendRequestParam.java

@@ -3,6 +3,9 @@ package com.tzld.piaoquan.ad.engine.service.score.param;
 import lombok.Data;
 import lombok.ToString;
 
+import java.util.List;
+import java.util.Map;
+
 @Data
 @ToString
 public class RecommendRequestParam {
@@ -14,4 +17,6 @@ public class RecommendRequestParam {
     //市-中文
     String city = "-1";
     Integer newExpGroup;
+
+    List<Map> adAbExpArr ;
 }