浏览代码

加入T-Digest算法
增加阈值日志

gufengshou1 1 年之前
父节点
当前提交
f167522630

+ 6 - 0
ad-engine-service/pom.xml

@@ -30,5 +30,11 @@
             <artifactId>odps-sdk-core</artifactId>
             <version>0.45.6-public</version>
         </dependency>
+        <!-- t-digest算法包       -->
+        <dependency>
+            <groupId>com.tdunning</groupId>
+            <artifactId>t-digest</artifactId>
+            <version>3.3</version>
+        </dependency>
     </dependencies>
 </project>

+ 33 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/container/ThresholdModelContainer.java

@@ -1,10 +1,13 @@
 package com.tzld.piaoquan.ad.engine.service.predict.container;
 
+import com.tdunning.math.stats.Centroid;
+import com.tdunning.math.stats.MergingDigest;
 import com.tzld.piaoquan.ad.engine.service.predict.model.threshold.ThresholdPredictModel;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationContext;
 import org.springframework.stereotype.Component;
 
+
 import javax.annotation.PostConstruct;
 import java.util.HashMap;
 import java.util.Map;
@@ -16,6 +19,8 @@ public class ThresholdModelContainer {
     private ApplicationContext applicationContext;
 
     public static Map<String,ThresholdPredictModel> modelMap=new HashMap<>();
+    public static MergingDigest mergingDigest = new MergingDigest(100);
+
     @PostConstruct
     public void init() {
         Map<String,ThresholdPredictModel> beanMap= applicationContext.getBeansOfType(ThresholdPredictModel.class);
@@ -31,4 +36,32 @@ public class ThresholdModelContainer {
     public static ThresholdPredictModel getBasicPredictModel(){
         return modelMap.get("basic");
     }
+
+    public static void mergingDigestAddScore(Double score){
+        mergingDigest.add(score);
+    }
+
+    public static double getThresholdByTDigest(Double sortPosition){
+        return mergingDigest.quantile(sortPosition);
+    }
+
+//    public static void main(String[] args){
+//        MergingDigest mergingDigest = new MergingDigest(100);
+//        for(long i=0;i<1000;i++){
+//            double newDataPoint = Math.random() * 100;
+//            // 向MergingDigest中添加新数据
+//            mergingDigest.add(newDataPoint);
+//        }
+//        System.out.println(mergingDigest.quantile(0.12));
+//        System.out.println(mergingDigest.quantile(0.6));
+//        Iterable<Centroid> centroids = mergingDigest.centroids();
+//        Integer totalW=0;
+//        Integer totalS=0;
+//        // 遍历质点列表并输出
+//        for (Centroid centroid : centroids) {
+//            System.out.println("值: " + centroid.mean() + ", 权重: " + centroid.count());
+//        }
+//        System.out.println(totalW);
+//        System.out.println(totalS);
+//    }
 }

+ 8 - 1
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/predict/model/threshold/ScoreV2ThresholdPredictModel.java

@@ -5,6 +5,7 @@ import com.tzld.piaoquan.ad.engine.commons.score.AdConfig;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerUtils;
 import com.tzld.piaoquan.ad.engine.commons.util.CommonCollectionUtils;
+import com.tzld.piaoquan.ad.engine.service.predict.container.ThresholdModelContainer;
 import com.tzld.piaoquan.ad.engine.service.predict.impl.PredictModelServiceImpl;
 import com.tzld.piaoquan.ad.engine.service.predict.param.ThresholdPredictModelParam;
 import com.tzld.piaoquan.ad.engine.service.remote.FeatureRemoteService;
@@ -34,6 +35,9 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
     @Value("${ad.predict.threshold:1}")
     private double threshold;
 
+    @Value("${ad.predict.t-digest.position:0.52}")
+    private double position;
+
     @Override
     String initName() {
         return "modelV2";
@@ -95,12 +99,15 @@ public class ScoreV2ThresholdPredictModel extends ThresholdPredictModel {
             // Otherwise, show the ad
             adPredict = 2;
         }
+        if(maxItem != null){
+            ThresholdModelContainer.mergingDigestAddScore(maxItem.getScore());
+        }
 
         Map<String, Object> result = new HashMap<>();
         result.put("threshold", threshold);
         result.put("score", maxItem == null ? -1 : maxItem.getScore());
         result.put("ad_predict", adPredict);
-        log.info("svc=predict modelName=modelV2 result={}", JSONObject.toJSONString(result));
+        log.info("svc=predict modelName=modelV2 mergingDigestThreshold={}", ThresholdModelContainer.getThresholdByTDigest(position));
 
         return result;
     }