IGNITE-7590: fixed tree example
authorartemmalykh <amalykh@gridgain.com>
Thu, 1 Feb 2018 09:43:02 +0000 (12:43 +0300)
committerYury Babak <ybabak@gridgain.com>
Thu, 1 Feb 2018 09:43:02 +0000 (12:43 +0300)
this closes #3459

examples/src/main/java/org/apache/ignite/examples/ml/MLExamplesCommonArgs.java [new file with mode: 0644]
examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java [moved from examples/src/main/java/org/apache/ignite/examples/ml/trees/MNISTExample.java with 52% similarity]
examples/src/test/java/org/apache/ignite/testsuites/IgniteExamplesMLTestSuite.java

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/MLExamplesCommonArgs.java b/examples/src/main/java/org/apache/ignite/examples/ml/MLExamplesCommonArgs.java
new file mode 100644 (file)
index 0000000..701894b
--- /dev/null
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml;
+
+/**
+ * Some common arguments for examples in ML module.
+ */
+public class MLExamplesCommonArgs {
+    /**
+     * Unattended argument.
+     */
+    public static String UNATTENDED = "unattended";
+
+    /** Empty args for ML examples. */
+    public static final String[] EMPTY_ARGS_ML = new String[] {"--" + UNATTENDED};
+}
 
 package org.apache.ignite.examples.ml.trees;
 
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
 import java.io.IOException;
+import java.net.URL;
+import java.nio.channels.Channels;
+import java.nio.channels.ReadableByteChannel;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
 import java.util.Random;
+import java.util.Scanner;
 import java.util.function.Function;
+import java.util.stream.Collectors;
 import java.util.stream.Stream;
+import java.util.zip.GZIPInputStream;
 import org.apache.commons.cli.BasicParser;
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.CommandLineParser;
@@ -34,11 +46,10 @@ import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.IgniteDataStreamer;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
 import org.apache.ignite.cache.CacheWriteSynchronizationMode;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.examples.ExampleNodeStartup;
+import org.apache.ignite.examples.ml.MLExamplesCommonArgs;
 import org.apache.ignite.internal.util.IgniteUtils;
 import org.apache.ignite.lang.IgniteBiTuple;
 import org.apache.ignite.ml.Model;
@@ -70,32 +81,12 @@ import org.jetbrains.annotations.NotNull;
  * It is recommended to start at least one node prior to launching this example if you intend
  * to run it with default memory settings.</p>
  * <p>
- * This example should with program arguments, for example
- * -ts_i /path/to/train-images-idx3-ubyte
- * -ts_l /path/to/train-labels-idx1-ubyte
- * -tss_i /path/to/t10k-images-idx3-ubyte
- * -tss_l /path/to/t10k-labels-idx1-ubyte
+ * This example should be run with program arguments, for example
  * -cfg examples/config/example-ignite.xml.</p>
  * <p>
- * -ts_i specifies path to training set images of MNIST;
- * -ts_l specifies path to training set labels of MNIST;
- * -tss_i specifies path to test set images of MNIST;
- * -tss_l specifies path to test set labels of MNIST;
  * -cfg specifies path to a config path.</p>
  */
-public class MNISTExample {
-    /** Name of parameter specifying path to training set images. */
-    private static final String MNIST_TRAINING_IMAGES_PATH = "ts_i";
-
-    /** Name of parameter specifying path to training set labels. */
-    private static final String MNIST_TRAINING_LABELS_PATH = "ts_l";
-
-    /** Name of parameter specifying path to test set images. */
-    private static final String MNIST_TEST_IMAGES_PATH = "tss_i";
-
-    /** Name of parameter specifying path to test set labels. */
-    private static final String MNIST_TEST_LABELS_PATH = "tss_l";
-
+public class DecisionTreesExample {
     /** Name of parameter specifying path of Ignite config. */
     private static final String CONFIG = "cfg";
 
@@ -103,11 +94,38 @@ public class MNISTExample {
     private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml";
 
     /**
+     * Folder in which MNIST dataset is expected.
+     */
+    private static String MNIST_DIR = "examples/src/main/resources/";
+
+    /**
+     * Key for MNIST training images.
+     */
+    private static String MNIST_TRAIN_IMAGES = "train_images";
+
+    /**
+     * Key for MNIST training labels.
+     */
+    private static String MNIST_TRAIN_LABELS = "train_labels";
+
+    /**
+     * Key for MNIST test images.
+     */
+    private static String MNIST_TEST_IMAGES = "test_images";
+
+    /**
+     * Key for MNIST test labels.
+     */
+    private static String MNIST_TEST_LABELS = "test_labels";
+
+    /**
      * Launches example.
      *
      * @param args Program arguments.
      */
-    public static void main(String[] args) {
+    public static void main(String[] args) throws IOException {
+        System.out.println(">>> Decision trees example started.");
+
         String igniteCfgPath;
 
         CommandLineParser parser = new BasicParser();
@@ -118,14 +136,23 @@ public class MNISTExample {
         String testImagesPath;
         String testLabelsPath;
 
+        Map<String, String> mnistPaths = new HashMap<>();
+
+        mnistPaths.put(MNIST_TRAIN_IMAGES, "train-images-idx3-ubyte");
+        mnistPaths.put(MNIST_TRAIN_LABELS, "train-labels-idx1-ubyte");
+        mnistPaths.put(MNIST_TEST_IMAGES, "t10k-images-idx3-ubyte");
+        mnistPaths.put(MNIST_TEST_LABELS, "t10k-labels-idx1-ubyte");
+
         try {
             // Parse the command line arguments.
             CommandLine line = parser.parse(buildOptions(), args);
 
-            trainingImagesPath = line.getOptionValue(MNIST_TRAINING_IMAGES_PATH);
-            trainingLabelsPath = line.getOptionValue(MNIST_TRAINING_LABELS_PATH);
-            testImagesPath = line.getOptionValue(MNIST_TEST_IMAGES_PATH);
-            testLabelsPath = line.getOptionValue(MNIST_TEST_LABELS_PATH);
+            if (line.hasOption(MLExamplesCommonArgs.UNATTENDED)) {
+                System.out.println(">>> Skipped example execution because 'unattended' mode is used.");
+                System.out.println(">>> Decision trees example finished.");
+                return;
+            }
+
             igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG);
         }
         catch (ParseException e) {
@@ -133,67 +160,144 @@ public class MNISTExample {
             return;
         }
 
+        if (!getMNIST(mnistPaths.values())) {
+            System.out.println(">>> You should have MNIST dataset in " + MNIST_DIR + " to run this example.");
+            return;
+        }
+
+        trainingImagesPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" +
+            mnistPaths.get(MNIST_TRAIN_IMAGES))).getPath();
+        trainingLabelsPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" +
+            mnistPaths.get(MNIST_TRAIN_LABELS))).getPath();
+        testImagesPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" +
+            mnistPaths.get(MNIST_TEST_IMAGES))).getPath();
+        testLabelsPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" +
+            mnistPaths.get(MNIST_TEST_LABELS))).getPath();
+
         try (Ignite ignite = Ignition.start(igniteCfgPath)) {
             IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
 
             int ptsCnt = 60000;
             int featCnt = 28 * 28;
 
-            Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), ptsCnt);
-            Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000);
+            Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath,
+                new Random(123L), ptsCnt);
+
+            Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath,
+                new Random(123L), 10_000);
 
             IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite);
 
             loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite);
 
-            ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
-                new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);
+            ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10,
+                ContinuousSplitCalculators.GINI.apply(ignite),
+                RegionCalculators.GINI,
+                RegionCalculators.MOST_COMMON,
+                ignite);
 
             System.out.println(">>> Training started");
             long before = System.currentTimeMillis();
             DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));
             System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
 
-            IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage();
-            Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
+            IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse =
+                Estimators.errorsPercentage();
+
+            Double accuracy = mse.apply(mdl, testMnistStream.map(v ->
+                new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
+
             System.out.println(">>> Errs percentage: " + accuracy);
         }
         catch (IOException e) {
             e.printStackTrace();
         }
+
+        System.out.println(">>> Decision trees example finished.");
     }
 
     /**
-     * Build cli options.
+     * Get MNIST dataset. Value of predicate 'MNIST dataset is present in expected folder' is returned.
+     *
+     * @param mnistFileNames File names of MNIST dataset.
+     * @return Value of predicate 'MNIST dataset is present in expected folder'.
+     * @throws IOException In case of file system errors.
      */
-    @NotNull private static Options buildOptions() {
-        Options options = new Options();
+    private static boolean getMNIST(Collection<String> mnistFileNames) throws IOException {
+        List<String> missing = mnistFileNames.stream().
+            filter(f -> IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + f) == null).
+            collect(Collectors.toList());
+
+        if (!missing.isEmpty()) {
+            System.out.println(">>> You have not fully downloaded MNIST dataset in directory " + MNIST_DIR +
+                ", do you want it to be downloaded? [y]/n");
+            Scanner s = new Scanner(System.in);
+            String str = s.nextLine();
+
+            if (!str.isEmpty() && !str.toLowerCase().equals("y"))
+                return false;
+        }
 
-        Option trsImagesPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_IMAGES_PATH).withLongOpt(MNIST_TRAINING_IMAGES_PATH).hasArg()
-            .withDescription("Path to the MNIST training set.")
-            .isRequired(true).create();
+        for (String s : missing) {
+            String f = s + ".gz";
+            System.out.println(">>> Downloading " + f + "...");
+            URL website = new URL("http://yann.lecun.com/exdb/mnist/" + f);
+            ReadableByteChannel rbc = Channels.newChannel(website.openStream());
+            FileOutputStream fos = new FileOutputStream(MNIST_DIR + "/" + f);
+            fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE);
+            System.out.println(">>> Done.");
 
-        Option trsLabelsPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_LABELS_PATH).withLongOpt(MNIST_TRAINING_LABELS_PATH).hasArg()
-            .withDescription("Path to the MNIST training set.")
-            .isRequired(true).create();
+            System.out.println(">>> Unzipping " + f + "...");
+            unzip(MNIST_DIR + "/" + f, MNIST_DIR + "/" + s);
 
-        Option tssImagesPathOpt = OptionBuilder.withArgName(MNIST_TEST_IMAGES_PATH).withLongOpt(MNIST_TEST_IMAGES_PATH).hasArg()
-            .withDescription("Path to the MNIST test set.")
-            .isRequired(true).create();
+            System.out.println(">>> Deleting gzip " + f + ", status: " +
+                Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + f)).delete());
 
-        Option tssLabelsPathOpt = OptionBuilder.withArgName(MNIST_TEST_LABELS_PATH).withLongOpt(MNIST_TEST_LABELS_PATH).hasArg()
-            .withDescription("Path to the MNIST test set.")
-            .isRequired(true).create();
+            System.out.println(">>> Done.");
+        }
+
+        return true;
+    }
 
-        Option configOpt = OptionBuilder.withArgName(CONFIG).withLongOpt(CONFIG).hasArg()
+    /**
+     * Unzip file located in {@code input} to {@code output}.
+     *
+     * @param input Input file path.
+     * @param output Output file path.
+     * @throws IOException In case of file system errors.
+     */
+    private static void unzip(String input, String output) throws IOException {
+        byte[] buf = new byte[1024];
+
+        try (GZIPInputStream gis = new GZIPInputStream(new FileInputStream(input));
+             FileOutputStream out = new FileOutputStream(output)) {
+            int sz;
+            while ((sz = gis.read(buf)) > 0)
+                out.write(buf, 0, sz);
+        }
+    }
+
+    /**
+     * Build cli options.
+     */
+    @NotNull private static Options buildOptions() {
+        Options options = new Options();
+
+        Option cfgOpt = OptionBuilder
+            .withArgName(CONFIG)
+            .withLongOpt(CONFIG)
+            .hasArg()
             .withDescription("Path to the config.")
             .isRequired(false).create();
 
-        options.addOption(trsImagesPathOpt);
-        options.addOption(trsLabelsPathOpt);
-        options.addOption(tssImagesPathOpt);
-        options.addOption(tssLabelsPathOpt);
-        options.addOption(configOpt);
+        Option unattended = OptionBuilder
+            .withArgName(MLExamplesCommonArgs.UNATTENDED)
+            .withLongOpt(MLExamplesCommonArgs.UNATTENDED)
+            .withDescription("Is example run unattended.")
+            .isRequired(false).create();
+
+        options.addOption(cfgOpt);
+        options.addOption(unattended);
 
         return options;
     }
@@ -210,20 +314,9 @@ public class MNISTExample {
         // Write to primary.
         cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
 
-        // Atomic transactions only.
-        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-
-        // No eviction.
-        cfg.setEvictionPolicy(null);
-
         // No copying of values.
         cfg.setCopyOnRead(false);
 
-        // Cache is partitioned.
-        cfg.setCacheMode(CacheMode.PARTITIONED);
-
-        cfg.setBackups(0);
-
         cfg.setName("TMP_BI_INDEXED_CACHE");
 
         return ignite.getOrCreateCache(cfg);
@@ -233,11 +326,11 @@ public class MNISTExample {
      * Loads vectors into cache.
      *
      * @param cacheName Name of cache.
-     * @param vectorsIterator Iterator over vectors to load.
+     * @param vectorsIter Iterator over vectors to load.
      * @param vectorSize Size of vector.
      * @param ignite Ignite instance.
      */
-    private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? extends Vector> vectorsIterator,
+    private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? extends Vector> vectorsIter,
         int vectorSize, Ignite ignite) {
         try (IgniteDataStreamer<BiIndex, Double> streamer =
                  ignite.dataStreamer(cacheName)) {
@@ -245,8 +338,8 @@ public class MNISTExample {
 
             streamer.perNodeBufferSize(10000);
 
-            while (vectorsIterator.hasNext()) {
-                org.apache.ignite.ml.math.Vector next = vectorsIterator.next();
+            while (vectorsIter.hasNext()) {
+                org.apache.ignite.ml.math.Vector next = vectorsIter.next();
 
                 for (int i = 0; i < vectorSize; i++)
                     streamer.addData(new BiIndex(sampleIdx, i), next.getX(i));
@@ -254,7 +347,7 @@ public class MNISTExample {
                 sampleIdx++;
 
                 if (sampleIdx % 1000 == 0)
-                    System.out.println("Loaded " + sampleIdx + " vectors.");
+                    System.out.println(">>> Loaded " + sampleIdx + " vectors.");
             }
         }
     }
index d2f40e6..df85f1a 100644 (file)
@@ -30,6 +30,7 @@ import javassist.CtClass;
 import javassist.CtNewMethod;
 import javassist.NotFoundException;
 import junit.framework.TestSuite;
+import org.apache.ignite.examples.ml.MLExamplesCommonArgs;
 import org.apache.ignite.testframework.GridTestUtils;
 import org.apache.ignite.testframework.junits.common.GridAbstractExamplesTest;
 
@@ -85,8 +86,8 @@ public class IgniteExamplesMLTestSuite extends TestSuite {
         cl.addMethod(CtNewMethod.make("public void testExample() { "
             + exampleCls.getCanonicalName()
             + ".main("
-            + GridAbstractExamplesTest.class.getName()
-            + ".EMPTY_ARGS); }", cl));
+            + MLExamplesCommonArgs.class.getName()
+            + ".EMPTY_ARGS_ML); }", cl));
 
         return cl.toClass();
     }