Kmeans.java 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package com.aliyun.odps.examples.graph;
  2. import java.io.DataInput;
  3. import java.io.DataOutput;
  4. import java.io.IOException;
  5. import org.apache.commons.logging.Log;
  6. import org.apache.commons.logging.LogFactory;
  7. import com.aliyun.odps.data.TableInfo;
  8. import com.aliyun.odps.graph.Aggregator;
  9. import com.aliyun.odps.graph.ComputeContext;
  10. import com.aliyun.odps.graph.GraphJob;
  11. import com.aliyun.odps.graph.GraphLoader;
  12. import com.aliyun.odps.graph.MutationContext;
  13. import com.aliyun.odps.graph.Vertex;
  14. import com.aliyun.odps.graph.WorkerContext;
  15. import com.aliyun.odps.io.DoubleWritable;
  16. import com.aliyun.odps.io.LongWritable;
  17. import com.aliyun.odps.io.NullWritable;
  18. import com.aliyun.odps.io.Text;
  19. import com.aliyun.odps.io.Tuple;
  20. import com.aliyun.odps.io.Writable;
  21. import com.aliyun.odps.io.WritableRecord;
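/**
 * K-means clustering example for the ODPS Graph framework.
 *
 * <p>Resource arguments: {@code kmeans_centers} - a cache file holding the initial centers,
 * one center per line, with the coordinates of a center separated by commas.
 *
 * <p>Program arguments: {@code kmeans_in kmeans_out [max iterations]} - the input table, the
 * output table, and an optional iteration limit (30 when omitted).
 */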
  28. public class Kmeans {
  29. private final static Log LOG = LogFactory.getLog(Kmeans.class);
  30. public static class KmeansVertex extends Vertex<Text, Tuple, NullWritable, NullWritable> {
  31. @Override
  32. public void compute(ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
  33. Iterable<NullWritable> messages) throws IOException {
  34. context.aggregate(getValue());
  35. }
  36. }
  37. public static class KmeansVertexReader extends
  38. GraphLoader<Text, Tuple, NullWritable, NullWritable> {
  39. @Override
  40. public void load(LongWritable recordNum, WritableRecord record,
  41. MutationContext<Text, Tuple, NullWritable, NullWritable> context) throws IOException {
  42. KmeansVertex vertex = new KmeansVertex();
  43. vertex.setId(new Text(String.valueOf(recordNum.get())));
  44. vertex.setValue(new Tuple(record.getAll()));
  45. context.addVertexRequest(vertex);
  46. }
  47. }
  48. public static class KmeansAggrValue implements Writable {
  49. Tuple centers = new Tuple();
  50. Tuple sums = new Tuple();
  51. Tuple counts = new Tuple();
  52. public void write(DataOutput out) throws IOException {
  53. centers.write(out);
  54. sums.write(out);
  55. counts.write(out);
  56. }
  57. public void readFields(DataInput in) throws IOException {
  58. centers = new Tuple();
  59. centers.readFields(in);
  60. sums = new Tuple();
  61. sums.readFields(in);
  62. counts = new Tuple();
  63. counts.readFields(in);
  64. }
  65. @Override
  66. public String toString() {
  67. return "centers " + centers.toString() + ", sums " + sums.toString() + ", counts "
  68. + counts.toString();
  69. }
  70. }
  71. public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
  72. @SuppressWarnings("rawtypes")
  73. @Override
  74. public KmeansAggrValue createInitialValue(WorkerContext context) throws IOException {
  75. KmeansAggrValue aggrVal = null;
  76. if (context.getSuperstep() == 0) {
  77. aggrVal = new KmeansAggrValue();
  78. aggrVal.centers = new Tuple();
  79. aggrVal.sums = new Tuple();
  80. aggrVal.counts = new Tuple();
  81. byte[] centers = context.readCacheFile("kmeans_centers");
  82. String lines[] = new String(centers).split("\n");
  83. for (int i = 0; i < lines.length; i++) {
  84. String[] ss = lines[i].split(",");
  85. Tuple center = new Tuple();
  86. Tuple sum = new Tuple();
  87. for (int j = 0; j < ss.length; ++j) {
  88. center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
  89. sum.append(new DoubleWritable(0.0));
  90. }
  91. LongWritable count = new LongWritable(0);
  92. aggrVal.sums.append(sum);
  93. aggrVal.counts.append(count);
  94. aggrVal.centers.append(center);
  95. }
  96. } else {
  97. aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
  98. }
  99. return aggrVal;
  100. }
  101. @Override
  102. public void aggregate(KmeansAggrValue value, Object item) {
  103. int min = 0;
  104. double mindist = Double.MAX_VALUE;
  105. Tuple point = (Tuple) item;
  106. for (int i = 0; i < value.centers.size(); i++) {
  107. Tuple center = (Tuple) value.centers.get(i);
  108. // use Euclidean Distance, no need to calculate sqrt
  109. double dist = 0.0d;
  110. for (int j = 0; j < center.size(); j++) {
  111. double v = ((DoubleWritable) point.get(j)).get() - ((DoubleWritable) center.get(j)).get();
  112. dist += v * v;
  113. }
  114. if (dist < mindist) {
  115. mindist = dist;
  116. min = i;
  117. }
  118. }
  119. // update sum and count
  120. Tuple sum = (Tuple) value.sums.get(min);
  121. for (int i = 0; i < point.size(); i++) {
  122. DoubleWritable s = (DoubleWritable) sum.get(i);
  123. s.set(s.get() + ((DoubleWritable) point.get(i)).get());
  124. }
  125. LongWritable count = (LongWritable) value.counts.get(min);
  126. count.set(count.get() + 1);
  127. }
  128. @Override
  129. public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
  130. for (int i = 0; i < value.sums.size(); i++) {
  131. Tuple sum = (Tuple) value.sums.get(i);
  132. Tuple that = (Tuple) partial.sums.get(i);
  133. for (int j = 0; j < sum.size(); j++) {
  134. DoubleWritable s = (DoubleWritable) sum.get(j);
  135. s.set(s.get() + ((DoubleWritable) that.get(j)).get());
  136. }
  137. }
  138. for (int i = 0; i < value.counts.size(); i++) {
  139. LongWritable count = (LongWritable) value.counts.get(i);
  140. count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
  141. }
  142. }
  143. @SuppressWarnings("rawtypes")
  144. @Override
  145. public boolean terminate(WorkerContext context, KmeansAggrValue value) throws IOException {
  146. // compute new centers
  147. Tuple newCenters = new Tuple(value.sums.size());
  148. for (int i = 0; i < value.sums.size(); i++) {
  149. Tuple sum = (Tuple) value.sums.get(i);
  150. Tuple newCenter = new Tuple(sum.size());
  151. LongWritable c = (LongWritable) value.counts.get(i);
  152. for (int j = 0; j < sum.size(); j++) {
  153. DoubleWritable s = (DoubleWritable) sum.get(j);
  154. double val = s.get() / c.get();
  155. newCenter.set(j, new DoubleWritable(val));
  156. // reset sum for next iteration
  157. s.set(0.0d);
  158. }
  159. // reset count for next iteration
  160. c.set(0);
  161. newCenters.set(i, newCenter);
  162. }
  163. // update centers
  164. Tuple oldCenters = value.centers;
  165. value.centers = newCenters;
  166. LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
  167. // compare new/old centers
  168. boolean converged = true;
  169. for (int i = 0; i < value.centers.size() && converged; i++) {
  170. Tuple oldCenter = (Tuple) oldCenters.get(i);
  171. Tuple newCenter = (Tuple) newCenters.get(i);
  172. double sum = 0.0d;
  173. for (int j = 0; j < newCenter.size(); j++) {
  174. double v =
  175. ((DoubleWritable) newCenter.get(j)).get() - ((DoubleWritable) oldCenter.get(j)).get();
  176. sum += v * v;
  177. }
  178. double dist = Math.sqrt(sum);
  179. LOG.info("old center: " + oldCenter + ", new center: " + newCenter + ", dist: " + dist);
  180. // converge threshold for each center: 0.05
  181. converged = dist < 0.05d;
  182. }
  183. if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
  184. // converged or reach max iteration, output centers
  185. for (int i = 0; i < value.centers.size(); i++) {
  186. context.write(((Tuple) value.centers.get(i)).toArray());
  187. }
  188. // true means to terminate iteration
  189. return true;
  190. }
  191. // false means to continue iteration
  192. return false;
  193. }
  194. }
  195. private static void printUsage() {
  196. System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
  197. System.exit(-1);
  198. }
  199. public static void main(String[] args) throws IOException {
  200. if (args.length < 2)
  201. printUsage();
  202. GraphJob job = new GraphJob();
  203. job.setGraphLoaderClass(KmeansVertexReader.class);
  204. job.setRuntimePartitioning(false);
  205. job.setVertexClass(KmeansVertex.class);
  206. job.setAggregatorClass(KmeansAggregator.class);
  207. job.addInput(TableInfo.builder().tableName(args[0]).build());
  208. job.addOutput(TableInfo.builder().tableName(args[1]).build());
  209. // default max iteration is 30
  210. job.setMaxIteration(30);
  211. if (args.length >= 3)
  212. job.setMaxIteration(Integer.parseInt(args[2]));
  213. long start = System.currentTimeMillis();
  214. job.run();
  215. System.out.println("Job Finished in " + (System.currentTimeMillis() - start) / 1000.0
  216. + " seconds");
  217. }
  218. }