IGNITE-7990: Integrate MLP with partition based dataset
authordmitrievanthony <dmitrievanthony@gmail.com>
Thu, 29 Mar 2018 12:26:21 +0000 (15:26 +0300)
committerYury Babak <ybabak@gridgain.com>
Thu, 29 Mar 2018 12:26:21 +0000 (15:26 +0300)
this closes #3673

75 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java
examples/src/main/java/org/apache/ignite/examples/ml/dataset/CacheBasedDatasetExample.java
examples/src/main/java/org/apache/ignite/examples/ml/dataset/LocalDatasetExample.java
examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java [deleted file]
examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java [deleted file]
examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java [new file with mode: 0644]
examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/SimpleDataset.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/data/SimpleDatasetData.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/data/SimpleLabeledDatasetData.java
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataBuilderOnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataOnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/nn/LabeledVectorsCache.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/nn/MultilayerPerceptron.java
modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/AbstractMLPGroupUpdateTrainerInput.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPCache.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupTrainingCacheValue.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerDataCache.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerLocalContext.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingContext.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/package-info.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovParameterUpdate.java
modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovUpdateCalculator.java
modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator.java
modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropParameterUpdate.java
modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java
modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameterUpdate.java
modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionDataBuilderOnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java [moved from modules/ml/src/main/java/org/apache/ignite/ml/DatasetTrainer.java with 74% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/MultiLabelDatasetTrainer.java [moved from modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java with 73% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/SingleLabelDatasetTrainer.java [moved from modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/package-info.java with 73% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/UpdatesStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java
modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java [deleted file]
modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java [deleted file]
modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java
modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java [deleted file]
modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java [deleted file]
modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java [deleted file]
modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties

index b693dfa..b73e5fb 100644 (file)
@@ -66,8 +66,7 @@ public class AlgorithmSpecificDatasetExample {
                 (upstream, upstreamSize) -> new AlgorithmSpecificPartitionContext(),
                 new SimpleLabeledDatasetDataBuilder<Integer, Person, AlgorithmSpecificPartitionContext>(
                     (k, v) -> new double[] {v.getAge()},
-                    (k, v) -> v.getSalary(),
-                    1
+                    (k, v) -> new double[] {v.getSalary()}
                 ).andThen((data, ctx) -> {
                     double[] features = data.getFeatures();
                     int rows = data.getRows();
@@ -80,7 +79,7 @@ public class AlgorithmSpecificDatasetExample {
 
                     System.arraycopy(features, 0, a, rows, features.length);
 
-                    return new SimpleLabeledDatasetData(a, rows, data.getCols() + 1, data.getLabels());
+                    return new SimpleLabeledDatasetData(a, data.getLabels(), rows);
                 })
             ).wrap(AlgorithmSpecificDataset::new)) {
                 // Trains linear regression model using gradient descent.
@@ -125,11 +124,12 @@ public class AlgorithmSpecificDatasetExample {
         double[] gradient(double[] x) {
             return computeWithCtx((ctx, data, partIdx) -> {
                 double[] tmp = Arrays.copyOf(data.getLabels(), data.getRows());
-                blas.dgemv("N", data.getRows(), data.getCols(), 1.0, data.getFeatures(),
+                int featureCols = data.getFeatures().length / data.getRows();
+                blas.dgemv("N", data.getRows(), featureCols, 1.0, data.getFeatures(),
                     Math.max(1, data.getRows()), x, 1, -1.0, tmp, 1);
 
-                double[] res = new double[data.getCols()];
-                blas.dgemv("T", data.getRows(), data.getCols(), 1.0, data.getFeatures(),
+                double[] res = new double[featureCols];
+                blas.dgemv("T", data.getRows(), featureCols, 1.0, data.getFeatures(),
                     Math.max(1, data.getRows()), tmp, 1, 0.0, res, 1);
 
                 int iteration = ctx.getIteration();
index b1413ad..1ab9210 100644 (file)
@@ -43,8 +43,7 @@ public class CacheBasedDatasetExample {
             try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(
                 ignite,
                 persons,
-                (k, v) -> new double[]{ v.getAge(), v.getSalary() },
-                2
+                (k, v) -> new double[]{ v.getAge(), v.getSalary() }
             )) {
                 // Calculation of the mean value. This calculation will be performed in map-reduce manner.
                 double[] mean = dataset.mean();
index af14836..7ede803 100644 (file)
@@ -42,8 +42,7 @@ public class LocalDatasetExample {
             try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(
                 persons,
                 2,
-                (k, v) -> new double[]{ v.getAge(), v.getSalary() },
-                2
+                (k, v) -> new double[]{ v.getAge(), v.getSalary() }
             )) {
                 // Calculation of the mean value. This calculation will be performed in map-reduce manner.
                 double[] mean = dataset.mean();
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPGroupTrainerExample.java
deleted file mode 100644 (file)
index d45e957..0000000
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.nn;
-
-import java.util.Random;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.IgniteDataStreamer;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.examples.ExampleNodeStartup;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.StorageConstants;
-import org.apache.ignite.ml.math.Tracer;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.nn.Activators;
-import org.apache.ignite.ml.nn.LabeledVectorsCache;
-import org.apache.ignite.ml.nn.MLPGroupUpdateTrainerCacheInput;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.nn.initializers.RandomInitializer;
-import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.thread.IgniteThread;
-
-/**
- * Example of using distributed {@link MultilayerPerceptron}.
- * <p>
- * Remote nodes should always be started with special configuration file which
- * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
- * <p>
- * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
- * with {@code examples/config/example-ignite.xml} configuration.</p>
- */
-public class MLPGroupTrainerExample {
-    /**
-     * Executes example.
-     *
-     * @param args Command line arguments, none required.
-     */
-    public static void main(String[] args) throws InterruptedException {
-        // IMPL NOTE based on MLPGroupTrainerTest#testXOR
-        System.out.println(">>> Distributed multilayer perceptron example started.");
-
-        // Start ignite grid.
-        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
-            System.out.println(">>> Ignite grid started.");
-
-            // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
-            // because we create ignite cache internally.
-            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
-                MLPGroupTrainerExample.class.getSimpleName(), () -> {
-
-                int samplesCnt = 10000;
-
-                Matrix xorInputs = new DenseLocalOnHeapMatrix(
-                    new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}},
-                    StorageConstants.ROW_STORAGE_MODE).transpose();
-
-                Matrix xorOutputs = new DenseLocalOnHeapMatrix(
-                    new double[][] {{0.0}, {1.0}, {1.0}, {0.0}},
-                    StorageConstants.ROW_STORAGE_MODE).transpose();
-
-                MLPArchitecture conf = new MLPArchitecture(2).
-                    withAddedLayer(10, true, Activators.RELU).
-                    withAddedLayer(1, false, Activators.SIGMOID);
-
-                IgniteCache<Integer, LabeledVector<Vector, Vector>> cache = LabeledVectorsCache.createNew(ignite);
-                String cacheName = cache.getName();
-                Random rnd = new Random(12345L);
-
-                try (IgniteDataStreamer<Integer, LabeledVector<Vector, Vector>> streamer =
-                         ignite.dataStreamer(cacheName)) {
-                    streamer.perNodeBufferSize(100);
-
-                    for (int i = 0; i < samplesCnt; i++) {
-                        int col = Math.abs(rnd.nextInt()) % 4;
-                        streamer.addData(i, new LabeledVector<>(xorInputs.getCol(col), xorOutputs.getCol(col)));
-                    }
-                }
-
-                int totalCnt = 100;
-                int failCnt = 0;
-                MLPGroupUpdateTrainer<RPropParameterUpdate> trainer = MLPGroupUpdateTrainer.getDefault(ignite).
-                    withSyncPeriod(3).
-                    withTolerance(0.001).
-                    withMaxGlobalSteps(20);
-
-                for (int i = 0; i < totalCnt; i++) {
-
-                    MLPGroupUpdateTrainerCacheInput trainerInput = new MLPGroupUpdateTrainerCacheInput(conf,
-                        new RandomInitializer(rnd), 6, cache, 10);
-
-                    MultilayerPerceptron mlp = trainer.train(trainerInput);
-
-                    Matrix predict = mlp.apply(xorInputs);
-
-                    System.out.println(">>> Prediction data at step " + i + " of total " + totalCnt + ":");
-
-                    Tracer.showAscii(predict);
-
-                    System.out.println("Difference estimate: " + xorOutputs.getRow(0).minus(predict.getRow(0)).kNorm(2));
-
-                    failCnt += closeEnough(xorOutputs.getRow(0), predict.getRow(0)) ? 0 : 1;
-                }
-
-                double failRatio = (double)failCnt / totalCnt;
-
-                System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%.");
-
-                System.out.println("\n>>> Distributed multilayer perceptron example completed.");
-            });
-
-            igniteThread.start();
-
-            igniteThread.join();
-        }
-    }
-
-    /** */
-    private static boolean closeEnough(Vector v1, Vector v2) {
-        return v1.minus(v2).kNorm(2) < 5E-1;
-    }
-}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPLocalTrainerExample.java
deleted file mode 100644 (file)
index 02280ce..0000000
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.nn;
-
-import java.util.Random;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.StorageConstants;
-import org.apache.ignite.ml.math.Tracer;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.nn.Activators;
-import org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput;
-import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.nn.initializers.RandomInitializer;
-import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
-import org.apache.ignite.ml.util.Utils;
-
-/**
- * Example of using local {@link MultilayerPerceptron}.
- */
-public class MLPLocalTrainerExample {
-    /**
-     * Executes example.
-     *
-     * @param args Command line arguments, none required.
-     */
-    public static void main(String[] args) {
-        // IMPL NOTE based on MLPLocalTrainerTest#testXORRProp
-        System.out.println(">>> Local multilayer perceptron example started.");
-
-        Matrix xorInputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}},
-            StorageConstants.ROW_STORAGE_MODE).transpose();
-
-        System.out.println("\n>>> Input data:");
-
-        Tracer.showAscii(xorInputs);
-
-        Matrix xorOutputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0}, {1.0}, {1.0}, {0.0}},
-            StorageConstants.ROW_STORAGE_MODE).transpose();
-
-        MLPArchitecture conf = new MLPArchitecture(2).
-            withAddedLayer(10, true, Activators.RELU).
-            withAddedLayer(1, false, Activators.SIGMOID);
-
-        SimpleMLPLocalBatchTrainerInput trainerInput = new SimpleMLPLocalBatchTrainerInput(conf,
-            new Random(1234L), xorInputs, xorOutputs, 4);
-
-        System.out.println("\n>>> Perform training.");
-
-        MultilayerPerceptron mlp = new MLPLocalBatchTrainer<>(LossFunctions.MSE,
-            RPropUpdateCalculator::new,
-            0.0001,
-            16000).train(trainerInput);
-
-        System.out.println("\n>>> Apply model.");
-
-        Matrix predict = mlp.apply(xorInputs);
-
-        System.out.println("\n>>> Predicted data:");
-
-        Tracer.showAscii(predict);
-
-        System.out.println("\n>>> Reference expected data:");
-
-        Tracer.showAscii(xorOutputs);
-
-        System.out.println("\n>>> Difference estimate: " + xorOutputs.getRow(0).minus(predict.getRow(0)).kNorm(2));
-
-        System.out.println("\n>>> Local multilayer perceptron example completed.");
-    }
-
-    /**
-     * Class for local batch training of {@link MultilayerPerceptron}.
-     *
-     * It is constructed from two matrices: one containing inputs of function to approximate and other containing ground truth
-     * values of this function for corresponding inputs.
-     *
-     * We fix batch size given by this input by some constant value.
-     */
-    private static class SimpleMLPLocalBatchTrainerInput implements LocalBatchTrainerInput<MultilayerPerceptron> {
-        /**
-         * Multilayer perceptron to be trained.
-         */
-        private final MultilayerPerceptron mlp;
-
-        /**
-         * Inputs stored as columns.
-         */
-        private final Matrix inputs;
-
-        /**
-         * Ground truths stored as columns.
-         */
-        private final Matrix groundTruth;
-
-        /**
-         * Size of batch returned on each step.
-         */
-        private final int batchSize;
-
-        /**
-         * Construct instance of this class.
-         *
-         * @param arch Architecture of multilayer perceptron.
-         * @param rnd Random numbers generator.
-         * @param inputs Inputs stored as columns.
-         * @param groundTruth Ground truth stored as columns.
-         * @param batchSize Size of batch returned on each step.
-         */
-        SimpleMLPLocalBatchTrainerInput(MLPArchitecture arch, Random rnd, Matrix inputs, Matrix groundTruth, int batchSize) {
-            this.mlp = new MultilayerPerceptron(arch, new RandomInitializer(rnd));
-            this.inputs = inputs;
-            this.groundTruth = groundTruth;
-            this.batchSize = batchSize;
-        }
-
-        /** {@inheritDoc} */
-        @Override public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() {
-            return () -> {
-                int inputRowSize = inputs.rowSize();
-                int outputRowSize = groundTruth.rowSize();
-
-                Matrix vectors = new DenseLocalOnHeapMatrix(inputRowSize, batchSize);
-                Matrix labels = new DenseLocalOnHeapMatrix(outputRowSize, batchSize);
-
-                int[] samples = Utils.selectKDistinct(inputs.columnSize(), batchSize);
-
-                for (int i = 0; i < batchSize; i++) {
-                    vectors.assignColumn(i, inputs.getCol(samples[i]));
-                    labels.assignColumn(i, groundTruth.getCol(samples[i]));
-                }
-
-                return new IgniteBiTuple<>(vectors, labels);
-            };
-        }
-
-        /** {@inheritDoc} */
-        @Override public MultilayerPerceptron mdl() {
-            return mlp;
-        }
-    }
-}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
new file mode 100644 (file)
index 0000000..efa1ba7
--- /dev/null
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.nn;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.examples.ExampleNodeStartup;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Example of using distributed {@link MultilayerPerceptron}.
+ * <p>
+ * Remote nodes should always be started with special configuration file which
+ * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
+ * <p>
+ * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
+ * with {@code examples/config/example-ignite.xml} configuration.</p>
+ */
+public class MLPTrainerExample {
+    /**
+     * Executes example.
+     *
+     * @param args Command line arguments, none required.
+     */
+    public static void main(String[] args) throws InterruptedException {
+        // IMPL NOTE based on MLPGroupTrainerTest#testXOR
+        System.out.println(">>> Distributed multilayer perceptron example started.");
+
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
+            // because we create ignite cache internally.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                MLPTrainerExample.class.getSimpleName(), () -> {
+
+                // Create cache with training data.
+                CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
+                trainingSetCfg.setName("TRAINING_SET");
+                trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+                IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
+
+                // Fill cache with training data.
+                trainingSet.put(0, new LabeledPoint(0, 0, 0));
+                trainingSet.put(1, new LabeledPoint(0, 1, 1));
+                trainingSet.put(2, new LabeledPoint(1, 0, 1));
+                trainingSet.put(3, new LabeledPoint(1, 1, 0));
+
+                // Define a layered architecture.
+                MLPArchitecture arch = new MLPArchitecture(2).
+                    withAddedLayer(10, true, Activators.RELU).
+                    withAddedLayer(1, false, Activators.SIGMOID);
+
+                // Define a neural network trainer.
+                MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
+                    arch,
+                    LossFunctions.MSE,
+                    new UpdatesStrategy<>(
+                        new SimpleGDUpdateCalculator(0.1),
+                        SimpleGDParameterUpdate::sumLocal,
+                        SimpleGDParameterUpdate::avg
+                    ),
+                    3000,
+                    4,
+                    50,
+                    123L
+                );
+
+                // Train neural network and get multilayer perceptron model.
+                MultilayerPerceptron mlp = trainer.fit(
+                    new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+                    (k, v) -> new double[] {v.x, v.y},
+                    (k, v) -> new double[] {v.lb}
+                );
+
+                int totalCnt = 4;
+                int failCnt = 0;
+
+                // Calculate score.
+                for (int i = 0; i < 4; i++) {
+                    LabeledPoint pnt = trainingSet.get(i);
+                    Matrix predicted = mlp.apply(new DenseLocalOnHeapMatrix(new double[][] {{pnt.x, pnt.y}}));
+                    failCnt += Math.abs(predicted.get(0, 0) - pnt.lb) < 0.5 ? 0 : 1;
+                }
+
+                double failRatio = (double)failCnt / totalCnt;
+
+                System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%.");
+
+                System.out.println("\n>>> Distributed multilayer perceptron example completed.");
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+
+    /** Point data class. */
+    private static class Point {
+        /** X coordinate. */
+        final double x;
+
+        /** Y coordinate. */
+        final double y;
+
+        /**
+         * Constructs a new instance of point.
+         *
+         * @param x X coordinate.
+         * @param y Y coordinate.
+         */
+        Point(double x, double y) {
+            this.x = x;
+            this.y = y;
+        }
+    }
+
+    /** Labeled point data class. */
+    private static class LabeledPoint extends Point {
+        /** Point label. */
+        final double lb;
+
+        /**
+         * Constructs a new instance of labeled point data.
+         *
+         * @param x X coordinate.
+         * @param y Y coordinate.
+         * @param lb Point label.
+         */
+        LabeledPoint(double x, double y, double lb) {
+            super(x, y);
+            this.lb = lb;
+        }
+    }
+}
index 008b4ca..e0bcd08 100644 (file)
@@ -62,8 +62,7 @@ public class NormalizationExample {
             // Creates a cache based simple dataset containing features and providing standard dataset API.
             try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(
                 builder,
-                preprocessor,
-                2
+                preprocessor
             )) {
                 // Calculation of the mean value. This calculation will be performed in map-reduce manner.
                 double[] mean = dataset.mean();
index 61195c4..567a599 100644 (file)
@@ -126,14 +126,13 @@ public class DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample {
                 );
 
                 System.out.println(">>> Create new linear regression trainer object.");
-                LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>();
+                LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
 
                 System.out.println(">>> Perform the training to get the model.");
                 LinearRegressionModel mdl = trainer.fit(
                     new CacheBasedDatasetBuilder<>(ignite, dataCache),
                     preprocessor,
-                    (k, v) -> v[0],
-                    4
+                    (k, v) -> v[0]
                 );
 
                 System.out.println(">>> Linear regression model: " + mdl);
index 20e0653..a853092 100644 (file)
@@ -112,14 +112,13 @@ public class DistributedLinearRegressionWithLSQRTrainerExample {
                 IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
 
                 System.out.println(">>> Create new linear regression trainer object.");
-                LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>();
+                LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
 
                 System.out.println(">>> Perform the training to get the model.");
                 LinearRegressionModel mdl = trainer.fit(
                     new CacheBasedDatasetBuilder<>(ignite, dataCache),
                     (k, v) -> Arrays.copyOfRange(v, 1, v.length),
-                    (k, v) -> v[0],
-                    4
+                    (k, v) -> v[0]
                 );
 
                 System.out.println(">>> Linear regression model: " + mdl);
index c00b327..f8bf521 100644 (file)
@@ -51,13 +51,13 @@ public class SVMBinaryClassificationExample {
                 SVMBinaryClassificationExample.class.getSimpleName(), () -> {
                 IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
 
-                SVMLinearBinaryClassificationTrainer<Integer, double[]> trainer = new SVMLinearBinaryClassificationTrainer<>();
+                SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer();
 
                 SVMLinearBinaryClassificationModel mdl = trainer.fit(
                     new CacheBasedDatasetBuilder<>(ignite, dataCache),
                     (k, v) -> Arrays.copyOfRange(v, 1, v.length),
-                    (k, v) -> v[0],
-                    4);
+                    (k, v) -> v[0]
+                );
 
                 System.out.println(">>> SVM model " + mdl);
 
index 8d5df6e..f8281e4 100644 (file)
@@ -31,7 +31,6 @@ import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor;
 import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
-import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
 import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
 import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer;
 import org.apache.ignite.thread.IgniteThread;
@@ -55,13 +54,13 @@ public class SVMMultiClassClassificationExample {
                 SVMMultiClassClassificationExample.class.getSimpleName(), () -> {
                 IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
 
-                SVMLinearMultiClassClassificationTrainer<Integer, double[]> trainer = new SVMLinearMultiClassClassificationTrainer<>();
+                SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer();
 
                 SVMLinearMultiClassClassificationModel mdl = trainer.fit(
                     new CacheBasedDatasetBuilder<>(ignite, dataCache),
                     (k, v) -> Arrays.copyOfRange(v, 1, v.length),
-                    (k, v) -> v[0],
-                    5);
+                    (k, v) -> v[0]
+                );
 
                 System.out.println(">>> SVM Multi-class model");
                 System.out.println(mdl.toString());
@@ -77,8 +76,8 @@ public class SVMMultiClassClassificationExample {
                 SVMLinearMultiClassClassificationModel mdlWithNormalization = trainer.fit(
                     new CacheBasedDatasetBuilder<>(ignite, dataCache),
                     preprocessor,
-                    (k, v) -> v[0],
-                    5);
+                    (k, v) -> v[0]
+                );
 
                 System.out.println(">>> SVM Multi-class model with normalization");
                 System.out.println(mdlWithNormalization.toString());
index 3860e8e..b1b2c42 100644 (file)
@@ -180,10 +180,10 @@ public class DecisionTreesExample {
             int ptsCnt = 60000;
             int featCnt = 28 * 28;
 
-            Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath,
+            Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnistAsStream(trainingImagesPath, trainingLabelsPath,
                 new Random(123L), ptsCnt);
 
-            Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath,
+            Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnistAsStream(testImagesPath, testLabelsPath,
                 new Random(123L), 10_000);
 
             IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite);
@@ -241,7 +241,7 @@ public class DecisionTreesExample {
         for (String s : missing) {
             String f = s + ".gz";
             System.out.println(">>> Downloading " + f + "...");
-            URL website = new URL("http://yann.lecun.com/exdb/mnist/" + f);
+            URL website = new URL("http://yann.lecun.com/exdb/mnistAsStream/" + f);
             ReadableByteChannel rbc = Channels.newChannel(website.openStream());
             FileOutputStream fos = new FileOutputStream(MNIST_DIR + "/" + f);
             fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE);
index af44a8a..9e580c4 100644 (file)
@@ -87,7 +87,10 @@ public class DatasetFactory {
     public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create(
         DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder,
         PartitionDataBuilder<K, V, C, D> partDataBuilder) {
-        return datasetBuilder.build(partCtxBuilder, partDataBuilder);
+        return datasetBuilder.build(
+            partCtxBuilder,
+            partDataBuilder
+        );
     }
     /**
      * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and
@@ -107,7 +110,11 @@ public class DatasetFactory {
     public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create(
         Ignite ignite, IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder,
         PartitionDataBuilder<K, V, C, D> partDataBuilder) {
-        return create(new CacheBasedDatasetBuilder<>(ignite, upstreamCache), partCtxBuilder, partDataBuilder);
+        return create(
+            new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            partCtxBuilder,
+            partDataBuilder
+        );
     }
 
     /**
@@ -118,7 +125,6 @@ public class DatasetFactory {
      * @param datasetBuilder Dataset builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
@@ -126,11 +132,11 @@ public class DatasetFactory {
      */
     public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(
         DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder,
-        IgniteBiFunction<K, V, double[]> featureExtractor, int cols) {
+        IgniteBiFunction<K, V, double[]> featureExtractor) {
         return create(
             datasetBuilder,
             partCtxBuilder,
-            new SimpleDatasetDataBuilder<>(featureExtractor, cols)
+            new SimpleDatasetDataBuilder<>(featureExtractor)
         ).wrap(SimpleDataset::new);
     }
 
@@ -143,7 +149,6 @@ public class DatasetFactory {
      * @param upstreamCache Ignite Cache with {@code upstream} data.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
@@ -151,9 +156,12 @@ public class DatasetFactory {
      */
     public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(Ignite ignite,
         IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder,
-        IgniteBiFunction<K, V, double[]> featureExtractor, int cols) {
-        return createSimpleDataset(new CacheBasedDatasetBuilder<>(ignite, upstreamCache), partCtxBuilder,
-            featureExtractor, cols);
+        IgniteBiFunction<K, V, double[]> featureExtractor) {
+        return createSimpleDataset(
+            new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            partCtxBuilder,
+            featureExtractor
+        );
     }
 
     /**
@@ -165,7 +173,6 @@ public class DatasetFactory {
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and buikd {@link SimpleLabeledDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
@@ -173,11 +180,11 @@ public class DatasetFactory {
      */
     public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
         DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder,
-        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
         return create(
             datasetBuilder,
             partCtxBuilder,
-            new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor, cols)
+            new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
         ).wrap(SimpleLabeledDataset::new);
     }
 
@@ -191,7 +198,6 @@ public class DatasetFactory {
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and buikd {@link SimpleLabeledDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
@@ -199,9 +205,13 @@ public class DatasetFactory {
      */
     public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(Ignite ignite,
         IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder,
-        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
-        return createSimpleLabeledDataset(new CacheBasedDatasetBuilder<>(ignite, upstreamCache), partCtxBuilder,
-            featureExtractor, lbExtractor, cols);
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
+        return createSimpleLabeledDataset(
+            new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            partCtxBuilder,
+            featureExtractor,
+            lbExtractor
+        );
     }
 
     /**
@@ -211,14 +221,17 @@ public class DatasetFactory {
      *
      * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
     public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, double[]> featureExtractor, int cols) {
-        return createSimpleDataset(datasetBuilder, new EmptyContextBuilder<>(), featureExtractor, cols);
+        IgniteBiFunction<K, V, double[]> featureExtractor) {
+        return createSimpleDataset(
+            datasetBuilder,
+            new EmptyContextBuilder<>(),
+            featureExtractor
+        );
     }
 
     /**
@@ -229,14 +242,16 @@ public class DatasetFactory {
      * @param ignite Ignite instance.
      * @param upstreamCache Ignite Cache with {@code upstream} data.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
     public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(Ignite ignite, IgniteCache<K, V> upstreamCache,
-        IgniteBiFunction<K, V, double[]> featureExtractor, int cols) {
-        return createSimpleDataset(new CacheBasedDatasetBuilder<>(ignite, upstreamCache), featureExtractor, cols);
+        IgniteBiFunction<K, V, double[]> featureExtractor) {
+        return createSimpleDataset(
+            new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            featureExtractor
+        );
     }
 
     /**
@@ -247,16 +262,19 @@ public class DatasetFactory {
      * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and buikd {@link SimpleLabeledDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
     public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(
         DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor,
-        IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
-        return createSimpleLabeledDataset(datasetBuilder, new EmptyContextBuilder<>(), featureExtractor, lbExtractor,
-            cols);
+        IgniteBiFunction<K, V, double[]> lbExtractor) {
+        return createSimpleLabeledDataset(
+            datasetBuilder,
+            new EmptyContextBuilder<>(),
+            featureExtractor,
+            lbExtractor
+        );
     }
 
     /**
@@ -268,16 +286,18 @@ public class DatasetFactory {
      * @param upstreamCache Ignite Cache with {@code upstream} data.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and buikd {@link SimpleLabeledDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
     public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(Ignite ignite,
         IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, double[]> featureExtractor,
-        IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
-        return createSimpleLabeledDataset(new CacheBasedDatasetBuilder<>(ignite, upstreamCache), featureExtractor,
-            lbExtractor, cols);
+        IgniteBiFunction<K, V, double[]> lbExtractor) {
+        return createSimpleLabeledDataset(
+            new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            featureExtractor,
+            lbExtractor
+        );
     }
 
     /**
@@ -298,7 +318,11 @@ public class DatasetFactory {
     public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create(
         Map<K, V> upstreamMap, int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder,
         PartitionDataBuilder<K, V, C, D> partDataBuilder) {
-        return create(new LocalDatasetBuilder<>(upstreamMap, partitions), partCtxBuilder, partDataBuilder);
+        return create(
+            new LocalDatasetBuilder<>(upstreamMap, partitions),
+            partCtxBuilder,
+            partDataBuilder
+        );
     }
 
     /**
@@ -310,7 +334,6 @@ public class DatasetFactory {
      * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
@@ -318,9 +341,12 @@ public class DatasetFactory {
      */
     public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(Map<K, V> upstreamMap,
         int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder,
-        IgniteBiFunction<K, V, double[]> featureExtractor, int cols) {
-        return createSimpleDataset(new LocalDatasetBuilder<>(upstreamMap, partitions), partCtxBuilder, featureExtractor,
-            cols);
+        IgniteBiFunction<K, V, double[]> featureExtractor) {
+        return createSimpleDataset(
+            new LocalDatasetBuilder<>(upstreamMap, partitions),
+            partCtxBuilder,
+            featureExtractor
+        );
     }
 
     /**
@@ -333,7 +359,6 @@ public class DatasetFactory {
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and buikd {@link SimpleLabeledDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
@@ -341,9 +366,12 @@ public class DatasetFactory {
      */
     public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
         Map<K, V> upstreamMap, int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder,
-        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
-        return createSimpleLabeledDataset(new LocalDatasetBuilder<>(upstreamMap, partitions), partCtxBuilder,
-            featureExtractor, lbExtractor, cols);
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
+        return createSimpleLabeledDataset(
+            new LocalDatasetBuilder<>(upstreamMap, partitions),
+            partCtxBuilder,
+            featureExtractor, lbExtractor
+        );
     }
 
     /**
@@ -354,14 +382,16 @@ public class DatasetFactory {
      * @param upstreamMap {@code Map} with {@code upstream} data.
      * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
     public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(Map<K, V> upstreamMap, int partitions,
-        IgniteBiFunction<K, V, double[]> featureExtractor, int cols) {
-        return createSimpleDataset(new LocalDatasetBuilder<>(upstreamMap, partitions), featureExtractor, cols);
+        IgniteBiFunction<K, V, double[]> featureExtractor) {
+        return createSimpleDataset(
+            new LocalDatasetBuilder<>(upstreamMap, partitions),
+            featureExtractor
+        );
     }
 
     /**
@@ -373,15 +403,17 @@ public class DatasetFactory {
      * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
-     * @param cols Number of columns (features) will be extracted.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
     public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(Map<K, V> upstreamMap,
-        int partitions, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor,
-        int cols) {
-        return createSimpleLabeledDataset(new LocalDatasetBuilder<>(upstreamMap, partitions), featureExtractor,
-            lbExtractor, cols);
+        int partitions, IgniteBiFunction<K, V, double[]> featureExtractor,
+        IgniteBiFunction<K, V, double[]> lbExtractor) {
+        return createSimpleLabeledDataset(
+            new LocalDatasetBuilder<>(upstreamMap, partitions),
+            featureExtractor,
+            lbExtractor
+        );
     }
 }
index 47c0c4b..50e4aba 100644 (file)
@@ -49,7 +49,7 @@ public class SimpleDataset<C extends Serializable> extends DatasetWrapper<C, Sim
         ValueWithCount<double[]> res = delegate.compute((data, partIdx) -> {
             double[] features = data.getFeatures();
             int rows = data.getRows();
-            int cols = data.getCols();
+            int cols = features.length / rows;
 
             double[] y = new double[cols];
 
@@ -78,7 +78,7 @@ public class SimpleDataset<C extends Serializable> extends DatasetWrapper<C, Sim
         ValueWithCount<double[]> res = delegate.compute(data -> {
             double[] features = data.getFeatures();
             int rows = data.getRows();
-            int cols = data.getCols();
+            int cols = features.length / rows;
 
             double[] y = new double[cols];
 
@@ -109,7 +109,7 @@ public class SimpleDataset<C extends Serializable> extends DatasetWrapper<C, Sim
         ValueWithCount<double[][]> res = delegate.compute(data -> {
             double[] features = data.getFeatures();
             int rows = data.getRows();
-            int cols = data.getCols();
+            int cols = features.length / rows;
 
             double[][] y = new double[cols][cols];
 
index 6f29e2f..dc7d8cb 100644 (file)
@@ -39,31 +39,32 @@ public class SimpleDatasetDataBuilder<K, V, C extends Serializable>
     /** Function that extracts features from an {@code upstream} data. */
     private final IgniteBiFunction<K, V, double[]> featureExtractor;
 
-    /** Number of columns (features). */
-    private final int cols;
-
     /**
      * Construct a new instance of partition {@code data} builder that makes {@link SimpleDatasetData}.
      *
      * @param featureExtractor Function that extracts features from an {@code upstream} data.
-     * @param cols Number of columns (features).
      */
-    public SimpleDatasetDataBuilder(IgniteBiFunction<K, V, double[]> featureExtractor, int cols) {
+    public SimpleDatasetDataBuilder(IgniteBiFunction<K, V, double[]> featureExtractor) {
         this.featureExtractor = featureExtractor;
-        this.cols = cols;
     }
 
     /** {@inheritDoc} */
     @Override public SimpleDatasetData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
         // Prepares the matrix of features in flat column-major format.
-        double[] features = new double[Math.toIntExact(upstreamDataSize * cols)];
+        int cols = -1;
+        double[] features = null;
 
         int ptr = 0;
         while (upstreamData.hasNext()) {
             UpstreamEntry<K, V> entry = upstreamData.next();
             double[] row = featureExtractor.apply(entry.getKey(), entry.getValue());
 
-            assert row.length == cols : "Feature extractor must return exactly " + cols + " features";
+            if (cols < 0) {
+                cols = row.length;
+                features = new double[Math.toIntExact(upstreamDataSize * cols)];
+            }
+            else
+                assert row.length == cols : "Feature extractor must return exactly " + cols + " features";
 
             for (int i = 0; i < cols; i++)
                 features[Math.toIntExact(i * upstreamDataSize + ptr)] = row[i];
@@ -71,6 +72,6 @@ public class SimpleDatasetDataBuilder<K, V, C extends Serializable>
             ptr++;
         }
 
-        return new SimpleDatasetData(features, Math.toIntExact(upstreamDataSize), cols);
+        return new SimpleDatasetData(features, Math.toIntExact(upstreamDataSize));
     }
 }
index 12fcc4c..d301bbe 100644 (file)
@@ -40,47 +40,61 @@ public class SimpleLabeledDatasetDataBuilder<K, V, C extends Serializable>
     private final IgniteBiFunction<K, V, double[]> featureExtractor;
 
     /** Function that extracts labels from an {@code upstream} data. */
-    private final IgniteBiFunction<K, V, Double> lbExtractor;
-
-    /** Number of columns (features). */
-    private final int cols;
+    private final IgniteBiFunction<K, V, double[]> lbExtractor;
 
     /**
      * Constructs a new instance of partition {@code data} builder that makes {@link SimpleLabeledDatasetData}.
      *
      * @param featureExtractor Function that extracts features from an {@code upstream} data.
      * @param lbExtractor Function that extracts labels from an {@code upstream} data.
-     * @param cols Number of columns (features).
      */
     public SimpleLabeledDatasetDataBuilder(IgniteBiFunction<K, V, double[]> featureExtractor,
-        IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
+        IgniteBiFunction<K, V, double[]> lbExtractor) {
         this.featureExtractor = featureExtractor;
         this.lbExtractor = lbExtractor;
-        this.cols = cols;
     }
 
     /** {@inheritDoc} */
     @Override public SimpleLabeledDatasetData build(Iterator<UpstreamEntry<K, V>> upstreamData,
         long upstreamDataSize, C ctx) {
         // Prepares the matrix of features in flat column-major format.
-        double[] features = new double[Math.toIntExact(upstreamDataSize * cols)];
-        double[] labels = new double[Math.toIntExact(upstreamDataSize)];
+        int featureCols = -1;
+        int lbCols = -1;
+        double[] features = null;
+        double[] labels = null;
 
         int ptr = 0;
         while (upstreamData.hasNext()) {
             UpstreamEntry<K, V> entry = upstreamData.next();
-            double[] row = featureExtractor.apply(entry.getKey(), entry.getValue());
 
-            assert row.length == cols : "Feature extractor must return exactly " + cols + " features";
+            double[] featureRow = featureExtractor.apply(entry.getKey(), entry.getValue());
+
+            if (featureCols < 0) {
+                featureCols = featureRow.length;
+                features = new double[Math.toIntExact(upstreamDataSize * featureCols)];
+            }
+            else
+                assert featureRow.length == featureCols : "Feature extractor must return exactly " + featureCols
+                    + " features";
+
+            for (int i = 0; i < featureCols; i++)
+                features[Math.toIntExact(i * upstreamDataSize) + ptr] = featureRow[i];
+
+            double[] lbRow = lbExtractor.apply(entry.getKey(), entry.getValue());
+
+            if (lbCols < 0) {
+                lbCols = lbRow.length;
+                labels = new double[Math.toIntExact(upstreamDataSize * lbCols)];
+            }
 
-            for (int i = 0; i < cols; i++)
-                features[Math.toIntExact(i * upstreamDataSize) + ptr] = row[i];
+            assert lbRow.length == lbCols : "Label extractor must return exactly " + lbCols + " labels";
 
-            labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue());
+            for (int i = 0; i < lbCols; i++)
+                labels[Math.toIntExact(i * upstreamDataSize) + ptr] = lbRow[i];
 
             ptr++;
         }
 
-        return new SimpleLabeledDatasetData(features, Math.toIntExact(upstreamDataSize), cols, labels);
+        return new SimpleLabeledDatasetData(features, labels, Math.toIntExact(upstreamDataSize));
     }
 }
index 7f82720..0b1b3ae 100644 (file)
@@ -30,21 +30,16 @@ public class SimpleDatasetData implements AutoCloseable {
     /** Number of rows. */
     private final int rows;
 
-    /** Number of columns. */
-    private final int cols;
-
     /**
      * Constructs a new instance of partition {@code data} of the {@link SimpleDataset} containing matrix of features in
      * flat column-major format stored in heap.
      *
      * @param features Matrix of features in a dense flat column-major format.
      * @param rows Number of rows.
-     * @param cols Number of columns.
      */
-    public SimpleDatasetData(double[] features, int rows, int cols) {
+    public SimpleDatasetData(double[] features, int rows) {
         this.features = features;
         this.rows = rows;
-        this.cols = cols;
     }
 
     /** */
@@ -57,11 +52,6 @@ public class SimpleDatasetData implements AutoCloseable {
         return rows;
     }
 
-    /** */
-    public int getCols() {
-        return cols;
-    }
-
     /** {@inheritDoc} */
     @Override public void close() {
         // Do nothing, GC will clean up.
index 38041f8..45dfc3a 100644 (file)
@@ -27,29 +27,24 @@ public class SimpleLabeledDatasetData implements AutoCloseable {
     /** Matrix with features in a dense flat column-major format. */
     private final double[] features;
 
-    /** Number of rows. */
-    private final int rows;
-
-    /** Number of columns. */
-    private final int cols;
-
     /** Vector with labels. */
     private final double[] labels;
 
+    /** Number of rows. */
+    private final int rows;
+
     /**
      * Constructs a new instance of partition {@code data} of the {@link SimpleLabeledDataset} containing matrix of
      * features in flat column-major format stored in heap and vector of labels stored in heap as well.
      *
      * @param features Matrix with features in a dense flat column-major format.
-     * @param rows Number of rows.
-     * @param cols Number of columns.
      * @param labels Vector with labels.
+     * @param rows Number of rows.
      */
-    public SimpleLabeledDatasetData(double[] features, int rows, int cols, double[] labels) {
+    public SimpleLabeledDatasetData(double[] features, double[] labels, int rows) {
         this.features = features;
-        this.rows = rows;
-        this.cols = cols;
         this.labels = labels;
+        this.rows = rows;
     }
 
     /** */
@@ -63,11 +58,6 @@ public class SimpleLabeledDatasetData implements AutoCloseable {
     }
 
     /** */
-    public int getCols() {
-        return cols;
-    }
-
-    /** */
     public double[] getLabels() {
         return labels;
     }
index 1c2e2cf..e80b935 100644 (file)
@@ -41,28 +41,24 @@ public class LinSysPartitionDataBuilderOnHeap<K, V, C extends Serializable>
     /** Extractor of Y vector value. */
     private final IgniteBiFunction<K, V, Double> yExtractor;
 
-    /** Number of columns. */
-    private final int cols;
-
     /**
      * Constructs a new instance of linear system partition data builder.
      *
      * @param xExtractor Extractor of X matrix row.
      * @param yExtractor Extractor of Y vector value.
-     * @param cols Number of columns.
      */
     public LinSysPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, double[]> xExtractor,
-        IgniteBiFunction<K, V, Double> yExtractor, int cols) {
+        IgniteBiFunction<K, V, Double> yExtractor) {
         this.xExtractor = xExtractor;
         this.yExtractor = yExtractor;
-        this.cols = cols;
     }
 
     /** {@inheritDoc} */
     @Override public LinSysPartitionDataOnHeap build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize,
         C ctx) {
         // Prepares the matrix of features in flat column-major format.
-        double[] x = new double[Math.toIntExact(upstreamDataSize * cols)];
+        int xCols = -1;
+        double[] x = null;//new double[Math.toIntExact(upstreamDataSize * cols)];
         double[] y = new double[Math.toIntExact(upstreamDataSize)];
 
         int ptr = 0;
@@ -70,9 +66,14 @@ public class LinSysPartitionDataBuilderOnHeap<K, V, C extends Serializable>
             UpstreamEntry<K, V> entry = upstreamData.next();
             double[] row = xExtractor.apply(entry.getKey(), entry.getValue());
 
-            assert row.length == cols : "X extractor must return exactly " + cols + " columns";
+            if (xCols < 0) {
+                xCols = row.length;
+                x = new double[Math.toIntExact(upstreamDataSize * xCols)];
+            }
+            else
+                assert row.length == xCols : "X extractor must return exactly " + xCols + " columns";
 
-            for (int i = 0; i < cols; i++)
+            for (int i = 0; i < xCols; i++)
                 x[Math.toIntExact(i * upstreamDataSize) + ptr] = row[i];
 
             y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue());
@@ -80,6 +81,6 @@ public class LinSysPartitionDataBuilderOnHeap<K, V, C extends Serializable>
             ptr++;
         }
 
-        return new LinSysPartitionDataOnHeap(x, Math.toIntExact(upstreamDataSize), cols, y);
+        return new LinSysPartitionDataOnHeap(x, y, Math.toIntExact(upstreamDataSize));
     }
 }
index e0b8f46..89c8e44 100644 (file)
@@ -24,27 +24,22 @@ public class LinSysPartitionDataOnHeap implements AutoCloseable {
     /** Part of X matrix. */
     private final double[] x;
 
-    /** Number of rows. */
-    private final int rows;
-
-    /** Number of columns. */
-    private final int cols;
-
     /** Part of Y vector. */
     private final double[] y;
 
+    /** Number of rows. */
+    private final int rows;
+
     /**
      * Constructs a new instance of linear system partition data.
      *
      * @param x Part of X matrix.
-     * @param rows Number of rows.
-     * @param cols Number of columns.
      * @param y Part of Y vector.
+     * @param rows Number of rows.
      */
-    public LinSysPartitionDataOnHeap(double[] x, int rows, int cols, double[] y) {
+    public LinSysPartitionDataOnHeap(double[] x, double[] y, int rows) {
         this.x = x;
         this.rows = rows;
-        this.cols = cols;
         this.y = y;
     }
 
@@ -59,11 +54,6 @@ public class LinSysPartitionDataOnHeap implements AutoCloseable {
     }
 
     /** */
-    public int getCols() {
-        return cols;
-    }
-
-    /** */
     public double[] getY() {
         return y;
     }
index fa8e713..1db3e8b 100644 (file)
@@ -51,34 +51,43 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
             ctx.setU(Arrays.copyOf(data.getY(), data.getY().length));
 
             return BLAS.getInstance().dnrm2(data.getY().length, data.getY(), 1);
-        }, (a, b) -> a == null ? b : Math.sqrt(a * a + b * b));
+        }, (a, b) -> a == null ? b : b == null ? a : Math.sqrt(a * a + b * b));
     }
 
     /** {@inheritDoc} */
     @Override protected double beta(double[] x, double alfa, double beta) {
         return dataset.computeWithCtx((ctx, data) -> {
-            BLAS.getInstance().dgemv("N", data.getRows(), data.getCols(), alfa, data.getX(),
+            if (data.getX() == null)
+                return null;
+
+            int cols = data.getX().length / data.getRows();
+            BLAS.getInstance().dgemv("N", data.getRows(), cols, alfa, data.getX(),
                 Math.max(1, data.getRows()), x, 1, beta, ctx.getU(), 1);
 
             return BLAS.getInstance().dnrm2(ctx.getU().length, ctx.getU(), 1);
-        }, (a, b) -> a == null ? b : Math.sqrt(a * a + b * b));
+        }, (a, b) -> a == null ? b : b == null ? a : Math.sqrt(a * a + b * b));
     }
 
     /** {@inheritDoc} */
     @Override protected double[] iter(double bnorm, double[] target) {
         double[] res = dataset.computeWithCtx((ctx, data) -> {
+            if (data.getX() == null)
+                return null;
+
+            int cols =  data.getX().length / data.getRows();
             BLAS.getInstance().dscal(ctx.getU().length, 1 / bnorm, ctx.getU(), 1);
-            double[] v = new double[data.getCols()];
-            BLAS.getInstance().dgemv("T", data.getRows(), data.getCols(), 1.0, data.getX(),
+            double[] v = new double[cols];
+            BLAS.getInstance().dgemv("T", data.getRows(), cols, 1.0, data.getX(),
                 Math.max(1, data.getRows()), ctx.getU(), 1, 0, v, 1);
 
             return v;
         }, (a, b) -> {
             if (a == null)
                 return b;
+            else if (b == null)
+                return a;
             else {
                 BLAS.getInstance().daxpy(a.length, 1.0, a, 1, b, 1);
-
                 return b;
             }
         });
@@ -92,7 +101,7 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
      * @return number of columns
      */
     @Override protected int getColumns() {
-        return dataset.compute(LinSysPartitionDataOnHeap::getCols, (a, b) -> a == null ? b : a);
+        return dataset.compute(data -> data.getX() == null ? null :  data.getX().length / data.getRows(), (a, b) -> a == null ? b : a);
     }
 
     /** {@inheritDoc} */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/LabeledVectorsCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/LabeledVectorsCache.java
deleted file mode 100644 (file)
index 07a6e2a..0000000
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn;
-
-import java.util.UUID;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.structures.LabeledVector;
-
-/**
- * Class for working with labeled vectors cache.
- */
-public class LabeledVectorsCache {
-    /**
-     * Create new labeled vectors cache.
-     *
-     * @param ignite Ignite instance.
-     * @return new labeled vectors cache.
-     */
-    public static IgniteCache<Integer, LabeledVector<Vector, Vector>> createNew(Ignite ignite) {
-        CacheConfiguration<Integer, LabeledVector<Vector, Vector>> cfg = new CacheConfiguration<>();
-
-        // Write to primary.
-        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
-
-        // Atomic transactions only.
-        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-
-        // No copying of values.
-        cfg.setCopyOnRead(false);
-
-        // Cache is partitioned.
-        cfg.setCacheMode(CacheMode.PARTITIONED);
-
-        cfg.setBackups(0);
-
-        cfg.setOnheapCacheEnabled(true);
-
-        cfg.setName("LBLD_VECS_" + UUID.randomUUID().toString());
-
-        return ignite.getOrCreateCache(cfg);
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPGroupUpdateTrainerCacheInput.java
deleted file mode 100644 (file)
index ce42938..0000000
+++ /dev/null
@@ -1,157 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.affinity.Affinity;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.nn.initializers.MLPInitializer;
-import org.apache.ignite.ml.nn.trainers.distributed.AbstractMLPGroupUpdateTrainerInput;
-import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.util.Utils;
-
-/**
- * Input for {@link MLPGroupUpdateTrainer} where batches are taken from cache of labeled vectors.
- */
-public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrainerInput {
-    /**
-     * Cache of labeled vectors.
-     */
-    private final IgniteCache<Integer, LabeledVector<Vector, Vector>> cache;
-
-    /**
-     * Size of batch to return on each training iteration.
-     */
-    private final int batchSize;
-
-    /**
-     * Multilayer perceptron.
-     */
-    private final MultilayerPerceptron mlp;
-
-    /**
-     * Random number generator.
-     */
-    private final Random rand;
-
-    /**
-     * Construct instance of this class with given parameters.
-     *
-     * @param arch Architecture of multilayer perceptron.
-     * @param init Initializer of multilayer perceptron.
-     * @param networksCnt Count of networks to be trained in parallel by {@link MLPGroupUpdateTrainer}.
-     * @param cache Cache with labeled vectors.
-     * @param batchSize Size of batch to return on each training iteration.
-     * @param rand RNG.
-     */
-    public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, MLPInitializer init,
-        int networksCnt, IgniteCache<Integer, LabeledVector<Vector, Vector>> cache,
-        int batchSize, Random rand) {
-        super(networksCnt);
-
-        this.batchSize = batchSize;
-        this.cache = cache;
-        this.mlp = new MultilayerPerceptron(arch, init);
-        this.rand = rand;
-    }
-
-    /**
-     * Construct instance of this class with given parameters.
-     *
-     * @param arch Architecture of multilayer perceptron.
-     * @param init Initializer of multilayer perceptron.
-     * @param networksCnt Count of networks to be trained in parallel by {@link MLPGroupUpdateTrainer}.
-     * @param cache Cache with labeled vectors.
-     * @param batchSize Size of batch to return on each training iteration.
-     */
-    public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, MLPInitializer init,
-        int networksCnt, IgniteCache<Integer, LabeledVector<Vector, Vector>> cache,
-        int batchSize) {
-        this(arch, init, networksCnt, cache, batchSize, null);
-    }
-
-    /**
-     * Construct instance of this class with given parameters and default initializer.
-     *
-     * @param arch Architecture of multilayer perceptron.
-     * @param networksCnt Count of networks to be trained in parallel by {@link MLPGroupUpdateTrainer}.
-     * @param cache Cache with labeled vectors.
-     * @param batchSize Size of batch to return on each training iteration.
-     */
-    public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, int networksCnt,
-        IgniteCache<Integer, LabeledVector<Vector, Vector>> cache,
-        int batchSize) {
-        this(arch, null, networksCnt, cache, batchSize);
-    }
-
-    /** {@inheritDoc} */
-    @Override public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() {
-        String cName = cache.getName();
-
-        int bs = batchSize; // This line is for prohibiting of 'this' object be caught into serialization context of lambda.
-        Random r = rand; // This line is for prohibiting of 'this' object be caught into serialization context of lambda.
-
-        return () -> {
-            Ignite ignite = Ignition.localIgnite();
-            IgniteCache<Integer, LabeledVector<Vector, Vector>> cache = ignite.getOrCreateCache(cName);
-            int total = cache.size();
-            Affinity<Integer> affinity = ignite.affinity(cName);
-
-            List<Integer> allKeys = IntStream.range(0, total).boxed().collect(Collectors.toList());
-            List<Integer> keys = new ArrayList<>(affinity.mapKeysToNodes(allKeys).get(ignite.cluster().localNode()));
-
-            int locKeysCnt = keys.size();
-
-            int[] selected = Utils.selectKDistinct(locKeysCnt, Math.min(bs, locKeysCnt), r);
-
-            // Get dimensions of vectors in cache. We suppose that every feature vector has
-            // same dimension d 1 and every label has the same dimension d2.
-            LabeledVector<Vector, Vector> dimEntry = cache.get(keys.get(selected[0]));
-
-            Matrix inputs = new DenseLocalOnHeapMatrix(dimEntry.features().size(), bs);
-            Matrix groundTruth = new DenseLocalOnHeapMatrix(dimEntry.label().size(), bs);
-
-            for (int i = 0; i < selected.length; i++) {
-                LabeledVector<Vector, Vector> labeled = cache.get(keys.get(selected[i]));
-
-                inputs.assignColumn(i, labeled.features());
-                groundTruth.assignColumn(i, labeled.label());
-            }
-
-            return new IgniteBiTuple<>(inputs, groundTruth);
-        };
-    }
-
-    /** {@inheritDoc} */
-    @Override public MultilayerPerceptron mdl() {
-        return mlp;
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
new file mode 100644 (file)
index 0000000..47d2022
--- /dev/null
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.nn.initializers.RandomInitializer;
+import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+import org.apache.ignite.ml.util.Utils;
+
+/**
+ * Multilayer perceptron trainer based on partition based {@link Dataset}.
+ *
+ * @param <P> Type of model update used in this trainer.
+ */
+public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrainer<MultilayerPerceptron> {
+    /** Multilayer perceptron architecture that defines layers and activators. */
+    private final MLPArchitecture arch;
+
+    /** Loss function to be minimized during the training. */
+    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+    /** Update strategy that defines how to update model parameters during the training. */
+    private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
+
+    /** Maximal number of iterations before the training will be stopped. */
+    private final int maxIterations;
+
+    /** Batch size (per every partition). */
+    private final int batchSize;
+
+    /** Maximal number of local iterations before synchronization. */
+    private final int locIterations;
+
+    /** Multilayer perceptron model initializer. */
+    private final long seed;
+
+    /**
+     * Constructs a new instance of multilayer perceptron trainer.
+     *
+     * @param arch Multilayer perceptron architecture that defines layers and activators.
+     * @param loss Loss function to be minimized during the training.
+     * @param updatesStgy Update strategy that defines how to update model parameters during the training.
+     * @param maxIterations Maximal number of iterations before the training will be stopped.
+     * @param batchSize Batch size (per every partition).
+     * @param locIterations Maximal number of local iterations before synchronization.
+     * @param seed Random initializer seed.
+     */
+    public MLPTrainer(MLPArchitecture arch, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
+        UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize,
+        int locIterations, long seed) {
+        this.arch = arch;
+        this.loss = loss;
+        this.updatesStgy = updatesStgy;
+        this.maxIterations = maxIterations;
+        this.batchSize = batchSize;
+        this.locIterations = locIterations;
+        this.seed = seed;
+    }
+
+    /** {@inheritDoc} */
+    public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
+
+        MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
+        ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();
+
+        try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(),
+            new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
+        )) {
+            for (int i = 0; i < maxIterations; i += locIterations) {
+
+                MultilayerPerceptron finalMdl = mdl;
+                int finalI = i;
+
+                List<P> totUp = dataset.compute(
+                    data -> {
+                        P update = updater.init(finalMdl, loss);
+
+                        MultilayerPerceptron mlp = Utils.copy(finalMdl);
+
+                        if (data.getFeatures() != null) {
+                            List<P> updates = new ArrayList<>();
+
+                            for (int locStep = 0; locStep < locIterations; locStep++) {
+                                int[] rows = Utils.selectKDistinct(
+                                    data.getRows(),
+                                    Math.min(batchSize, data.getRows()),
+                                    new Random(seed ^ (finalI * locStep))
+                                );
+
+                                double[] inputsBatch = batch(data.getFeatures(), rows, data.getRows());
+                                double[] groundTruthBatch = batch(data.getLabels(), rows, data.getRows());
+
+                                Matrix inputs = new DenseLocalOnHeapMatrix(inputsBatch, rows.length, 0);
+                                Matrix groundTruth = new DenseLocalOnHeapMatrix(groundTruthBatch, rows.length, 0);
+
+                                update = updater.calculateNewUpdate(
+                                    mlp,
+                                    update,
+                                    locStep,
+                                    inputs.transpose(),
+                                    groundTruth.transpose()
+                                );
+
+                                mlp = updater.update(mlp, update);
+                                updates.add(update);
+                            }
+
+                            List<P> res = new ArrayList<>();
+                            res.add(updatesStgy.locStepUpdatesReducer().apply(updates));
+
+                            return res;
+                        }
+
+                        return null;
+                    },
+                    (a, b) -> {
+                        if (a == null)
+                            return b;
+                        else if (b == null)
+                            return a;
+                        else {
+                            a.addAll(b);
+                            return a;
+                        }
+                    }
+                );
+
+                P update = updatesStgy.allUpdatesReducer().apply(totUp);
+                mdl = updater.update(mdl, update);
+            }
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return mdl;
+    }
+
+    /**
+     * Builds a batch of the data by fetching specified rows.
+     *
+     * @param data All data.
+     * @param rows Rows to be fetched from the data.
+     * @param totalRows Total number of rows in all data.
+     * @return Batch data.
+     */
+    static double[] batch(double[] data, int[] rows, int totalRows) {
+        int cols = data.length / totalRows;
+
+        double[] res = new double[cols * rows.length];
+
+        for (int i = 0; i < rows.length; i++)
+            for (int j = 0; j < cols; j++)
+                res[j * rows.length + i] = data[j * totalRows + rows[i]];
+
+        return res;
+    }
+}
index 7bf238d..819ed6e 100644 (file)
@@ -42,7 +42,8 @@ import static org.apache.ignite.ml.math.util.MatrixUtil.elementWiseTimes;
 /**
  * Class encapsulating logic of multilayer perceptron.
  */
-public class MultilayerPerceptron implements Model<Matrix, Matrix>, SmoothParametrized<MultilayerPerceptron>, Serializable {
+public class MultilayerPerceptron implements Model<Matrix, Matrix>, SmoothParametrized<MultilayerPerceptron>,
+    Serializable {
     /**
      * This MLP architecture.
      */
@@ -169,16 +170,15 @@ public class MultilayerPerceptron implements Model<Matrix, Matrix>, SmoothParame
     }
 
     /**
-     * Predict values on inputs given as columns in a given matrix.
+     * Makes a prediction for the given objects.
      *
-     * @param val Matrix containing inputs as columns.
-     * @return Matrix with predicted vectors stored in columns with column indexes corresponding to column indexes in
-     * the input matrix.
+     * @param val Matrix containing objects.
+     * @return Matrix with predicted vectors.
      */
     @Override public Matrix apply(Matrix val) {
         MLPState state = new MLPState(null);
-        forwardPass(val, state, false);
-        return state.activatorsOutput.get(state.activatorsOutput.size() - 1);
+        forwardPass(val.transpose(), state, false);
+        return state.activatorsOutput.get(state.activatorsOutput.size() - 1).transpose();
     }
 
     /**
index 25c27cd..fa905f7 100644 (file)
@@ -39,6 +39,22 @@ public class RandomInitializer implements MLPInitializer {
         this.rnd = rnd;
     }
 
+    /**
+     * Constructs RandomInitializer with the given seed.
+     *
+     * @param seed Seed.
+     */
+    public RandomInitializer(long seed) {
+        this.rnd = new Random(seed);
+    }
+
+    /**
+     * Constructs RandomInitializer with random seed.
+     */
+    public RandomInitializer() {
+        this.rnd = new Random();
+    }
+
     /** {@inheritDoc} */
     @Override public void initWeights(Matrix weights) {
         weights.map(value -> 2 * (rnd.nextDouble() - 0.5));
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/AbstractMLPGroupUpdateTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/AbstractMLPGroupUpdateTrainerInput.java
deleted file mode 100644 (file)
index f2d95d5..0000000
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import java.util.UUID;
-import java.util.stream.Stream;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
-import org.apache.ignite.ml.trainers.group.GroupTrainerInput;
-
-/**
- * Abstract class for {@link MLPGroupUpdateTrainer} inputs.
- */
-public abstract class AbstractMLPGroupUpdateTrainerInput implements GroupTrainerInput<Void>, LocalBatchTrainerInput<MultilayerPerceptron> {
-    /**
-     * Count of networks to be trained in parallel.
-     */
-    private final int networksCnt;
-
-    /**
-     * Construct instance of this class with given parameters.
-     *
-     * @param networksCnt Count of networks to be trained in parallel.
-     */
-    public AbstractMLPGroupUpdateTrainerInput(int networksCnt) {
-        this.networksCnt = networksCnt;
-    }
-
-    /** {@inheritDoc} */
-    @Override public IgniteSupplier<Stream<GroupTrainerCacheKey<Void>>> initialKeys(UUID trainingUUID) {
-        final int nt = networksCnt; // IMPL NOTE intermediate variable is intended to have smaller lambda
-        return () -> MLPCache.allKeys(nt, trainingUUID);
-    }
-
-    /**
-     * Get count of networks to be trained in parallel.
-     *
-     * @return Count of networks.
-     */
-    public int trainingsCount() {
-        return networksCnt;
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPCache.java
deleted file mode 100644 (file)
index 0fa2f29..0000000
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import java.util.UUID;
-import java.util.stream.IntStream;
-import java.util.stream.Stream;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.cache.affinity.Affinity;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
-
-/**
- * Cache for distributed MLP.
- */
-public class MLPCache {
-    /**
-     * Cache name.
-     */
-    public static String CACHE_NAME = "MLP_CACHE";
-
-    /**
-     * Affinity service for region projections cache.
-     *
-     * @return Affinity service for region projections cache.
-     */
-    public static Affinity<GroupTrainerCacheKey<Void>> affinity() {
-        return Ignition.localIgnite().affinity(CACHE_NAME);
-    }
-
-    /**
-     * Get or create region projections cache.
-     *
-     * @param ignite Ignite instance.
-     * @return Region projections cache.
-     */
-    public static IgniteCache<GroupTrainerCacheKey<Void>, MLPGroupTrainingCacheValue> getOrCreate(Ignite ignite) {
-        CacheConfiguration<GroupTrainerCacheKey<Void>, MLPGroupTrainingCacheValue> cfg = new CacheConfiguration<>();
-
-        // Write to primary.
-        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
-
-        // Atomic transactions only.
-        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-
-        // No copying of values.
-        cfg.setCopyOnRead(false);
-
-        // Cache is partitioned.
-        cfg.setCacheMode(CacheMode.PARTITIONED);
-
-        cfg.setBackups(0);
-
-        cfg.setOnheapCacheEnabled(true);
-
-        cfg.setName(CACHE_NAME);
-
-        return ignite.getOrCreateCache(cfg);
-    }
-
-    /**
-     * Get all keys of this cache for given parameters.
-     *
-     * @param trainingsCnt Parallel trainings count.
-     * @param uuid Training UUID.
-     * @return All keys of this cache for given parameters.
-     */
-    public static Stream<GroupTrainerCacheKey<Void>> allKeys(int trainingsCnt, UUID uuid) {
-        return IntStream.range(0, trainingsCnt).mapToObj(i -> new GroupTrainerCacheKey<Void>(i, null, uuid));
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupTrainingCacheValue.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupTrainingCacheValue.java
deleted file mode 100644 (file)
index f8e75f6..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-
-/**
- * Value of cache used for group training by {@link MLPGroupUpdateTrainer}.
- */
-public class MLPGroupTrainingCacheValue {
-    /**
-     * Multilayer perceptron.
-     */
-    private MultilayerPerceptron mlp;
-
-    /**
-     * Construct instance of this class with given parameters.
-     *
-     * @param mlp Multilayer perceptron.
-     */
-    public MLPGroupTrainingCacheValue(MultilayerPerceptron mlp) {
-        this.mlp = mlp;
-    }
-
-    /**
-     * Get multilayer perceptron.
-     *
-     * @return Multilayer perceptron.
-     */
-    public MultilayerPerceptron perceptron() {
-        return mlp;
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainer.java
deleted file mode 100644 (file)
index 333afcc..0000000
+++ /dev/null
@@ -1,377 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.UUID;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.util.MatrixUtil;
-import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.optimization.SmoothParametrized;
-import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
-import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
-import org.apache.ignite.ml.trainers.group.MetaoptimizerGroupTrainer;
-import org.apache.ignite.ml.trainers.group.ResultAndUpdates;
-import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
-import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
-import org.apache.ignite.ml.util.Utils;
-
-/**
- * Update-based distributed training of MLP.
- *
- * @param <U> Type of update.
- */
-public class MLPGroupUpdateTrainer<U extends Serializable> extends
-    MetaoptimizerGroupTrainer<MLPGroupUpdateTrainerLocalContext,
-        Void,
-        MLPGroupTrainingCacheValue,
-        U,
-        MultilayerPerceptron,
-        U,
-        MultilayerPerceptron,
-        AbstractMLPGroupUpdateTrainerInput,
-        MLPGroupUpdateTrainingContext<U>,
-        ArrayList<U>,
-        MLPGroupUpdateTrainingLoopData<U>,
-        U> {
-    /**
-     * Loss function.
-     */
-    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
-
-    /**
-     * Error tolerance.
-     */
-    private final double tolerance;
-
-    /**
-     * Maximal count of global steps.
-     */
-    private final int maxGlobalSteps;
-
-    /**
-     * Synchronize updates between networks every syncPeriod steps.
-     */
-    private final int syncPeriod;
-
-    /**
-     * Function used to reduce updates from different networks (for example, averaging of gradients of all networks).
-     */
-    private final IgniteFunction<List<U>, U> allUpdatesReducer;
-
-    /**
-     * Function used to reduce updates in one training (for example, sum all sequential gradient updates to get one
-     * gradient update).
-     */
-    private final IgniteFunction<List<U>, U> locStepUpdatesReducer;
-
-    /**
-     * Updates calculator.
-     */
-    private final ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator;
-
-    /**
-     * Default maximal count of global steps.
-     */
-    private static final int DEFAULT_MAX_GLOBAL_STEPS = 30;
-
-    /**
-     * Default sync rate.
-     */
-    private static final int DEFAULT_SYNC_RATE = 5;
-
-    /**
-     * Default all updates reducer.
-     */
-    private static final IgniteFunction<List<RPropParameterUpdate>, RPropParameterUpdate>
-        DEFAULT_ALL_UPDATES_REDUCER = RPropParameterUpdate::avg;
-
-    /**
-     * Default local steps updates reducer.
-     */
-    private static final IgniteFunction<List<RPropParameterUpdate>, RPropParameterUpdate>
-        DEFAULT_LOCAL_STEP_UPDATES_REDUCER = RPropParameterUpdate::sumLocal;
-
-    /**
-     * Default update calculator.
-     */
-    private static final ParameterUpdateCalculator<SmoothParametrized, RPropParameterUpdate>
-        DEFAULT_UPDATE_CALCULATOR = new RPropUpdateCalculator();
-
-    /**
-     * Default loss function.
-     */
-    private static final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> DEFAULT_LOSS
-        = LossFunctions.MSE;
-
-    /**
-     * Construct instance of this class with given parameters.
-     *
-     * @param loss Loss function.
-     * @param ignite Ignite instance.
-     * @param tolerance Error tolerance.
-     */
-    public MLPGroupUpdateTrainer(int maxGlobalSteps,
-        int syncPeriod,
-        IgniteFunction<List<U>, U> allUpdatesReducer,
-        IgniteFunction<List<U>, U> locStepUpdatesReducer,
-        ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator,
-        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
-        Ignite ignite, double tolerance) {
-        super(new MLPMetaoptimizer<>(allUpdatesReducer), MLPCache.getOrCreate(ignite), ignite);
-
-        this.maxGlobalSteps = maxGlobalSteps;
-        this.syncPeriod = syncPeriod;
-        this.allUpdatesReducer = allUpdatesReducer;
-        this.locStepUpdatesReducer = locStepUpdatesReducer;
-        this.updateCalculator = updateCalculator;
-        this.loss = loss;
-        this.tolerance = tolerance;
-    }
-
-    /**
-     * Get default {@link MLPGroupUpdateTrainer}.
-     *
-     * @param ignite Ignite instance.
-     * @return Default {@link MLPGroupUpdateTrainer}.
-     */
-    public static MLPGroupUpdateTrainer<RPropParameterUpdate> getDefault(Ignite ignite) {
-        return new MLPGroupUpdateTrainer<>(DEFAULT_MAX_GLOBAL_STEPS, DEFAULT_SYNC_RATE, DEFAULT_ALL_UPDATES_REDUCER,
-            DEFAULT_LOCAL_STEP_UPDATES_REDUCER, DEFAULT_UPDATE_CALCULATOR, DEFAULT_LOSS, ignite, 0.01);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void init(AbstractMLPGroupUpdateTrainerInput data, UUID trainingUUID) {
-        super.init(data, trainingUUID);
-
-        MLPGroupUpdateTrainerDataCache.getOrCreate(ignite).put(trainingUUID, new MLPGroupUpdateTrainingData<>(
-            updateCalculator,
-            syncPeriod,
-            locStepUpdatesReducer,
-            data.batchSupplier(),
-            loss,
-            tolerance
-        ));
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteFunction<GroupTrainerCacheKey<Void>, ResultAndUpdates<U>> distributedInitializer(
-        AbstractMLPGroupUpdateTrainerInput data) {
-        MultilayerPerceptron initPerceptron = data.mdl();
-
-        // For each key put initial network into the cache.
-        return key -> {
-            Ignite ignite = Ignition.localIgnite();
-
-            U initUpdate = updateCalculator.init(initPerceptron, loss);
-
-            return ResultAndUpdates.of(initUpdate).updateCache(MLPCache.getOrCreate(ignite), key,
-                new MLPGroupTrainingCacheValue(initPerceptron));
-        };
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteFunction<EntryAndContext<Void, MLPGroupTrainingCacheValue,
-        MLPGroupUpdateTrainingContext<U>>, MLPGroupUpdateTrainingLoopData<U>> trainingLoopStepDataExtractor() {
-        return entryAndContext -> {
-            MLPGroupUpdateTrainingContext<U> ctx = entryAndContext.context();
-            Map.Entry<GroupTrainerCacheKey<Void>, MLPGroupTrainingCacheValue> entry = entryAndContext.entry();
-            MLPGroupUpdateTrainingData<U> data = ctx.data();
-
-            return new MLPGroupUpdateTrainingLoopData<>(entry.getValue().perceptron(),
-                data.updateCalculator(), data.stepsCnt(), data.updateReducer(), ctx.previousUpdate(), entry.getKey(),
-                data.batchSupplier(), data.loss(), data.tolerance());
-        };
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteSupplier<Stream<GroupTrainerCacheKey<Void>>> keysToProcessInTrainingLoop(
-        MLPGroupUpdateTrainerLocalContext locCtx) {
-        int trainingsCnt = locCtx.parallelTrainingsCnt();
-        UUID uuid = locCtx.trainingUUID();
-
-        return () -> MLPCache.allKeys(trainingsCnt, uuid);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteSupplier<MLPGroupUpdateTrainingContext<U>> remoteContextExtractor(U prevUpdate,
-        MLPGroupUpdateTrainerLocalContext ctx) {
-        UUID uuid = ctx.trainingUUID();
-
-        return () -> {
-            MLPGroupUpdateTrainingData<U> data = MLPGroupUpdateTrainerDataCache.getOrCreate(Ignition.localIgnite()).get(uuid);
-            return new MLPGroupUpdateTrainingContext<>(data, prevUpdate);
-        };
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteFunction<MLPGroupUpdateTrainingLoopData<U>, ResultAndUpdates<U>> dataProcessor() {
-        return data -> {
-            MultilayerPerceptron mlp = data.mlp();
-
-            // Apply previous update.
-            MultilayerPerceptron newMlp = updateCalculator.update(mlp, data.previousUpdate());
-
-            MultilayerPerceptron mlpCp = Utils.copy(newMlp);
-            ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator = data.updateCalculator();
-            IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss = data.loss();
-
-            // ParameterUpdateCalculator API to have proper way to setting loss.
-            updateCalculator.init(mlpCp, loss);
-
-            // Generate new update.
-            int steps = data.stepsCnt();
-            List<U> updates = new ArrayList<>(steps);
-            U curUpdate = data.previousUpdate();
-
-            for (int i = 0; i < steps; i++) {
-                IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get();
-                Matrix input = batch.get1();
-                Matrix truth = batch.get2();
-
-                int batchSize = truth.columnSize();
-
-                curUpdate = updateCalculator.calculateNewUpdate(mlpCp, curUpdate, i, input, truth);
-                mlpCp = updateCalculator.update(mlpCp, curUpdate);
-                updates.add(curUpdate);
-
-                Matrix predicted = mlpCp.apply(input);
-
-                double err = MatrixUtil.zipFoldByColumns(predicted, truth, (predCol, truthCol) ->
-                    loss.apply(truthCol).apply(predCol)).sum() / batchSize;
-
-                if (err < data.tolerance())
-                    break;
-            }
-
-            U accumulatedUpdate = data.getUpdateReducer().apply(updates);
-
-            return new ResultAndUpdates<>(accumulatedUpdate).
-                updateCache(MLPCache.getOrCreate(Ignition.localIgnite()), data.key(),
-                    new MLPGroupTrainingCacheValue(newMlp));
-        };
-    }
-
-    /** {@inheritDoc} */
-    @Override protected MLPGroupUpdateTrainerLocalContext<U> initialLocalContext(
-        AbstractMLPGroupUpdateTrainerInput data, UUID trainingUUID) {
-        return new MLPGroupUpdateTrainerLocalContext<>(trainingUUID, maxGlobalSteps, allUpdatesReducer,
-            data.trainingsCount());
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteSupplier<Stream<GroupTrainerCacheKey<Void>>> finalResultKeys(U data,
-        MLPGroupUpdateTrainerLocalContext locCtx) {
-        UUID uuid = locCtx.trainingUUID();
-        int trainingsCnt = locCtx.parallelTrainingsCnt();
-
-        return () -> MLPCache.allKeys(trainingsCnt, uuid);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteSupplier<MLPGroupUpdateTrainingContext<U>> extractContextForFinalResultCreation(U data,
-        MLPGroupUpdateTrainerLocalContext locCtx) {
-        return () -> null;
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteFunction<EntryAndContext<Void, MLPGroupTrainingCacheValue,
-        MLPGroupUpdateTrainingContext<U>>, ResultAndUpdates<MultilayerPerceptron>> finalResultsExtractor() {
-        return context -> ResultAndUpdates.of(context.entry().getValue().perceptron());
-    }
-
-    /** {@inheritDoc} */
-    @Override protected IgniteFunction<List<MultilayerPerceptron>, MultilayerPerceptron> finalResultsReducer() {
-        // Just take any of MLPs since they will be in the same state.
-        return mlps -> mlps.stream().filter(Objects::nonNull).findFirst().orElse(null);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected MultilayerPerceptron mapFinalResult(MultilayerPerceptron res,
-        MLPGroupUpdateTrainerLocalContext locCtx) {
-        return res;
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void cleanup(MLPGroupUpdateTrainerLocalContext locCtx) {
-        MLPGroupUpdateTrainerDataCache.getOrCreate(ignite).remove(locCtx.trainingUUID());
-        Set<GroupTrainerCacheKey<Void>> toRmv = MLPCache.allKeys(locCtx.parallelTrainingsCnt(), locCtx.trainingUUID()).collect(Collectors.toSet());
-        MLPCache.getOrCreate(ignite).removeAll(toRmv);
-    }
-
-    /**
-     * Create new {@link MLPGroupUpdateTrainer} with new maxGlobalSteps value.
-     *
-     * @param maxGlobalSteps New maxGlobalSteps value.
-     * @return New {@link MLPGroupUpdateTrainer} with new maxGlobalSteps value.
-     */
-    public MLPGroupUpdateTrainer<U> withMaxGlobalSteps(int maxGlobalSteps) {
-        return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, allUpdatesReducer, locStepUpdatesReducer,
-            updateCalculator, loss, ignite, tolerance);
-    }
-
-    /**
-     * Create new {@link MLPGroupUpdateTrainer} with new syncPeriod value.
-     *
-     * @param syncPeriod New syncPeriod value.
-     * @return New {@link MLPGroupUpdateTrainer} with new syncPeriod value.
-     */
-    public MLPGroupUpdateTrainer<U> withSyncPeriod(int syncPeriod) {
-        return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod
-            , allUpdatesReducer, locStepUpdatesReducer, updateCalculator, loss, ignite, tolerance);
-    }
-
-    /**
-     * Create new {@link MLPGroupUpdateTrainer} with new tolerance.
-     *
-     * @param tolerance New tolerance value.
-     * @return New {@link MLPGroupUpdateTrainer} with new tolerance value.
-     */
-    public MLPGroupUpdateTrainer<U> withTolerance(double tolerance) {
-        return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, allUpdatesReducer, locStepUpdatesReducer,
-            updateCalculator, loss, ignite, tolerance);
-    }
-
-    /**
-     * Create new {@link MLPGroupUpdateTrainer} with new update strategy.
-     *
-     * @param stgy New update strategy.
-     * @return New {@link MLPGroupUpdateTrainer} with new tolerance value.
-     */
-    public <U1 extends Serializable> MLPGroupUpdateTrainer<U1> withUpdateStrategy(UpdatesStrategy<? super MultilayerPerceptron, U1> stgy) {
-        return new MLPGroupUpdateTrainer<>(maxGlobalSteps, syncPeriod, stgy.allUpdatesReducer(), stgy.locStepUpdatesReducer(),
-            stgy.getUpdatesCalculator(), loss, ignite, tolerance);
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerDataCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerDataCache.java
deleted file mode 100644 (file)
index 4200321..0000000
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import java.util.UUID;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.cache.affinity.Affinity;
-import org.apache.ignite.configuration.CacheConfiguration;
-
-/**
- * Cache used for storing data for {@link MLPGroupUpdateTrainer}.
- */
-public class MLPGroupUpdateTrainerDataCache {
-    /**
-     * Cache name.
-     */
-    public static String CACHE_NAME = "MLP_GRP_TRN_DATA_CACHE";
-
-    /**
-     * Affinity service for region projections cache.
-     *
-     * @return Affinity service for region projections cache.
-     */
-    public static Affinity<UUID> affinity() {
-        return Ignition.localIgnite().affinity(CACHE_NAME);
-    }
-
-    /**
-     * Get or create region projections cache.
-     *
-     * @param ignite Ignite instance.
-     * @return Region projections cache.
-     */
-    public static IgniteCache<UUID, MLPGroupUpdateTrainingData> getOrCreate(Ignite ignite) {
-        CacheConfiguration<UUID, MLPGroupUpdateTrainingData> cfg = new CacheConfiguration<>();
-
-        // Write to primary.
-        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.FULL_SYNC);
-
-        // Atomic transactions only.
-        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-
-        // No copying of values.
-        cfg.setCopyOnRead(false);
-
-        // Cache is partitioned.
-        cfg.setCacheMode(CacheMode.REPLICATED);
-
-        cfg.setBackups(0);
-
-        cfg.setOnheapCacheEnabled(true);
-
-        cfg.setName(CACHE_NAME);
-
-        return ignite.getOrCreateCache(cfg);
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerLocalContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainerLocalContext.java
deleted file mode 100644 (file)
index ecb141d..0000000
+++ /dev/null
@@ -1,117 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import java.util.List;
-import java.util.UUID;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
-
-/**
- * Local context for {@link MLPGroupUpdateTrainer}.
- *
- * @param <U> Type of updates on which training is done.
- */
-public class MLPGroupUpdateTrainerLocalContext<U> implements HasTrainingUUID {
-    /**
-     * UUID of training.
-     */
-    private final UUID trainingUUID;
-
-    /**
-     * Maximal number of global steps.
-     */
-    private final int globalStepsMaxCnt;
-
-    /**
-     * Reducer used to reduce updates resulted from each parallel training.
-     */
-    private final IgniteFunction<List<U>, U> allUpdatesReducer;
-
-    /**
-     * Count of networks to be trained in parallel.
-     */
-    private final int parallelTrainingsCnt;
-
-    /**
-     * Current global step of {@link MLPGroupUpdateTrainer}.
-     */
-    private int curStep;
-
-    /** Create multilayer perceptron group update trainer local context. */
-    public MLPGroupUpdateTrainerLocalContext(UUID trainingUUID, int globalStepsMaxCnt,
-        IgniteFunction<List<U>, U> allUpdatesReducer, int parallelTrainingsCnt) {
-        this.trainingUUID = trainingUUID;
-        this.globalStepsMaxCnt = globalStepsMaxCnt;
-        this.allUpdatesReducer = allUpdatesReducer;
-        this.parallelTrainingsCnt = parallelTrainingsCnt;
-        curStep = 0;
-    }
-
-    /** {@inheritDoc} */
-    @Override public UUID trainingUUID() {
-        return trainingUUID;
-    }
-
-    /**
-     * Get global steps max count.
-     *
-     * @return Global steps max count.
-     */
-    public int globalStepsMaxCount() {
-        return globalStepsMaxCnt;
-    }
-
-    /**
-     * Get reducer used to reduce updates resulted from each parallel training.
-     *
-     * @return Reducer used to reduce updates resulted from each parallel training.
-     */
-    public IgniteFunction<List<U>, U> allUpdatesReducer() {
-        return allUpdatesReducer;
-    }
-
-    /**
-     * Get count of networks to be trained in parallel.
-     *
-     * @return Count of networks to be trained in parallel.
-     */
-    public int parallelTrainingsCnt() {
-        return parallelTrainingsCnt;
-    }
-
-    /**
-     * Get current global step.
-     *
-     * @return Current global step.
-     */
-    public int currentStep() {
-        return curStep;
-    }
-
-    /**
-     * Increment current global step.
-     *
-     * @return This object.
-     */
-    public MLPGroupUpdateTrainerLocalContext<U> incrementCurrentStep() {
-        curStep++;
-
-        return this;
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingContext.java
deleted file mode 100644 (file)
index f4ccd98..0000000
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-/**
- * Context extracted in distribute phase of training loop step in {@link MLPGroupUpdateTrainer}.
- *
- * @param <U> Type of update.
- */
-public class MLPGroupUpdateTrainingContext<U> {
-    /**
-     * Group training data.
-     */
-    private final MLPGroupUpdateTrainingData<U> data;
-
-    /**
-     * Update produced by previous training loop step.
-     */
-    private final U previousUpdate;
-
-    /**
-     * Construct an instance of this class.
-     *
-     * @param data Group training data.
-     * @param previousUpdate Update produced by previous training loop step.
-     */
-    public MLPGroupUpdateTrainingContext(MLPGroupUpdateTrainingData<U> data, U previousUpdate) {
-        this.data = data;
-        this.previousUpdate = previousUpdate;
-    }
-
-    /**
-     * Get group training data.
-     *
-     * @return Group training data.
-     */
-    public MLPGroupUpdateTrainingData<U> data() {
-        return data;
-    }
-
-    /**
-     * Get update produced by previous training loop step.
-     *
-     * @return Update produced by previous training loop step.
-     */
-    public U previousUpdate() {
-        return previousUpdate;
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingData.java
deleted file mode 100644 (file)
index 3031c8f..0000000
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import java.util.List;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
-
-/** Multilayer perceptron group update training data. */
-public class MLPGroupUpdateTrainingData<U> {
-    /** {@link ParameterUpdateCalculator}. */
-    private final ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator;
-
-    /**
-     * Count of steps which should be done by each of parallel trainings before sending it's update for combining with
-     * other parallel trainings updates.
-     */
-    private final int stepsCnt;
-
-    /**
-     * Function used to reduce updates in one training (for example, sum all sequential gradient updates to get one
-     * gradient update).
-     */
-    private final IgniteFunction<List<U>, U> updateReducer;
-
-    /**
-     * Supplier of batches in the form of (inputs, groundTruths).
-     */
-    private final IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier;
-
-    /**
-     * Loss function.
-     */
-    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
-
-    /** Error tolerance. */
-    private final double tolerance;
-
-    /** Construct multilayer perceptron group update training data with all parameters provided. */
-    public MLPGroupUpdateTrainingData(
-        ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator, int stepsCnt,
-        IgniteFunction<List<U>, U> updateReducer,
-        IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier,
-        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, double tolerance) {
-        this.updateCalculator = updateCalculator;
-        this.stepsCnt = stepsCnt;
-        this.updateReducer = updateReducer;
-        this.batchSupplier = batchSupplier;
-        this.loss = loss;
-        this.tolerance = tolerance;
-    }
-
-    /** Get update calculator. */
-    public ParameterUpdateCalculator<? super MultilayerPerceptron, U> updateCalculator() {
-        return updateCalculator;
-    }
-
-    /** Get count of steps. */
-    public int stepsCnt() {
-        return stepsCnt;
-    }
-
-    /** Get update reducer. */
-    public IgniteFunction<List<U>, U> updateReducer() {
-        return updateReducer;
-    }
-
-    /** Get batch supplier. */
-    public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() {
-        return batchSupplier;
-    }
-
-    /** Get loss function. */
-    public IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss() {
-        return loss;
-    }
-
-    /** Get tolerance. */
-    public double tolerance() {
-        return tolerance;
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPGroupUpdateTrainingLoopData.java
deleted file mode 100644 (file)
index 342e7d5..0000000
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import java.io.Serializable;
-import java.util.List;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
-import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
-
-/** Multilayer perceptron group update training loop data. */
-public class MLPGroupUpdateTrainingLoopData<P> implements Serializable {
-    /** {@link ParameterUpdateCalculator}. */
-    private final ParameterUpdateCalculator<? super MultilayerPerceptron, P> updateCalculator;
-
-    /**
-     * Count of steps which should be done by each of parallel trainings before sending it's update for combining with
-     * other parallel trainings updates.
-     */
-    private final int stepsCnt;
-
-    /** Function used to reduce updates of all steps of given parallel training. */
-    private final IgniteFunction<List<P>, P> updateReducer;
-
-    /** Previous update. */
-    private final P previousUpdate;
-
-    /** Supplier of batches. */
-    private final IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier;
-
-    /** Loss function. */
-    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
-
-    /** Error tolerance. */
-    private final double tolerance;
-
-    /** Key. */
-    private final GroupTrainerCacheKey<Void> key;
-
-    /** MLP. */
-    private final MultilayerPerceptron mlp;
-
-    /** Create multilayer perceptron group update training loop data. */
-    public MLPGroupUpdateTrainingLoopData(MultilayerPerceptron mlp,
-        ParameterUpdateCalculator<? super MultilayerPerceptron, P> updateCalculator, int stepsCnt,
-        IgniteFunction<List<P>, P> updateReducer, P previousUpdate,
-        GroupTrainerCacheKey<Void> key, IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier,
-        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
-        double tolerance) {
-        this.mlp = mlp;
-        this.updateCalculator = updateCalculator;
-        this.stepsCnt = stepsCnt;
-        this.updateReducer = updateReducer;
-        this.previousUpdate = previousUpdate;
-        this.key = key;
-        this.batchSupplier = batchSupplier;
-        this.loss = loss;
-        this.tolerance = tolerance;
-    }
-
-    /** Get perceptron. */
-    public MultilayerPerceptron mlp() {
-        return mlp;
-    }
-
-    /** Get update calculator. */
-    public ParameterUpdateCalculator<? super MultilayerPerceptron, P> updateCalculator() {
-        return updateCalculator;
-    }
-
-    /** Get steps count. */
-    public int stepsCnt() {
-        return stepsCnt;
-    }
-
-    /** Get update reducer. */
-    public IgniteFunction<List<P>, P> getUpdateReducer() {
-        return updateReducer;
-    }
-
-    /** Get previous update. */
-    public P previousUpdate() {
-        return previousUpdate;
-    }
-
-    /** Get group trainer cache key. */
-    public GroupTrainerCacheKey<Void> key() {
-        return key;
-    }
-
-    /** Get batch supplier. */
-    public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() {
-        return batchSupplier;
-    }
-
-    /** Get loss function. */
-    public IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss() {
-        return loss;
-    }
-
-    /** Get tolerance. */
-    public double tolerance() {
-        return tolerance;
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/MLPMetaoptimizer.java
deleted file mode 100644 (file)
index ff95a27..0000000
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.distributed;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Objects;
-import java.util.stream.Collectors;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.trainers.group.Metaoptimizer;
-
-/** Meta-optimizer for multilayer perceptron. */
-public class MLPMetaoptimizer<P> implements Metaoptimizer<MLPGroupUpdateTrainerLocalContext,
-    MLPGroupUpdateTrainingLoopData<P>, P, P, P, ArrayList<P>> {
-    /** Function used for reducing updates produced by parallel trainings. */
-    private final IgniteFunction<List<P>, P> allUpdatesReducer;
-
-    /** Construct metaoptimizer. */
-    public MLPMetaoptimizer(IgniteFunction<List<P>, P> allUpdatesReducer) {
-        this.allUpdatesReducer = allUpdatesReducer;
-    }
-
-    /** {@inheritDoc} */
-    @Override public IgniteFunction<List<P>, P> initialReducer() {
-        return allUpdatesReducer;
-    }
-
-    /** {@inheritDoc} */
-    @Override public P locallyProcessInitData(P data, MLPGroupUpdateTrainerLocalContext locCtx) {
-        return data;
-    }
-
-    /** {@inheritDoc} */
-    @Override public IgniteFunction<P, ArrayList<P>> distributedPostprocessor() {
-        return p -> {
-            ArrayList<P> res = new ArrayList<>();
-            res.add(p);
-            return res;
-        };
-    }
-
-    /** {@inheritDoc} */
-    @Override public IgniteFunction<List<ArrayList<P>>, ArrayList<P>> postProcessReducer() {
-        // Flatten.
-        return lists -> new ArrayList<>(lists.stream()
-            .flatMap(List::stream)
-            .collect(Collectors.toList()));
-    }
-
-    /** {@inheritDoc} */
-    @Override public P localProcessor(ArrayList<P> input, MLPGroupUpdateTrainerLocalContext locCtx) {
-        locCtx.incrementCurrentStep();
-
-        return allUpdatesReducer.apply(input.stream().filter(Objects::nonNull).collect(Collectors.toList()));
-    }
-
-    /** {@inheritDoc} */
-    @Override public boolean shouldContinue(P input, MLPGroupUpdateTrainerLocalContext locCtx) {
-        return input != null && locCtx.currentStep() < locCtx.globalStepsMaxCount();
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/distributed/package-info.java
deleted file mode 100644 (file)
index 24c0046..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains multilayer perceptron distributed trainers.
- */
-package org.apache.ignite.ml.nn.trainers.distributed;
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java
deleted file mode 100644 (file)
index ebb78c0..0000000
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.trainers.local;
-
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
-import org.apache.ignite.ml.trainers.local.LocalBatchTrainer;
-
-/**
- * Local batch trainer for MLP.
- *
- * @param <P> Parameter updater parameters.
- */
-public class MLPLocalBatchTrainer<P>
-    extends LocalBatchTrainer<MultilayerPerceptron, P> {
-    /**
-     * Default loss function.
-     */
-    private static final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> DEFAULT_LOSS =
-        LossFunctions.MSE;
-
-    /**
-     * Default error threshold.
-     */
-    private static final double DEFAULT_ERROR_THRESHOLD = 1E-5;
-
-    /**
-     * Default maximal iterations count.
-     */
-    private static final int DEFAULT_MAX_ITERATIONS = 100;
-
-    /**
-     * Construct a trainer.
-     *
-     * @param loss Loss function.
-     * @param updaterSupplier Supplier of updater function.
-     * @param errorThreshold Error threshold.
-     * @param maxIterations Maximal iterations count.
-     */
-    public MLPLocalBatchTrainer(
-        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
-        IgniteSupplier<ParameterUpdateCalculator<? super MultilayerPerceptron, P>> updaterSupplier,
-        double errorThreshold, int maxIterations) {
-        super(loss, updaterSupplier, errorThreshold, maxIterations);
-    }
-
-    /**
-     * Get MLPLocalBatchTrainer with default parameters.
-     *
-     * @return MLPLocalBatchTrainer with default parameters.
-     */
-    public static MLPLocalBatchTrainer<RPropParameterUpdate> getDefault() {
-        return new MLPLocalBatchTrainer<>(DEFAULT_LOSS, () -> new RPropUpdateCalculator(), DEFAULT_ERROR_THRESHOLD,
-            DEFAULT_MAX_ITERATIONS);
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java
deleted file mode 100644 (file)
index c90f67a..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains multilayer perceptron trainers.
- */
-package org.apache.ignite.ml.nn.trainers;
\ No newline at end of file
index b494b14..0402df6 100644 (file)
@@ -27,6 +27,9 @@ import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
  * Data needed for Nesterov parameters updater.
  */
 public class NesterovParameterUpdate implements Serializable {
+    /** */
+    private static final long serialVersionUID = -6370106062737202385L;
+
     /**
      * Previous step weights updates.
      */
@@ -77,8 +80,13 @@ public class NesterovParameterUpdate implements Serializable {
      * @return Sum of parameters updates.
      */
     public static NesterovParameterUpdate sum(List<NesterovParameterUpdate> parameters) {
-        return parameters.stream().filter(Objects::nonNull).map(NesterovParameterUpdate::prevIterationUpdates)
-            .reduce(Vector::plus).map(NesterovParameterUpdate::new).orElse(null);
+        return parameters
+            .stream()
+            .filter(Objects::nonNull)
+            .map(NesterovParameterUpdate::prevIterationUpdates)
+            .reduce(Vector::plus)
+            .map(NesterovParameterUpdate::new)
+            .orElse(null);
     }
 
     /**
@@ -89,6 +97,8 @@ public class NesterovParameterUpdate implements Serializable {
      */
     public static NesterovParameterUpdate avg(List<NesterovParameterUpdate> parameters) {
         NesterovParameterUpdate sum = sum(parameters);
-        return sum != null ? sum.setPreviousUpdates(sum.prevIterationUpdates().divide(parameters.size())) : null;
+        return sum != null ? sum.setPreviousUpdates(sum.prevIterationUpdates()
+            .divide(parameters.stream()
+                .filter(Objects::nonNull).count())) : null;
     }
 }
index 2bee506..a9b4521 100644 (file)
@@ -28,6 +28,9 @@ import org.apache.ignite.ml.optimization.SmoothParametrized;
  */
 public class NesterovUpdateCalculator<M extends SmoothParametrized<M>>
     implements ParameterUpdateCalculator<M, NesterovParameterUpdate> {
+    /** */
+    private static final long serialVersionUID = 251066184668190622L;
+
     /**
      * Learning rate.
      */
@@ -60,14 +63,12 @@ public class NesterovUpdateCalculator<M extends SmoothParametrized<M>>
 
         M newMdl = mdl;
 
-        if (iteration > 0) {
-            Vector curParams = mdl.parameters();
-            newMdl = mdl.withParameters(curParams.minus(prevUpdates.times(momentum)));
-        }
+        if (iteration > 0)
+            newMdl = mdl.withParameters(mdl.parameters().minus(prevUpdates.times(momentum)));
 
         Vector gradient = newMdl.differentiateByParameters(loss, inputs, groundTruth);
 
-        return new NesterovParameterUpdate(prevUpdates.plus(gradient.times(learningRate)));
+        return new NesterovParameterUpdate(prevUpdates.times(momentum).plus(gradient.times(learningRate)));
     }
 
     /** {@inheritDoc} */
index 92f7583..2853f0d 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.optimization.updatecalculators;
 
+import java.io.Serializable;
 import org.apache.ignite.ml.math.Matrix;
 import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
@@ -28,7 +29,7 @@ import org.apache.ignite.ml.math.functions.IgniteFunction;
  * @param <M> Type of model to be updated.
  * @param <P> Type of parameters needed for this update calculator.
  */
-public interface ParameterUpdateCalculator<M, P> {
+public interface ParameterUpdateCalculator<M, P extends Serializable> extends Serializable {
     /**
      * Initializes the update calculator.
      *
index fd0a045..9c3c8d9 100644 (file)
@@ -31,6 +31,9 @@ import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
  * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p>
  */
 public class RPropParameterUpdate implements Serializable {
+    /** */
+    private static final long serialVersionUID = -165584242642323332L;
+
     /**
      * Previous iteration parameters updates. In original paper they are labeled with "delta w".
      */
index f706a6c..569961c 100644 (file)
@@ -31,6 +31,9 @@ import org.apache.ignite.ml.optimization.SmoothParametrized;
  * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p>
  */
 public class RPropUpdateCalculator implements ParameterUpdateCalculator<SmoothParametrized, RPropParameterUpdate> {
+    /** */
+    private static final long serialVersionUID = -5156816330041409864L;
+
     /**
      * Default initial update.
      */
index 13731ea..9f85c74 100644 (file)
@@ -28,6 +28,9 @@ import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
  * Parameters for {@link SimpleGDUpdateCalculator}.
  */
 public class SimpleGDParameterUpdate implements Serializable {
+    /** */
+    private static final long serialVersionUID = -8732955283436005621L;
+
     /** Gradient. */
     private Vector gradient;
 
index f102396..0056f15 100644 (file)
@@ -27,6 +27,9 @@ import org.apache.ignite.ml.optimization.SmoothParametrized;
  * Simple gradient descent parameters updater.
  */
 public class SimpleGDUpdateCalculator implements ParameterUpdateCalculator<SmoothParametrized, SimpleGDParameterUpdate> {
+    /** */
+    private static final long serialVersionUID = -4237332083320879334L;
+
     /** Learning rate. */
     private double learningRate;
 
index d7d587e..ae15f2f 100644 (file)
@@ -18,7 +18,7 @@
 package org.apache.ignite.ml.regressions.linear;
 
 import java.util.Arrays;
-import org.apache.ignite.ml.DatasetTrainer;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -31,31 +31,18 @@ import org.apache.ignite.ml.math.isolve.lsqr.LSQRResult;
 /**
  * Trainer of the linear regression model based on LSQR algorithm.
  *
- * @param <K> Type of a key in {@code upstream} data.
- * @param <V> Type of a value in {@code upstream} data.
- *
  * @see AbstractLSQR
  */
-public class LinearRegressionLSQRTrainer<K, V> implements DatasetTrainer<K, V, LinearRegressionModel> {
+public class LinearRegressionLSQRTrainer implements SingleLabelDatasetTrainer<LinearRegressionModel> {
     /** {@inheritDoc} */
-    @Override public LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
+    @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
         LSQRResult res;
 
         try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(
             datasetBuilder,
-            new LinSysPartitionDataBuilderOnHeap<>(
-                (k, v) -> {
-                    double[] row = Arrays.copyOf(featureExtractor.apply(k, v), cols + 1);
-
-                    row[cols] = 1.0;
-
-                    return row;
-                },
-                lbExtractor,
-                cols + 1
-            )
+            new LinSysPartitionDataBuilderOnHeap<>(new FeatureExtractorWrapper<>(featureExtractor), lbExtractor)
         )) {
             res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
         }
@@ -63,8 +50,42 @@ public class LinearRegressionLSQRTrainer<K, V> implements DatasetTrainer<K, V, L
             throw new RuntimeException(e);
         }
 
-        Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(res.getX(), 0, cols));
+        double[] x = res.getX();
+        Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(x, 0, x.length - 1));
+
+        return new LinearRegressionModel(weights, x[x.length - 1]);
+    }
+
+    /**
+     * Feature extractor wrapper that adds additional column filled by 1.
+     *
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     */
+    private static class FeatureExtractorWrapper<K, V> implements IgniteBiFunction<K, V, double[]> {
+        /** */
+        private static final long serialVersionUID = -2686524650955735635L;
+
+        /** Underlying feature extractor. */
+        private final IgniteBiFunction<K, V, double[]> featureExtractor;
+
+        /**
+         * Constructs a new instance of feature extractor wrapper.
+         *
+         * @param featureExtractor Underlying feature extractor.
+         */
+        FeatureExtractorWrapper(IgniteBiFunction<K, V, double[]> featureExtractor) {
+            this.featureExtractor = featureExtractor;
+        }
+
+        /** {@inheritDoc} */
+        @Override public double[] apply(K k, V v) {
+            double[] featureRow = featureExtractor.apply(k, v);
+            double[] row = Arrays.copyOf(featureRow, featureRow.length + 1);
+
+            row[featureRow.length] = 1.0;
 
-        return new LinearRegressionModel(weights, res.getX()[cols]);
+            return row;
+        }
     }
 }
index 08def93..84f5eba 100644 (file)
@@ -18,7 +18,7 @@
 package org.apache.ignite.ml.svm;
 
 import java.util.concurrent.ThreadLocalRandom;
-import org.apache.ignite.ml.DatasetTrainer;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
@@ -35,8 +35,7 @@ import org.jetbrains.annotations.NotNull;
  * and +1 labels for two classes and makes binary classification. </p> The paper about this algorithm could be found
  * here https://arxiv.org/abs/1409.1458.
  */
-public class SVMLinearBinaryClassificationTrainer<K, V>
-    implements DatasetTrainer<K, V, SVMLinearBinaryClassificationModel> {
+public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetTrainer<SVMLinearBinaryClassificationModel> {
     /** Amount of outer SDCA algorithm iterations. */
     private int amountOfIterations = 200;
 
@@ -52,17 +51,16 @@ public class SVMLinearBinaryClassificationTrainer<K, V>
      * @param datasetBuilder   Dataset builder.
      * @param featureExtractor Feature extractor.
      * @param lbExtractor      Label extractor.
-     * @param cols             Number of columns.
      * @return Model.
      */
-    @Override public SVMLinearBinaryClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
+    @Override public <K, V> SVMLinearBinaryClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, SVMPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new SVMPartitionDataBuilderOnHeap<>(
             featureExtractor,
-            lbExtractor,
-            cols
+            lbExtractor
         );
 
         Vector weights;
@@ -71,6 +69,7 @@ public class SVMLinearBinaryClassificationTrainer<K, V>
             (upstream, upstreamSize) -> new SVMPartitionContext(),
             partDataBuilder
         )) {
+            final int cols = dataset.compute(data -> data.colSize(), (a, b) -> a == null ? b : a);
             final int weightVectorSizeWithIntercept = cols + 1;
             weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept);
 
index acaa4b1..cc1039f 100644 (file)
@@ -24,7 +24,7 @@ import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
-import org.apache.ignite.ml.DatasetTrainer;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
@@ -39,8 +39,8 @@ import org.apache.ignite.ml.svm.multi.LabelPartitionDataOnHeap;
  *
  * All common parameters are shared with bunch of binary classification trainers.
  */
-public class SVMLinearMultiClassClassificationTrainer<K, V>
-    implements DatasetTrainer<K, V, SVMLinearMultiClassClassificationModel> {
+public class SVMLinearMultiClassClassificationTrainer
+    implements SingleLabelDatasetTrainer<SVMLinearMultiClassClassificationModel> {
     /** Amount of outer SDCA algorithm iterations. */
     private int amountOfIterations = 20;
 
@@ -56,12 +56,11 @@ public class SVMLinearMultiClassClassificationTrainer<K, V>
      * @param datasetBuilder   Dataset builder.
      * @param featureExtractor Feature extractor.
      * @param lbExtractor      Label extractor.
-     * @param cols             Number of columns.
      * @return Model.
      */
-    @Override public SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
                                                                 IgniteBiFunction<K, V, double[]> featureExtractor,
-                                                                IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
+                                                                IgniteBiFunction<K, V, Double> lbExtractor) {
         List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
 
         SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();
@@ -80,14 +79,14 @@ public class SVMLinearMultiClassClassificationTrainer<K, V>
                 else
                     return -1.0;
             };
-            multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer, cols));
+            multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer));
         });
 
         return multiClsMdl;
     }
 
     /** Iterates among dataset and collects class labels. */
-    private List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
+    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, LabelPartitionContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
@@ -122,7 +121,7 @@ public class SVMLinearMultiClassClassificationTrainer<K, V>
      * @param lambda The regularization parameter. Should be more than 0.0.
      * @return Trainer with new lambda parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer<K, V>  withLambda(double lambda) {
+    public SVMLinearMultiClassClassificationTrainer  withLambda(double lambda) {
         assert lambda > 0.0;
         this.lambda = lambda;
         return this;
@@ -152,7 +151,7 @@ public class SVMLinearMultiClassClassificationTrainer<K, V>
      * @param amountOfIterations The parameter value.
      * @return Trainer with new amountOfIterations parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer<K, V>  withAmountOfIterations(int amountOfIterations) {
+    public SVMLinearMultiClassClassificationTrainer  withAmountOfIterations(int amountOfIterations) {
         this.amountOfIterations = amountOfIterations;
         return this;
     }
@@ -172,7 +171,7 @@ public class SVMLinearMultiClassClassificationTrainer<K, V>
      * @param amountOfLocIterations The parameter value.
      * @return Trainer with new amountOfLocIterations parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer<K, V>  withAmountOfLocIterations(int amountOfLocIterations) {
+    public SVMLinearMultiClassClassificationTrainer  withAmountOfLocIterations(int amountOfLocIterations) {
         this.amountOfLocIterations = amountOfLocIterations;
         return this;
     }
index 9954892..ba1b82a 100644 (file)
@@ -43,27 +43,24 @@ public class SVMPartitionDataBuilderOnHeap<K, V, C extends Serializable>
     /** Extractor of Y vector value. */
     private final IgniteBiFunction<K, V, Double> yExtractor;
 
-    /** Number of columns. */
-    private final int cols;
-
     /**
      * Constructs a new instance of SVM partition data builder.
      *
      * @param xExtractor Extractor of X matrix row.
      * @param yExtractor Extractor of Y vector value.
-     * @param cols       Number of columns.
      */
     public SVMPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, double[]> xExtractor,
-                                         IgniteBiFunction<K, V, Double> yExtractor, int cols) {
+                                         IgniteBiFunction<K, V, Double> yExtractor) {
         this.xExtractor = xExtractor;
         this.yExtractor = yExtractor;
-        this.cols = cols;
     }
 
     /** {@inheritDoc} */
-    @Override public LabeledDataset<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize,
-                                                                 C ctx) {
-        double[][] x = new double[Math.toIntExact(upstreamDataSize)][cols];
+    @Override public LabeledDataset<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
+        long upstreamDataSize, C ctx) {
+
+        int xCols = -1;
+        double[][] x = null;
         double[] y = new double[Math.toIntExact(upstreamDataSize)];
 
         int ptr = 0;
@@ -72,7 +69,12 @@ public class SVMPartitionDataBuilderOnHeap<K, V, C extends Serializable>
             UpstreamEntry<K, V> entry = upstreamData.next();
             double[] row = xExtractor.apply(entry.getKey(), entry.getValue());
 
-            assert row.length == cols : "X extractor must return exactly " + cols + " columns";
+            if (xCols < 0) {
+                xCols = row.length;
+                x = new double[Math.toIntExact(upstreamDataSize)][xCols];
+            }
+            else
+                assert row.length == xCols : "X extractor must return exactly " + xCols + " columns";
 
             x[ptr] = row;
 
  * limitations under the License.
  */
 
-package org.apache.ignite.ml;
+package org.apache.ignite.ml.trainers;
 
+import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 
 /**
  * Interface for trainers. Trainer is just a function which produces model from the data.
  *
- * @param <K> Type of a key in {@code upstream} data.
- * @param <V> Type of a value in {@code upstream} data.
  * @param <M> Type of a produced model.
+ * @param <L> Type of a label.
  */
-public interface DatasetTrainer<K, V, M extends Model> {
+public interface DatasetTrainer<M extends Model, L> {
     /**
      * Trains model based on the specified data.
      *
      * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
      * @param lbExtractor Label extractor.
-     * @param cols Number of columns.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
      * @return Model.
      */
-    public M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor,
-        IgniteBiFunction<K, V, Double> lbExtractor, int cols);
+    public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor);
 }
  * limitations under the License.
  */
 
+package org.apache.ignite.ml.trainers;
+
+import org.apache.ignite.ml.Model;
+
 /**
- * <!-- Package description. -->
- * Contains multilayer perceptron local trainers.
+ * Interface for trainers that trains on dataset with multiple label per object.
+ *
+ * @param <M> Type of a produced model.
  */
-package org.apache.ignite.ml.nn.trainers.local;
\ No newline at end of file
+public interface MultiLabelDatasetTrainer<M extends Model> extends DatasetTrainer<M, double[]> {
+}
  * limitations under the License.
  */
 
+package org.apache.ignite.ml.trainers;
+
+import org.apache.ignite.ml.Model;
+
 /**
- * <!-- Package description. -->
- * Contains local trainers.
+ * Interface for trainers that trains on dataset with singe label per object.
+ *
+ * @param <M> Type of a produced model.
  */
-package org.apache.ignite.ml.trainers.local;
\ No newline at end of file
+public interface SingleLabelDatasetTrainer<M extends Model> extends DatasetTrainer<M, Double> {
+}
index 9deb460..5288dbf 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.trainers.group;
 
+import java.io.Serializable;
 import java.util.List;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
@@ -27,7 +28,7 @@ import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalcul
  * @param <M> Type of model to be optimized.
  * @param <U> Type of update.
  */
-public class UpdatesStrategy<M, U> {
+public class UpdatesStrategy<M, U extends Serializable> {
     /**
      * {@link ParameterUpdateCalculator}.
      */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java
deleted file mode 100644 (file)
index cb6fd89..0000000
+++ /dev/null
@@ -1,178 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trainers.local;
-
-import org.apache.ignite.IgniteLogger;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.Trainer;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.util.MatrixUtil;
-import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
-
-/**
- * Batch trainer. This trainer is not distributed on the cluster, but input can theoretically read data from
- * Ignite cache.
- */
-public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P>
-    implements Trainer<M, LocalBatchTrainerInput<M>> {
-    /**
-     * Supplier for updater function.
-     */
-    private final IgniteSupplier<ParameterUpdateCalculator<? super M, P>> updaterSupplier;
-
-    /**
-     * Error threshold.
-     */
-    private final double errorThreshold;
-
-    /**
-     * Maximal iterations count.
-     */
-    private final int maxIterations;
-
-    /**
-     * Loss function.
-     */
-    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
-
-    /**
-     * Logger.
-     */
-    private IgniteLogger log;
-
-    /**
-     * Construct a trainer.
-     *
-     * @param loss Loss function.
-     * @param updaterSupplier Supplier of updater function.
-     * @param errorThreshold Error threshold.
-     * @param maxIterations Maximal iterations count.
-     */
-    public LocalBatchTrainer(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
-        IgniteSupplier<ParameterUpdateCalculator<? super M, P>> updaterSupplier, double errorThreshold, int maxIterations) {
-        this.loss = loss;
-        this.updaterSupplier = updaterSupplier;
-        this.errorThreshold = errorThreshold;
-        this.maxIterations = maxIterations;
-    }
-
-    /** {@inheritDoc} */
-    @Override public M train(LocalBatchTrainerInput<M> data) {
-        int i = 0;
-        M mdl = data.mdl();
-        double err;
-
-        ParameterUpdateCalculator<? super M, P> updater = updaterSupplier.get();
-
-        P updaterParams = updater.init(mdl, loss);
-
-        while (i < maxIterations) {
-            IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get();
-            Matrix input = batch.get1();
-            Matrix truth = batch.get2();
-
-            updaterParams = updater.calculateNewUpdate(mdl, updaterParams, i, input, truth);
-
-            // Update mdl with updater parameters.
-            mdl = updater.update(mdl, updaterParams);
-
-            Matrix predicted = mdl.apply(input);
-
-            int batchSize = input.columnSize();
-
-            err = MatrixUtil.zipFoldByColumns(predicted, truth, (predCol, truthCol) ->
-                loss.apply(truthCol).apply(predCol)).sum() / batchSize;
-
-            debug("Error: " + err);
-
-            if (err < errorThreshold)
-                break;
-
-            i++;
-        }
-
-        return mdl;
-    }
-
-    /**
-     * Construct new trainer with the same parameters as this trainer, but with new loss.
-     *
-     * @param loss New loss function.
-     * @return new trainer with the same parameters as this trainer, but with new loss.
-     */
-    public LocalBatchTrainer withLoss(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
-        return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
-    }
-
-    /**
-     * Construct new trainer with the same parameters as this trainer, but with new updater supplier.
-     *
-     * @param updaterSupplier New updater supplier.
-     * @return new trainer with the same parameters as this trainer, but with new updater supplier.
-     */
-    public LocalBatchTrainer withUpdater(IgniteSupplier<ParameterUpdateCalculator<? super M, P>> updaterSupplier) {
-        return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
-    }
-
-    /**
-     * Construct new trainer with the same parameters as this trainer, but with new error threshold.
-     *
-     * @param errorThreshold New error threshold.
-     * @return new trainer with the same parameters as this trainer, but with new error threshold.
-     */
-    public LocalBatchTrainer withErrorThreshold(double errorThreshold) {
-        return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
-    }
-
-    /**
-     * Construct new trainer with the same parameters as this trainer, but with new maximal iterations count.
-     *
-     * @param maxIterations New maximal iterations count.
-     * @return new trainer with the same parameters as this trainer, but with new maximal iterations count.
-     */
-    public LocalBatchTrainer withMaxIterations(int maxIterations) {
-        return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
-    }
-
-    /**
-     * Set logger.
-     *
-     * @param log Logger.
-     * @return This object.
-     */
-    public LocalBatchTrainer setLogger(IgniteLogger log) {
-        this.log = log;
-
-        return this;
-    }
-
-    /**
-     * Output debug message.
-     *
-     * @param msg Message.
-     */
-    private void debug(String msg) {
-        if (log != null && log.isDebugEnabled())
-            log.debug(msg);
-    }
-}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java
deleted file mode 100644 (file)
index 38b7592..0000000
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trainers.local;
-
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-
-/**
- * Interface for classes containing input parameters for LocalBatchTrainer.
- */
-public interface LocalBatchTrainerInput<M extends Model<Matrix, Matrix>> {
-    /**
-     * Get supplier of next batch in form of matrix of inputs and matrix of outputs.
-     *
-     * @return Supplier of next batch.
-     */
-    IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier();
-
-    /**
-     * Model to train.
-     *
-     * @return Model to train.
-     */
-    M mdl();
-}
index dd5640c..5fc50b5 100644 (file)
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.util;
 import java.io.FileInputStream;
 import java.io.FileWriter;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
@@ -43,7 +44,7 @@ public class MnistUtils {
      * @return Stream of MNIST samples.
      * @throws IgniteException In case of exception.
      */
-    public static Stream<DenseLocalOnHeapVector> mnist(String imagesPath, String labelsPath, Random rnd, int cnt)
+    public static Stream<DenseLocalOnHeapVector> mnistAsStream(String imagesPath, String labelsPath, Random rnd, int cnt)
         throws IOException {
         FileInputStream isImages = new FileInputStream(imagesPath);
         FileInputStream isLabels = new FileInputStream(labelsPath);
@@ -78,6 +79,50 @@ public class MnistUtils {
     }
 
     /**
+     * Read random {@code count} samples from MNIST dataset from two files (images and labels) into a stream of labeled
+     * vectors.
+     *
+     * @param imagesPath Path to the file with images.
+     * @param labelsPath Path to the file with labels.
+     * @param rnd Random numbers generator.
+     * @param cnt Count of samples to read.
+     * @return List of MNIST samples.
+     * @throws IOException In case of exception.
+     */
+    public static List<MnistLabeledImage> mnistAsList(String imagesPath, String labelsPath, Random rnd, int cnt) throws IOException {
+
+        List<MnistLabeledImage> res = new ArrayList<>();
+
+        try (
+            FileInputStream isImages = new FileInputStream(imagesPath);
+            FileInputStream isLabels = new FileInputStream(labelsPath)
+        ) {
+            read4Bytes(isImages); // Skip magic number.
+            int numOfImages = read4Bytes(isImages);
+            int imgHeight = read4Bytes(isImages);
+            int imgWidth = read4Bytes(isImages);
+
+            read4Bytes(isLabels); // Skip magic number.
+            read4Bytes(isLabels); // Skip number of labels.
+
+            int numOfPixels = imgHeight * imgWidth;
+
+            for (int imgNum = 0; imgNum < numOfImages; imgNum++) {
+                double[] pixels = new double[numOfPixels];
+                for (int p = 0; p < numOfPixels; p++) {
+                    int c = 128 - isImages.read();
+                    pixels[p] = ((double)c) / 128;
+                }
+                res.add(new MnistLabeledImage(pixels, isLabels.read()));
+            }
+        }
+
+        Collections.shuffle(res, rnd);
+
+        return res.subList(0, cnt);
+    }
+
+    /**
      * Convert random {@code count} samples from MNIST dataset from two files (images and labels) into libsvm format.
      *
      * @param imagesPath Path to the file with images.
@@ -91,7 +136,7 @@ public class MnistUtils {
         throws IOException {
 
         try (FileWriter fos = new FileWriter(outPath)) {
-            mnist(imagesPath, labelsPath, rnd, cnt).forEach(vec -> {
+            mnistAsStream(imagesPath, labelsPath, rnd, cnt).forEach(vec -> {
                 try {
                     fos.write((int)vec.get(vec.size() - 1) + " ");
 
@@ -121,4 +166,50 @@ public class MnistUtils {
     private static int read4Bytes(FileInputStream is) throws IOException {
         return (is.read() << 24) | (is.read() << 16) | (is.read() << 8) | (is.read());
     }
+
+    /**
+     * MNIST image.
+     */
+    public static class MnistImage {
+        /** Pixels. */
+        private final double[] pixels;
+
+        /**
+         * Construct a new instance of MNIST image.
+         *
+         * @param pixels Pixels.
+         */
+        public MnistImage(double[] pixels) {
+            this.pixels = pixels;
+        }
+
+        /** */
+        public double[] getPixels() {
+            return pixels;
+        }
+    }
+
+    /**
+     * MNIST labeled image.
+     */
+    public static class MnistLabeledImage extends MnistImage {
+        /** Label. */
+        private final int lb;
+
+        /**
+         * Constructs a new instance of MNIST labeled image.
+         *
+         * @param pixels Pixels.
+         * @param lb Label.
+         */
+        public MnistLabeledImage(double[] pixels, int lb) {
+            super(pixels);
+            this.lb = lb;
+        }
+
+        /** */
+        public int getLabel() {
+            return lb;
+        }
+    }
 }
index 4892ff8..ec9fdaa 100644 (file)
@@ -66,8 +66,7 @@ public class LSQROnHeapTest {
             datasetBuilder,
             new LinSysPartitionDataBuilderOnHeap<>(
                 (k, v) -> Arrays.copyOf(v, v.length - 1),
-                (k, v) -> v[3],
-                3
+                (k, v) -> v[3]
             )
         );
 
@@ -90,8 +89,7 @@ public class LSQROnHeapTest {
             datasetBuilder,
             new LinSysPartitionDataBuilderOnHeap<>(
                 (k, v) -> Arrays.copyOf(v, v.length - 1),
-                (k, v) -> v[3],
-                3
+                (k, v) -> v[3]
             )
         );
 
@@ -122,8 +120,7 @@ public class LSQROnHeapTest {
             datasetBuilder,
             new LinSysPartitionDataBuilderOnHeap<>(
                 (k, v) -> Arrays.copyOf(v, v.length - 1),
-                (k, v) -> v[4],
-                4
+                (k, v) -> v[4]
             )
         )) {
             LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java
deleted file mode 100644 (file)
index abd8ad2..0000000
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn;
-
-import java.io.Serializable;
-import java.util.Random;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.IgniteDataStreamer;
-import org.apache.ignite.internal.util.typedef.X;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.StorageConstants;
-import org.apache.ignite.ml.math.Tracer;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.nn.initializers.RandomInitializer;
-import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.trainers.group.UpdateStrategies;
-import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
-
-/**
- * Test group trainer.
- */
-public class MLPGroupTrainerTest extends GridCommonAbstractTest {
-    /** Count of nodes. */
-    private static final int NODE_COUNT = 3;
-
-    /** Grid instance. */
-    private Ignite ignite;
-
-    /**
-     * {@inheritDoc}
-     */
-    @Override protected void beforeTest() throws Exception {
-        ignite = grid(NODE_COUNT);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void beforeTestsStarted() throws Exception {
-        for (int i = 1; i <= NODE_COUNT; i++)
-            startGrid(i);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void afterTestsStopped() throws Exception {
-        stopAllGrids();
-    }
-
-    /**
-     * Test training 'xor' by RProp.
-     */
-    public void testXORRProp() {
-        doTestXOR(UpdateStrategies.RProp());
-    }
-
-    /**
-     * Test training 'xor' by SimpleGD.
-     */
-    public void testXORGD() {
-        doTestXOR(new UpdatesStrategy<>(
-            new SimpleGDUpdateCalculator().withLearningRate(0.5),
-            SimpleGDParameterUpdate::sumLocal,
-            SimpleGDParameterUpdate::avg));
-    }
-
-    /**
-     * Test training of 'xor' by {@link MLPGroupUpdateTrainer}.
-     */
-    private <U extends Serializable> void doTestXOR(UpdatesStrategy<? super MultilayerPerceptron, U> stgy) {
-        int samplesCnt = 1000;
-
-        Matrix xorInputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}},
-            StorageConstants.ROW_STORAGE_MODE).transpose();
-
-        Matrix xorOutputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0}, {1.0}, {1.0}, {0.0}},
-            StorageConstants.ROW_STORAGE_MODE).transpose();
-
-        MLPArchitecture conf = new MLPArchitecture(2).
-            withAddedLayer(10, true, Activators.RELU).
-            withAddedLayer(1, false, Activators.SIGMOID);
-
-        IgniteCache<Integer, LabeledVector<Vector, Vector>> cache = LabeledVectorsCache.createNew(ignite);
-        String cacheName = cache.getName();
-        Random rnd = new Random(12345L);
-
-        try (IgniteDataStreamer<Integer, LabeledVector<Vector, Vector>> streamer =
-                 ignite.dataStreamer(cacheName)) {
-            streamer.perNodeBufferSize(10000);
-
-            for (int i = 0; i < samplesCnt; i++) {
-                int col = Math.abs(rnd.nextInt()) % 4;
-                streamer.addData(i, new LabeledVector<>(xorInputs.getCol(col), xorOutputs.getCol(col)));
-            }
-        }
-
-        int totalCnt = 30;
-        int failCnt = 0;
-        double maxFailRatio = 0.3;
-
-        MLPGroupUpdateTrainer<U> trainer = MLPGroupUpdateTrainer.getDefault(ignite).
-            withSyncPeriod(3).
-            withTolerance(0.001).
-            withMaxGlobalSteps(100).
-            withUpdateStrategy(stgy);
-
-        for (int i = 0; i < totalCnt; i++) {
-            MLPGroupUpdateTrainerCacheInput trainerInput = new MLPGroupUpdateTrainerCacheInput(conf,
-                new RandomInitializer(new Random(123L + i)), 6, cache, 10, new Random(123L + i));
-
-            MultilayerPerceptron mlp = trainer.train(trainerInput);
-
-            Matrix predict = mlp.apply(xorInputs);
-
-            Tracer.showAscii(predict);
-
-            X.println(xorOutputs.getRow(0).minus(predict.getRow(0)).kNorm(2) + "");
-
-            failCnt += TestUtils.checkIsInEpsilonNeighbourhoodBoolean(xorOutputs.getRow(0), predict.getRow(0), 5E-1) ? 0 : 1;
-        }
-
-        double failRatio = (double)failCnt / totalCnt;
-
-        System.out.println("Fail percentage: " + (failRatio * 100) + "%.");
-
-        assertTrue(failRatio < maxFailRatio);
-    }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java
deleted file mode 100644 (file)
index 3119170..0000000
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn;
-
-import java.util.Random;
-import org.apache.ignite.internal.util.typedef.X;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.StorageConstants;
-import org.apache.ignite.ml.math.Tracer;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer;
-import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.junit.Test;
-
-/**
- * Tests for {@link MLPLocalBatchTrainer}.
- */
-public class MLPLocalTrainerTest {
-    /**
-     * Test 'XOR' operation training with {@link SimpleGDUpdateCalculator} updater.
-     */
-    @Test
-    public void testXORSimpleGD() {
-        xorTest(() -> new SimpleGDUpdateCalculator(0.3));
-    }
-
-    /**
-     * Test 'XOR' operation training with {@link RPropUpdateCalculator}.
-     */
-    @Test
-    public void testXORRProp() {
-        xorTest(RPropUpdateCalculator::new);
-    }
-
-    /**
-     * Test 'XOR' operation training with {@link NesterovUpdateCalculator}.
-     */
-    @Test
-    public void testXORNesterov() {
-        xorTest(() -> new NesterovUpdateCalculator<>(0.1, 0.7));
-    }
-
-    /**
-     * Common method for testing 'XOR' with various updaters.
-     * @param updaterSupplier Updater supplier.
-     * @param <P> Updater parameters type.
-     */
-    private <P> void xorTest(IgniteSupplier<ParameterUpdateCalculator<? super MultilayerPerceptron, P>> updaterSupplier) {
-        Matrix xorInputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}},
-            StorageConstants.ROW_STORAGE_MODE).transpose();
-
-        Matrix xorOutputs = new DenseLocalOnHeapMatrix(new double[][] {{0.0}, {1.0}, {1.0}, {0.0}},
-            StorageConstants.ROW_STORAGE_MODE).transpose();
-
-        MLPArchitecture conf = new MLPArchitecture(2).
-            withAddedLayer(10, true, Activators.RELU).
-            withAddedLayer(1, false, Activators.SIGMOID);
-
-        SimpleMLPLocalBatchTrainerInput trainerInput = new SimpleMLPLocalBatchTrainerInput(conf,
-            new Random(123L), xorInputs, xorOutputs, 4);
-
-        MultilayerPerceptron mlp = new MLPLocalBatchTrainer<>(LossFunctions.MSE,
-            updaterSupplier,
-            0.0001,
-            16000).train(trainerInput);
-
-        Matrix predict = mlp.apply(xorInputs);
-
-        Tracer.showAscii(predict);
-
-        X.println(xorOutputs.getRow(0).minus(predict.getRow(0)).kNorm(2) + "");
-
-        TestUtils.checkIsInEpsilonNeighbourhood(xorOutputs.getRow(0), predict.getRow(0), 1E-1);
-    }
-}
index 555abce..3072abb 100644 (file)
@@ -30,7 +30,7 @@ import org.junit.Assert;
 import org.junit.Test;
 
 /**
- * Tests for Multilayer perceptron.
+ * Tests for {@link MultilayerPerceptron}.
  */
 public class MLPTest {
     /**
@@ -66,12 +66,12 @@ public class MLPTest {
         mlp.setWeights(2, new DenseLocalOnHeapMatrix(new double[][] {{20.0, 20.0}}));
         mlp.setBiases(2, new DenseLocalOnHeapVector(new double[] {-30.0}));
 
-        Matrix input = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}).transpose();
+        Matrix input = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}});
 
         Matrix predict = mlp.apply(input);
-        Vector truth = new DenseLocalOnHeapVector(new double[] {0.0, 1.0, 1.0, 0.0});
+        Matrix truth = new DenseLocalOnHeapMatrix(new double[][] {{0.0}, {1.0}, {1.0}, {0.0}});
 
-        TestUtils.checkIsInEpsilonNeighbourhood(predict.getRow(0), truth, 1E-4);
+        TestUtils.checkIsInEpsilonNeighbourhood(predict.getRow(0), truth.getRow(0), 1E-4);
     }
 
     /**
@@ -99,8 +99,8 @@ public class MLPTest {
 
         MultilayerPerceptron stackedMLP = mlp1.add(mlp2);
 
-        Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][] {{1, 2, 3, 4}}).transpose());
-        Matrix stackedPredict = stackedMLP.apply(new DenseLocalOnHeapMatrix(new double[][] {{1, 2, 3, 4}}).transpose());
+        Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][] {{1}, {2}, {3}, {4}}).transpose());
+        Matrix stackedPredict = stackedMLP.apply(new DenseLocalOnHeapMatrix(new double[][] {{1}, {2}, {3}, {4}}).transpose());
 
         Assert.assertEquals(predict, stackedPredict);
     }
index 7a3edf0..2e41813 100644 (file)
@@ -26,8 +26,8 @@ import org.junit.runners.Suite;
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
     MLPTest.class,
-    MLPLocalTrainerTest.class,
-    MLPGroupTrainerTest.class,
+    MLPTrainerTest.class,
+    MLPTrainerIntegrationTest.class
 })
 public class MLPTestSuite {
     // No-op.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java
new file mode 100644 (file)
index 0000000..5ca661f
--- /dev/null
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn;
+
+import java.io.Serializable;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.internal.util.typedef.X;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Tests for {@link MLPTrainer} that require to start the whole Ignite infrastructure.
+ */
+public class MLPTrainerIntegrationTest extends GridCommonAbstractTest {
+    /** Number of nodes in grid */
+    private static final int NODE_COUNT = 3;
+
+    /** Ignite instance. */
+    private Ignite ignite;
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTestsStarted() throws Exception {
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void afterTestsStopped() {
+        stopAllGrids();
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override protected void beforeTest() throws Exception {
+        /* Grid instance. */
+        ignite = grid(NODE_COUNT);
+        ignite.configuration().setPeerClassLoadingEnabled(true);
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+    }
+
+    /**
+     * Test 'XOR' operation training with {@link SimpleGDUpdateCalculator}.
+     */
+    public void testXORSimpleGD() {
+        xorTest(new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator(0.3),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        ));
+    }
+
+    /**
+     * Test 'XOR' operation training with {@link RPropUpdateCalculator}.
+     */
+    public void testXORRProp() {
+        xorTest(new UpdatesStrategy<>(
+            new RPropUpdateCalculator(),
+            RPropParameterUpdate::sumLocal,
+            RPropParameterUpdate::avg
+        ));
+    }
+
+    /**
+     * Test 'XOR' operation training with {@link NesterovUpdateCalculator}.
+     */
+    public void testXORNesterov() {
+        xorTest(new UpdatesStrategy<>(
+            new NesterovUpdateCalculator<MultilayerPerceptron>(0.1, 0.7),
+            NesterovParameterUpdate::sum,
+            NesterovParameterUpdate::avg
+        ));
+    }
+
+    /**
+     * Common method for testing 'XOR' with various updaters.
+     * @param updatesStgy Update strategy.
+     * @param <P> Updater parameters type.
+     */
+    private <P extends Serializable> void xorTest(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) {
+        CacheConfiguration<Integer, LabeledPoint> xorCacheCfg = new CacheConfiguration<>();
+        xorCacheCfg.setName("XorData");
+        xorCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 5));
+        IgniteCache<Integer, LabeledPoint> xorCache = ignite.createCache(xorCacheCfg);
+
+        try {
+            xorCache.put(0, new LabeledPoint(0.0, 0.0, 0.0));
+            xorCache.put(1, new LabeledPoint(0.0, 1.0, 1.0));
+            xorCache.put(2, new LabeledPoint(1.0, 0.0, 1.0));
+            xorCache.put(3, new LabeledPoint(1.0, 1.0, 0.0));
+
+            MLPArchitecture arch = new MLPArchitecture(2).
+                withAddedLayer(10, true, Activators.RELU).
+                withAddedLayer(1, false, Activators.SIGMOID);
+
+            MLPTrainer<P> trainer = new MLPTrainer<>(
+                arch,
+                LossFunctions.MSE,
+                updatesStgy,
+                2500,
+                4,
+                50,
+                123L
+            );
+
+            MultilayerPerceptron mlp = trainer.fit(
+                new CacheBasedDatasetBuilder<>(ignite, xorCache),
+                (k, v) -> new double[]{ v.x, v.y },
+                (k, v) -> new double[]{ v.lb}
+            );
+
+            Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][]{
+                {0.0, 0.0},
+                {0.0, 1.0},
+                {1.0, 0.0},
+                {1.0, 1.0}
+            }));
+
+            Tracer.showAscii(predict);
+
+            X.println(new DenseLocalOnHeapVector(new double[]{0.0}).minus(predict.getRow(0)).kNorm(2) + "");
+
+            TestUtils.checkIsInEpsilonNeighbourhood(new DenseLocalOnHeapVector(new double[]{0.0}), predict.getRow(0), 1E-1);
+        }
+        finally {
+            xorCache.destroy();
+        }
+    }
+
+    /** Labeled point data class. */
+    private static class LabeledPoint {
+        /** X coordinate. */
+        private final double x;
+
+        /** Y coordinate. */
+        private final double y;
+
+        /** Point label. */
+        private final double lb;
+
+        /**
+         * Constructs a new instance of labeled point data.
+         *
+         * @param x X coordinate.
+         * @param y Y coordinate.
+         * @param lb Point label.
+         */
+        public LabeledPoint(double x, double y, double lb) {
+            this.x = x;
+            this.y = y;
+            this.lb = lb;
+        }
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
new file mode 100644 (file)
index 0000000..6906424
--- /dev/null
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Tests for {@link MLPTrainer} that don't require to start the whole Ignite infrastructure.
+ */
+@RunWith(Enclosed.class)
+public class MLPTrainerTest {
+    /**
+     * Parameterized tests.
+     */
+    @RunWith(Parameterized.class)
+    public static class ComponentParamTests {
+        /** Number of parts to be tested. */
+        private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
+
+        /** Batch sizes to be tested. */
+        private static final int[] batchSizesToBeTested = new int[] {1, 2, 3, 4};
+
+        /** Parameters. */
+        @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}")
+        public static Iterable<Integer[]> data() {
+            List<Integer[]> res = new ArrayList<>();
+            for (int part : partsToBeTested)
+                for (int batchSize1 : batchSizesToBeTested)
+                    res.add(new Integer[] {part, batchSize1});
+
+            return res;
+        }
+
+        /** Number of partitions. */
+        @Parameterized.Parameter
+        public int parts;
+
+        /** Batch size. */
+        @Parameterized.Parameter(1)
+        public int batchSize;
+
+        /**
+         * Test 'XOR' operation training with {@link SimpleGDUpdateCalculator} updater.
+         */
+        @Test
+        public void testXORSimpleGD() {
+            xorTest(new UpdatesStrategy<>(
+                new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal,
+                SimpleGDParameterUpdate::avg
+            ));
+        }
+
+        /**
+         * Test 'XOR' operation training with {@link RPropUpdateCalculator}.
+         */
+        @Test
+        public void testXORRProp() {
+            xorTest(new UpdatesStrategy<>(
+                new RPropUpdateCalculator(),
+                RPropParameterUpdate::sumLocal,
+                RPropParameterUpdate::avg
+            ));
+        }
+
+        /**
+         * Test 'XOR' operation training with {@link NesterovUpdateCalculator}.
+         */
+        @Test
+        public void testXORNesterov() {
+            xorTest(new UpdatesStrategy<>(
+                new NesterovUpdateCalculator<MultilayerPerceptron>(0.1, 0.7),
+                NesterovParameterUpdate::sum,
+                NesterovParameterUpdate::avg
+            ));
+        }
+
+        /**
+         * Common method for testing 'XOR' with various updaters.
+         * @param updatesStgy Update strategy.
+         * @param <P> Updater parameters type.
+         */
+        private <P extends Serializable> void xorTest(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) {
+            Map<Integer, double[][]> xorData = new HashMap<>();
+            xorData.put(0, new double[][]{{0.0, 0.0}, {0.0}});
+            xorData.put(1, new double[][]{{0.0, 1.0}, {1.0}});
+            xorData.put(2, new double[][]{{1.0, 0.0}, {1.0}});
+            xorData.put(3, new double[][]{{1.0, 1.0}, {0.0}});
+
+            MLPArchitecture arch = new MLPArchitecture(2).
+                withAddedLayer(10, true, Activators.RELU).
+                withAddedLayer(1, false, Activators.SIGMOID);
+
+            MLPTrainer<P> trainer = new MLPTrainer<>(
+                arch,
+                LossFunctions.MSE,
+                updatesStgy,
+                3000,
+                batchSize,
+                50,
+                123L
+            );
+
+            MultilayerPerceptron mlp = trainer.fit(
+                new LocalDatasetBuilder<>(xorData, parts),
+                (k, v) -> v[0],
+                (k, v) -> v[1]
+            );
+
+            Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][]{
+                {0.0, 0.0},
+                {0.0, 1.0},
+                {1.0, 0.0},
+                {1.0, 1.0}
+            }));
+
+            TestUtils.checkIsInEpsilonNeighbourhood(new DenseLocalOnHeapVector(new double[]{0.0}), predict.getRow(0), 1E-1);
+        }
+    }
+
+    /**
+     * Non-parameterized tests.
+     */
+    public static class ComponentSingleTests {
+        /** Data. */
+        private double[] data;
+
+        /** Initialization. */
+        @Before
+        public void init() {
+            data = new double[10];
+            for (int i = 0; i < 10; i++)
+                data[i] = i;
+        }
+
+        /** */
+        @Test
+        public void testBatchWithSingleColumnAndSingleRow() {
+            double[] res = MLPTrainer.batch(data, new int[]{1}, 10);
+
+            TestUtils.assertEquals(new double[]{1.0}, res, 1e-12);
+        }
+
+        /** */
+        @Test
+        public void testBatchWithMultiColumnAndSingleRow() {
+            double[] res = MLPTrainer.batch(data, new int[]{1}, 5);
+
+            TestUtils.assertEquals(new double[]{1.0, 6.0}, res, 1e-12);
+        }
+
+        /** */
+        @Test
+        public void testBatchWithMultiColumnAndMultiRow() {
+            double[] res = MLPTrainer.batch(data, new int[]{1, 3}, 5);
+
+            TestUtils.assertEquals(new double[]{1.0, 3.0, 6.0, 8.0}, res, 1e-12);
+        }
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
deleted file mode 100644 (file)
index 8bc0a6d..0000000
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn;
-
-import java.util.Random;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.nn.initializers.RandomInitializer;
-import org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput;
-import org.apache.ignite.ml.util.Utils;
-
-/**
- * Class for local batch training of {@link MultilayerPerceptron}.
- *
- * It is constructed from two matrices: one containing inputs of function to approximate and other containing ground truth
- * values of this function for corresponding inputs.
- *
- * We fix batch size given by this input by some constant value.
- */
-public class SimpleMLPLocalBatchTrainerInput implements LocalBatchTrainerInput<MultilayerPerceptron> {
-    /**
-     * Multilayer perceptron to be trained.
-     */
-    private final MultilayerPerceptron mlp;
-
-    /**
-     * Inputs stored as columns.
-     */
-    private final Matrix inputs;
-
-    /**
-     * Ground truths stored as columns.
-     */
-    private final Matrix groundTruth;
-
-    /**
-     * Size of batch returned on each step.
-     */
-    private final int batchSize;
-
-    /**
-     * Construct instance of this class.
-     *
-     * @param arch Architecture of multilayer perceptron.
-     * @param rnd Random numbers generator.
-     * @param inputs Inputs stored as columns.
-     * @param groundTruth Ground truth stored as columns.
-     * @param batchSize Size of batch returned on each step.
-     */
-    public SimpleMLPLocalBatchTrainerInput(MLPArchitecture arch, Random rnd, Matrix inputs, Matrix groundTruth, int batchSize) {
-        this.mlp = new MultilayerPerceptron(arch, new RandomInitializer(rnd));
-        this.inputs = inputs;
-        this.groundTruth = groundTruth;
-        this.batchSize = batchSize;
-    }
-
-    /** {@inheritDoc} */
-    @Override public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() {
-        return () -> {
-            int inputRowSize = inputs.rowSize();
-            int outputRowSize = groundTruth.rowSize();
-
-            Matrix vectors = new DenseLocalOnHeapMatrix(inputRowSize, batchSize);
-            Matrix labels = new DenseLocalOnHeapMatrix(outputRowSize, batchSize);
-
-            int[] samples = Utils.selectKDistinct(inputs.columnSize(), batchSize);
-
-            for (int i = 0; i < batchSize; i++) {
-                vectors.assignColumn(i, inputs.getCol(samples[i]));
-                labels.assignColumn(i, groundTruth.getCol(samples[i]));
-            }
-
-            return new IgniteBiTuple<>(vectors, labels);
-        };
-    }
-
-    /** {@inheritDoc} */
-    @Override public MultilayerPerceptron mdl() {
-        return mlp;
-    }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
new file mode 100644 (file)
index 0000000..c787a47
--- /dev/null
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.performance;
+
+import java.io.IOException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.VectorUtils;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+import org.apache.ignite.ml.util.MnistUtils;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Tests {@link MLPTrainer} on the MNIST dataset that require to start the whole Ignite infrastructure.
+ */
+public class MLPTrainerMnistIntegrationTest extends GridCommonAbstractTest {
+    /** Number of nodes in grid */
+    private static final int NODE_COUNT = 3;
+
+    /** Ignite instance. */
+    private Ignite ignite;
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTestsStarted() throws Exception {
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void afterTestsStopped() {
+        stopAllGrids();
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override protected void beforeTest() throws Exception {
+        /* Grid instance. */
+        ignite = grid(NODE_COUNT);
+        ignite.configuration().setPeerClassLoadingEnabled(true);
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+    }
+
+    /** Tests on the MNIST dataset. */
+    public void testMNIST() throws IOException {
+        int featCnt = 28 * 28;
+        int hiddenNeuronsCnt = 100;
+
+        CacheConfiguration<Integer, MnistUtils.MnistLabeledImage> trainingSetCacheCfg = new CacheConfiguration<>();
+        trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+        trainingSetCacheCfg.setName("MNIST_TRAINING_SET");
+        IgniteCache<Integer, MnistUtils.MnistLabeledImage> trainingSet = ignite.createCache(trainingSetCacheCfg);
+
+        int i = 0;
+        for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTrainingSet(6_000))
+            trainingSet.put(i++, e);
+
+        MLPArchitecture arch = new MLPArchitecture(featCnt).
+            withAddedLayer(hiddenNeuronsCnt, true, Activators.SIGMOID).
+            withAddedLayer(10, false, Activators.SIGMOID);
+
+        MLPTrainer<RPropParameterUpdate> trainer = new MLPTrainer<>(
+            arch,
+            LossFunctions.MSE,
+            new UpdatesStrategy<>(
+                new RPropUpdateCalculator(),
+                RPropParameterUpdate::sum,
+                RPropParameterUpdate::avg
+            ),
+            200,
+            2000,
+            10,
+            123L
+        );
+
+        System.out.println("Start training...");
+        long start = System.currentTimeMillis();
+        MultilayerPerceptron mdl = trainer.fit(
+            new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+            (k, v) -> v.getPixels(),
+            (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
+        );
+        System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms");
+
+        int correctAnswers = 0;
+        int incorrectAnswers = 0;
+
+        for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(1_000)) {
+            Matrix input = new DenseLocalOnHeapMatrix(new double[][]{e.getPixels()});
+            Matrix outputMatrix = mdl.apply(input);
+
+            int predicted = (int) VectorUtils.vec2Num(outputMatrix.getRow(0));
+
+            if (predicted == e.getLabel())
+                correctAnswers++;
+            else
+                incorrectAnswers++;
+        }
+
+        double accuracy = 1.0 * correctAnswers / (correctAnswers + incorrectAnswers);
+        assertTrue("Accuracy should be >= 80%", accuracy >= 0.8);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
new file mode 100644 (file)
index 0000000..354af2c
--- /dev/null
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.performance;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.VectorUtils;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
+import org.apache.ignite.ml.trainers.group.UpdatesStrategy;
+import org.apache.ignite.ml.util.MnistUtils;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests {@link MLPTrainer} on the MNIST dataset using locally stored data.
+ */
+public class MLPTrainerMnistTest {
+    /** Tests on the MNIST dataset. */
+    @Test
+    public void testMNIST() throws IOException {
+        int featCnt = 28 * 28;
+        int hiddenNeuronsCnt = 100;
+
+        Map<Integer, MnistUtils.MnistLabeledImage> trainingSet = new HashMap<>();
+
+        int i = 0;
+        for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTrainingSet(60_000))
+            trainingSet.put(i++, e);
+
+        MLPArchitecture arch = new MLPArchitecture(featCnt).
+            withAddedLayer(hiddenNeuronsCnt, true, Activators.SIGMOID).
+            withAddedLayer(10, false, Activators.SIGMOID);
+
+        MLPTrainer<?> trainer = new MLPTrainer<>(
+            arch,
+            LossFunctions.MSE,
+            new UpdatesStrategy<>(
+                new RPropUpdateCalculator(),
+                RPropParameterUpdate::sum,
+                RPropParameterUpdate::avg
+            ),
+            200,
+            2000,
+            10,
+            123L
+        );
+
+        System.out.println("Start training...");
+        long start = System.currentTimeMillis();
+        MultilayerPerceptron mdl = trainer.fit(
+            new LocalDatasetBuilder<>(trainingSet, 1),
+            (k, v) -> v.getPixels(),
+            (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
+        );
+        System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms");
+
+        int correctAnswers = 0;
+        int incorrectAnswers = 0;
+
+        for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) {
+            Matrix input = new DenseLocalOnHeapMatrix(new double[][]{e.getPixels()});
+            Matrix outputMatrix = mdl.apply(input);
+
+            int predicted = (int) VectorUtils.vec2Num(outputMatrix.getRow(0));
+
+            if (predicted == e.getLabel())
+                correctAnswers++;
+            else
+                incorrectAnswers++;
+        }
+
+        double accuracy = 1.0 * correctAnswers / (correctAnswers + incorrectAnswers);
+        assertTrue("Accuracy should be >= 80% (not " + accuracy * 100 + "%)", accuracy >= 0.8);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java
deleted file mode 100644 (file)
index 5656f68..0000000
+++ /dev/null
@@ -1,154 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.performance;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.IgniteDataStreamer;
-import org.apache.ignite.internal.util.typedef.X;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Tracer;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.VectorUtils;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.nn.Activators;
-import org.apache.ignite.ml.nn.LabeledVectorsCache;
-import org.apache.ignite.ml.nn.MLPGroupUpdateTrainerCacheInput;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.util.MnistUtils;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
-
-import static org.apache.ignite.ml.nn.performance.MnistMLPTestUtil.createDataset;
-import static org.apache.ignite.ml.nn.performance.MnistMLPTestUtil.loadMnist;
-
-/**
- * Various benchmarks for hand runs.
- */
-public class MnistDistributed extends GridCommonAbstractTest {
-    /** Count of nodes. */
-    private static final int NODE_COUNT = 3;
-
-    /** Features count in MNIST. */
-    private static final int FEATURES_CNT = 28 * 28;
-
-    /** Grid instance. */
-    private Ignite ignite;
-
-    /**
-     * {@inheritDoc}
-     */
-    @Override protected void beforeTest() throws Exception {
-        ignite = grid(NODE_COUNT);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void beforeTestsStarted() throws Exception {
-        for (int i = 1; i <= NODE_COUNT; i++)
-            startGrid(i);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void afterTestsStopped() throws Exception {
-        stopAllGrids();
-    }
-
-    /** */
-    public void testMNISTDistributed() throws IOException {
-        int samplesCnt = 60_000;
-        int hiddenNeuronsCnt = 100;
-
-        IgniteBiTuple<Stream<DenseLocalOnHeapVector>, Stream<DenseLocalOnHeapVector>> trainingAndTest = loadMnist(samplesCnt);
-
-        // Load training mnist part into a cache.
-        Stream<DenseLocalOnHeapVector> trainingMnist = trainingAndTest.get1();
-        List<DenseLocalOnHeapVector> trainingMnistLst = trainingMnist.collect(Collectors.toList());
-
-        IgniteCache<Integer, LabeledVector<Vector, Vector>> labeledVectorsCache = LabeledVectorsCache.createNew(ignite);
-        loadIntoCache(trainingMnistLst, labeledVectorsCache);
-
-        MLPGroupUpdateTrainer<RPropParameterUpdate> trainer = MLPGroupUpdateTrainer.getDefault(ignite).
-            withMaxGlobalSteps(35).
-            withSyncPeriod(2);
-
-        MLPArchitecture arch = new MLPArchitecture(FEATURES_CNT).
-            withAddedLayer(hiddenNeuronsCnt, true, Activators.SIGMOID).
-            withAddedLayer(10, false, Activators.SIGMOID);
-
-        MultilayerPerceptron mdl = trainer.train(new MLPGroupUpdateTrainerCacheInput(arch, 9, labeledVectorsCache, 2000));
-
-        IgniteBiTuple<Matrix, Matrix> testDs = createDataset(trainingAndTest.get2(), 10_000, FEATURES_CNT);
-
-        Vector truth = testDs.get2().foldColumns(VectorUtils::vec2Num);
-        Vector predicted = mdl.apply(testDs.get1()).foldColumns(VectorUtils::vec2Num);
-
-        Tracer.showAscii(truth);
-        Tracer.showAscii(predicted);
-
-        X.println("Accuracy: " + VectorUtils.zipWith(predicted, truth, (x, y) -> x.equals(y) ? 1.0 : 0.0).sum() / truth.size() * 100 + "%.");
-    }
-
-    /**
-     * Load MNIST into cache.
-     *
-     * @param trainingMnistLst List with mnist data.
-     * @param labeledVectorsCache Cache to load MNIST into.
-     */
-    private void loadIntoCache(List<DenseLocalOnHeapVector> trainingMnistLst,
-        IgniteCache<Integer, LabeledVector<Vector, Vector>> labeledVectorsCache) {
-        String cacheName = labeledVectorsCache.getName();
-
-        try (IgniteDataStreamer<Integer, LabeledVector<Vector, Vector>> streamer =
-                 ignite.dataStreamer(cacheName)) {
-            int sampleIdx = 0;
-
-            streamer.perNodeBufferSize(10000);
-
-            for (DenseLocalOnHeapVector vector : trainingMnistLst) {
-                streamer.addData(sampleIdx, asLabeledVector(vector, FEATURES_CNT));
-
-                if (sampleIdx % 5000 == 0)
-                    X.println("Loaded " + sampleIdx + " samples.");
-
-                sampleIdx++;
-            }
-        }
-    }
-
-    /**
-     * Transform vector created by {@link MnistUtils} to {@link LabeledVector}.
-     *
-     * @param v Vector to transform.
-     * @param featsCnt Count of features.
-     * @return Vector created by {@link MnistUtils} transformed to {@link LabeledVector}.
-     */
-    private static LabeledVector<Vector, Vector> asLabeledVector(Vector v, int featsCnt) {
-        Vector features = VectorUtils.copyPart(v, 0, featsCnt);
-        Vector lb = VectorUtils.num2Vec((int)v.get(featsCnt), 10);
-
-        return new LabeledVector<>(features, lb);
-    }
-}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java
deleted file mode 100644 (file)
index 14c02aa..0000000
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.nn.performance;
-
-import java.io.IOException;
-import java.util.Random;
-import java.util.stream.Stream;
-import org.apache.ignite.internal.util.typedef.X;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.Tracer;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.VectorUtils;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.nn.Activators;
-import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.nn.SimpleMLPLocalBatchTrainerInput;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer;
-import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
-import org.junit.Test;
-
-import static org.apache.ignite.ml.nn.performance.MnistMLPTestUtil.createDataset;
-import static org.apache.ignite.ml.nn.performance.MnistMLPTestUtil.loadMnist;
-
-/**
- * Various benchmarks for hand runs.
- */
-public class MnistLocal {
-    /**
-     * Run nn classifier on MNIST using bi-indexed cache as a storage for dataset.
-     * To run this test rename this method so it starts from 'test'.
-     *
-     * @throws IOException In case of loading MNIST dataset errors.
-     */
-    @Test
-    public void tstMNISTLocal() throws IOException {
-        int samplesCnt = 60_000;
-        int featCnt = 28 * 28;
-        int hiddenNeuronsCnt = 100;
-
-        IgniteBiTuple<Stream<DenseLocalOnHeapVector>, Stream<DenseLocalOnHeapVector>> trainingAndTest = loadMnist(samplesCnt);
-
-        Stream<DenseLocalOnHeapVector> trainingMnistStream = trainingAndTest.get1();
-        Stream<DenseLocalOnHeapVector> testMnistStream = trainingAndTest.get2();
-
-        IgniteBiTuple<Matrix, Matrix> ds = createDataset(trainingMnistStream, samplesCnt, featCnt);
-        IgniteBiTuple<Matrix, Matrix> testDs = createDataset(testMnistStream, 10000, featCnt);
-
-        MLPArchitecture conf = new MLPArchitecture(featCnt).
-            withAddedLayer(hiddenNeuronsCnt, true, Activators.SIGMOID).
-            withAddedLayer(10, false, Activators.SIGMOID);
-
-        SimpleMLPLocalBatchTrainerInput input = new SimpleMLPLocalBatchTrainerInput(conf,
-            new Random(),
-            ds.get1(),
-            ds.get2(),
-            2000);
-
-        MultilayerPerceptron mdl = new MLPLocalBatchTrainer<>(LossFunctions.MSE,
-            () -> new RPropUpdateCalculator(0.1, 1.2, 0.5),
-            1E-7,
-            200).
-            train(input);
-
-        X.println("Training started");
-        long before = System.currentTimeMillis();
-
-        X.println("Training finished in " + (System.currentTimeMillis() - before));
-
-        Vector predicted = mdl.apply(testDs.get1()).foldColumns(VectorUtils::vec2Num);
-        Vector truth = testDs.get2().foldColumns(VectorUtils::vec2Num);
-
-        Tracer.showAscii(truth);
-        Tracer.showAscii(predicted);
-
-        X.println("Accuracy: " + VectorUtils.zipWith(predicted, truth, (x, y) -> x.equals(y) ? 1.0 : 0.0).sum() / truth.size() * 100 + "%.");
-    }
-}
index 42ce523..e624004 100644 (file)
@@ -22,17 +22,12 @@ import java.io.InputStream;
 import java.util.List;
 import java.util.Properties;
 import java.util.Random;
-import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.Matrix;
-import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.trees.performance.ColumnDecisionTreeTrainerBenchmark;
 import org.apache.ignite.ml.util.MnistUtils;
 
-import static org.apache.ignite.ml.math.VectorUtils.num2Vec;
-
 /** */
 class MnistMLPTestUtil {
     /** Name of the property specifying path to training set images. */
@@ -51,15 +46,39 @@ class MnistMLPTestUtil {
     static IgniteBiTuple<Stream<DenseLocalOnHeapVector>, Stream<DenseLocalOnHeapVector>> loadMnist(int samplesCnt) throws IOException {
         Properties props = loadMNISTProperties();
 
-        Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(props.getProperty(PROP_TRAINING_IMAGES),
+        Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TRAINING_IMAGES),
             props.getProperty(PROP_TRAINING_LABELS), new Random(123L), samplesCnt);
 
-        Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(props.getProperty(PROP_TEST_IMAGES),
+        Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TEST_IMAGES),
             props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000);
 
         return new IgniteBiTuple<>(trainingMnistStream, testMnistStream);
     }
 
+    /**
+     * Loads training set.
+     *
+     * @param cnt Count of objects.
+     * @return List of MNIST images.
+     * @throws IOException In case of exception.
+     */
+    static List<MnistUtils.MnistLabeledImage> loadTrainingSet(int cnt) throws IOException {
+        Properties props = loadMNISTProperties();
+        return MnistUtils.mnistAsList(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), cnt);
+    }
+
+    /**
+     * Loads test set.
+     *
+     * @param cnt Count of objects.
+     * @return List of MNIST images.
+     * @throws IOException In case of exception.
+     */
+    static List<MnistUtils.MnistLabeledImage> loadTestSet(int cnt) throws IOException {
+        Properties props = loadMNISTProperties();
+        return MnistUtils.mnistAsList(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), cnt);
+    }
+
     /** Load properties for MNIST tests. */
     private static Properties loadMNISTProperties() throws IOException {
         Properties res = new Properties();
@@ -70,19 +89,4 @@ class MnistMLPTestUtil {
 
         return res;
     }
-
-    /** */
-    static IgniteBiTuple<Matrix, Matrix> createDataset(Stream<DenseLocalOnHeapVector> s, int samplesCnt, int featCnt) {
-        Matrix vectors = new DenseLocalOnHeapMatrix(featCnt, samplesCnt);
-        Matrix labels = new DenseLocalOnHeapMatrix(10, samplesCnt);
-        List<DenseLocalOnHeapVector> sc = s.collect(Collectors.toList());
-
-        for (int i = 0; i < samplesCnt; i++) {
-            DenseLocalOnHeapVector v = sc.get(i);
-            vectors.assignColumn(i, v.viewPart(0, featCnt));
-            labels.assignColumn(i, num2Vec((int)v.getX(featCnt), 10));
-        }
-
-        return new IgniteBiTuple<>(vectors, labels);
-    }
 }
index 3bb3ee7..e3f60ec 100644 (file)
@@ -69,13 +69,12 @@ public class LinearRegressionLSQRTrainerTest {
         data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949});
         data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583});
 
-        LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>();
+        LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
 
         LinearRegressionModel mdl = trainer.fit(
             new LocalDatasetBuilder<>(data, parts),
             (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
-            (k, v) -> v[4],
-            4
+            (k, v) -> v[4]
         );
 
         assertArrayEquals(
@@ -108,13 +107,12 @@ public class LinearRegressionLSQRTrainerTest {
             data.put(i, x);
         }
 
-        LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>();
+        LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
 
         LinearRegressionModel mdl = trainer.fit(
             new LocalDatasetBuilder<>(data, parts),
             (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
-            (k, v) -> v[coef.length],
-            coef.length
+            (k, v) -> v[coef.length]
         );
 
         assertArrayEquals(coef, mdl.getWeights().getStorage().data(), 1e-6);
index 9e43f91..26ba2fb 100644 (file)
@@ -59,13 +59,13 @@ public class SVMBinaryTrainerTest {
             data.put(i, vec);
         }
 
-        SVMLinearBinaryClassificationTrainer<Integer, double[]> trainer = new SVMLinearBinaryClassificationTrainer<>();
+        SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer();
 
         SVMLinearBinaryClassificationModel mdl = trainer.fit(
             new LocalDatasetBuilder<>(data, 10),
             (k, v) -> Arrays.copyOfRange(v, 1, v.length),
-            (k, v) -> v[0],
-            AMOUNT_OF_FEATURES);
+            (k, v) -> v[0]
+        );
 
         TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION);
         TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION);
index 4218085..ad95eb4 100644 (file)
@@ -59,7 +59,7 @@ public class SVMMultiClassTrainerTest {
             data.put(i, vec);
         }
 
-        SVMLinearMultiClassClassificationTrainer<Integer, double[]> trainer = new SVMLinearMultiClassClassificationTrainer<Integer, double[]>()
+        SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer()
             .withLambda(0.3)
             .withAmountOfLocIterations(100)
             .withAmountOfIterations(20);
@@ -67,8 +67,8 @@ public class SVMMultiClassTrainerTest {
         SVMLinearMultiClassClassificationModel mdl = trainer.fit(
             new LocalDatasetBuilder<>(data, 10),
             (k, v) -> Arrays.copyOfRange(v, 1, v.length),
-            (k, v) -> v[0],
-            AMOUNT_OF_FEATURES);
+            (k, v) -> v[0]
+        );
 
         TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION);
         TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION);
index a72dec2..21fd692 100644 (file)
@@ -154,8 +154,8 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
 
         Properties props = loadMNISTProperties();
 
-        Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt);
-        Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000);
+        Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt);
+        Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000);
 
         IgniteCache<BiIndex, Double> cache = createBiIndexedCache();
 
@@ -193,8 +193,8 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
 
         Properties props = loadMNISTProperties();
 
-        Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt);
-        Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000);
+        Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt);
+        Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnistAsStream(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000);
 
         SparseDistributedMatrix m = new SparseDistributedMatrix(ptsCnt, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
 
index 2fd77ed..b7c9c6d 100644 (file)
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-# Paths to mnist dataset parts.
+# Paths to mnistAsStream dataset parts.
 mnist.training.images=/path/to/mnist/train-images-idx3-ubyte
 mnist.training.labels=/path/to/mnist/train-labels-idx1-ubyte
 mnist.test.images=/path/to/mnist/t10k-images-idx3-ubyte