@@ -0,0 +1,419 @@
+package com.tzld.piaoquan.recommend.server.service.score.model;
+
+import com.typesafe.config.Config;
+import com.typesafe.config.ConfigFactory;
+
+import it.unimi.dsi.fastutil.longs.Long2FloatMap;
+import it.unimi.dsi.fastutil.longs.Long2FloatOpenHashMap;
+import com.tzld.piaoquan.recommend.server.gen.recommend.LRSamples;
+import com.tzld.piaoquan.recommend.server.gen.recommend.GroupedFeature;
+import com.tzld.piaoquan.recommend.server.gen.recommend.BaseFeature;
+
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+public class LRModel extends Model {
+    protected static final int MODEL_FIRST_LOAD_COUNT = 1 << 25; // 32M entries
+    private static final Logger LOGGER = LoggerFactory.getLogger(LRModel.class);
+    private static final ExecutorService executorService = Executors.newCachedThreadPool();
+    private final int bucketBits = 10; // 2^10 => 1024 buckets
+    private List<Long2FloatMap> lrModel;
+    private Config config = ConfigFactory.load("hdfs_conf.properties");
+    private Configuration hdfsConf = new Configuration();
+
+    public LRModel() {
+        // configure the HDFS conf for the current environment
+        String coreSiteFile = config.hasPath("hdfs.coreSiteFile") ? StringUtils.trim(config.getString("hdfs.coreSiteFile")) : "core-site.xml";
+        String hdfsSiteFile = config.hasPath("hdfs.hdfsSiteFile") ? StringUtils.trim(config.getString("hdfs.hdfsSiteFile")) : "hdfs-site.xml";
+        hdfsConf.addResource(coreSiteFile);
+        hdfsConf.addResource(hdfsSiteFile);
+        this.lrModel = constructModel();
+    }
+
+    public List<Long2FloatMap> getLrModel() {
+        return lrModel;
+    }
+
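+    // The weight table is sharded into 2^bucketBits small hash maps instead of one large
+    // Long2FloatOpenHashMap. A plausible reading of this layout: each shard stays smaller,
+    // so rehashing while the model loads touches less memory at once. constructModel()
+    // builds the empty shards and getBucket() routes a feature hash to its shard.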
+    public List<Long2FloatMap> constructModel() {
+        List<Long2FloatMap> initModel = new ArrayList<Long2FloatMap>();
+        int buckets = (int) Math.pow(2, bucketBits);
+        for (int i = 0; i < buckets; i++) {
+            Long2FloatMap internalModel = new Long2FloatOpenHashMap();
+            internalModel.defaultReturnValue(0.0f);
+            initModel.add(internalModel);
+        }
+        return initModel;
+    }
+
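+    // getBucket() keeps only the low bucketBits bits of the feature hash:
+    // ((h >> k) << k) ^ h == h & ((1L << k) - 1), so the result is always in [0, 1024).
+    // putFeature()/getWeight() use it to pick the shard that owns a feature.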
+    public int getBucket(long featureHash) {
+        return (int) (((featureHash >> bucketBits) << bucketBits) ^ featureHash);
+    }
+
+    public void putFeature(List<Long2FloatMap> model, long featureHash, float weight) {
+        model.get(getBucket(featureHash)).put(featureHash, weight);
+    }
+
+    public float getWeight(List<Long2FloatMap> model, long featureHash) {
+        return model.get(getBucket(featureHash)).get(featureHash);
+    }
+
+    @Override
+    public int getModelSize() {
+        if (this.lrModel == null)
+            return 0;
+        int sum = 0;
+        for (Map<Long, Float> model : this.lrModel) {
+            sum += model.size();
+        }
+        return sum;
+    }
+
+    public void cleanModel() {
+        this.lrModel = null;
+    }
+
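+    /**
+     * Standard LR scoring: sum the weights of all features present in the sample and
+     * squash the sum through the sigmoid 1 / (1 + e^-sum) to get a predicted CTR in (0, 1).
+     * Features missing from the model contribute 0 (the shards' defaultReturnValue).
+     */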
+    public Float score(LRSamples lrSamples) {
+        float sum = 0.0f;
+        for (int i = 0; i < lrSamples.getFeaturesCount(); i++) {
+            GroupedFeature gf = lrSamples.getFeatures(i);
+            if (gf != null && gf.getFeaturesCount() > 0) {
+                for (int j = 0; j < gf.getFeaturesCount(); j++) {
+                    BaseFeature fea = gf.getFeatures(j);
+                    if (fea != null) {
+                        float tmp = getWeight(this.lrModel, fea.getIdentifier());
+                        // NOTE: protobuf messages are immutable; toBuilder() returns a copy,
+                        // so the weight is not written back into lrSamples here.
+                        fea.toBuilder().setWeight(tmp);
+                        sum += tmp;
+                    }
+                }
+            }
+        }
+
+        float pro = (float) (1.0f / (1 + Math.exp(-sum)));
+        // NOTE: same as above, the predicted CTR is not written back into the message.
+        lrSamples.toBuilder().setPredictCtr(pro);
+        return pro;
+    }
+
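+    /**
+     * Same traversal as score(), but returns the raw linear sum of feature weights
+     * without applying the sigmoid.
+     */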
+    public Float getWeights(LRSamples lrSamples) {
+        float sum = 0.0f;
+
+        for (int i = 0; i < lrSamples.getFeaturesCount(); i++) {
+            GroupedFeature gf = lrSamples.getFeatures(i);
+            if (gf != null && gf.getFeaturesCount() > 0) {
+                for (int j = 0; j < gf.getFeaturesCount(); j++) {
+                    BaseFeature fea = gf.getFeatures(j);
+                    if (fea != null) {
+                        float tmp = getWeight(this.lrModel, fea.getIdentifier());
+                        // NOTE: as in score(), this does not mutate the immutable message.
+                        fea.toBuilder().setWeight(tmp);
+                        sum += tmp;
+                    }
+                }
+            }
+        }
+        // NOTE: the summed weight is not written back into lrSamples either.
+        lrSamples.toBuilder().setWeight(sum);
+        return sum;
+    }
+
+    /**
+     * The model is large, so it is loaded in two stages:
+     * (1) load the first MODEL_FIRST_LOAD_COUNT entries and swap them in immediately;
+     * (2) load the remaining entries.
+     * Between the two stages scoring is served from the partial model, i.e. it is
+     * temporarily lossy.
+     *
+     * @param in
+     * @return
+     * @throws IOException
+     */
+    @Override
+    public boolean loadFromStream(InputStreamReader in) throws IOException {
+
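+        // Build the new model off to the side; the currently published lrModel keeps
+        // serving reads until the first-stage swap below.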
+        List<Long2FloatMap> model = constructModel();
+        BufferedReader input = new BufferedReader(in);
+        String line = null;
+        int cnt = 0;
+
+        int curTime = (int) (System.currentTimeMillis() / 1000);
+        LOGGER.info("[MODELLOAD] before model load, key size: {}, current time: {}", getModelSize(), curTime);
+        // first stage
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split("\t");
+            if (items.length < 2) {
+                continue;
+            }
+
+            // parse via BigInteger: feature hashes may exceed Long.MAX_VALUE and are wrapped into a signed long
+            putFeature(model, new BigInteger(items[0]).longValue(), Float.parseFloat(items[1]));
+            if (cnt++ < 10) {
+                LOGGER.debug("fea: " + items[0] + ", weight: " + items[1]);
+            }
+            if (cnt > MODEL_FIRST_LOAD_COUNT) {
+                break;
+            }
+        }
+        // swap in the partially loaded model
+        this.lrModel = model;
+        LOGGER.info("[MODELLOAD] after first stage model load, key size: {}, current time: {}", getModelSize(), curTime);
+        // final stage: keep filling the same (already published) model
+        while ((line = input.readLine()) != null) {
+            String[] items = line.split("\t");
+            if (items.length < 2) {
+                continue;
+            }
+            putFeature(model, new BigInteger(items[0]).longValue(), Float.parseFloat(items[1]));
+        }
+        LOGGER.info("[MODELLOAD] after model load, key size: {}, current time: {}", getModelSize(), curTime);
+
+        LOGGER.info("[MODELLOAD] model load over and size " + cnt);
+        input.close();
+        in.close();
+        return true;
+    }
+
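+    /**
+     * Dispatch between the two loaders. Judging from the implementations below, the
+     * parallel path is presumably meant for first-time registration (startup), while the
+     * single-threaded path is meant for in-place reloads where GC pressure matters more.
+     */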
+    @Override
+    public boolean loadPartitions(final String modelPath, boolean isRegister) {
+        if (isRegister) {
+            return loadPartitionsParallel(modelPath);
+        } else {
+            return loadPartitionsSingle(modelPath);
+        }
+    }
+
+    /**
+     * While the service is running the model is too large to load with multiple threads
+     * without risking a full GC, so this path loads the partitions single-threaded.
+     *
+     * @param modelPath
+     * @return
+     */
+    public boolean loadPartitionsSingle(final String modelPath) {
+        try {
+            Path path = new Path(modelPath);
+            hdfsConf.setBoolean("fs.hdfs.impl.disable.cache", true);
+
+            FileSystem fs = path.getFileSystem(hdfsConf);
+            FileStatus[] listStatus = fs.listStatus(path);
+
+            // guard against a missing or empty partition directory
+            if (listStatus == null || listStatus.length == 0) {
+                LOGGER.error("model path is a dir, but it contains no hdfs partition files");
+                return false;
+            }
+
+            long startTime = System.currentTimeMillis();
+
+            // pre-build the full model structure so the long-lived maps land in the old generation
+            List<Long2FloatMap> currLrModel = constructModel();
+
+            // load each partition file sequentially
+            int failedPartitionNum = 0;
+            int partitionsNum = listStatus.length;
+            for (final FileStatus file : listStatus) {
+                String absPath = String.format("%s/%s", modelPath, file.getPath().getName());
+                InputStreamReader fin = null;
+                try {
+                    Path tmpPath = new Path(absPath);
+                    FileSystem tmpFs = tmpPath.getFileSystem(hdfsConf);
+                    FSDataInputStream inputStream = tmpFs.open(tmpPath);
+                    fin = new InputStreamReader(inputStream);
+
+                    BufferedReader input = new BufferedReader(fin);
+                    String line = null;
+                    int cnt = 0;
+
+                    // each line is "featureHash<TAB>weight"
+                    while ((line = input.readLine()) != null) {
+                        String[] items = line.split("\t");
+                        if (items.length < 2) {
+                            continue;
+                        }
+
+                        putFeature(currLrModel, new BigInteger(items[0]).longValue(), Float.parseFloat(items[1]));
+                        cnt++;
+                    }
+                    LOGGER.info("load model [SUCCESS] , file path [{}], load item number [{}]", absPath, cnt);
+                } catch (Exception e) {
+                    failedPartitionNum++;
+                    LOGGER.error("load model file from hdfs occur error [FAILED], path: [{}], [{}]",
+                            absPath, ExceptionUtils.getFullStackTrace(e));
+                } finally {
+                    if (fin != null) {
+                        try {
+                            fin.close();
+                        } catch (IOException e) {
+                            LOGGER.error("close [{}] fail: [{}]", absPath, ExceptionUtils.getFullStackTrace(e));
+                        }
+                    }
+                }
+            }
+
+            if (failedPartitionNum == 0) {
+                this.lrModel = currLrModel;
+                LOGGER.info("[end] load model data from hdfs, spend time: [{}ms] model size: [{}], " +
+                                "total partition number [{}], failed partition numbers: [{}], model path [{}]", new Object[]{
+                        (System.currentTimeMillis() - startTime), getModelSize(), partitionsNum, failedPartitionNum, modelPath});
+                return true;
+            } else {
+                LOGGER.error("load model failed parts [{}]", failedPartitionNum);
+                return false;
+            }
+        } catch (Exception e) {
+            LOGGER.error("load model partitions occur error, model path [{}], error: [{}]",
+                    modelPath, ExceptionUtils.getFullStackTrace(e));
+            return false;
+        }
+    }
+
+    /**
+     * Load the model from modelPath concurrently, one task per partition file.
+     * Each task reads its partition into a private map (no shared writes); the
+     * per-partition maps are merged into the sharded model on the calling thread.
+     * Returns true only if there is at least one partition and every partition loads;
+     * any exception, an empty directory, or a failed partition returns false.
+     *
+     * @param modelPath
+     * @return
+     */
+    public boolean loadPartitionsParallel(final String modelPath) {
+        try {
+            Path path = new Path(modelPath);
+            FileSystem fs = path.getFileSystem(hdfsConf);
+            FileStatus[] listStatus = fs.listStatus(path);
+
+            // guard against a missing or empty partition directory
+            if (listStatus == null || listStatus.length == 0) {
+                LOGGER.error("model path is a dir, but it contains no hdfs partition files");
+                return false;
+            }
+
+            long startTime = System.currentTimeMillis();
+            List<Callable<Long2FloatMap>> callables = new ArrayList<Callable<Long2FloatMap>>();
+            // one loader task per partition file
+            for (final FileStatus file : listStatus) {
+                callables.add(new Callable<Long2FloatMap>() {
+                    @Override
+                    public Long2FloatMap call() {
+                        // LOGGER.debug("load model file path [{}]", file.getPath().getName());
+                        String abspath = String.format("%s/%s", modelPath, file.getPath().getName());
+                        InputStreamReader fin = null;
+                        Long2FloatMap partModel = new Long2FloatOpenHashMap();
+                        try {
+                            Path path = new Path(abspath);
+                            FileSystem fs = path.getFileSystem(hdfsConf);
+                            FSDataInputStream inputStream = fs.open(path);
+                            fin = new InputStreamReader(inputStream);
+
+                            BufferedReader input = new BufferedReader(fin);
+                            String line = null;
+                            int cnt = 0;
+
+                            // each line is "featureHash<TAB>weight"; this task writes only
+                            // its own private map, so no synchronization is needed here
+                            while ((line = input.readLine()) != null) {
+                                String[] items = line.split("\t");
+                                if (items.length < 2) {
+                                    continue;
+                                }
+
+                                partModel.put(new BigInteger(items[0]).longValue(), Float.parseFloat(items[1]));
+                                cnt++;
+                            }
+                            LOGGER.info("load model [SUCCESS] , file path [{}], load item number [{}]", abspath, cnt);
+                            return partModel;
+                        } catch (Exception e) {
+                            LOGGER.error("load model file from hdfs occur error [FAILED], path: [{}], [{}]",
+                                    abspath, ExceptionUtils.getFullStackTrace(e));
+                            // null signals a failed partition to the caller
+                            return null;
+                        } finally {
+                            // do not return from finally: that would override the
+                            // success/failure result chosen in the try/catch above
+                            if (fin != null) {
+                                try {
+                                    fin.close();
+                                } catch (IOException e) {
+                                    LOGGER.error("close [{}] fail: [{}]", abspath, ExceptionUtils.getFullStackTrace(e));
+                                }
+                            }
+                        }
+                    }
+                });
+            }
+
+
+            // run all loader tasks; bail out if the executor is interrupted
+            List<Future<Long2FloatMap>> futures = null;
+            try {
+                futures = executorService.invokeAll(callables, 10, TimeUnit.MINUTES);
+            } catch (InterruptedException e) {
+                LOGGER.error("execute invoke fail: {}", ExceptionUtils.getFullStackTrace(e));
+                Thread.currentThread().interrupt();
+                return false;
+            }
+
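+            // invokeAll returns only once every task has completed or the timeout has
+            // cancelled it, so each future below is either done or cancelled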
+            // merge the per-partition maps into the sharded model
+            int failedPartitionNum = 0;
+            int partitionsNum = listStatus.length;
+            List<Long2FloatMap> currLrModel = constructModel();
+            for (Future<Long2FloatMap> future : futures) {
+                try {
+                    if (future.isDone() && !future.isCancelled()) {
+                        Long2FloatMap ret = future.get();
+                        if (ret == null) {
+                            // the task logged its own failure and returned null
+                            failedPartitionNum++;
+                            continue;
+                        }
+                        for (Map.Entry<Long, Float> entry : ret.entrySet()) {
+                            putFeature(currLrModel, entry.getKey(), entry.getValue());
+                        }
+                        ret = null; // allow the partition map to be collected
+                    } else {
+                        // cancelled by the invokeAll timeout
+                        failedPartitionNum++;
+                    }
+                } catch (InterruptedException e) {
+                    failedPartitionNum++;
+                    LOGGER.error("InterruptedException [{}]", ExceptionUtils.getFullStackTrace(e));
+                } catch (ExecutionException e) {
+                    failedPartitionNum++;
+                    LOGGER.error("ExecutionException [{}]", ExceptionUtils.getFullStackTrace(e));
+                }
+            }
+
+            long endTime = System.currentTimeMillis();
+            // check all load success
+            if (failedPartitionNum == 0) {
+                this.lrModel = currLrModel;
+                // counter for alarm
+                LOGGER.info("[end] load model data from hdfs, spend time: [{}ms] model size: [{}], " +
+                                "total partition number [{}], failed partition numbers: [{}], model path [{}]", new Object[]{
+                        (endTime - startTime), getModelSize(), partitionsNum, failedPartitionNum, modelPath});
+                return true;
+            }
+
+            return false;
+        } catch (Exception e) {
+            LOGGER.error("load model partitions occur error, model path [{}], error: [{}]",
+                    modelPath, ExceptionUtils.getFullStackTrace(e));
+            return false;
+        }
+    }
+}