|
@@ -0,0 +1,259 @@
|
|
|
+package com.aliyun.odps.examples.graph;
|
|
|
+
|
|
|
+import java.io.DataInput;
|
|
|
+import java.io.DataOutput;
|
|
|
+import java.io.IOException;
|
|
|
+
|
|
|
+import org.apache.commons.logging.Log;
|
|
|
+import org.apache.commons.logging.LogFactory;
|
|
|
+
|
|
|
+import com.aliyun.odps.data.TableInfo;
|
|
|
+import com.aliyun.odps.graph.Aggregator;
|
|
|
+import com.aliyun.odps.graph.ComputeContext;
|
|
|
+import com.aliyun.odps.graph.GraphJob;
|
|
|
+import com.aliyun.odps.graph.GraphLoader;
|
|
|
+import com.aliyun.odps.graph.MutationContext;
|
|
|
+import com.aliyun.odps.graph.Vertex;
|
|
|
+import com.aliyun.odps.graph.WorkerContext;
|
|
|
+import com.aliyun.odps.io.DoubleWritable;
|
|
|
+import com.aliyun.odps.io.LongWritable;
|
|
|
+import com.aliyun.odps.io.NullWritable;
|
|
|
+import com.aliyun.odps.io.Text;
|
|
|
+import com.aliyun.odps.io.Tuple;
|
|
|
+import com.aliyun.odps.io.Writable;
|
|
|
+import com.aliyun.odps.io.WritableRecord;
|
|
|
+
|
|
|
+/**
+ * K-means clustering example built on the ODPS Graph API.
+ *
+ * Set resources arguments:
+ * kmeans_centers
+ * Set program arguments:
+ * kmeans_in kmeans_out
+ */
|
|
|
+public class Kmeans {
|
|
|
+ private final static Log LOG = LogFactory.getLog(Kmeans.class);
|
|
|
+
|
|
|
+ public static class KmeansVertex extends Vertex<Text, Tuple, NullWritable, NullWritable> {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void compute(ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
|
|
|
+ Iterable<NullWritable> messages) throws IOException {
|
|
|
+ context.aggregate(getValue());
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+  /**
+   * Graph loader that turns every input record into one {@link KmeansVertex}.
+   */
+  public static class KmeansVertexReader extends
+      GraphLoader<Text, Tuple, NullWritable, NullWritable> {
+
+    /**
+     * Creates a vertex whose id is the record number (as text) and whose value is a tuple of all
+     * fields of the record, then requests that it be added to the graph.
+     *
+     * @param recordNum number of the record being loaded; becomes the vertex id
+     * @param record input record; all of its fields become the vertex value
+     * @param context mutation context that receives the add-vertex request
+     * @throws IOException if the add-vertex request fails
+     */
+    @Override
+    public void load(LongWritable recordNum, WritableRecord record,
+        MutationContext<Text, Tuple, NullWritable, NullWritable> context) throws IOException {
+      KmeansVertex vertex = new KmeansVertex();
+
+      Text vertexId = new Text(Long.toString(recordNum.get()));
+      vertex.setId(vertexId);
+
+      Tuple point = new Tuple(record.getAll());
+      vertex.setValue(point);
+
+      context.addVertexRequest(vertex);
+    }
+  }
|
|
|
+
|
|
|
+ public static class KmeansAggrValue implements Writable {
|
|
|
+
|
|
|
+ Tuple centers = new Tuple();
|
|
|
+ Tuple sums = new Tuple();
|
|
|
+ Tuple counts = new Tuple();
|
|
|
+
|
|
|
+ public void write(DataOutput out) throws IOException {
|
|
|
+ centers.write(out);
|
|
|
+ sums.write(out);
|
|
|
+ counts.write(out);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void readFields(DataInput in) throws IOException {
|
|
|
+ centers = new Tuple();
|
|
|
+ centers.readFields(in);
|
|
|
+ sums = new Tuple();
|
|
|
+ sums.readFields(in);
|
|
|
+ counts = new Tuple();
|
|
|
+ counts.readFields(in);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String toString() {
|
|
|
+ return "centers " + centers.toString() + ", sums " + sums.toString() + ", counts "
|
|
|
+ + counts.toString();
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
|
|
|
+
|
|
|
+    /**
+     * Produces the aggregator value a worker starts a superstep with.
+     *
+     * <p>On superstep 0 the initial centers are parsed from the {@code kmeans_centers} cache
+     * resource: one center per line, coordinates separated by commas. For every center a
+     * zero-filled sum tuple of the same length and a zero counter are created. On any later
+     * superstep the value returned by {@code getLastAggregatedValue(0)} is reused.
+     *
+     * <p>The resource is decoded as UTF-8 rather than with the platform default charset, and
+     * blank lines are skipped so that a stray empty line does not abort the job with a
+     * {@link NumberFormatException}.
+     *
+     * @param context worker context giving access to the superstep number, the cache resource
+     *     and the previously aggregated value
+     * @return the parsed initial state on superstep 0, otherwise the last aggregated value
+     * @throws IOException if the cache resource cannot be read or decoded
+     */
+    @SuppressWarnings("rawtypes")
+    @Override
+    public KmeansAggrValue createInitialValue(WorkerContext context) throws IOException {
+      if (context.getSuperstep() != 0) {
+        return (KmeansAggrValue) context.getLastAggregatedValue(0);
+      }
+
+      KmeansAggrValue aggrVal = new KmeansAggrValue();
+      aggrVal.centers = new Tuple();
+      aggrVal.sums = new Tuple();
+      aggrVal.counts = new Tuple();
+
+      byte[] centers = context.readCacheFile("kmeans_centers");
+      // Explicit charset: the default one differs between machines.
+      String[] lines = new String(centers, "UTF-8").split("\n");
+
+      for (String rawLine : lines) {
+        String line = rawLine.trim();
+        if (line.isEmpty()) {
+          // Blank line (or a lone "\r"): there is no center to parse here.
+          continue;
+        }
+
+        String[] ss = line.split(",");
+        Tuple center = new Tuple();
+        Tuple sum = new Tuple();
+        for (int j = 0; j < ss.length; ++j) {
+          center.append(new DoubleWritable(Double.parseDouble(ss[j].trim())));
+          sum.append(new DoubleWritable(0.0));
+        }
+
+        aggrVal.sums.append(sum);
+        aggrVal.counts.append(new LongWritable(0));
+        aggrVal.centers.append(center);
+      }
+
+      return aggrVal;
+    }
|
|
|
+
|
|
|
+    /**
+     * Assigns one point to its nearest center (by squared Euclidean distance) and adds the point
+     * to that center's running sum and counter. The first center wins on ties.
+     *
+     * @param value aggregate state whose sums and counts are updated in place
+     * @param item the point to assign; cast to a {@code Tuple} of {@code DoubleWritable}
+     */
+    @Override
+    public void aggregate(KmeansAggrValue value, Object item) {
+      int nearest = 0;
+      double nearestDist = Double.MAX_VALUE;
+      Tuple point = (Tuple) item;
+
+      // Find the index of the closest center.
+      for (int c = 0; c < value.centers.size(); c++) {
+        Tuple center = (Tuple) value.centers.get(c);
+
+        double squared = 0.0d;
+        for (int d = 0; d < center.size(); d++) {
+          double pointCoord = ((DoubleWritable) point.get(d)).get();
+          double centerCoord = ((DoubleWritable) center.get(d)).get();
+          double diff = pointCoord - centerCoord;
+          squared += diff * diff;
+        }
+
+        if (squared < nearestDist) {
+          nearestDist = squared;
+          nearest = c;
+        }
+      }
+
+      // Fold the point into the chosen cluster's accumulators.
+      Tuple clusterSum = (Tuple) value.sums.get(nearest);
+      for (int d = 0; d < point.size(); d++) {
+        DoubleWritable acc = (DoubleWritable) clusterSum.get(d);
+        double addend = ((DoubleWritable) point.get(d)).get();
+        acc.set(acc.get() + addend);
+      }
+
+      LongWritable clusterCount = (LongWritable) value.counts.get(nearest);
+      clusterCount.set(clusterCount.get() + 1);
+    }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
|
|
|
+ for (int i = 0; i < value.sums.size(); i++) {
|
|
|
+ Tuple sum = (Tuple) value.sums.get(i);
|
|
|
+ Tuple that = (Tuple) partial.sums.get(i);
|
|
|
+
|
|
|
+ for (int j = 0; j < sum.size(); j++) {
|
|
|
+ DoubleWritable s = (DoubleWritable) sum.get(j);
|
|
|
+ s.set(s.get() + ((DoubleWritable) that.get(j)).get());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int i = 0; i < value.counts.size(); i++) {
|
|
|
+ LongWritable count = (LongWritable) value.counts.get(i);
|
|
|
+ count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+    /**
+     * Turns the merged sums and counts into new centers, resets the accumulators, and decides
+     * whether iteration should stop.
+     *
+     * <p>Each new center is the per-coordinate sum divided by the cluster's point count. A
+     * cluster that received no points keeps its previous center: dividing its 0.0 sum by a zero
+     * count would yield NaN coordinates, which can never win the nearest-center comparison again
+     * and can never satisfy the convergence test.
+     *
+     * <p>When every center moved by less than 0.05 (Euclidean distance), or the current superstep
+     * is the last one allowed by the iteration limit, the centers are written to the output and
+     * {@code true} is returned.
+     *
+     * @param context worker context supplying the superstep number, the iteration limit and the
+     *     output writer
+     * @param value merged aggregate; its centers are replaced and its sums and counts are zeroed
+     * @return {@code true} if the centers were written and iteration should stop
+     * @throws IOException if writing a center to the output fails
+     */
+    @SuppressWarnings("rawtypes")
+    @Override
+    public boolean terminate(WorkerContext context, KmeansAggrValue value) throws IOException {
+      // Compute the new centers and reset the accumulators for the next superstep.
+      Tuple newCenters = new Tuple(value.sums.size());
+      for (int i = 0; i < value.sums.size(); i++) {
+        Tuple sum = (Tuple) value.sums.get(i);
+        Tuple previousCenter = (Tuple) value.centers.get(i);
+        Tuple newCenter = new Tuple(sum.size());
+        LongWritable c = (LongWritable) value.counts.get(i);
+        long members = c.get();
+
+        for (int j = 0; j < sum.size(); j++) {
+          DoubleWritable s = (DoubleWritable) sum.get(j);
+          // Empty cluster: reuse the old coordinate instead of computing 0.0 / 0 (NaN).
+          double val = (members == 0)
+              ? ((DoubleWritable) previousCenter.get(j)).get()
+              : s.get() / members;
+          newCenter.set(j, new DoubleWritable(val));
+
+          s.set(0.0d);
+        }
+
+        c.set(0);
+        newCenters.set(i, newCenter);
+      }
+
+      // Publish the new centers, keeping the old ones for the convergence test.
+      Tuple oldCenters = value.centers;
+      value.centers = newCenters;
+
+      LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
+
+      // Converged only if every center moved less than the threshold; stop at the first miss.
+      boolean converged = true;
+      for (int i = 0; i < value.centers.size() && converged; i++) {
+        Tuple oldCenter = (Tuple) oldCenters.get(i);
+        Tuple newCenter = (Tuple) newCenters.get(i);
+        double sum = 0.0d;
+        for (int j = 0; j < newCenter.size(); j++) {
+          double v =
+              ((DoubleWritable) newCenter.get(j)).get() - ((DoubleWritable) oldCenter.get(j)).get();
+          sum += v * v;
+        }
+        double dist = Math.sqrt(sum);
+        LOG.info("old center: " + oldCenter + ", new center: " + newCenter + ", dist: " + dist);
+
+        converged = dist < 0.05d;
+      }
+
+      if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
+        // Final result: one output record per center.
+        for (int i = 0; i < value.centers.size(); i++) {
+          context.write(((Tuple) value.centers.get(i)).toArray());
+        }
+
+        return true;
+      }
+
+      return false;
+    }
|
|
|
+ }
|
|
|
+
|
|
|
+ private static void printUsage() {
|
|
|
+ System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
|
|
|
+ System.exit(-1);
|
|
|
+ }
|
|
|
+
|
|
|
+  /**
+   * Configures and runs the k-means graph job, then prints the elapsed wall-clock time.
+   *
+   * @param args {@code args[0]} input table name, {@code args[1]} output table name, optional
+   *     {@code args[2]} maximum number of iterations (30 when absent)
+   * @throws IOException if the job fails with an I/O error
+   */
+  public static void main(String[] args) throws IOException {
+    if (args.length < 2) {
+      printUsage();
+    }
+
+    GraphJob job = new GraphJob();
+
+    job.setGraphLoaderClass(KmeansVertexReader.class);
+    job.setRuntimePartitioning(false);
+    job.setVertexClass(KmeansVertex.class);
+    job.setAggregatorClass(KmeansAggregator.class);
+
+    TableInfo inputTable = TableInfo.builder().tableName(args[0]).build();
+    job.addInput(inputTable);
+    TableInfo outputTable = TableInfo.builder().tableName(args[1]).build();
+    job.addOutput(outputTable);
+
+    // Default iteration limit, overridden by the optional third argument.
+    job.setMaxIteration(30);
+    if (args.length >= 3) {
+      job.setMaxIteration(Integer.parseInt(args[2]));
+    }
+
+    long start = System.currentTimeMillis();
+    job.run();
+    long elapsedMillis = System.currentTimeMillis() - start;
+    System.out.println("Job Finished in " + elapsedMillis / 1000.0 + " seconds");
+  }
|
|
|
+}
|