Browse Source

LR、FM模型

丁云鹏 10 months ago
parent
commit
34ac09bde9

+ 23 - 19
recommend-server-service/src/main/java/com/tzld/piaoquan/recommend/server/service/score/model/FMModel.java

@@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory;
 import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStreamReader;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -19,8 +20,15 @@ public class FMModel extends Model {
     private static final Logger LOGGER = LoggerFactory.getLogger(FMModel.class);
     private Map<String, List<Float>> model;
 
-    public void putFeature(Map<String, Float> model, String featureKey, float weight) {
-        model.put(featureKey, weight);
+    public void putFeature(Map<String, List<Float>> model, String[] items) {
+
+        String featureKey = items[0];
+        List<Float> weights = new ArrayList<>();
+        for (int i = 1; i < items.length; i++) {
+            weights.add(Float.valueOf(items[i]));
+        }
+
+        model.put(featureKey, weights);
     }
 
     public float getWeight(Map<String, List<Float>> model, String featureKey, int index) {
@@ -64,14 +72,13 @@ public class FMModel extends Model {
                 for (Map.Entry<String, String> e : featureMap.entrySet()) {
                     float x = NumberUtils.toFloat(e.getValue(), 0.0f);
                     float v = getWeight(this.model, e.getKey(), i);
-                    sum10 = v * x;
-                    sum11 = v * v * x * x;
+                    float d = v * x;
+                    sum10 += d;
+                    sum11 += d * d;
                 }
                 sum1 += sum10 * sum10 - sum11;
             }
-            sum1 = sum1 / 2;
-
-
+            sum1 = 0.5f * sum1;
             float biasW = model.get("bias").get(0);
             sum = biasW + sum0 + sum1;
         }
@@ -92,41 +99,38 @@ public class FMModel extends Model {
     @Override
     public boolean loadFromStream(InputStreamReader in) throws IOException {
 
-        Map<String, Float> model = new HashMap<>();
+        Map<String, List<Float>> model = new HashMap<>();
         BufferedReader input = new BufferedReader(in);
         String line = null;
         int cnt = 0;
 
         Integer curTime = new Long(System.currentTimeMillis() / 1000).intValue();
-        LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", lrModel.size(), curTime);
+        LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", model.size(), curTime);
         //first stage
         while ((line = input.readLine()) != null) {
             String[] items = line.split(" ");
-            if (items.length < 2) {
+            if (items.length < 9) {
                 continue;
             }
 
-            putFeature(model, items[0], Float.valueOf(items[1].trim()).floatValue());
-            if (cnt++ < 10) {
-                LOGGER.debug("fea: " + items[0] + ", weight: " + items[1]);
-            }
+            putFeature(model, items);
             if (cnt > MODEL_FIRST_LOAD_COUNT) {
                 break;
             }
         }
         //model update
-        this.lrModel = model;
+        this.model = model;
 
-        LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", lrModel.size(), curTime);
+        LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", model.size(), curTime);
         //final stage
         while ((line = input.readLine()) != null) {
             String[] items = line.split(" ");
-            if (items.length < 2) {
+            if (items.length < 9) {
                 continue;
             }
-            putFeature(model, items[0], Float.valueOf(items[1]).floatValue());
+            putFeature(model, items);
         }
-        LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", lrModel.size(), curTime);
+        LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", model.size(), curTime);
 
         LOGGER.info("[MODELLOAD] model load over and size " + cnt);
         input.close();