|
@@ -0,0 +1,137 @@
|
|
|
+package com.tzld.piaoquan.recommend.server.service.score.model;
|
|
|
+
|
|
|
+
|
|
|
+import org.apache.commons.collections4.MapUtils;
|
|
|
+import org.apache.commons.lang.math.NumberUtils;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+
|
|
|
+import java.io.BufferedReader;
|
|
|
+import java.io.IOException;
|
|
|
+import java.io.InputStreamReader;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
+
|
|
|
+public class FMModel extends Model {
|
|
|
+ protected static final int MODEL_FIRST_LOAD_COUNT = 1 << 25; // 32M
|
|
|
+ private static final Logger LOGGER = LoggerFactory.getLogger(FMModel.class);
|
|
|
+ private Map<String, List<Float>> model;
|
|
|
+
|
|
|
+ public void putFeature(Map<String, Float> model, String featureKey, float weight) {
|
|
|
+ model.put(featureKey, weight);
|
|
|
+ }
|
|
|
+
|
|
|
+ public float getWeight(Map<String, List<Float>> model, String featureKey, int index) {
|
|
|
+ if (!model.containsKey(featureKey)) {
|
|
|
+ return 0.0f;
|
|
|
+ }
|
|
|
+
|
|
|
+ return model.get(featureKey).get(index);
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int getModelSize() {
|
|
|
+ if (this.model == null)
|
|
|
+ return 0;
|
|
|
+ return model.size();
|
|
|
+ }
|
|
|
+
|
|
|
+ public void cleanModel() {
|
|
|
+ this.model = null;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Float score(Map<String, String> featureMap) {
|
|
|
+ float sum = 0.0f;
|
|
|
+
|
|
|
+ if (MapUtils.isNotEmpty(featureMap)) {
|
|
|
+ // 计算 sum w*x
|
|
|
+ float sum0 = 0.0f;
|
|
|
+ for (Map.Entry<String, String> e : featureMap.entrySet()) {
|
|
|
+ float x = NumberUtils.toFloat(e.getValue(), 0.0f);
|
|
|
+ float w = getWeight(this.model, e.getKey(), 0);
|
|
|
+ sum0 += w * x;
|
|
|
+ }
|
|
|
+ sum += sum0;
|
|
|
+
|
|
|
+ // 计算 sum v*v*x*X
|
|
|
+ float sum1 = 0.0f;
|
|
|
+ for (int i = 1; i < 9; i++) {
|
|
|
+ float sum10 = 0.0f;
|
|
|
+ float sum11 = 0.0f;
|
|
|
+ for (Map.Entry<String, String> e : featureMap.entrySet()) {
|
|
|
+ float x = NumberUtils.toFloat(e.getValue(), 0.0f);
|
|
|
+ float v = getWeight(this.model, e.getKey(), i);
|
|
|
+ sum10 = v * x;
|
|
|
+ sum11 = v * v * x * x;
|
|
|
+ }
|
|
|
+ sum1 += sum10 * sum10 - sum11;
|
|
|
+ }
|
|
|
+ sum1 = sum1 / 2;
|
|
|
+
|
|
|
+
|
|
|
+ float biasW = model.get("bias").get(0);
|
|
|
+ sum = biasW + sum0 + sum1;
|
|
|
+ }
|
|
|
+
|
|
|
+ return (float) (1.0f / (1 + Math.exp(-sum)));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 目前模型比较大,分两个阶段load模型
|
|
|
+ * (1). load 8M 模型, 并更新;
|
|
|
+ * (2). load 剩余的模型
|
|
|
+ * 中间提供一段时间有损的打分服务
|
|
|
+ *
|
|
|
+ * @param in
|
|
|
+ * @return
|
|
|
+ * @throws IOException
|
|
|
+ */
|
|
|
+ @Override
|
|
|
+ public boolean loadFromStream(InputStreamReader in) throws IOException {
|
|
|
+
|
|
|
+ Map<String, Float> model = new HashMap<>();
|
|
|
+ BufferedReader input = new BufferedReader(in);
|
|
|
+ String line = null;
|
|
|
+ int cnt = 0;
|
|
|
+
|
|
|
+ Integer curTime = new Long(System.currentTimeMillis() / 1000).intValue();
|
|
|
+ LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", lrModel.size(), curTime);
|
|
|
+ //first stage
|
|
|
+ while ((line = input.readLine()) != null) {
|
|
|
+ String[] items = line.split(" ");
|
|
|
+ if (items.length < 2) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ putFeature(model, items[0], Float.valueOf(items[1].trim()).floatValue());
|
|
|
+ if (cnt++ < 10) {
|
|
|
+ LOGGER.debug("fea: " + items[0] + ", weight: " + items[1]);
|
|
|
+ }
|
|
|
+ if (cnt > MODEL_FIRST_LOAD_COUNT) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ //model update
|
|
|
+ this.lrModel = model;
|
|
|
+
|
|
|
+ LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", lrModel.size(), curTime);
|
|
|
+ //final stage
|
|
|
+ while ((line = input.readLine()) != null) {
|
|
|
+ String[] items = line.split(" ");
|
|
|
+ if (items.length < 2) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ putFeature(model, items[0], Float.valueOf(items[1]).floatValue());
|
|
|
+ }
|
|
|
+ LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", lrModel.size(), curTime);
|
|
|
+
|
|
|
+ LOGGER.info("[MODELLOAD] model load over and size " + cnt);
|
|
|
+ input.close();
|
|
|
+ in.close();
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+}
|