|
|
@@ -0,0 +1,113 @@
|
|
|
+package examples.model;
|
|
|
+
|
|
|
+import java.io.Serializable;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
+public class FMModel implements Serializable {
|
|
|
+ private int factor;
|
|
|
+ private double bias;
|
|
|
+ private Map<String, double[]> weight;
|
|
|
+
|
|
|
+ public static FMModel builder(List<String> lines, int factor) {
|
|
|
+ if (factor < 0 || factor > 32) {
|
|
|
+ System.out.printf("factor=%d is wrong\n", factor);
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ double bias = 0;
|
|
|
+ Map<String, double[]> weight = new HashMap<>();
|
|
|
+
|
|
|
+ int number = 0;
|
|
|
+ final int filedNum = 1 + 1 + factor;
|
|
|
+ for (String line : lines) {
|
|
|
+ number += 1;
|
|
|
+ String[] cells = line.split("\t");
|
|
|
+ if (1 == number) {
|
|
|
+ // w0
|
|
|
+ if (cells.length != 2) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ if ("bias".equals(cells[0])) {
|
|
|
+ bias = Double.parseDouble(cells[1]);
|
|
|
+ } else {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // w1 & vector
|
|
|
+ if (cells.length != filedNum) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ String key = cells[0];
|
|
|
+ double[] coefficient = new double[factor + 1];
|
|
|
+ for (int i = 1; i <= factor + 1; i++) {
|
|
|
+ coefficient[i - 1] = Double.parseDouble(cells[i]);
|
|
|
+ }
|
|
|
+ weight.put(key, coefficient);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ System.out.printf("load %d features\n", weight.size());
|
|
|
+ if (weight.size() < 100) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ FMModel model = new FMModel();
|
|
|
+ model.setFactor(factor);
|
|
|
+ model.setBias(bias);
|
|
|
+ model.setWeight(weight);
|
|
|
+ return model;
|
|
|
+ }
|
|
|
+
|
|
|
+ private void setBias(double bias) {
|
|
|
+ this.bias = bias;
|
|
|
+ }
|
|
|
+
|
|
|
+ private void setFactor(int factor) {
|
|
|
+ this.factor = factor;
|
|
|
+ }
|
|
|
+
|
|
|
+ private void setWeight(Map<String, double[]> weight) {
|
|
|
+ this.weight = weight;
|
|
|
+ }
|
|
|
+
|
|
|
+ private double sigmod(double score) {
|
|
|
+ return 1.0 / (1.0 + Math.exp(-score));
|
|
|
+ }
|
|
|
+
|
|
|
+ public double predict(Map<String, Double> features) {
|
|
|
+ double score = 0;
|
|
|
+ if (null != features && !features.isEmpty() && null != this.weight) {
|
|
|
+ // bias
|
|
|
+ score += this.bias;
|
|
|
+
|
|
|
+ double[] sumSquare = new double[this.factor];
|
|
|
+ double[] squareSum = new double[this.factor];
|
|
|
+ for (Map.Entry<String, Double> entry : features.entrySet()) {
|
|
|
+ String key = entry.getKey();
|
|
|
+ double val = entry.getValue();
|
|
|
+ double[] vector = this.weight.get(key);
|
|
|
+ if (vector == null) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ // w1
|
|
|
+ score += val * vector[0];
|
|
|
+
|
|
|
+ // sumSquare, squareSum
|
|
|
+ for (int i = 0; i < this.factor; i++) {
|
|
|
+ double mul = val * vector[i + 1]; // Vni*X
|
|
|
+ sumSquare[i] += mul;
|
|
|
+ squareSum[i] += mul * mul; // (Vni*X)^2
|
|
|
+
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // vector
|
|
|
+ for (int i = 0; i < this.factor; i++) {
|
|
|
+ score += 0.5 * (sumSquare[i] * sumSquare[i] - squareSum[i]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return sigmod(score);
|
|
|
+ }
|
|
|
+}
|