소스 검색

LR、FM模型

丁云鹏 10 달 전
부모
커밋
9d662bfe60

+ 31 - 1
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/FeatureService.java

@@ -31,6 +31,7 @@ public class FeatureService {
 
         for (String vid : vidList) {
             // TODO 补充其他特征
+            // vid
             protos.add(genWithVid("alg_vid_feature_all_exp", vid));
             protos.add(genWithVid("alg_vid_feature_all_return", vid));
             protos.add(genWithVid("alg_vid_feature_all_share", vid));
@@ -44,17 +45,29 @@ public class FeatureService {
             protos.add(genWithVid("alg_vid_feature_feed_play", vid));
             protos.add(genWithVid("alg_vid_feature_head_play", vid));
             protos.add(genWithVid("alg_vid_feature_share2return", vid));
-
+            // vid + apptype
             protos.add(genWithVidAndAppType("alg_vid_feature_feed_apptype_exp", vid, appType));
             protos.add(genWithVidAndAppType("alg_vid_feature_feed_apptype_root_return", vid, appType));
             protos.add(genWithVidAndAppType("alg_vid_feature_feed_apptype_root_share", vid, appType));
 
+            // vid + province
             protos.add(genWithVidAndProvince("alg_vid_feature_feed_province_exp", vid, province));
             protos.add(genWithVidAndProvince("alg_vid_feature_feed_province_root_return", vid, province));
             protos.add(genWithVidAndProvince("alg_vid_feature_feed_province_root_share", vid, province));
 
+            // vid + headvid
+            protos.add(genWithVidAndHeadVid("", vid, headVid));
         }
 
+
+        // user
+        protos.add(genWithVidAndProvince("", mid));
+        protos.add(genWithVidAndProvince("", mid));
+        protos.add(genWithVidAndProvince("", mid));
+        protos.add(genWithVidAndProvince("", mid));
+        protos.add(genWithVidAndProvince("", mid));
+
+
         Map<String, String> result = remoteService.getFeature(protos);
 
         Map<String, Map<String, Map<String, String>>> data = new HashMap<>();
@@ -105,5 +118,22 @@ public class FeatureService {
                 .build();
     }
 
+    private FeatureKeyProto genWithVidAndHeadVid(String table, String vid, String headVid) {
+        return FeatureKeyProto.newBuilder()
+                .setUniqueKey(String.format(ukFormat, table, vid))
+                .setTableName(table)
+                .putFieldValue("vid", vid)
+                .putFieldValue("headVid", )
+                .build();
+    }
+
+    private FeatureKeyProto genWithMid(String table, String mid) {
+        return FeatureKeyProto.newBuilder()
+                .setUniqueKey(String.format(ukFormat, table, mid))
+                .setTableName(table)
+                .putFieldValue("mid", )
+                .build();
+    }
+
 
 }

+ 137 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/FMModel.java

@@ -0,0 +1,137 @@
+package com.tzld.piaoquan.recommend.server.service.score.model;
+
+
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang.math.NumberUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+
+public class FMModel extends Model {
+    protected static final int MODEL_FIRST_LOAD_COUNT = 1 << 25; // 32M
+    private static final Logger LOGGER = LoggerFactory.getLogger(FMModel.class);
+    private Map<String, List<Float>> model;
+
+    public void putFeature(Map<String, Float> model, String featureKey, float weight) {
+        model.put(featureKey, weight);
+    }
+
+    public float getWeight(Map<String, List<Float>> model, String featureKey, int index) {
+        if (!model.containsKey(featureKey)) {
+            return 0.0f;
+        }
+
+        return model.get(featureKey).get(index);
+
+    }
+
+    @Override
+    public int getModelSize() {
+        if (this.model == null)
+            return 0;
+        return model.size();
+    }
+
+    public void cleanModel() {
+        this.model = null;
+    }
+
+    public Float score(Map<String, String> featureMap) {
+        float sum = 0.0f;
+
+        if (MapUtils.isNotEmpty(featureMap)) {
+            // 计算 sum w*x
+            float sum0 = 0.0f;
+            for (Map.Entry<String, String> e : featureMap.entrySet()) {
+                float x = NumberUtils.toFloat(e.getValue(), 0.0f);
+                float w = getWeight(this.model, e.getKey(), 0);
+                sum0 += w * x;
+            }
+            sum += sum0;
+
+            // 计算 sum v*v*x*X
+            float sum1 = 0.0f;
+            for (int i = 1; i < 9; i++) {
+                float sum10 = 0.0f;
+                float sum11 = 0.0f;
+                for (Map.Entry<String, String> e : featureMap.entrySet()) {
+                    float x = NumberUtils.toFloat(e.getValue(), 0.0f);
+                    float v = getWeight(this.model, e.getKey(), i);
+                    sum10 = v * x;
+                    sum11 = v * v * x * x;
+                }
+                sum1 += sum10 * sum10 - sum11;
+            }
+            sum1 = sum1 / 2;
+
+
+            float biasW = model.get("bias").get(0);
+            sum = biasW + sum0 + sum1;
+        }
+
+        return (float) (1.0f / (1 + Math.exp(-sum)));
+    }
+
+    /**
+     * 目前模型比较大,分两个阶段load模型
+     * (1). load 8M 模型, 并更新;
+     * (2). load 剩余的模型
+     * 中间提供一段时间有损的打分服务
+     *
+     * @param in
+     * @return
+     * @throws IOException
+     */
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws IOException {
+
+        Map<String, Float> model = new HashMap<>();
+        BufferedReader input = new BufferedReader(in);
+        String line = null;
+        int cnt = 0;
+
+        Integer curTime = new Long(System.currentTimeMillis() / 1000).intValue();
+        LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", lrModel.size(), curTime);
+        //first stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split(" ");
+            if (items.length < 2) {
+                continue;
+            }
+
+            putFeature(model, items[0], Float.valueOf(items[1].trim()).floatValue());
+            if (cnt++ < 10) {
+                LOGGER.debug("fea: " + items[0] + ", weight: " + items[1]);
+            }
+            if (cnt > MODEL_FIRST_LOAD_COUNT) {
+                break;
+            }
+        }
+        //model update
+        this.lrModel = model;
+
+        LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", lrModel.size(), curTime);
+        //final stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split(" ");
+            if (items.length < 2) {
+                continue;
+            }
+            putFeature(model, items[0], Float.valueOf(items[1]).floatValue());
+        }
+        LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", lrModel.size(), curTime);
+
+        LOGGER.info("[MODELLOAD] model load over and size " + cnt);
+        input.close();
+        in.close();
+        return true;
+    }
+
+}

+ 112 - 0
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/LRV2Model.java

@@ -0,0 +1,112 @@
+package com.tzld.piaoquan.recommend.server.service.score.model;
+
+
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang.math.NumberUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.HashMap;
+import java.util.Map;
+
+
+public class LRV2Model extends Model {
+    protected static final int MODEL_FIRST_LOAD_COUNT = 1 << 25; // 32M
+    private static final Logger LOGGER = LoggerFactory.getLogger(LRV2Model.class);
+    private Map<String, Float> lrModel;
+
+    public void putFeature(Map<String, Float> model, String featureKey, float weight) {
+        model.put(featureKey, weight);
+    }
+
+    public float getWeight(Map<String, Float> model, String featureKey) {
+        return model.getOrDefault(featureKey, 0.0f);
+    }
+
+    @Override
+    public int getModelSize() {
+        if (this.lrModel == null)
+            return 0;
+        return lrModel.size();
+    }
+
+    public void cleanModel() {
+        this.lrModel = null;
+    }
+
+    public Float score(Map<String, String> featureMap) {
+        float sum = 0.0f;
+
+        if (MapUtils.isNotEmpty(featureMap)) {
+            for (Map.Entry<String, String> e : featureMap.entrySet()) {
+                float w = getWeight(this.lrModel, e.getKey());
+                sum += w * NumberUtils.toFloat(e.getValue(), 0.0f);
+            }
+
+            float biasW = lrModel.get("bias");
+            sum += biasW;
+        }
+
+
+        return (float) (1.0f / (1 + Math.exp(-sum)));
+    }
+
+    /**
+     * 目前模型比较大,分两个阶段load模型
+     * (1). load 8M 模型, 并更新;
+     * (2). load 剩余的模型
+     * 中间提供一段时间有损的打分服务
+     *
+     * @param in
+     * @return
+     * @throws IOException
+     */
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws IOException {
+
+        Map<String, Float> model = new HashMap<>();
+        BufferedReader input = new BufferedReader(in);
+        String line = null;
+        int cnt = 0;
+
+        Integer curTime = new Long(System.currentTimeMillis() / 1000).intValue();
+        LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", lrModel.size(), curTime);
+        //first stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split(" ");
+            if (items.length < 2) {
+                continue;
+            }
+
+            putFeature(model, items[0], Float.valueOf(items[1].trim()).floatValue());
+            if (cnt++ < 10) {
+                LOGGER.debug("fea: " + items[0] + ", weight: " + items[1]);
+            }
+            if (cnt > MODEL_FIRST_LOAD_COUNT) {
+                break;
+            }
+        }
+        //model update
+        this.lrModel = model;
+
+        LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", lrModel.size(), curTime);
+        //final stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split(" ");
+            if (items.length < 2) {
+                continue;
+            }
+            putFeature(model, items[0], Float.valueOf(items[1]).floatValue());
+        }
+        LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", lrModel.size(), curTime);
+
+        LOGGER.info("[MODELLOAD] model load over and size " + cnt);
+        input.close();
+        in.close();
+        return true;
+    }
+
+}