丁云鹏 7 months ago
parent
commit
69138777cf

+ 50 - 0
.gitignore

@@ -0,0 +1,50 @@
+# ---> Java
+*.class
+
+# Mobile Tools for Java (J2ME)
+.mtj.tmp/
+
+# Package Files #
+*.jar
+*.war
+*.ear
+
+# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
+hs_err_pid*
+
+HELP.md
+target/
+!.mvn/wrapper/maven-wrapper.jar
+!**/src/main/**/target/
+!**/src/test/**/target/
+
+### STS ###
+.apt_generated
+.classpath
+.factorypath
+.project
+.settings
+.springBeans
+.sts4-cache
+
+### IntelliJ IDEA ###
+.idea
+*.iws
+*.iml
+*.ipr
+
+### NetBeans ###
+/nbproject/private/
+/nbbuild/
+/dist/
+/nbdist/
+/.nb-gradle/
+build/
+!**/src/main/**/build/
+!**/src/test/**/build/
+
+### VS Code ###
+.vscode/
+
+### log ###
+logs/*

+ 1 - 0
README.md

@@ -0,0 +1 @@
+https://github.com/NLPchina/Word2VEC_java

+ 44 - 0
pom.xml

@@ -0,0 +1,44 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+
+    <modelVersion>4.0.0</modelVersion>
+    <packaging>pom</packaging>
+    <parent>
+        <groupId>com.tzld.commons</groupId>
+        <artifactId>supom</artifactId>
+        <version>1.0.9</version>
+    </parent>
+    <groupId>com.tzld.piaoquan</groupId>
+    <artifactId>recommend-similarity</artifactId>
+    <version>1.0.0</version>
+    <name>recommend-similarity</name>
+    <description>recommend-similarity</description>
+    <dependencies>
+        <dependency>
+            <groupId>org.ansj</groupId>
+            <artifactId>ansj_seg</artifactId>
+            <version>5.1.6</version>
+        </dependency>
+        <dependency>
+            <groupId>org.projectlombok</groupId>
+            <artifactId>lombok</artifactId>
+            <version>1.18.16</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-collections4</artifactId>
+            <version>4.1</version>
+        </dependency>
+        <dependency>
+            <groupId>com.aliyun.oss</groupId>
+            <artifactId>aliyun-sdk-oss</artifactId>
+            <version>3.15.1</version>
+        </dependency>
+    </dependencies>
+    <build>
+        <finalName>recommend-similarity-word2vec</finalName>
+    </build>
+
+</project>

+ 453 - 0
src/main/java/com/ansj/vec/Learn.java

@@ -0,0 +1,453 @@
+package com.ansj.vec;
+
+/**
+ * @author dyp
+ */
+import java.io.BufferedOutputStream;
+import java.io.BufferedReader;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import com.ansj.vec.util.MapCount;
+import com.ansj.vec.domain.HiddenNeuron;
+import com.ansj.vec.domain.Neuron;
+import com.ansj.vec.domain.WordNeuron;
+import com.ansj.vec.util.Haffman;
+
+public class Learn {
+
+    private Map<String, Neuron> wordMap = new HashMap<>();
+    /**
+     * Dimensionality of the feature vectors
+     */
+    private int layerSize = 200;
+
+    /**
+     * Context window size
+     */
+    private int window = 5;
+
+    private double sample = 1e-3;
+    private double alpha = 0.025;
+    private double startingAlpha = alpha;
+
+    public int EXP_TABLE_SIZE = 1000;
+
+    private Boolean isCbow = false;
+
+    private double[] expTable = new double[EXP_TABLE_SIZE];
+
+    private int trainWordsCount = 0;
+
+    private int MAX_EXP = 6;
+
+    public Learn(Boolean isCbow, Integer layerSize, Integer window, Double alpha,
+                 Double sample) {
+        createExpTable();
+        if (isCbow != null) {
+            this.isCbow = isCbow;
+        }
+        if (layerSize != null)
+            this.layerSize = layerSize;
+        if (window != null)
+            this.window = window;
+        if (alpha != null)
+            this.alpha = alpha;
+        if (sample != null)
+            this.sample = sample;
+    }
+
+    public Learn() {
+        createExpTable();
+    }
+
+    /**
+     * trainModel
+     *
+     * @throws IOException
+     */
+    private void trainModel(File file) throws IOException {
+        try (BufferedReader br = new BufferedReader(new InputStreamReader(
+                new FileInputStream(file)))) {
+            String temp = null;
+            long nextRandom = 5;
+            int wordCount = 0;
+            int lastWordCount = 0;
+            int wordCountActual = 0;
+            while ((temp = br.readLine()) != null) {
+                if (wordCount - lastWordCount > 10000) {
+                    System.out.println("alpha:" + alpha + "\tProgress: "
+                            + (int) (wordCountActual / (double) (trainWordsCount + 1) * 100)
+                            + "%");
+                    wordCountActual += wordCount - lastWordCount;
+                    lastWordCount = wordCount;
+                    alpha = startingAlpha
+                            * (1 - wordCountActual / (double) (trainWordsCount + 1));
+                    if (alpha < startingAlpha * 0.0001) {
+                        alpha = startingAlpha * 0.0001;
+                    }
+                }
+                String[] strs = temp.split(" ");
+                wordCount += strs.length;
+                List<WordNeuron> sentence = new ArrayList<WordNeuron>();
+                for (int i = 0; i < strs.length; i++) {
+                    Neuron entry = wordMap.get(strs[i]);
+                    if (entry == null) {
+                        continue;
+                    }
+                    // The subsampling randomly discards frequent words while keeping the
+                    // ranking same
+                    if (sample > 0) {
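+                        // keep probability: (sqrt(freq / (sample * N)) + 1) * (sample * N) / freq,
+                        // compared against the low 16 bits of a linear congruential generator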
+                        double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1)
+                                * (sample * trainWordsCount) / entry.freq;
+                        nextRandom = nextRandom * 25214903917L + 11;
+                        if (ran < (nextRandom & 0xFFFF) / (double) 65536) {
+                            continue;
+                        }
+                    }
+                    sentence.add((WordNeuron) entry);
+                }
+
+                for (int index = 0; index < sentence.size(); index++) {
+                    nextRandom = nextRandom * 25214903917L + 11;
+                    if (isCbow) {
+                        cbowGram(index, sentence, (int) nextRandom % window);
+                    } else {
+                        skipGram(index, sentence, (int) nextRandom % window);
+                    }
+                }
+
+            }
+            System.out.println("Vocab size: " + wordMap.size());
+            System.out.println("Words in train file: " + trainWordsCount);
+            System.out.println("sucess train over!");
+        }
+    }
+
+    /**
+     * Skip-gram model training
+     *
+     * @param index
+     * @param sentence
+     * @param b
+     */
+    private void skipGram(int index, List<WordNeuron> sentence, int b) {
+        WordNeuron word = sentence.get(index);
+        int a, c = 0;
+        for (a = b; a < window * 2 + 1 - b; a++) {
+            if (a == window) {
+                continue;
+            }
+            c = index - window + a;
+            if (c < 0 || c >= sentence.size()) {
+                continue;
+            }
+
+            double[] neu1e = new double[layerSize]; // error accumulator
+            // HIERARCHICAL SOFTMAX
+            List<Neuron> neurons = word.neurons;
+            WordNeuron we = sentence.get(c);
+            for (int i = 0; i < neurons.size(); i++) {
+                HiddenNeuron out = (HiddenNeuron) neurons.get(i);
+                double f = 0;
+                // Propagate hidden -> output
+                for (int j = 0; j < layerSize; j++) {
+                    f += we.syn0[j] * out.syn1[j];
+                }
+                if (f <= -MAX_EXP || f >= MAX_EXP) {
+                    continue;
+                } else {
+                    f = (f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2);
+                    f = expTable[(int) f];
+                }
+                // 'g' is the gradient multiplied by the learning rate
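+                // word.codeArr[i] is the Huffman code bit at this tree level; the
+                // target for the sigmoid output f is 1 - code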
+                double g = (1 - word.codeArr[i] - f) * alpha;
+                // Propagate errors output -> hidden
+                for (c = 0; c < layerSize; c++) {
+                    neu1e[c] += g * out.syn1[c];
+                }
+                // Learn weights hidden -> output
+                for (c = 0; c < layerSize; c++) {
+                    out.syn1[c] += g * we.syn0[c];
+                }
+            }
+
+            // Learn weights input -> hidden
+            for (int j = 0; j < layerSize; j++) {
+                we.syn0[j] += neu1e[j];
+            }
+        }
+
+    }
+
+    /**
+     * CBOW (continuous bag-of-words) model training
+     *
+     * @param index
+     * @param sentence
+     * @param b
+     */
+    private void cbowGram(int index, List<WordNeuron> sentence, int b) {
+        WordNeuron word = sentence.get(index);
+        int a, c = 0;
+
+        List<Neuron> neurons = word.neurons;
+        double[] neu1e = new double[layerSize]; // error accumulator
+        double[] neu1 = new double[layerSize]; // sum of the context word vectors
+        WordNeuron last_word;
+
+        for (a = b; a < window * 2 + 1 - b; a++)
+            if (a != window) {
+                c = index - window + a;
+                if (c < 0)
+                    continue;
+                if (c >= sentence.size())
+                    continue;
+                last_word = sentence.get(c);
+                if (last_word == null)
+                    continue;
+                for (c = 0; c < layerSize; c++)
+                    neu1[c] += last_word.syn0[c];
+            }
+
+        // HIERARCHICAL SOFTMAX
+        for (int d = 0; d < neurons.size(); d++) {
+            HiddenNeuron out = (HiddenNeuron) neurons.get(d);
+            double f = 0;
+            // Propagate hidden -> output
+            for (c = 0; c < layerSize; c++)
+                f += neu1[c] * out.syn1[c];
+            if (f <= -MAX_EXP)
+                continue;
+            else if (f >= MAX_EXP)
+                continue;
+            else
+                f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
+            // 'g' is the gradient multiplied by the learning rate
+            // double g = (1 - word.codeArr[d] - f) * alpha;
+            // double g = f*(1-f)*( word.codeArr[i] - f) * alpha;
+            double g = f * (1 - f) * (word.codeArr[d] - f) * alpha;
+            //
+            for (c = 0; c < layerSize; c++) {
+                neu1e[c] += g * out.syn1[c];
+            }
+            // Learn weights hidden -> output
+            for (c = 0; c < layerSize; c++) {
+                out.syn1[c] += g * neu1[c];
+            }
+        }
+        for (a = b; a < window * 2 + 1 - b; a++) {
+            if (a != window) {
+                c = index - window + a;
+                if (c < 0)
+                    continue;
+                if (c >= sentence.size())
+                    continue;
+                last_word = sentence.get(c);
+                if (last_word == null)
+                    continue;
+                for (c = 0; c < layerSize; c++)
+                    last_word.syn0[c] += neu1e[c];
+            }
+
+        }
+    }
+
+    /**
+     * Count word frequencies and build the vocabulary
+     *
+     * @param file
+     * @throws IOException
+     */
+    private void readVocab(File file) throws IOException {
+        MapCount<String> mc = new MapCount<>();
+        try (BufferedReader br = new BufferedReader(new InputStreamReader(
+                new FileInputStream(file)))) {
+            String temp = null;
+            while ((temp = br.readLine()) != null) {
+                String[] split = temp.split(" ");
+                trainWordsCount += split.length;
+                for (String string : split) {
+                    mc.add(string);
+                }
+            }
+        }
+        for (Entry<String, Integer> element : mc.get().entrySet()) {
+            wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
+                    (double) element.getValue() / mc.size(), layerSize));
+        }
+    }
+
+    /**
+     * Build the vocabulary from pre-classified text files
+     *
+     * @param files
+     * @throws IOException
+     */
+    private void readVocabWithSupervised(File[] files) throws IOException {
+        for (int category = 0; category < files.length; category++) {
+            // learn from each classified file
+            MapCount<String> mc = new MapCount<>();
+            try (BufferedReader br = new BufferedReader(new InputStreamReader(
+                    new FileInputStream(files[category])))) {
+                String temp = null;
+                while ((temp = br.readLine()) != null) {
+                    String[] split = temp.split(" ");
+                    trainWordsCount += split.length;
+                    for (String string : split) {
+                        mc.add(string);
+                    }
+                }
+            }
+            for (Entry<String, Integer> element : mc.get().entrySet()) {
+                double tarFreq = (double) element.getValue() / mc.size();
+                if (wordMap.get(element.getKey()) != null) {
+                    double srcFreq = wordMap.get(element.getKey()).freq;
+                    if (srcFreq >= tarFreq) {
+                        continue;
+                    } else {
+                        Neuron wordNeuron = wordMap.get(element.getKey());
+                        wordNeuron.category = category;
+                        wordNeuron.freq = tarFreq;
+                    }
+                } else {
+                    wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
+                            tarFreq, category, layerSize));
+                }
+            }
+        }
+    }
+
+    /**
+     * Precompute the sigmoid table: f(x) = e^x / (e^x + 1) for x in [-MAX_EXP, MAX_EXP]
+     */
+    private void createExpTable() {
+        for (int i = 0; i < EXP_TABLE_SIZE; i++) {
+            expTable[i] = Math.exp(((i / (double) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP));
+            expTable[i] = expTable[i] / (expTable[i] + 1);
+        }
+    }
+
+    /**
+     * Learn from a file
+     *
+     * @param file
+     * @throws IOException
+     */
+    public void learnFile(File file) throws IOException {
+        readVocab(file);
+        new Haffman(layerSize).make(wordMap.values());
+
+        // build the Huffman path neurons for each word
+        for (Neuron neuron : wordMap.values()) {
+            ((WordNeuron) neuron).makeNeurons();
+        }
+
+        trainModel(file);
+    }
+
+    /**
+     * Learn from pre-classified files
+     *
+     * @param summaryFile
+     *          merged corpus file
+     * @param classifiedFiles
+     *          per-category corpus files
+     * @throws IOException
+     */
+    public void learnFile(File summaryFile, File[] classifiedFiles)
+            throws IOException {
+        readVocabWithSupervised(classifiedFiles);
+        new Haffman(layerSize).make(wordMap.values());
+        // build the Huffman path neurons for each word
+        for (Neuron neuron : wordMap.values()) {
+            ((WordNeuron) neuron).makeNeurons();
+        }
+        trainModel(summaryFile);
+    }
+
+    /**
+     * Save the model
+     */
+    public void saveModel(File file) {
+        try (DataOutputStream dataOutputStream = new DataOutputStream(
+                new BufferedOutputStream(new FileOutputStream(file)))) {
+            dataOutputStream.writeInt(wordMap.size());
+            dataOutputStream.writeInt(layerSize);
+            double[] syn0 = null;
+            for (Entry<String, Neuron> element : wordMap.entrySet()) {
+                dataOutputStream.writeUTF(element.getKey());
+                syn0 = ((WordNeuron) element.getValue()).syn0;
+                for (double d : syn0) {
+                    dataOutputStream.writeFloat(((Double) d).floatValue());
+                }
+            }
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+    }
+
+    public int getLayerSize() {
+        return layerSize;
+    }
+
+    public void setLayerSize(int layerSize) {
+        this.layerSize = layerSize;
+    }
+
+    public int getWindow() {
+        return window;
+    }
+
+    public void setWindow(int window) {
+        this.window = window;
+    }
+
+    public double getSample() {
+        return sample;
+    }
+
+    public void setSample(double sample) {
+        this.sample = sample;
+    }
+
+    public double getAlpha() {
+        return alpha;
+    }
+
+    public void setAlpha(double alpha) {
+        this.alpha = alpha;
+        this.startingAlpha = alpha;
+    }
+
+    public Boolean getIsCbow() {
+        return isCbow;
+    }
+
+    public void setIsCbow(Boolean isCbow) {
+        this.isCbow = isCbow;
+    }
+
+    public static void main(String[] args) throws IOException {
+        Learn learn = new Learn();
+        long start = System.currentTimeMillis();
+        learn.learnFile(new File("library/xh.txt"));
+        System.out.println("use time " + (System.currentTimeMillis() - start));
+        learn.saveModel(new File("library/javaVector"));
+
+    }
+}
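
For reference, a minimal training sketch built on the constructor above (a sketch only; corpus.txt and model.bin are hypothetical paths, and the corpus is expected to contain one space-separated sentence per line):

    import com.ansj.vec.Learn;
    import java.io.File;

    public class TrainDemo {
        public static void main(String[] args) throws Exception {
            // CBOW, 200-dimensional vectors, window 5, alpha 0.025, subsampling 1e-3
            Learn learn = new Learn(true, 200, 5, 0.025, 1e-3);
            learn.learnFile(new File("corpus.txt"));
            learn.saveModel(new File("model.bin")); // read back by Word2VEC.loadJavaModel
        }
    }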

+ 386 - 0
src/main/java/com/ansj/vec/Word2VEC.java

@@ -0,0 +1,386 @@
+package com.ansj.vec;
+
+/**
+ * @author dyp
+ */
+
+import java.io.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.TreeSet;
+
+import com.aliyun.oss.OSS;
+import com.aliyun.oss.OSSClientBuilder;
+import com.aliyun.oss.common.auth.CredentialsProvider;
+import com.aliyun.oss.common.auth.DefaultCredentialProvider;
+import com.aliyun.oss.model.GetObjectRequest;
+import com.ansj.vec.domain.WordEntry;
+
+public class Word2VEC {
+
+    private Map<String, float[]> wordMap = new HashMap<String, float[]>();
+
+    private int words;
+    private int size;
+    private int topNSize = 40;
+
+    /**
+     * Load a model in the C word2vec binary format
+     *
+     * @param path path to the model file
+     * @throws IOException
+     */
+    public void loadGoogleModel(String path) throws IOException {
+        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(path));
+             DataInputStream dis = new DataInputStream(bis)) {
+            // number of words
+            words = Integer.parseInt(readString(dis));
+            // vector dimensionality
+            size = Integer.parseInt(readString(dis));
+            String word;
+            float[] vectors = null;
+            for (int i = 0; i < words; i++) {
+                word = readString(dis);
+                vectors = new float[size];
+                double len = 0;
+                float vector;
+                for (int j = 0; j < size; j++) {
+                    vector = readFloat(dis);
+                    len += vector * vector;
+                    vectors[j] = vector;
+                }
+                len = Math.sqrt(len);
+
+                // normalize to unit length
+                for (int j = 0; j < size; j++) {
+                    vectors[j] /= len;
+                }
+
+                wordMap.put(word, vectors);
+                dis.read(); // skip the separator byte after each vector
+            }
+        }
+    }
+
+    /**
+     * Load a model in the C word2vec binary format from Aliyun OSS
+     *
+     * @param path object key of the model file in the bucket
+     * @throws IOException
+     */
+    public void loadGoogleModelFromOss(String endpoint, String bucketName, String path, String accessKeyId, String accessKeySecret) throws IOException {
+
+        CredentialsProvider credentialsProvider = new DefaultCredentialProvider(accessKeyId, accessKeySecret);
+        OSS client = new OSSClientBuilder().build(endpoint, credentialsProvider);
+        String file = "word2vec.bin";
+        client.getObject(new GetObjectRequest(bucketName, path), new File(file));
+        loadGoogleModel(file);
+    }
+
+    public void loadJavaModel(String path) throws IOException {
+        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
+            words = dis.readInt();
+            size = dis.readInt();
+
+            float vector = 0;
+
+            String key = null;
+            float[] value = null;
+            for (int i = 0; i < words; i++) {
+                double len = 0;
+                key = dis.readUTF();
+                value = new float[size];
+                for (int j = 0; j < size; j++) {
+                    vector = dis.readFloat();
+                    len += vector * vector;
+                    value[j] = vector;
+                }
+
+                len = Math.sqrt(len);
+
+                for (int j = 0; j < size; j++) {
+                    value[j] /= len;
+                }
+                wordMap.put(key, value);
+            }
+
+        }
+    }
+
+    private static final int MAX_SIZE = 50;
+
+    /**
+     * Word analogy: find words w such that word0 : word1 ≈ word2 : w
+     *
+     * @return the top-N candidates, or null if any input word is out of vocabulary
+     */
+    public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
+        float[] wv0 = getWordVector(word0);
+        float[] wv1 = getWordVector(word1);
+        float[] wv2 = getWordVector(word2);
+
+        if (wv1 == null || wv2 == null || wv0 == null) {
+            return null;
+        }
+        float[] wordVector = new float[size];
+        for (int i = 0; i < size; i++) {
+            wordVector[i] = wv1[i] - wv0[i] + wv2[i];
+        }
+        float[] tempVector;
+        String name;
+        List<WordEntry> wordEntrys = new ArrayList<WordEntry>(topNSize);
+        for (Entry<String, float[]> entry : wordMap.entrySet()) {
+            name = entry.getKey();
+            if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
+                continue;
+            }
+            float dist = 0;
+            tempVector = entry.getValue();
+            for (int i = 0; i < wordVector.length; i++) {
+                dist += wordVector[i] * tempVector[i];
+            }
+            insertTopN(name, dist, wordEntrys);
+        }
+        return new TreeSet<WordEntry>(wordEntrys);
+    }
+
+    private void insertTopN(String name, float score, List<WordEntry> wordsEntrys) {
+        if (wordsEntrys.size() < topNSize) {
+            wordsEntrys.add(new WordEntry(name, score));
+            return;
+        }
+        float min = Float.MAX_VALUE;
+        int minOffe = 0;
+        for (int i = 0; i < topNSize; i++) {
+            WordEntry wordEntry = wordsEntrys.get(i);
+            if (min > wordEntry.score) {
+                min = wordEntry.score;
+                minOffe = i;
+            }
+        }
+
+        if (score > min) {
+            wordsEntrys.set(minOffe, new WordEntry(name, score));
+        }
+
+    }
+
+    public Set<WordEntry> distance(String queryWord) {
+
+        float[] center = wordMap.get(queryWord);
+        if (center == null) {
+            return Collections.emptySet();
+        }
+
+        int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
+        TreeSet<WordEntry> result = new TreeSet<WordEntry>();
+
+        double min = Float.MIN_VALUE;
+        for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
+            float[] vector = entry.getValue();
+            float dist = 0;
+            for (int i = 0; i < vector.length; i++) {
+                dist += center[i] * vector[i];
+            }
+
+            if (dist > min) {
+                result.add(new WordEntry(entry.getKey(), dist));
+                if (resultSize < result.size()) {
+                    result.pollLast();
+                }
+                min = result.last().score;
+            }
+        }
+        result.pollFirst();
+
+        return result;
+    }
+
+    public Set<WordEntry> distance(List<String> words) {
+
+        float[] center = null;
+        for (String word : words) {
+            center = sum(center, wordMap.get(word));
+        }
+
+        if (center == null) {
+            return Collections.emptySet();
+        }
+
+        int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
+        TreeSet<WordEntry> result = new TreeSet<WordEntry>();
+
+        double min = Float.MIN_VALUE;
+        for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
+            float[] vector = entry.getValue();
+            float dist = 0;
+            for (int i = 0; i < vector.length; i++) {
+                dist += center[i] * vector[i];
+            }
+
+            if (dist > min) {
+                result.add(new WordEntry(entry.getKey(), dist));
+                if (resultSize < result.size()) {
+                    result.pollLast();
+                }
+                min = result.last().score;
+            }
+        }
+        result.pollFirst();
+
+        return result;
+    }
+
+    private float[] sum(float[] center, float[] fs) {
+        if (center == null && fs == null) {
+            return null;
+        }
+
+        if (fs == null) {
+            return center;
+        }
+
+        if (center == null) {
+            return fs;
+        }
+
+        for (int i = 0; i < fs.length; i++) {
+            center[i] += fs[i];
+        }
+
+        return center;
+    }
+
+    /**
+     * Get the word vector
+     *
+     * @param word
+     * @return the vector, or null if the word is not in the vocabulary
+     */
+    public float[] getWordVector(String word) {
+        return wordMap.get(word);
+    }
+
+    public static float readFloat(InputStream is) throws IOException {
+        byte[] bytes = new byte[4];
+        is.read(bytes);
+        return getFloat(bytes);
+    }
+
+    /**
+     * Decode a little-endian float from 4 bytes
+     *
+     * @param b
+     * @return
+     */
+    public static float getFloat(byte[] b) {
+        int accum = 0;
+        accum = accum | (b[0] & 0xff) << 0;
+        accum = accum | (b[1] & 0xff) << 8;
+        accum = accum | (b[2] & 0xff) << 16;
+        accum = accum | (b[3] & 0xff) << 24;
+        return Float.intBitsToFloat(accum);
+    }
+
+    /**
+     * Read one whitespace-terminated token (ends at a space or newline byte)
+     *
+     * @param dis
+     * @return
+     * @throws IOException
+     */
+    private static String readString(DataInputStream dis) throws IOException {
+        // accumulate bytes until a space (32) or newline (10); flush the buffer into
+        // the StringBuilder whenever it fills up
+        byte[] bytes = new byte[MAX_SIZE];
+        byte b = dis.readByte();
+        int i = -1;
+        StringBuilder sb = new StringBuilder();
+        while (b != 32 && b != 10) {
+            i++;
+            bytes[i] = b;
+            b = dis.readByte();
+            if (i == MAX_SIZE - 1) {
+                sb.append(new String(bytes));
+                i = -1;
+                bytes = new byte[MAX_SIZE];
+            }
+        }
+        sb.append(new String(bytes, 0, i + 1));
+        return sb.toString();
+    }
+
+    public int getTopNSize() {
+        return topNSize;
+    }
+
+    public void setTopNSize(int topNSize) {
+        this.topNSize = topNSize;
+    }
+
+    public Map<String, float[]> getWordMap() {
+        return wordMap;
+    }
+
+    public int getWords() {
+        return words;
+    }
+
+    public int getSize() {
+        return size;
+    }
+
+}
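
A minimal query sketch against the class above (a sketch only; model.bin is a hypothetical path and the query word is assumed to be in the vocabulary):

    import com.ansj.vec.Word2VEC;
    import com.ansj.vec.domain.WordEntry;

    public class QueryDemo {
        public static void main(String[] args) throws Exception {
            Word2VEC vec = new Word2VEC();
            vec.loadJavaModel("model.bin"); // binary model written by Learn.saveModel
            // distance() returns the topNSize nearest words by dot product on
            // unit-normalized vectors, i.e. cosine similarity
            for (WordEntry entry : vec.distance("中国")) {
                System.out.println(entry); // prints "word<TAB>score"
            }
        }
    }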

+ 14 - 0
src/main/java/com/ansj/vec/domain/HiddenNeuron.java

@@ -0,0 +1,14 @@
+package com.ansj.vec.domain;
+
+/**
+ * @author dyp
+ */
+public class HiddenNeuron extends Neuron {
+
+    public double[] syn1; // hidden -> output weights
+
+    public HiddenNeuron(int layerSize) {
+        syn1 = new double[layerSize];
+    }
+
+}

+ 27 - 0
src/main/java/com/ansj/vec/domain/Neuron.java

@@ -0,0 +1,27 @@
+package com.ansj.vec.domain;
+
+/**
+ * @author dyp
+ */
+public abstract class Neuron implements Comparable<Neuron> {
+    public double freq;
+    public Neuron parent;
+    public int code;
+    // corpus pre-classification category (-1 = unclassified)
+    public int category = -1;
+
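+    // Note: compareTo never returns 0, so a TreeSet of Neurons keeps nodes with
+    // equal frequency as distinct elements; the Huffman builder relies on this.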
+    @Override
+    public int compareTo(Neuron neuron) {
+        if (this.category == neuron.category) {
+            if (this.freq > neuron.freq) {
+                return 1;
+            } else {
+                return -1;
+            }
+        } else if (this.category > neuron.category) {
+            return 1;
+        } else {
+            return -1;
+        }
+    }
+}

+ 31 - 0
src/main/java/com/ansj/vec/domain/WordEntry.java

@@ -0,0 +1,31 @@
+package com.ansj.vec.domain;
+
+/**
+ * @author dyp
+ */
+public class WordEntry implements Comparable<WordEntry> {
+    public String name;
+    public float score;
+
+    public WordEntry(String name, float score) {
+        this.name = name;
+        this.score = score;
+    }
+
+    @Override
+    public String toString() {
+        return this.name + "\t" + score;
+    }
+
+    @Override
+    public int compareTo(WordEntry o) {
+        if (this.score < o.score) {
+            return 1;
+        } else {
+            return -1;
+        }
+    }
+
+}

+ 65 - 0
src/main/java/com/ansj/vec/domain/WordNeuron.java

@@ -0,0 +1,65 @@
+package com.ansj.vec.domain;
+
+/**
+ * @author dyp
+ */
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+
+public class WordNeuron extends Neuron {
+    public String name;
+    public double[] syn0 = null; // input->hidden
+    public List<Neuron> neurons = null; // neurons on the Huffman path to this word
+    public int[] codeArr = null;
+
+    public List<Neuron> makeNeurons() {
+        if (neurons != null) {
+            return neurons;
+        }
+        Neuron neuron = this;
+        neurons = new LinkedList<>();
+        while ((neuron = neuron.parent) != null) {
+            neurons.add(neuron);
+        }
+        Collections.reverse(neurons);
+        codeArr = new int[neurons.size()];
+
+        for (int i = 1; i < neurons.size(); i++) {
+            codeArr[i - 1] = neurons.get(i).code;
+        }
+        codeArr[codeArr.length - 1] = this.code;
+
+        return neurons;
+    }
+
+    public WordNeuron(String name, double freq, int layerSize) {
+        this.name = name;
+        this.freq = freq;
+        this.syn0 = new double[layerSize];
+        Random random = new Random();
+        for (int i = 0; i < syn0.length; i++) {
+            syn0[i] = (random.nextDouble() - 0.5) / layerSize;
+        }
+    }
+
+    /**
+     * Constructor used when building the Huffman tree with supervision (pre-classified corpus)
+     *
+     * @param name
+     * @param freq
+     * @param category
+     * @param layerSize
+     */
+    public WordNeuron(String name, double freq, int category, int layerSize) {
+        this.name = name;
+        this.freq = freq;
+        this.syn0 = new double[layerSize];
+        this.category = category;
+        Random random = new Random();
+        for (int i = 0; i < syn0.length; i++) {
+            syn0[i] = (random.nextDouble() - 0.5) / layerSize;
+        }
+    }
+
+}

+ 48 - 0
src/main/java/com/ansj/vec/util/Haffman.java

@@ -0,0 +1,48 @@
+package com.ansj.vec.util;
+
+/**
+ * @author dyp
+ */
+import java.util.Collection;
+import java.util.List;
+import java.util.TreeSet;
+
+import com.ansj.vec.domain.HiddenNeuron;
+import com.ansj.vec.domain.Neuron;
+
+/**
+ * Builds the Huffman coding tree
+ *
+ * @author ansj
+ */
+public class Haffman {
+    private int layerSize;
+
+    public Haffman(int layerSize) {
+        this.layerSize = layerSize;
+    }
+
+    private TreeSet<Neuron> set = new TreeSet<>();
+
+    public void make(Collection<Neuron> neurons) {
+        set.addAll(neurons);
+        while (set.size() > 1) {
+            merger();
+        }
+    }
+
+    private void merger() {
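+        // Huffman step: pop the two lowest-frequency nodes, merge them under a new
+        // hidden parent (code 0 for the smaller, 1 for the other), and re-insert it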
+        HiddenNeuron hn = new HiddenNeuron(layerSize);
+        Neuron min1 = set.pollFirst();
+        Neuron min2 = set.pollFirst();
+        hn.category = min2.category;
+        hn.freq = min1.freq + min2.freq;
+        min1.parent = hn;
+        min2.parent = hn;
+        min1.code = 0;
+        min2.code = 1;
+        set.add(hn);
+    }
+
+}

+ 66 - 0
src/main/java/com/ansj/vec/util/MapCount.java

@@ -0,0 +1,66 @@
+package com.ansj.vec.util;
+
+/**
+ * @author dyp
+ */
+import java.util.HashMap;
+import java.util.Map.Entry;
+
+public class MapCount<T> {
+    private HashMap<T, Integer> hm = null;
+
+    public MapCount() {
+        this.hm = new HashMap<>();
+    }
+
+    public MapCount(int initialCapacity) {
+        this.hm = new HashMap<>(initialCapacity);
+    }
+
+    public void add(T t, int n) {
+        Integer integer = this.hm.get(t);
+        if (integer != null) {
+            this.hm.put(t, integer + n);
+        } else {
+            this.hm.put(t, n);
+        }
+    }
+
+    public void add(T t) {
+        this.add(t, 1);
+    }
+
+    public int size() {
+        return this.hm.size();
+    }
+
+    public void remove(T t) {
+        this.hm.remove(t);
+    }
+
+    public HashMap<T, Integer> get() {
+        return this.hm;
+    }
+
+    public String getDic() {
+        StringBuilder sb = new StringBuilder();
+        for (Entry<T, Integer> next : this.hm.entrySet()) {
+            sb.append(next.getKey());
+            sb.append("\t");
+            sb.append(next.getValue());
+            sb.append("\n");
+        }
+        return sb.toString();
+    }
+
+}

+ 164 - 0
src/main/java/com/ansj/vec/util/WordKmeans.java

@@ -0,0 +1,164 @@
+package com.ansj.vec.util;
+
+/**
+ * @author dyp
+ */
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import com.ansj.vec.Word2VEC;
+
+/**
+ * K-means clustering of word vectors
+ *
+ * @author ansj
+ */
+public class WordKmeans {
+
+    public static void main(String[] args) throws IOException {
+        Word2VEC vec = new Word2VEC();
+        vec.loadGoogleModel("vectors.bin");
+        System.out.println("load model ok!");
+        WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 50);
+        Classes[] explain = wordKmeans.explain();
+
+        for (int i = 0; i < explain.length; i++) {
+            System.out.println("--------" + i + "---------");
+            System.out.println(explain[i].getTop(10));
+        }
+
+    }
+
+    private Map<String, float[]> wordMap = null;
+
+    private int iter;
+
+    private Classes[] cArray = null;
+
+    public WordKmeans(Map<String, float[]> wordMap, int clcn, int iter) {
+        this.wordMap = wordMap;
+        this.iter = iter;
+        cArray = new Classes[clcn];
+    }
+
+    public Classes[] explain() {
+        // initialization: use the first clcn word vectors as the initial centers
+        Iterator<Entry<String, float[]>> iterator = wordMap.entrySet().iterator();
+        for (int i = 0; i < cArray.length; i++) {
+            Entry<String, float[]> next = iterator.next();
+            cArray[i] = new Classes(i, next.getValue());
+        }
+
+        for (int i = 0; i < iter; i++) {
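+            // each iteration: clear the clusters, assign every word to the nearest
+            // center by squared Euclidean distance, then recompute the centers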
+            for (Classes classes : cArray) {
+                classes.clean();
+            }
+
+            iterator = wordMap.entrySet().iterator();
+            while (iterator.hasNext()) {
+                Entry<String, float[]> next = iterator.next();
+                double miniScore = Double.MAX_VALUE;
+                double tempScore;
+                int classesId = 0;
+                for (Classes classes : cArray) {
+                    tempScore = classes.distance(next.getValue());
+                    if (miniScore > tempScore) {
+                        miniScore = tempScore;
+                        classesId = classes.id;
+                    }
+                }
+                cArray[classesId].putValue(next.getKey(), miniScore);
+            }
+
+            for (Classes classes : cArray) {
+                classes.updateCenter(wordMap);
+            }
+            System.out.println("iter " + i + " ok!");
+        }
+
+        return cArray;
+    }
+
+    public static class Classes {
+        private int id;
+
+        private float[] center;
+
+        public Classes(int id, float[] center) {
+            this.id = id;
+            this.center = center.clone();
+        }
+
+        Map<String, Double> values = new HashMap<>();
+
+        public double distance(float[] value) {
+            double sum = 0;
+            for (int i = 0; i < value.length; i++) {
+                sum += (center[i] - value[i])*(center[i] - value[i]) ;
+            }
+            return sum ;
+        }
+
+        public void putValue(String word, double score) {
+            values.put(word, score);
+        }
+
+        /**
+         * Recompute the cluster center as the mean of its members
+         * @param wordMap
+         */
+        public void updateCenter(Map<String, float[]> wordMap) {
+            for (int i = 0; i < center.length; i++) {
+                center[i] = 0;
+            }
+            float[] value = null;
+            for (String keyWord : values.keySet()) {
+                value = wordMap.get(keyWord);
+                for (int i = 0; i < value.length; i++) {
+                    center[i] += value[i];
+                }
+            }
+            for (int i = 0; i < center.length; i++) {
+                center[i] = center[i] / values.size();
+            }
+        }
+
+        /**
+         * Clear the previous iteration's assignments
+         */
+        public void clean() {
+            values.clear();
+        }
+
+        /**
+         * Get the top n words of this cluster, nearest to the center first
+         * @param n
+         * @return
+         */
+        public List<Entry<String, Double>> getTop(int n) {
+            List<Map.Entry<String, Double>> arrayList = new ArrayList<Map.Entry<String, Double>>(
+                    values.entrySet());
+            Collections.sort(arrayList, new Comparator<Map.Entry<String, Double>>() {
+                @Override
+                public int compare(Entry<String, Double> o1, Entry<String, Double> o2) {
+                    return o1.getValue() > o2.getValue() ? 1 : -1;
+                }
+            });
+            int min = Math.min(n, arrayList.size() - 1);
+            if (min <= 1) {
+                return Collections.emptyList();
+            }
+            return arrayList.subList(0, min);
+        }
+
+    }
+
+}

+ 79 - 0
src/main/java/com/tzld/piaoquan/recommend/similarity/word2vec/Demo.java

@@ -0,0 +1,79 @@
+package com.tzld.piaoquan.recommend.similarity.word2vec;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * @author dyp
+ */
+public class Demo {
+
+    public static void main(String[] args) throws Exception {
+
+        Word2Vec vec = new Word2Vec();
+        try {
+            String endpoint = "oss-cn-hangzhou.aliyuncs.com";
+            String bucketName = "art-recommend";
+            String path = "similarity/word2vec/Google_word2vec_zhwiki210720_300d.bin";
+            String accessKeyId = "LTAIP6x1l3DXfSxm";
+            String accessKeySecret = "KbTaM9ars4OX3PMS6Xm7rtxGr1FLon";
+
+            long start = System.currentTimeMillis();
+            vec.loadGoogleModelFromOss(endpoint, bucketName, path, accessKeyId, accessKeySecret);
+            long end = System.currentTimeMillis();
+            System.out.println("loadGoogleModelFromOss cost = " + (end - start));
+
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+
+        String[] s = new String[]{
+                "🔴终于找到了这首歌,献给你!",
+                "各位退休的同学,请听!",
+                "这首歌太好听了,听醉了别怪我!",
+                "老了真的很难!",
+                "老同学在聚会上的演讲幽默是太实在了💢",
+                "🔥🔥🔥一篇关于养老金问题的文章,请过来看看",
+                "🔴老人考级的标准出台!符合6个条件的了不得🔥",
+                "超级贵的景色,看过的彻底傻眼了📣",
+                "她走了!泪目!留下了这段话,让人潸然泪下!",
+                "🔴老同学❗️好久不见了,大家来看看吧!",
+                "⭕谁写的?把人《一辈子》写明白了,给老友看看吧 ~!",
+                "太美了,难得一见的美景~"
+        };
+
+
+        // segment each sentence into a word list
+        List<String>[] words = new ArrayList[s.length];
+        for (int i = 0; i < s.length; i++) {
+            words[i] = Segment.getWords(s[i]);
+        }
+
+        // fast sentence similarity
+        System.out.println("fast sentence similarity:");
+        for (int i = 0; i < words.length - 1; i++) {
+            for (int j = i + 1; j < words.length; j++) {
+                System.out.println(s[i] + "|||" + s[j] + ": " + vec.fastSentenceSimilarity(words[i], words[j]));
+            }
+        }
+
+
+        // sentence similarity (all word weights set to 1)
+        System.out.println("sentence similarity:");
+        for (int i = 0; i < s.length - 1; i++) {
+            for (int j = i + 1; j < s.length; j++) {
+                System.out.println(s[i] + "|||" + s[j] + ": " + vec.sentenceSimilarity(words[i], words[j]));
+            }
+        }
+
+        // sentence similarity (nouns and verbs weighted 1.0, everything else 0.8)
+//        float[] weightArray1 = Segment.getPOSWeightArray(Segment.getPOS(s1));
+//        float[] weightArray2 = Segment.getPOSWeightArray(Segment.getPOS(s2));
+//        float[] weightArray3 = Segment.getPOSWeightArray(Segment.getPOS(s3));
+//        System.out.println("s1|s2: " + vec.sentenceSimilarity(wordList1, wordList2, weightArray1, weightArray2));
+//        System.out.println("s1|s3: " + vec.sentenceSimilarity(wordList1, wordList3, weightArray1, weightArray3));
+//        System.out.println("s2|s3: " + vec.sentenceSimilarity(wordList2, wordList3, weightArray2, weightArray3));
+    }
+
+}

+ 70 - 0
src/main/java/com/tzld/piaoquan/recommend/similarity/word2vec/Segment.java

@@ -0,0 +1,70 @@
+package com.tzld.piaoquan.recommend.similarity.word2vec;
+
+import org.ansj.domain.Term;
+import org.ansj.recognition.impl.StopRecognition;
+import org.ansj.splitWord.analysis.ToAnalysis;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * @author dyp
+ */
+
+
+public class Segment {
+    public Segment() {
+    }
+
+    public static List<Term> Seg(String sentence) {
+        StopRecognition filter = new StopRecognition();
+        filter.insertStopWords(new String[]{",", " ", ".", ",", "。", ":", ":", "'", "‘", "’", " ", "“", "”", "《", "》", "[",
+                "]", "-"});
+        return ToAnalysis.parse(sentence).recognition(filter).getTerms();
+    }
+
+    public static List<String> getWords(String sentence) {
+        List<Term> termList = Seg(sentence);
+        List<String> wordList = new ArrayList<>();
+        for (Term wordTerm : termList) {
+            wordList.add(wordTerm.getName());
+        }
+        return wordList;
+    }
+
+    public static List<String> getPOS(String sentence) {
+        List<Term> termList = Seg(sentence);
+        List<String> natureList = new ArrayList<>();
+        for (Term wordTerm : termList) {
+            natureList.add(wordTerm.getNatureStr());
+        }
+        return natureList;
+    }
+
+    public static float[] getPOSWeightArray(List<String> posList) {
+        float[] weightVector = new float[posList.size()];
+
+        for (int i = 0; i < weightVector.length; ++i) {
+            String pos = posList.get(i);
+            switch (pos.charAt(0)) {
+                case 'n':
+                case 'v':
+                    weightVector[i] = 1.0F;
+                    break;
+                default:
+                    weightVector[i] = 0.8F;
+            }
+        }
+
+        return weightVector;
+    }
+}
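
A short usage sketch for the helpers above (a sketch only; the sentence is arbitrary sample input):

    import java.util.List;
    import com.tzld.piaoquan.recommend.similarity.word2vec.Segment;

    public class SegmentDemo {
        public static void main(String[] args) {
            String sentence = "这首歌太好听了";
            List<String> words = Segment.getWords(sentence);  // ansj terms, punctuation filtered
            List<String> pos = Segment.getPOS(sentence);      // part-of-speech tag per term
            float[] weights = Segment.getPOSWeightArray(pos); // 1.0 for nouns/verbs, 0.8 otherwise
            System.out.println(words + " / " + pos + " / " + weights.length);
        }
    }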

+ 227 - 0
src/main/java/com/tzld/piaoquan/recommend/similarity/word2vec/Word2Vec.java

@@ -0,0 +1,227 @@
+package com.tzld.piaoquan.recommend.similarity.word2vec;
+
+import com.ansj.vec.Word2VEC;
+import com.ansj.vec.domain.WordEntry;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.Map.Entry;
+
+/**
+ * @author dyp
+ */
+
+
+public class Word2Vec {
+    private Word2VEC vec = new Word2VEC();
+    private boolean loadModel = false;
+
+    public Word2Vec() {
+    }
+
+    public void loadGoogleModelFromOss(String endpoint, String bucketName, String path, String accessKeyId, String accessKeySecret) throws IOException {
+        this.vec.loadGoogleModelFromOss(endpoint, bucketName, path, accessKeyId, accessKeySecret);
+        this.loadModel = true;
+    }
+
+    public void loadGoogleModel(String path) throws IOException {
+        this.vec.loadGoogleModel(path);
+        this.loadModel = true;
+    }
+
+    public float[] getWordVector(String word) {
+        return !this.loadModel ? null : this.vec.getWordVector(word);
+    }
+
+    private float calDist(float[] vec1, float[] vec2) {
+        float dist = 0.0F;
+
+        for (int i = 0; i < vec1.length; ++i) {
+            dist += vec1[i] * vec2[i];
+        }
+
+        return dist;
+    }
+
+    private void calSum(float[] sum, float[] vec) {
+        for (int i = 0; i < sum.length; ++i) {
+            sum[i] += vec[i];
+        }
+
+    }
+
+    public float wordSimilarity(String word1, String word2) {
+        if (!this.loadModel) {
+            return 0.0F;
+        } else {
+            float[] word1Vec = this.getWordVector(word1);
+            float[] word2Vec = this.getWordVector(word2);
+            return word1Vec != null && word2Vec != null ? this.calDist(word1Vec, word2Vec) : 0.0F;
+        }
+    }
+
+    public Set<WordEntry> getSimilarWords(String word, int maxReturnNum) {
+        if (!this.loadModel) {
+            return null;
+        } else {
+            float[] center = this.getWordVector(word);
+            if (center == null) {
+                return Collections.emptySet();
+            } else {
+                int resultSize = this.vec.getWords() < maxReturnNum ? this.vec.getWords() : maxReturnNum;
+                TreeSet<WordEntry> result = new TreeSet<>();
+                double min = Double.MIN_VALUE;
+
+                for (Entry<String, float[]> entry : this.vec.getWordMap().entrySet()) {
+                    float[] vector = entry.getValue();
+                    float dist = this.calDist(center, vector);
+                    if (result.size() <= resultSize) {
+                        result.add(new WordEntry(entry.getKey(), dist));
+                        min = result.last().score;
+                    } else if ((double) dist > min) {
+                        result.add(new WordEntry(entry.getKey(), dist));
+                        result.pollLast();
+                        min = result.last().score;
+                    }
+                }
+
+                result.pollFirst();
+                return result;
+            }
+        }
+    }
+
+    private float calMaxSimilarity(String centerWord, List<String> wordList) {
+        float max = -1.0F;
+        if (wordList.contains(centerWord)) {
+            return 1.0F;
+        }
+        for (String word : wordList) {
+            float temp = this.wordSimilarity(centerWord, word);
+            if (temp != 0.0F && temp > max) {
+                max = temp;
+            }
+        }
+        return max == -1.0F ? 0.0F : max;
+    }
+
+    public float fastSentenceSimilarity(List<String> sentence1Words, List<String> sentence2Words) {
+        if (!this.loadModel) {
+            return 0.0F;
+        } else if (!sentence1Words.isEmpty() && !sentence2Words.isEmpty()) {
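+            // cosine similarity between the (unnormalized) sums of each sentence's
+            // word vectors; out-of-vocabulary words are skipped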
+            float[] sen1vector = new float[this.vec.getSize()];
+            float[] sen2vector = new float[this.vec.getSize()];
+            double len1 = 0.0D;
+            double len2 = 0.0D;
+
+            int i;
+            float[] tmp;
+            for (i = 0; i < sentence1Words.size(); ++i) {
+                tmp = this.getWordVector((String) sentence1Words.get(i));
+                if (tmp != null) {
+                    this.calSum(sen1vector, tmp);
+                }
+            }
+
+            for (i = 0; i < sentence2Words.size(); ++i) {
+                tmp = this.getWordVector((String) sentence2Words.get(i));
+                if (tmp != null) {
+                    this.calSum(sen2vector, tmp);
+                }
+            }
+
+            for (i = 0; i < this.vec.getSize(); ++i) {
+                len1 += (double) (sen1vector[i] * sen1vector[i]);
+                len2 += (double) (sen2vector[i] * sen2vector[i]);
+            }
+
+            return (float) ((double) this.calDist(sen1vector, sen2vector) / Math.sqrt(len1 * len2));
+        } else {
+            return 0.0F;
+        }
+    }
+
+    public float sentenceSimilarity(List<String> sentence1Words, List<String> sentence2Words) {
+        if (!this.loadModel) {
+            return 0.0F;
+        } else if (!sentence1Words.isEmpty() && !sentence2Words.isEmpty()) {
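+            // for each in-vocabulary word, take its best similarity against the other
+            // sentence, then average those maxima over both directions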
+            float sum1 = 0.0F;
+            float sum2 = 0.0F;
+            int count1 = 0;
+            int count2 = 0;
+
+            int i;
+            for (i = 0; i < sentence1Words.size(); ++i) {
+                if (this.getWordVector((String) sentence1Words.get(i)) != null) {
+                    ++count1;
+                    sum1 += this.calMaxSimilarity((String) sentence1Words.get(i), sentence2Words);
+                }
+            }
+
+            for (i = 0; i < sentence2Words.size(); ++i) {
+                if (this.getWordVector((String) sentence2Words.get(i)) != null) {
+                    ++count2;
+                    sum2 += this.calMaxSimilarity((String) sentence2Words.get(i), sentence1Words);
+                }
+            }
+
+            return (sum1 + sum2) / (float) (count1 + count2);
+        } else {
+            return 0.0F;
+        }
+    }
+
+    public float sentenceSimilarity(List<String> sentence1Words, List<String> sentence2Words, float[] weightVector1, float[] weightVector2) throws Exception {
+        if (!this.loadModel) {
+            return 0.0F;
+        } else if (!sentence1Words.isEmpty() && !sentence2Words.isEmpty()) {
+            if (sentence1Words.size() == weightVector1.length && sentence2Words.size() == weightVector2.length) {
+                float sum1 = 0.0F;
+                float sum2 = 0.0F;
+                float divide1 = 0.0F;
+                float divide2 = 0.0F;
+
+                int i;
+                float wordMaxSimi;
+                for (i = 0; i < sentence1Words.size(); ++i) {
+                    if (this.getWordVector((String) sentence1Words.get(i)) != null) {
+                        wordMaxSimi = this.calMaxSimilarity((String) sentence1Words.get(i), sentence2Words);
+                        sum1 += wordMaxSimi * weightVector1[i];
+                        divide1 += weightVector1[i];
+                    }
+                }
+
+                for (i = 0; i < sentence2Words.size(); ++i) {
+                    if (this.getWordVector((String) sentence2Words.get(i)) != null) {
+                        wordMaxSimi = this.calMaxSimilarity((String) sentence2Words.get(i), sentence1Words);
+                        sum2 += wordMaxSimi * weightVector2[i];
+                        divide2 += weightVector2[i];
+                    }
+                }
+
+                return (sum1 + sum2) / (divide1 + divide2);
+            } else {
+                throw new Exception("length of word list and weight vector is different");
+            }
+        } else {
+            return 0.0F;
+        }
+    }
+}