|
@@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory;
|
|
|
import java.io.BufferedReader;
|
|
|
import java.io.IOException;
|
|
|
import java.io.InputStreamReader;
|
|
|
+import java.util.ArrayList;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
@@ -19,8 +20,15 @@ public class FMModel extends Model {
|
|
|
private static final Logger LOGGER = LoggerFactory.getLogger(FMModel.class);
|
|
|
private Map<String, List<Float>> model;
|
|
|
|
|
|
- public void putFeature(Map<String, Float> model, String featureKey, float weight) {
|
|
|
- model.put(featureKey, weight);
|
|
|
+ public void putFeature(Map<String, List<Float>> model, String[] items) {
|
|
|
+
|
|
|
+ String featureKey = items[0];
|
|
|
+ List<Float> weights = new ArrayList<>();
|
|
|
+ for (int i = 1; i < items.length; i++) {
|
|
|
+ weights.add(Float.valueOf(items[i]));
|
|
|
+ }
|
|
|
+
|
|
|
+ model.put(featureKey, weights);
|
|
|
}
|
|
|
|
|
|
public float getWeight(Map<String, List<Float>> model, String featureKey, int index) {
|
|
@@ -64,14 +72,13 @@ public class FMModel extends Model {
|
|
|
for (Map.Entry<String, String> e : featureMap.entrySet()) {
|
|
|
float x = NumberUtils.toFloat(e.getValue(), 0.0f);
|
|
|
float v = getWeight(this.model, e.getKey(), i);
|
|
|
- sum10 = v * x;
|
|
|
- sum11 = v * v * x * x;
|
|
|
+ float d = v * x;
|
|
|
+ sum10 += d;
|
|
|
+ sum11 += d * d;
|
|
|
}
|
|
|
sum1 += sum10 * sum10 - sum11;
|
|
|
}
|
|
|
- sum1 = sum1 / 2;
|
|
|
-
|
|
|
-
|
|
|
+ sum1 = 0.5f * sum1;
|
|
|
float biasW = model.get("bias").get(0);
|
|
|
sum = biasW + sum0 + sum1;
|
|
|
}
|
|
@@ -92,41 +99,38 @@ public class FMModel extends Model {
|
|
|
@Override
|
|
|
public boolean loadFromStream(InputStreamReader in) throws IOException {
|
|
|
|
|
|
- Map<String, Float> model = new HashMap<>();
|
|
|
+ Map<String, List<Float>> model = new HashMap<>();
|
|
|
BufferedReader input = new BufferedReader(in);
|
|
|
String line = null;
|
|
|
int cnt = 0;
|
|
|
|
|
|
Integer curTime = new Long(System.currentTimeMillis() / 1000).intValue();
|
|
|
- LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", lrModel.size(), curTime);
|
|
|
+ LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", model.size(), curTime);
|
|
|
//first stage
|
|
|
while ((line = input.readLine()) != null) {
|
|
|
String[] items = line.split(" ");
|
|
|
- if (items.length < 2) {
|
|
|
+ if (items.length < 9) {
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
- putFeature(model, items[0], Float.valueOf(items[1].trim()).floatValue());
|
|
|
- if (cnt++ < 10) {
|
|
|
- LOGGER.debug("fea: " + items[0] + ", weight: " + items[1]);
|
|
|
- }
|
|
|
+ putFeature(model, items);
|
|
|
if (cnt > MODEL_FIRST_LOAD_COUNT) {
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
//model update
|
|
|
- this.lrModel = model;
|
|
|
+ this.model = model;
|
|
|
|
|
|
- LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", lrModel.size(), curTime);
|
|
|
+ LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", model.size(), curTime);
|
|
|
//final stage
|
|
|
while ((line = input.readLine()) != null) {
|
|
|
String[] items = line.split(" ");
|
|
|
- if (items.length < 2) {
|
|
|
+ if (items.length < 9) {
|
|
|
continue;
|
|
|
}
|
|
|
- putFeature(model, items[0], Float.valueOf(items[1]).floatValue());
|
|
|
+ putFeature(model, items);
|
|
|
}
|
|
|
- LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", lrModel.size(), curTime);
|
|
|
+ LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", model.size(), curTime);
|
|
|
|
|
|
LOGGER.info("[MODELLOAD] model load over and size " + cnt);
|
|
|
input.close();
|