123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259 |
- package com.aliyun.odps.examples.graph;
- import java.io.DataInput;
- import java.io.DataOutput;
- import java.io.IOException;
- import org.apache.commons.logging.Log;
- import org.apache.commons.logging.LogFactory;
- import com.aliyun.odps.data.TableInfo;
- import com.aliyun.odps.graph.Aggregator;
- import com.aliyun.odps.graph.ComputeContext;
- import com.aliyun.odps.graph.GraphJob;
- import com.aliyun.odps.graph.GraphLoader;
- import com.aliyun.odps.graph.MutationContext;
- import com.aliyun.odps.graph.Vertex;
- import com.aliyun.odps.graph.WorkerContext;
- import com.aliyun.odps.io.DoubleWritable;
- import com.aliyun.odps.io.LongWritable;
- import com.aliyun.odps.io.NullWritable;
- import com.aliyun.odps.io.Text;
- import com.aliyun.odps.io.Tuple;
- import com.aliyun.odps.io.Writable;
- import com.aliyun.odps.io.WritableRecord;
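/**
 * K-means example job.
 *
 * Resource arguments:
 *   kmeans_centers - resource file with the initial centers, one center per line,
 *                    coordinates separated by commas
 * Program arguments:
 *   kmeans_in kmeans_out [max iterations (default 30)]
 */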
- public class Kmeans {
- private final static Log LOG = LogFactory.getLog(Kmeans.class);
- public static class KmeansVertex extends Vertex<Text, Tuple, NullWritable, NullWritable> {
- @Override
- public void compute(ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
- Iterable<NullWritable> messages) throws IOException {
- context.aggregate(getValue());
- }
- }
- public static class KmeansVertexReader extends
- GraphLoader<Text, Tuple, NullWritable, NullWritable> {
- @Override
- public void load(LongWritable recordNum, WritableRecord record,
- MutationContext<Text, Tuple, NullWritable, NullWritable> context) throws IOException {
- KmeansVertex vertex = new KmeansVertex();
- vertex.setId(new Text(String.valueOf(recordNum.get())));
- vertex.setValue(new Tuple(record.getAll()));
- context.addVertexRequest(vertex);
- }
- }
- public static class KmeansAggrValue implements Writable {
- Tuple centers = new Tuple();
- Tuple sums = new Tuple();
- Tuple counts = new Tuple();
- public void write(DataOutput out) throws IOException {
- centers.write(out);
- sums.write(out);
- counts.write(out);
- }
- public void readFields(DataInput in) throws IOException {
- centers = new Tuple();
- centers.readFields(in);
- sums = new Tuple();
- sums.readFields(in);
- counts = new Tuple();
- counts.readFields(in);
- }
- @Override
- public String toString() {
- return "centers " + centers.toString() + ", sums " + sums.toString() + ", counts "
- + counts.toString();
- }
- }
- public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
- @SuppressWarnings("rawtypes")
- @Override
- public KmeansAggrValue createInitialValue(WorkerContext context) throws IOException {
- KmeansAggrValue aggrVal = null;
- if (context.getSuperstep() == 0) {
- aggrVal = new KmeansAggrValue();
- aggrVal.centers = new Tuple();
- aggrVal.sums = new Tuple();
- aggrVal.counts = new Tuple();
- byte[] centers = context.readCacheFile("kmeans_centers");
- String lines[] = new String(centers).split("\n");
- for (int i = 0; i < lines.length; i++) {
- String[] ss = lines[i].split(",");
- Tuple center = new Tuple();
- Tuple sum = new Tuple();
- for (int j = 0; j < ss.length; ++j) {
- center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
- sum.append(new DoubleWritable(0.0));
- }
- LongWritable count = new LongWritable(0);
- aggrVal.sums.append(sum);
- aggrVal.counts.append(count);
- aggrVal.centers.append(center);
- }
- } else {
- aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
- }
- return aggrVal;
- }
- @Override
- public void aggregate(KmeansAggrValue value, Object item) {
- int min = 0;
- double mindist = Double.MAX_VALUE;
- Tuple point = (Tuple) item;
- for (int i = 0; i < value.centers.size(); i++) {
- Tuple center = (Tuple) value.centers.get(i);
- // use Euclidean Distance, no need to calculate sqrt
- double dist = 0.0d;
- for (int j = 0; j < center.size(); j++) {
- double v = ((DoubleWritable) point.get(j)).get() - ((DoubleWritable) center.get(j)).get();
- dist += v * v;
- }
- if (dist < mindist) {
- mindist = dist;
- min = i;
- }
- }
- // update sum and count
- Tuple sum = (Tuple) value.sums.get(min);
- for (int i = 0; i < point.size(); i++) {
- DoubleWritable s = (DoubleWritable) sum.get(i);
- s.set(s.get() + ((DoubleWritable) point.get(i)).get());
- }
- LongWritable count = (LongWritable) value.counts.get(min);
- count.set(count.get() + 1);
- }
- @Override
- public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
- for (int i = 0; i < value.sums.size(); i++) {
- Tuple sum = (Tuple) value.sums.get(i);
- Tuple that = (Tuple) partial.sums.get(i);
- for (int j = 0; j < sum.size(); j++) {
- DoubleWritable s = (DoubleWritable) sum.get(j);
- s.set(s.get() + ((DoubleWritable) that.get(j)).get());
- }
- }
- for (int i = 0; i < value.counts.size(); i++) {
- LongWritable count = (LongWritable) value.counts.get(i);
- count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
- }
- }
- @SuppressWarnings("rawtypes")
- @Override
- public boolean terminate(WorkerContext context, KmeansAggrValue value) throws IOException {
- // compute new centers
- Tuple newCenters = new Tuple(value.sums.size());
- for (int i = 0; i < value.sums.size(); i++) {
- Tuple sum = (Tuple) value.sums.get(i);
- Tuple newCenter = new Tuple(sum.size());
- LongWritable c = (LongWritable) value.counts.get(i);
- for (int j = 0; j < sum.size(); j++) {
- DoubleWritable s = (DoubleWritable) sum.get(j);
- double val = s.get() / c.get();
- newCenter.set(j, new DoubleWritable(val));
- // reset sum for next iteration
- s.set(0.0d);
- }
- // reset count for next iteration
- c.set(0);
- newCenters.set(i, newCenter);
- }
- // update centers
- Tuple oldCenters = value.centers;
- value.centers = newCenters;
- LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
- // compare new/old centers
- boolean converged = true;
- for (int i = 0; i < value.centers.size() && converged; i++) {
- Tuple oldCenter = (Tuple) oldCenters.get(i);
- Tuple newCenter = (Tuple) newCenters.get(i);
- double sum = 0.0d;
- for (int j = 0; j < newCenter.size(); j++) {
- double v =
- ((DoubleWritable) newCenter.get(j)).get() - ((DoubleWritable) oldCenter.get(j)).get();
- sum += v * v;
- }
- double dist = Math.sqrt(sum);
- LOG.info("old center: " + oldCenter + ", new center: " + newCenter + ", dist: " + dist);
- // converge threshold for each center: 0.05
- converged = dist < 0.05d;
- }
- if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
- // converged or reach max iteration, output centers
- for (int i = 0; i < value.centers.size(); i++) {
- context.write(((Tuple) value.centers.get(i)).toArray());
- }
- // true means to terminate iteration
- return true;
- }
- // false means to continue iteration
- return false;
- }
- }
- private static void printUsage() {
- System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
- System.exit(-1);
- }
- public static void main(String[] args) throws IOException {
- if (args.length < 2)
- printUsage();
- GraphJob job = new GraphJob();
- job.setGraphLoaderClass(KmeansVertexReader.class);
- job.setRuntimePartitioning(false);
- job.setVertexClass(KmeansVertex.class);
- job.setAggregatorClass(KmeansAggregator.class);
- job.addInput(TableInfo.builder().tableName(args[0]).build());
- job.addOutput(TableInfo.builder().tableName(args[1]).build());
- // default max iteration is 30
- job.setMaxIteration(30);
- if (args.length >= 3)
- job.setMaxIteration(Integer.parseInt(args[2]));
- long start = System.currentTimeMillis();
- job.run();
- System.out.println("Job Finished in " + (System.currentTimeMillis() - start) / 1000.0
- + " seconds");
- }
- }
|