IGNITE-10573: [ML] Consistent API for Ensemble training
authorArtem Malykh <amalykhgh@gmail.com>
Tue, 15 Jan 2019 17:19:48 +0000 (20:19 +0300)
committerYury Babak <ybabak@gridgain.com>
Tue, 15 Jan 2019 17:19:48 +0000 (20:19 +0300)
This closes #5767

55 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java
modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/genetic/Chromosome.java
modules/ml/src/main/java/org/apache/ignite/ml/genetic/MutateJob.java
modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/GeneCacheConfig.java
modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/PopulationCacheConfig.java
modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java
modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java
modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java

index 58f739d..c9b10b1 100644 (file)
@@ -22,7 +22,8 @@ import java.util.Arrays;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.bagging.BaggedModel;
+import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
 import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
@@ -31,7 +32,6 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalcula
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
 import org.apache.ignite.ml.selection.cv.CrossValidation;
 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.apache.ignite.ml.trainers.TrainerTransformers;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
@@ -75,7 +75,7 @@ public class BaggedLogisticRegressionSGDTrainerExample {
 
             System.out.println(">>> Perform the training to get the model.");
 
-            DatasetTrainer< ModelsComposition, Double> baggedTrainer = TrainerTransformers.makeBagged(
+            BaggedTrainer<Double> baggedTrainer = TrainerTransformers.makeBagged(
                 trainer,
                 10,
                 0.6,
@@ -85,7 +85,7 @@ public class BaggedLogisticRegressionSGDTrainerExample {
 
             System.out.println(">>> Perform evaluation of the model.");
 
-            double[] score = new CrossValidation<ModelsComposition, Double, Integer, Vector>().score(
+            double[] score = new CrossValidation<BaggedModel, Double, Integer, Vector>().score(
                 baggedTrainer,
                 new Accuracy<>(),
                 ignite,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
new file mode 100644 (file)
index 0000000..ec64764
--- /dev/null
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.tutorial;
+
+import java.io.FileNotFoundException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.stacking.StackedModel;
+import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
+import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
+import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
+import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
+import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+
+/**
+ * {@link MinMaxScalerTrainer} and {@link NormalizationTrainer} are used in this example due to different values
+ * distribution in columns and rows.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data (based on Titanic passengers data).</p>
+ * <p>
+ * After that it defines preprocessors that extract features from an upstream data and perform other desired changes
+ * over the extracted data, including the scaling.</p>
+ * <p>
+ * Then, it trains a stacked model that combines several decision tree classifiers with a logistic regression aggregator.</p>
+ * <p>
+ * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ */
+public class Step_10_Scaling_With_Stacking {
+    /** Run example. */
+    public static void main(String[] args) {
+        System.out.println();
+        System.out.println(">>> Tutorial step 5 (scaling) example started.");
+
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            try {
+                IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+
+                // Defines first preprocessor that extracts features from an upstream data.
+                // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
+                IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
+                    = (k, v) -> new Object[] {v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
+
+                IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+
+                IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
+                    .withEncoderType(EncoderType.STRING_ENCODER)
+                    .withEncodedFeature(1)
+                    .withEncodedFeature(6) // <--- Changed index here.
+                    .fit(ignite,
+                        dataCache,
+                        featureExtractor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
+                    .fit(ignite,
+                        dataCache,
+                        strEncoderPreprocessor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
+                    .fit(
+                        ignite,
+                        dataCache,
+                        imputingPreprocessor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
+                    .withP(1)
+                    .fit(
+                        ignite,
+                        dataCache,
+                        minMaxScalerPreprocessor
+                    );
+
+                DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
+                DecisionTreeClassificationTrainer trainer1 = new DecisionTreeClassificationTrainer(3, 0);
+                DecisionTreeClassificationTrainer trainer2 = new DecisionTreeClassificationTrainer(4, 0);
+
+                LogisticRegressionSGDTrainer aggregator = new LogisticRegressionSGDTrainer()
+                    .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+                        SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg));
+
+                StackedModel<Vector, Vector, Double, LogisticRegressionModel> mdl =
+                    new StackedVectorDatasetTrainer<>(aggregator)
+                    .addTrainerWithDoubleOutput(trainer)
+                    .addTrainerWithDoubleOutput(trainer1)
+                    .addTrainerWithDoubleOutput(trainer2)
+                    .fit(
+                        ignite,
+                        dataCache,
+                        normalizationPreprocessor,
+                        lbExtractor
+                    );
+
+                System.out.println("\n>>> Trained model: " + mdl);
+
+                double accuracy = BinaryClassificationEvaluator.evaluate(
+                    dataCache,
+                    mdl,
+                    normalizationPreprocessor,
+                    lbExtractor,
+                    new Accuracy<>()
+                );
+
+                System.out.println("\n>>> Accuracy " + accuracy);
+                System.out.println("\n>>> Test Error " + (1 - accuracy));
+
+                System.out.println(">>> Tutorial step 5 (scaling) example completed.");
+            }
+            catch (FileNotFoundException e) {
+                e.printStackTrace();
+            }
+        }
+    }
+}
index a1165e1..6268d06 100644 (file)
@@ -20,8 +20,10 @@ package org.apache.ignite.ml;
 import java.io.Serializable;
 import java.util.function.BiFunction;
 import org.apache.ignite.ml.inference.Model;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
 
 /** Basic interface for all models. */
+@FunctionalInterface
 public interface IgniteModel<T, V> extends Model<T, V>, Serializable {
     /**
      * Combines this model with other model via specified combiner
@@ -37,12 +39,46 @@ public interface IgniteModel<T, V> extends Model<T, V>, Serializable {
     /**
      * Get a composition model of the form {@code x -> after(mdl(x))}.
      *
-     * @param after Function to apply after this model.
+     * @param after Model to apply after this model.
      * @param <V1> Type of input of function applied before this model.
      * @return Composition model of the form {@code x -> after(mdl(x))}.
      */
     public default <V1> IgniteModel<T, V1> andThen(IgniteModel<V, V1> after) {
-        return t -> after.predict(predict(t));
+        IgniteModel<T, V> self = this;
+        return new IgniteModel<T, V1>() {
+            /** {@inheritDoc} */
+            @Override public V1 predict(T input) {
+                return after.predict(self.predict(input));
+            }
+
+            /** {@inheritDoc} */
+            @Override public void close() {
+                self.close();
+                after.close();
+            }
+        };
+    }
+
+    /**
+     * Get a composition model of the form {@code x -> after(mdl(x))}.
+     *
+     * @param after Function to apply after this model.
+     * @param <V1> Type of output of the function applied after this model.
+     * @return Composition model of the form {@code x -> after(mdl(x))}.
+     */
+    public default <V1> IgniteModel<T, V1> andThen(IgniteFunction<V, V1> after) {
+        IgniteModel<T, V> self = this;
+        return new IgniteModel<T, V1>() {
+            /** {@inheritDoc} */
+            @Override public V1 predict(T input) {
+                return after.apply(self.predict(input));
+            }
+
+            /** {@inheritDoc} */
+            @Override public void close() {
+                self.close();
+            }
+        };
     }
 
     /**
index 88ea9b9..3206b5f 100644 (file)
@@ -149,7 +149,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(KMeansModel mdl) {
+    @Override public boolean isUpdateable(KMeansModel mdl) {
         return mdl.getCenters().length == k && mdl.distanceMeasure().equals(distance);
     }
 
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java
new file mode 100644 (file)
index 0000000..5a2f40a
--- /dev/null
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * Various utility functions for trainers composition.
+ */
+public class CompositionUtils {
+    /**
+     * Perform blurring of model type of given trainer to {@code IgniteModel<I, O>}, where I, O are input and output
+     * types of original model.
+     *
+     * @param trainer Trainer to coerce.
+     * @param <I> Type of input of model produced by coerced trainer.
+     * @param <O> Type of output of model produced by coerced trainer.
+     * @param <M> Type of model produced by coerced trainer.
+     * @param <L> Type of labels.
+     * @return Trainer coerced to {@code DatasetTrainer<IgniteModel<I, O>, L>}.
+     */
+    public static <I, O, M extends IgniteModel<I, O>, L> DatasetTrainer<IgniteModel<I, O>, L> unsafeCoerce(
+        DatasetTrainer<? extends M, L> trainer) {
+        return new DatasetTrainer<IgniteModel<I, O>, L>() {
+            /** {@inheritDoc} */
+            @Override public <K, V> IgniteModel<I, O> fit(DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                return trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
+            }
+
+            /** {@inheritDoc} */
+            @Override public <K, V> IgniteModel<I, O> update(IgniteModel<I, O> mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                DatasetTrainer<IgniteModel<I, O>, L> trainer1 = (DatasetTrainer<IgniteModel<I, O>, L>)trainer;
+                return trainer1.update(mdl, datasetBuilder, featureExtractor, lbExtractor);
+            }
+
+            /**
+             * This method is never called, instead of constructing logic of update from
+             * {@link DatasetTrainer#isUpdateable} and
+             * {@link DatasetTrainer#updateModel}
+             * in this class we explicitly override update method.
+             *
+             * @param mdl Model.
+             * @return True if current critical for training parameters correspond to parameters from last training.
+             */
+            @Override public boolean isUpdateable(IgniteModel<I, O> mdl) {
+                throw new IllegalStateException();
+            }
+
+            /**
+             * This method is never called, instead of constructing logic of update from
+             * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+             * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+             * in this class we explicitly override update method.
+             *
+             * @param mdl Model.
+             * @return Updated model.
+             */
+            @Override protected <K, V> IgniteModel<I, O> updateModel(IgniteModel<I, O> mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                throw new IllegalStateException();
+            }
+        };
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java
new file mode 100644 (file)
index 0000000..9547d54
--- /dev/null
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * This class represents dataset mapping. This is just a tuple of two mappings: one for features and one for labels.
+ *
+ * @param <L1> Type of labels before mapping.
+ * @param <L2> Type of labels after mapping.
+ */
+public interface DatasetMapping<L1, L2> {
+    /**
+     * Method used to map feature vectors.
+     *
+     * @param v Feature vector.
+     * @return Mapped feature vector.
+     */
+    public default Vector mapFeatures(Vector v) {
+        return v;
+    }
+
+    /**
+     * Method used to map labels.
+     *
+     * @param lbl Label.
+     * @return Mapped label.
+     */
+    public L2 mapLabels(L1 lbl);
+
+    /**
+     * Dataset mapping which maps features, leaving labels unaffected.
+     *
+     * @param mapper Function used to map features.
+     * @param <L> Type of labels.
+     * @return Dataset mapping which maps features, leaving labels unaffected.
+     */
+    public static <L> DatasetMapping<L, L> mappingFeatures(IgniteFunction<Vector, Vector> mapper) {
+        return new DatasetMapping<L, L>() {
+            /** {@inheritDoc} */
+            @Override public Vector mapFeatures(Vector v) {
+                return mapper.apply(v);
+            }
+
+            /** {@inheritDoc} */
+            @Override public L mapLabels(L lbl) {
+                return lbl;
+            }
+        };
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java
new file mode 100644 (file)
index 0000000..c59a634
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.bagging;
+
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * This class represents model produced by {@link BaggedTrainer}.
+ * It is a wrapper around inner representation of model produced by {@link BaggedTrainer}.
+ */
+public class BaggedModel implements IgniteModel<Vector, Double> {
+    /** Inner representation of model produced by {@link BaggedTrainer}. */
+    private IgniteModel<Vector, Double> mdl;
+
+    /**
+     * Construct instance of this class given specified model.
+     * @param mdl Model to wrap.
+     */
+    BaggedModel(IgniteModel<Vector, Double> mdl) {
+        this.mdl = mdl;
+    }
+
+    /**
+     * Get wrapped model.
+     *
+     * @return Wrapped model.
+     */
+    IgniteModel<Vector, Double> model() {
+        return mdl;
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double predict(Vector i) {
+        return mdl.predict(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void close() {
+        mdl.close();
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
new file mode 100644 (file)
index 0000000..5b0962a
--- /dev/null
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.bagging;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
+import org.apache.ignite.ml.util.Utils;
+
+/**
+ * Trainer encapsulating logic of bootstrap aggregating (bagging).
+ * This trainer accepts some other trainer and returns bagged version of it.
+ * Resulting model consists of submodels results of which are aggregated by a specified aggregator.
+ * <p>Bagging is done
+ * on both samples and features (<a href="https://en.wikipedia.org/wiki/Bootstrap_aggregating">Samples bagging</a>,
+ * <a href="https://en.wikipedia.org/wiki/Random_subspace_method">Features bagging</a>).</p>
+ *
+ * @param <L> Type of labels.
+ */
+public class BaggedTrainer<L> extends
+    DatasetTrainer<BaggedModel, L> {
+    /** Trainer for which bagged version is created. */
+    private final DatasetTrainer<? extends IgniteModel, L> tr;
+
+    /** Aggregator of submodels results. */
+    private final PredictionsAggregator aggregator;
+
+    /** Count of submodels in the ensemble. */
+    private final int ensembleSize;
+
+    /** Ratio determining which part of dataset will be taken as subsample for each submodel training. */
+    private final double subsampleRatio;
+
+    /** Dimensionality of feature vectors. */
+    private final int featuresVectorSize;
+
+    /** Dimension of subspace on which all samples from subsample are projected. */
+    private final int featureSubspaceDim;
+
+    /**
+     * Construct instance of this class with given parameters.
+     *
+     * @param tr Trainer for making bagged.
+     * @param aggregator Aggregator of models.
+     * @param ensembleSize Size of ensemble.
+     * @param subsampleRatio Ratio (subsample size) / (initial dataset size).
+     * @param featuresVectorSize Dimensionality of feature vector.
+     * @param featureSubspaceDim Dimensionality of feature subspace.
+     */
+    public BaggedTrainer(DatasetTrainer<? extends IgniteModel, L> tr,
+        PredictionsAggregator aggregator, int ensembleSize, double subsampleRatio, int featuresVectorSize,
+        int featureSubspaceDim) {
+        this.tr = tr;
+        this.aggregator = aggregator;
+        this.ensembleSize = ensembleSize;
+        this.subsampleRatio = subsampleRatio;
+        this.featuresVectorSize = featuresVectorSize;
+        this.featureSubspaceDim = featureSubspaceDim;
+    }
+
+    /**
+     * Create the bagged trainer from the configured base trainer and ensemble parameters.
+     *
+     * @return Bagged trainer.
+     */
+    private DatasetTrainer<IgniteModel<Vector, Double>, L> getTrainer() {
+        List<int[]> mappings = (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) ?
+            IntStream.range(0, ensembleSize).mapToObj(
+                modelIdx -> getMapping(
+                    featuresVectorSize,
+                    featureSubspaceDim,
+                    environment.randomNumbersGenerator().nextLong()))
+                .collect(Collectors.toList()) :
+            null;
+
+        List<DatasetTrainer<? extends IgniteModel, L>> trainers = Collections.nCopies(ensembleSize, tr);
+
+        // Generate a list of trainers, each a copy of the original trainer, but each working on its own feature subspace and subsample.
+        List<DatasetTrainer<IgniteModel<Vector, Double>, L>> subspaceTrainers = IntStream.range(0, ensembleSize)
+            .mapToObj(mdlIdx -> {
+                AdaptableDatasetTrainer<Vector, Double, Vector, Double, ? extends IgniteModel, L> tr =
+                    AdaptableDatasetTrainer.of(trainers.get(mdlIdx));
+                if (mappings != null) {
+                    tr = tr.afterFeatureExtractor(featureValues -> {
+                        int[] mapping = mappings.get(mdlIdx);
+                        double[] newFeaturesValues = new double[mapping.length];
+                        for (int j = 0; j < mapping.length; j++)
+                            newFeaturesValues[j] = featureValues.get(mapping[j]);
+
+                        return VectorUtils.of(newFeaturesValues);
+                    }).beforeTrainedModel(getProjector(mappings.get(mdlIdx)));
+                }
+                return tr
+                    .withUpstreamTransformerBuilder(BaggingUpstreamTransformer.builder(subsampleRatio, mdlIdx))
+                    .withEnvironmentBuilder(envBuilder);
+            })
+            .map(CompositionUtils::unsafeCoerce)
+            .collect(Collectors.toList());
+
+        AdaptableDatasetTrainer<Vector, Double, Vector, List<Double>, IgniteModel<Vector, List<Double>>, L> finalTrainer = AdaptableDatasetTrainer.of(
+            new TrainersParallelComposition<>(
+                subspaceTrainers)).afterTrainedModel(l -> aggregator.apply(l.stream().mapToDouble(Double::valueOf).toArray()));
+
+        return CompositionUtils.unsafeCoerce(finalTrainer);
+    }
+
+    /**
+     * Get mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl: an array of
+     * {@code maximumFeaturesCntPerMdl} distinct feature indices selected from
+     * {@code [0, featuresVectorSize)}. Deterministic for a given seed.
+     *
+     * @param featuresVectorSize Features vector size (dimension of initial space).
+     * @param maximumFeaturesCntPerMdl Dimension of target space.
+     * @param seed Seed for the random selection.
+     * @return Mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl.
+     */
+    public static int[] getMapping(int featuresVectorSize, int maximumFeaturesCntPerMdl, long seed) {
+        return Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed));
+    }
+
+    /**
+     * Get projector from index mapping: a function taking a vector in the initial space and
+     * returning the vector of its coordinates at indices {@code mapping[0..mapping.length)}.
+     *
+     * @param mapping Index mapping.
+     * @return Projector.
+     */
+    public static IgniteFunction<Vector, Vector> getProjector(int[] mapping) {
+        return v -> {
+            Vector res = VectorUtils.zeroes(mapping.length);
+            for (int i = 0; i < mapping.length; i++)
+                res.set(i, v.get(mapping[i]));
+
+            return res;
+        };
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> BaggedModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        // Delegate training to the composed ensemble trainer and wrap the result.
+        IgniteModel<Vector, Double> fit = getTrainer().fit(datasetBuilder, featureExtractor, lbExtractor);
+        return new BaggedModel(fit);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> BaggedModel update(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        // Delegate update of the underlying ensemble model and re-wrap the result.
+        IgniteModel<Vector, Double> updated = getTrainer().update(mdl.model(), datasetBuilder, featureExtractor, lbExtractor);
+        return new BaggedModel(updated);
+    }
+
+    /** {@inheritDoc} */
+    @Override public BaggedTrainer<L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        // Narrow the return type for fluent usage; the cast is safe because the superclass returns `this`.
+        return (BaggedTrainer<L>)super.withEnvironmentBuilder(envBuilder);
+    }
+
+    /**
+     * This method is never called; instead of constructing update logic from
+     * {@link DatasetTrainer#isUpdateable} and
+     * {@link DatasetTrainer#updateModel},
+     * this class explicitly overrides the update method.
+     *
+     * @param mdl Model.
+     * @return True if current critical for training parameters correspond to parameters from last training.
+     */
+    @Override public boolean isUpdateable(BaggedModel mdl) {
+        // Should never be called.
+        throw new IllegalStateException();
+    }
+
+    /**
+     * This method is never called; instead of constructing update logic from
+     * {@link DatasetTrainer#isUpdateable} and
+     * {@link DatasetTrainer#updateModel},
+     * this class explicitly overrides the update method.
+     *
+     * @param mdl Model.
+     * @return Updated model.
+     */
+    @Override protected <K, V> BaggedModel updateModel(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        // Should never be called.
+        throw new IllegalStateException();
+    }
+}
index 35502ab..7d88ddb 100644 (file)
@@ -141,7 +141,7 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(ModelsComposition mdl) {
+    @Override public boolean isUpdateable(ModelsComposition mdl) {
         return mdl instanceof GDBModel;
     }
 
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java
new file mode 100644 (file)
index 0000000..b39067d
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains various combinators of trainers and models.
+ */
+package org.apache.ignite.ml.composition.combinators;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java
new file mode 100644 (file)
index 0000000..601b639
--- /dev/null
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.combinators.parallel;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.ignite.ml.IgniteModel;
+
+/**
+ * Parallel composition of models.
+ * Parallel composition of models is a model which contains a list of submodels with same input and output types.
+ * Result of prediction in such model is a list of predictions of each of submodels.
+ *
+ * @param <I> Type of submodel input.
+ * @param <O> Type of submodel output.
+ */
+public class ModelsParallelComposition<I, O> implements IgniteModel<I, List<O>> {
+    /** List of submodels. */
+    private final List<IgniteModel<I, O>> submodels;
+
+    /**
+     * Construct an instance of this class from list of submodels.
+     *
+     * @param submodels List of submodels constituting this model.
+     */
+    public ModelsParallelComposition(List<IgniteModel<I, O>> submodels) {
+        this.submodels = submodels;
+    }
+
+    /** {@inheritDoc} */
+    @Override public List<O> predict(I i) {
+        // Apply each submodel to the same input; result order matches submodel order.
+        return submodels
+            .stream()
+            .map(m -> m.predict(i))
+            .collect(Collectors.toList());
+    }
+
+    /**
+     * List of submodels constituting this model.
+     *
+     * @return Unmodifiable view of the list of submodels constituting this model.
+     */
+    public List<IgniteModel<I, O>> submodels() {
+        return Collections.unmodifiableList(submodels);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void close() {
+        // Release resources held by each submodel.
+        submodels.forEach(IgniteModel::close);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java
new file mode 100644 (file)
index 0000000..411ed17
--- /dev/null
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.combinators.parallel;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.parallelism.Promise;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * This class represents a parallel composition of trainers.
+ * Parallel composition of trainers is a trainer itself which trains a list of trainers with same
+ * input and output. Training is done in following manner:
+ * <pre>
+ *     1. Independently train all trainers on the same dataset and get a list of models.
+ *     2. Combine models produced in step (1) into a {@link ModelsParallelComposition}.
+ * </pre>
+ * Updating is made in a similar fashion.
+ * Like in other trainers combinators we avoid to include type of contained trainers in type parameters
+ * because otherwise compositions of compositions would have a relatively complex generic type which will
+ * reduce readability.
+ *
+ * @param <I> Type of trainers inputs.
+ * @param <O> Type of trainers outputs.
+ * @param <L> Type of dataset labels.
+ */
+public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteModel<I, List<O>>, L> {
+    /** List of trainers. */
+    private final List<DatasetTrainer<IgniteModel<I, O>, L>> trainers;
+
+    /**
+     * Construct an instance of this class from a list of trainers.
+     *
+     * @param trainers Trainers.
+     * @param <M> Type of model. NOTE(review): this type parameter is not referenced in the
+     *        signature and could be removed.
+     * @param <T> Type of trainer.
+     */
+    public <M extends IgniteModel<I, O>, T extends DatasetTrainer<? extends IgniteModel<I, O>, L>> TrainersParallelComposition(
+        List<T> trainers) {
+        this.trainers = trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList());
+    }
+
+    /**
+     * Create parallel composition of trainers contained in a given list.
+     *
+     * @param trainers List of trainers.
+     * @param <I> Type of input of model produced by trainers.
+     * @param <O> Type of output of model produced by trainers.
+     * @param <M> Type of model produced by trainers.
+     * @param <T> Type of trainers.
+     * @param <L> Type of labels.
+     * @return Parallel composition of trainers contained in a given list.
+     */
+    public static <I, O, M extends IgniteModel<I, O>, T extends DatasetTrainer<M, L>, L> TrainersParallelComposition<I, O, L> of(List<T> trainers) {
+        List<DatasetTrainer<IgniteModel<I, O>, L>> trs =
+            trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList());
+
+        return new TrainersParallelComposition<>(trs);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> IgniteModel<I, List<O>> fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        // Build one training task per trainer; all tasks share the same dataset and extractors.
+        List<IgniteSupplier<IgniteModel<I, O>>> tasks = trainers.stream()
+            .map(tr -> (IgniteSupplier<IgniteModel<I, O>>)(() -> tr.fit(datasetBuilder, featureExtractor, lbExtractor)))
+            .collect(Collectors.toList());
+
+        // Run tasks under the environment's parallelism strategy and collect resulting models in order.
+        List<IgniteModel<I, O>> mdls = environment.parallelismStrategy().submit(tasks).stream()
+            .map(Promise::unsafeGet)
+            .collect(Collectors.toList());
+
+        return new ModelsParallelComposition<>(mdls);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> IgniteModel<I, List<O>> update(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        // The i-th trainer updates the i-th submodel, so submodel count must match trainer count.
+        ModelsParallelComposition<I, O> typedMdl = (ModelsParallelComposition<I, O>)mdl;
+
+        assert typedMdl.submodels().size() == trainers.size();
+        List<IgniteSupplier<IgniteModel<I, O>>> tasks = new ArrayList<>();
+
+        for (int i = 0; i < trainers.size(); i++) {
+            int j = i; // Effectively-final copy of the loop index for capture in the lambda.
+            tasks.add(() -> trainers.get(j).update(typedMdl.submodels().get(j), datasetBuilder, featureExtractor, lbExtractor));
+        }
+
+        List<IgniteModel<I, O>> mdls = environment.parallelismStrategy().submit(tasks).stream()
+            .map(Promise::unsafeGet)
+            .collect(Collectors.toList());
+
+        return new ModelsParallelComposition<>(mdls);
+    }
+
+    /**
+     * This method is never called; instead of constructing update logic from
+     * {@link DatasetTrainer#isUpdateable} and
+     * {@link DatasetTrainer#updateModel},
+     * this class explicitly overrides the update method.
+     *
+     * @param mdl Model.
+     * @return True if current critical for training parameters correspond to parameters from last training.
+     */
+    @Override public boolean isUpdateable(IgniteModel<I, List<O>> mdl) {
+        // Never called.
+        throw new IllegalStateException();
+    }
+
+    /**
+     * This method is never called; instead of constructing update logic from
+     * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+     * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)},
+     * this class explicitly overrides the update method.
+     *
+     * @param mdl Model.
+     * @return Updated model.
+     */
+    @Override protected <K, V> IgniteModel<I, List<O>> updateModel(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        // Never called.
+        throw new IllegalStateException();
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java
new file mode 100644 (file)
index 0000000..cb24250
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains parallel combinators of trainers and models.
+ */
+package org.apache.ignite.ml.composition.combinators.parallel;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java
new file mode 100644 (file)
index 0000000..78e2846
--- /dev/null
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.combinators.sequential;
+
+import java.util.List;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Sequential composition of models.
+ * Sequential composition is a model consisting of two models {@code mdl1 :: I -> O1, mdl2 :: O1 -> O2} with prediction
+ * corresponding to application of composition {@code mdl1 `andThen` mdl2} to input.
+ *
+ * @param <I> Type of input of the first model.
+ * @param <O1> Type of output of the first model (and input of second).
+ * @param <O2> Type of output of the second model.
+ */
+public class ModelsSequentialComposition<I, O1, O2> implements IgniteModel<I, O2> {
+    /** First model. */
+    private IgniteModel<I, O1> mdl1;
+
+    /** Second model. */
+    private IgniteModel<O1, O2> mdl2;
+
+    /**
+     * Get sequential composition of submodels with same type.
+     * Builds a right-nested chain recursively: {@code lst.get(0) andThen lst.get(1) andThen ...},
+     * converting each intermediate output back to an input via {@code output2Input}.
+     *
+     * @param lst List of submodels; must contain at least two elements.
+     * @param output2Input Function for conversion of output to input.
+     * @param <I> Type of input of submodel.
+     * @param <O> Type of output of submodel.
+     * @return Sequential composition of submodels with same type.
+     */
+    public static <I, O> ModelsSequentialComposition<I, I, O> ofSame(List<? extends IgniteModel<I, O>> lst,
+        IgniteFunction<O, I> output2Input) {
+        assert lst.size() >= 2;
+
+        // Base case: exactly two submodels left.
+        if (lst.size() == 2)
+            return new ModelsSequentialComposition<>(lst.get(0).andThen(output2Input),
+                lst.get(1));
+
+        // Recursive case: compose head with the composition of the tail.
+        return new ModelsSequentialComposition<>(lst.get(0).andThen(output2Input),
+            ofSame(lst.subList(1, lst.size()), output2Input));
+    }
+
+    /**
+     * Construct instance of this class from two given models.
+     *
+     * @param mdl1 First model.
+     * @param mdl2 Second model.
+     */
+    public ModelsSequentialComposition(IgniteModel<I, O1> mdl1, IgniteModel<O1, O2> mdl2) {
+        this.mdl1 = mdl1;
+        this.mdl2 = mdl2;
+    }
+
+    /**
+     * Get first model.
+     *
+     * @return First model.
+     */
+    public IgniteModel<I, O1> firstModel() {
+        return mdl1;
+    }
+
+    /**
+     * Get second model.
+     *
+     * @return Second model.
+     */
+    public IgniteModel<O1, O2> secondModel() {
+        return mdl2;
+    }
+
+    /** {@inheritDoc} */
+    @Override public O2 predict(I i1) {
+        // Feed the first model's output into the second model.
+        return mdl1.andThen(mdl2).predict(i1);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void close() {
+        mdl1.close();
+        mdl2.close();
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java
new file mode 100644 (file)
index 0000000..d36ff9c
--- /dev/null
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.combinators.sequential;
+
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.composition.DatasetMapping;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * Sequential composition of trainers.
+ * Sequential composition of trainers is itself trainer which produces {@link ModelsSequentialComposition}.
+ * Training is done in following fashion:
+ * <pre>
+ *     1. First trainer is trained and `mdl1` is produced.
+ *     2. From `mdl1` {@link DatasetMapping} is constructed. This mapping `dsM` encapsulates dependency between first
+ *     training result and second trainer.
+ *     3. Second trainer is trained using dataset aquired from application `dsM` to original dataset; `mdl2` is produced.
+ *     4. `mdl1` and `mdl2` are composed into {@link ModelsSequentialComposition}.
+ * </pre>
+ *
+ * @param <I> Type of input of model produced by first trainer.
+ * @param <O1> Type of output of model produced by first trainer.
+ * @param <O2> Type of output of model produced by second trainer.
+ * @param <L> Type of labels.
+ */
+public class TrainersSequentialComposition<I, O1, O2, L> extends DatasetTrainer<ModelsSequentialComposition<I, O1, O2>, L> {
+    /** First trainer. */
+    private DatasetTrainer<IgniteModel<I, O1>, L> tr1;
+
+    /** Second trainer. */
+    private DatasetTrainer<IgniteModel<O1, O2>, L> tr2;
+
+    /** Function producing the dataset mapping for the second trainer from the first trained model. */
+    private IgniteFunction<? super IgniteModel<I, O1>, DatasetMapping<L, L>> datasetMapping;
+
+    /**
+     * Construct sequential composition of given two trainers.
+     *
+     * @param tr1 First trainer.
+     * @param tr2 Second trainer.
+     * @param datasetMapping Function producing the dataset mapping from the first trained model.
+     */
+    public TrainersSequentialComposition(DatasetTrainer<? extends IgniteModel<I, O1>, L> tr1,
+        DatasetTrainer<? extends IgniteModel<O1, O2>, L> tr2,
+        IgniteFunction<? super IgniteModel<I, O1>, DatasetMapping<L, L>> datasetMapping) {
+        this.tr1 = CompositionUtils.unsafeCoerce(tr1);
+        this.tr2 = CompositionUtils.unsafeCoerce(tr2);
+        this.datasetMapping = datasetMapping;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> ModelsSequentialComposition<I, O1, O2> fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+
+        // Step 1: train the first model on the original dataset.
+        IgniteModel<I, O1> mdl1 = tr1.fit(datasetBuilder, featureExtractor, lbExtractor);
+        // Step 2: derive the dataset mapping from the first trained model.
+        DatasetMapping<L, L> mapping = datasetMapping.apply(mdl1);
+
+        // Step 3: train the second model on the mapped dataset.
+        IgniteModel<O1, O2> mdl2 = tr2.fit(datasetBuilder,
+            featureExtractor.andThen(mapping::mapFeatures),
+            lbExtractor.andThen(mapping::mapLabels));
+
+        // Step 4: compose both models.
+        return new ModelsSequentialComposition<>(mdl1, mdl2);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> ModelsSequentialComposition<I, O1, O2> update(
+        ModelsSequentialComposition<I, O1, O2> mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+
+        // Mirror fit(): update the first model, rebuild the mapping from the updated model,
+        // then update the second model on the remapped dataset.
+        IgniteModel<I, O1> firstUpdated = tr1.update(mdl.firstModel(), datasetBuilder, featureExtractor, lbExtractor);
+        DatasetMapping<L, L> mapping = datasetMapping.apply(firstUpdated);
+
+        IgniteModel<O1, O2> secondUpdated = tr2.update(mdl.secondModel(),
+            datasetBuilder,
+            featureExtractor.andThen(mapping::mapFeatures),
+            lbExtractor.andThen(mapping::mapLabels));
+
+        return new ModelsSequentialComposition<>(firstUpdated, secondUpdated);
+    }
+
+    /**
+     * This method is never called; instead of constructing update logic from
+     * {@link DatasetTrainer#isUpdateable} and
+     * {@link DatasetTrainer#updateModel},
+     * this class explicitly overrides the update method.
+     *
+     * @param mdl Model.
+     * @return True if current critical for training parameters correspond to parameters from last training.
+     */
+    @Override public boolean isUpdateable(ModelsSequentialComposition<I, O1, O2> mdl) {
+        // Never called.
+        throw new IllegalStateException();
+    }
+
+    /**
+     * This method is never called; instead of constructing update logic from
+     * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+     * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)},
+     * this class explicitly overrides the update method.
+     *
+     * @param mdl Model.
+     * @return Updated model.
+     */
+    @Override protected <K, V> ModelsSequentialComposition<I, O1, O2> updateModel(
+        ModelsSequentialComposition<I, O1, O2> mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        // Never called.
+        throw new IllegalStateException();
+    }
+
+    /**
+     * Performs coercion of this trainer to {@code DatasetTrainer<IgniteModel<I, O2>, L>}.
+     *
+     * @return Trainer coerced to {@code DatasetTrainer<IgniteModel<I, O2>, L>}.
+     */
+    public DatasetTrainer<IgniteModel<I, O2>, L> unsafeSimplyTyped() {
+        return CompositionUtils.unsafeCoerce(this);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java
new file mode 100644 (file)
index 0000000..02ca2df
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains sequential combinators of trainers and models.
+ */
+package org.apache.ignite.ml.composition.combinators.sequential;
index e58107d..45fcecc 100644 (file)
@@ -21,15 +21,18 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
 import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.composition.DatasetMapping;
+import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
+import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.environment.parallelism.Promise;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
 
 /**
@@ -220,31 +223,7 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L>
         // Unsafely coerce DatasetTrainer<M1, L> to DatasetTrainer<Model<IS, IA>, L>, but we fully control
         // usages of this unsafely coerced object, on the other hand this makes work with
         // submodelTrainers easier.
-        submodelsTrainers.add(new DatasetTrainer<IgniteModel<IS, IA>, L>() {
-            /** {@inheritDoc} */
-            @Override public <K, V> IgniteModel<IS, IA> fit(DatasetBuilder<K, V> datasetBuilder,
-                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
-                return trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
-            }
-
-            /** {@inheritDoc} */
-            @Override public <K, V> IgniteModel<IS, IA> update(IgniteModel<IS, IA> mdl, DatasetBuilder<K, V> datasetBuilder,
-                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
-                DatasetTrainer<IgniteModel<IS, IA>, L> trainer1 = (DatasetTrainer<IgniteModel<IS, IA>, L>)trainer;
-                return trainer1.update(mdl, datasetBuilder, featureExtractor, lbExtractor);
-            }
-
-            /** {@inheritDoc} */
-            @Override protected boolean checkState(IgniteModel<IS, IA> mdl) {
-                return true;
-            }
-
-            /** {@inheritDoc} */
-            @Override protected <K, V> IgniteModel<IS, IA> updateModel(IgniteModel<IS, IA> mdl, DatasetBuilder<K, V> datasetBuilder,
-                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
-                return null;
-            }
-        });
+        submodelsTrainers.add(CompositionUtils.unsafeCoerce(trainer));
 
         return this;
     }
@@ -254,62 +233,60 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L>
         IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, L> lbExtractor) {
 
-        return update(null, datasetBuilder, featureExtractor, lbExtractor);
+        return new StackedModel<>(getTrainer().fit(datasetBuilder, featureExtractor, lbExtractor));
     }
 
     /** {@inheritDoc} */
     @Override public <K, V> StackedModel<IS, IA, O, AM> update(StackedModel<IS, IA, O, AM> mdl,
         DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, L> lbExtractor) {
-        return runOnSubmodels(
-            ensemble -> {
-                List<IgniteSupplier<IgniteModel<IS, IA>>> res = new ArrayList<>();
-                for (int i = 0; i < ensemble.size(); i++) {
-                    final int j = i;
-                    res.add(() -> {
-                        DatasetTrainer<IgniteModel<IS, IA>, L> trainer = ensemble.get(j);
-                        return mdl == null ?
-                            trainer.fit(datasetBuilder, featureExtractor, lbExtractor) :
-                            trainer.update(mdl.submodels().get(j), datasetBuilder, featureExtractor, lbExtractor);
-                    });
-                }
-                return res;
-            },
-            (at, extr) -> mdl == null ?
-                at.fit(datasetBuilder, extr, lbExtractor) :
-                at.update(mdl.aggregatorModel(), datasetBuilder, extr, lbExtractor),
-            featureExtractor
-        );
-    }
 
-    /** {@inheritDoc} */
-    @Override public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder(
-        LearningEnvironmentBuilder envBuilder) {
-        submodelsTrainers =
-            submodelsTrainers.stream().map(x -> x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList());
-        aggregatorTrainer = aggregatorTrainer.withEnvironmentBuilder(envBuilder);
-
-        return this;
+        return new StackedModel<>(getTrainer().update(mdl, datasetBuilder, featureExtractor, lbExtractor));
     }
 
     /**
-     * <pre>
-     * 1. Obtain models produced by running specified tasks;
-     * 2. run other specified task on dataset augmented with results of models from step 2.
-     * </pre>
+     * Get the trainer for stacking.
      *
-     * @param taskSupplier Function used to generate tasks for first step.
-     * @param aggregatorProcessor Function used
-     * @param featureExtractor Feature extractor.
-     * @param <K> Type of keys in upstream.
-     * @param <V> Type of values in upstream.
-     * @return {@link StackedModel}.
+     * @return Trainer for stacking.
      */
-    private <K, V> StackedModel<IS, IA, O, AM> runOnSubmodels(
-        IgniteFunction<List<DatasetTrainer<IgniteModel<IS, IA>, L>>, List<IgniteSupplier<IgniteModel<IS, IA>>>> taskSupplier,
-        IgniteBiFunction<DatasetTrainer<AM, L>, IgniteBiFunction<K, V, Vector>, AM> aggregatorProcessor,
-        IgniteBiFunction<K, V, Vector> featureExtractor) {
+    private DatasetTrainer<IgniteModel<IS, O>, L> getTrainer() {
+        checkConsistency();
+
+        List<DatasetTrainer<IgniteModel<IS, IA>, L>> subs = new ArrayList<>();
+        if (submodelInput2AggregatingInputConverter != null) {
+            DatasetTrainer<IgniteModel<IS, IS>, L> id = DatasetTrainer.identityTrainer();
+            DatasetTrainer<IgniteModel<IS, IA>, L> mappedId = CompositionUtils.unsafeCoerce(
+                AdaptableDatasetTrainer.of(id).afterTrainedModel(submodelInput2AggregatingInputConverter));
+            subs.add(mappedId);
+        }
+
+        subs.addAll(submodelsTrainers);
+
+        TrainersParallelComposition<IS, IA, L> composition = new TrainersParallelComposition<>(subs);
 
+        IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> featureMapper = getFeatureExtractorForAggregator(
+            submodelOutput2VectorConverter,
+            vector2SubmodelInputConverter);
+
+        return AdaptableDatasetTrainer
+            .of(composition)
+            .afterTrainedModel(lst -> lst.stream().reduce(aggregatingInputMerger).get())
+            .andThen(aggregatorTrainer, model -> new DatasetMapping<L, L>() {
+                @Override public Vector mapFeatures(Vector v) {
+                    List<IgniteModel<IS, IA>> models = ((ModelsParallelComposition<IS, IA>)model.innerModel()).submodels();
+                    return featureMapper.apply(models, v);
+                }
+
+                @Override public L mapLabels(L lbl) {
+                    return lbl;
+                }
+            }).unsafeSimplyTyped();
+    }
+
+    /**
+     * Method checking consistency of this trainer.
+     */
+    private void checkConsistency() {
         // Make sure there is at least one way for submodel input to propagate to aggregator.
         if (submodelInput2AggregatingInputConverter == null && submodelsTrainers.isEmpty())
             throw new IllegalStateException("There should be at least one way for submodels " +
@@ -321,60 +298,36 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L>
 
         if (aggregatingInputMerger == null)
             throw new IllegalStateException("Binary operator used to convert outputs of submodels is not specified");
+    }
 
-        List<IgniteSupplier<IgniteModel<IS, IA>>> mdlSuppliers = taskSupplier.apply(submodelsTrainers);
-
-        List<IgniteModel<IS, IA>> subMdls = environment.parallelismStrategy().submit(mdlSuppliers).stream()
-            .map(Promise::unsafeGet)
-            .collect(Collectors.toList());
-
-        // Add new columns consisting in submodels output in features.
-        IgniteBiFunction<K, V, Vector> augmentedExtractor = getFeatureExtractorForAggregator(featureExtractor,
-            subMdls,
-            submodelInput2AggregatingInputConverter,
-            submodelOutput2VectorConverter,
-            vector2SubmodelInputConverter);
-
-        AM aggregator = aggregatorProcessor.apply(aggregatorTrainer, augmentedExtractor);
-
-        StackedModel<IS, IA, O, AM> res = new StackedModel<>(
-            aggregator,
-            aggregatingInputMerger,
-            submodelInput2AggregatingInputConverter);
-
-        for (IgniteModel<IS, IA> subMdl : subMdls)
-            res.addSubmodel(subMdl);
+    /** {@inheritDoc} */
+    @Override public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder(
+        LearningEnvironmentBuilder envBuilder) {
+        submodelsTrainers =
+            submodelsTrainers.stream().map(x -> x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList());
+        aggregatorTrainer = aggregatorTrainer.withEnvironmentBuilder(envBuilder);
 
-        return res;
+        return this;
     }
 
     /**
      * Get feature extractor which will be used for aggregator trainer from original feature extractor.
      * This method is static to make sure that we will not grab context of instance in serialization.
      *
-     * @param featureExtractor Original feature extractor.
-     * @param subMdls Submodels.
+     * @param <IS> Type of submodels input.
+     * @param <IA> Type of aggregator input.
      * @param <K> Type of upstream keys.
-     * @param <V> Type of upstream values.
+     * @param <V> Type of upstream values.
+     * @param submodelOutput2VectorConverter Function converting submodel output into {@link Vector}.
+     * @param vector2SubmodelInputConverter Function converting {@link Vector} into submodel input.
      * @return Feature extractor which will be used for aggregator trainer from original feature extractor.
      */
-    private static <IS, IA, K, V> IgniteBiFunction<K, V, Vector> getFeatureExtractorForAggregator(
-        IgniteBiFunction<K, V, Vector> featureExtractor, List<IgniteModel<IS, IA>> subMdls,
-        IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter,
+    private static <IS, IA, K, V> IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> getFeatureExtractorForAggregator(
         IgniteFunction<IA, Vector> submodelOutput2VectorConverter,
         IgniteFunction<Vector, IS> vector2SubmodelInputConverter) {
-        if (submodelInput2AggregatingInputConverter != null)
-            return featureExtractor.andThen((Vector v) -> {
-                Vector[] vs = subMdls.stream().map(sm ->
-                    applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
-                return VectorUtils.concat(v, vs);
-            });
-        else
-            return featureExtractor.andThen((Vector v) -> {
-                Vector[] vs = subMdls.stream().map(sm ->
-                    applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
-                return VectorUtils.concat(vs);
-            });
+        return (List<IgniteModel<IS, IA>> subMdls, Vector v) -> {
+            Vector[] vs = subMdls.stream().map(sm ->
+                applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
+            return VectorUtils.concat(vs);
+        };
     }
 
     /**
@@ -396,17 +349,34 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L>
         return vector2SubmodelInputConverter.andThen(mdl::predict).andThen(submodelOutput2VectorConverter).apply(v);
     }
 
-    /** {@inheritDoc} */
+    /**
+     * This method is never called, instead of constructing logic of update from
+     * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+     * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+     * in this class we explicitly override update method.
+     *
+     * @param mdl Model.
+     * @return Never returns normally; always throws {@link IllegalStateException}.
+     */
     @Override protected <K, V> StackedModel<IS, IA, O, AM> updateModel(StackedModel<IS, IA, O, AM> mdl,
         DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, L> lbExtractor) {
         // This method is never called, we override "update" instead.
-        return null;
+        throw new IllegalStateException();
     }
 
-    /** {@inheritDoc} */
-    @Override protected boolean checkState(StackedModel<IS, IA, O, AM> mdl) {
-        return true;
+    /**
+     * This method is never called, instead of constructing logic of update from
+     * {@link DatasetTrainer#isUpdateable} and
+     * {@link DatasetTrainer#updateModel}
+     * in this class we explicitly override update method.
+     *
+     * @param mdl Model.
+     * @return Never returns normally; always throws {@link IllegalStateException}.
+     */
+    @Override public boolean isUpdateable(StackedModel<IS, IA, O, AM> mdl) {
+        // Should be never called.
+        throw new IllegalStateException();
     }
 }
index a9be8f8..34e1a97 100644 (file)
 
 package org.apache.ignite.ml.composition.stacking;
 
-import java.util.ArrayList;
-import java.util.List;
 import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
 
 /**
+ * This is a wrapper for a model produced by {@link StackedDatasetTrainer}.
  * Model consisting of two layers:
  * <pre>
  *     1. Submodels layer {@code (IS -> IA)}.
  *     2. Aggregator layer {@code (IA -> O)}.
  * </pre>
- * Submodels layer is a "parallel" composition of several models {@code IS -> IA} each of them getting same input
+ * Submodels layer is a {@link ModelsParallelComposition} of several models {@code IS -> IA} each of them getting same input
  * {@code IS} and producing its own output; these outputs {@code [IA]}
  * are combined into a single output with a given binary "merger" operator {@code IA -> IA -> IA}. Result of merge
  * is then passed to the aggregator layer.
@@ -41,66 +39,24 @@ import org.apache.ignite.ml.math.functions.IgniteFunction;
  * @param <AM> Type of aggregator model.
  */
 public class StackedModel<IS, IA, O, AM extends IgniteModel<IA, O>> implements IgniteModel<IS, O> {
-    /** Submodels layer. */
-    private IgniteModel<IS, IA> subModelsLayer;
-
-    /** Aggregator model. */
-    private final AM aggregatorMdl;
-
-    /** Models constituting submodels layer. */
-    private List<IgniteModel<IS, IA>> submodels;
-
-    /** Binary operator merging submodels outputs. */
-    private final IgniteBinaryOperator<IA> aggregatingInputMerger;
-
-    /**
-     * Constructs instance of this class.
-     *
-     * @param aggregatorMdl Aggregator model.
-     * @param aggregatingInputMerger Binary operator used to merge submodels outputs.
-     * @param subMdlInput2AggregatingInput Function converting submodels input to aggregator input. (This function
-     * is needed when in {@link StackedDatasetTrainer} option to keep original features is chosen).
-     */
-    StackedModel(AM aggregatorMdl,
-        IgniteBinaryOperator<IA> aggregatingInputMerger,
-        IgniteFunction<IS, IA> subMdlInput2AggregatingInput) {
-        this.aggregatorMdl = aggregatorMdl;
-        this.aggregatingInputMerger = aggregatingInputMerger;
-        this.subModelsLayer = subMdlInput2AggregatingInput != null ? subMdlInput2AggregatingInput::apply : null;
-        submodels = new ArrayList<>();
-    }
-
-    /**
-     * Get submodels constituting first layer of this model.
-     *
-     * @return Submodels constituting first layer of this model.
-     */
-    List<IgniteModel<IS, IA>> submodels() {
-        return submodels;
-    }
+    /** Model to wrap. */
+    private IgniteModel<IS, O> mdl;
 
     /**
-     * Get aggregator model.
-     *
-     * @return Aggregator model.
+     * Construct instance of this class from {@link IgniteModel}.
+     * @param mdl Model to wrap.
      */
-    AM aggregatorModel() {
-        return aggregatorMdl;
+    StackedModel(IgniteModel<IS, O> mdl) {
+        this.mdl = mdl;
     }
 
-    /**
-     * Add submodel into first layer.
-     *
-     * @param subMdl Submodel to add.
-     */
-    void addSubmodel(IgniteModel<IS, IA> subMdl) {
-        submodels.add(subMdl);
-        subModelsLayer = subModelsLayer != null ? subModelsLayer.combine(subMdl, aggregatingInputMerger)
-            : subMdl;
+    /** {@inheritDoc} */
+    @Override public O predict(IS is) {
+        return mdl.predict(is);
     }
 
     /** {@inheritDoc} */
-    @Override public O predict(IS is) {
-        return subModelsLayer.andThen(aggregatorMdl).predict(is);
+    @Override public void close() {
+        mdl.close();
     }
 }
index 7230e3c..c25b721 100644 (file)
@@ -81,6 +81,7 @@ public class StackedVectorDatasetTrainer<O, AM extends IgniteModel<Vector, O>, L
     }
 
     /** {@inheritDoc} */
+    // TODO: IGNITE-10843 Add possibility to keep features with specific indices.
     @Override public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesKept(
         IgniteFunction<Vector, Vector> submodelInput2AggregatingInputConverter) {
         return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesKept(
index 9900659..c826a40 100644 (file)
@@ -67,7 +67,7 @@ public interface DatasetBuilder<K, V> {
      * @return Returns new instance of {@link DatasetBuilder} with new {@link UpstreamTransformerBuilder} added
      * to chain of upstream transformer builders.
      */
-    public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder);
+    public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder);
 
     /**
      * Returns new instance of DatasetBuilder using conjunction of internal filter and {@code filterToAdd}.
index 9c0e281..c7fb92f 100644 (file)
@@ -22,29 +22,15 @@ import java.util.stream.Stream;
 
 /**
  * Interface of transformer of upstream.
- *
- * @param <K> Type of keys in the upstream.
- * @param <V> Type of values in the upstream.
  */
 // TODO: IGNITE-10297: Investigate possibility of API change.
 @FunctionalInterface
-public interface UpstreamTransformer<K, V> extends Serializable {
+public interface UpstreamTransformer extends Serializable {
     /**
      * Transform upstream.
      *
      * @param upstream Upstream to transform.
      * @return Transformed upstream.
      */
-    public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream);
-
-    /**
-     * Get composition of this transformer and other transformer which is
-     * itself is {@link UpstreamTransformer} applying this transformer and then other transformer.
-     *
-     * @param other Other transformer.
-     * @return Composition of this and other transformer.
-     */
-    public default UpstreamTransformer<K, V> andThen(UpstreamTransformer<K, V> other) {
-        return upstream -> other.transform(transform(upstream));
-    }
+    public Stream<UpstreamEntry> transform(Stream<UpstreamEntry> upstream);
 }
index 9adfab5..ea9f126 100644 (file)
@@ -21,19 +21,17 @@ import java.io.Serializable;
 import org.apache.ignite.ml.environment.LearningEnvironment;
 
 /**
- * Builder of {@link UpstreamTransformerBuilder}.
- * @param <K> Type of keys in upstream.
- * @param <V> Type of values in upstream.
+ * Builder of {@link UpstreamTransformer}.
  */
 @FunctionalInterface
-public interface UpstreamTransformerBuilder<K, V> extends Serializable {
+public interface UpstreamTransformerBuilder extends Serializable {
     /**
      * Create {@link UpstreamTransformer} based on learning environment.
      *
      * @param env Learning environment.
      * @return Upstream transformer.
      */
-    public UpstreamTransformer<K, V> build(LearningEnvironment env);
+    public UpstreamTransformer build(LearningEnvironment env);
 
     /**
     * Combines two builders (this and other respectively)
@@ -49,11 +47,11 @@ public interface UpstreamTransformerBuilder<K, V> extends Serializable {
      * @param other Builder to combine with.
      * @return Compositional builder.
      */
-    public default UpstreamTransformerBuilder<K, V> andThen(UpstreamTransformerBuilder<K, V> other) {
-        UpstreamTransformerBuilder<K, V> self = this;
+    public default UpstreamTransformerBuilder andThen(UpstreamTransformerBuilder other) {
+        UpstreamTransformerBuilder self = this;
         return env -> {
-            UpstreamTransformer<K, V> transformer1 = self.build(env);
-            UpstreamTransformer<K, V> transformer2 = other.build(env);
+            UpstreamTransformer transformer1 = self.build(env);
+            UpstreamTransformer transformer2 = other.build(env);
 
             return upstream -> transformer2.transform(transformer1.transform(upstream));
         };
@@ -66,7 +64,7 @@ public interface UpstreamTransformerBuilder<K, V> extends Serializable {
      * @param <V> Type of values in upstream.
      * @return Identity upstream transformer.
      */
-    public static <K, V> UpstreamTransformerBuilder<K, V> identity() {
+    public static <K, V> UpstreamTransformerBuilder identity() {
         return env -> upstream -> upstream;
     }
 }
index bde4bb6..b2aa00b 100644 (file)
@@ -64,7 +64,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
     private final IgniteBiPredicate<K, V> filter;
 
     /** Builder of transformation applied to upstream. */
-    private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder;
+    private final UpstreamTransformerBuilder upstreamTransformerBuilder;
 
     /** Ignite Cache with partition {@code context}. */
     private final IgniteCache<Integer, C> datasetCache;
@@ -94,7 +94,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
         Ignite ignite,
         IgniteCache<K, V> upstreamCache,
         IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder,
+        UpstreamTransformerBuilder upstreamTransformerBuilder,
         IgniteCache<Integer, C> datasetCache,
         LearningEnvironmentBuilder envBuilder,
         PartitionDataBuilder<K, V, C, D> partDataBuilder,
index be40158..b85bfc2 100644 (file)
@@ -59,7 +59,7 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     private final IgniteBiPredicate<K, V> filter;
 
     /** Upstream transformer builder. */
-    private final UpstreamTransformerBuilder<K, V> transformerBuilder;
+    private final UpstreamTransformerBuilder transformerBuilder;
 
     /**
      * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default
@@ -93,7 +93,7 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     public CacheBasedDatasetBuilder(Ignite ignite,
         IgniteCache<K, V> upstreamCache,
         IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerBuilder<K, V> transformerBuilder) {
+        UpstreamTransformerBuilder transformerBuilder) {
         this.ignite = ignite;
         this.upstreamCache = upstreamCache;
         this.filter = filter;
@@ -136,7 +136,7 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     }
 
     /** {@inheritDoc} */
-    @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) {
+    @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder) {
         return new CacheBasedDatasetBuilder<>(ignite, upstreamCache, filter, transformerBuilder.andThen(builder));
     }
 
index 7fa1efa..f12977c 100644 (file)
@@ -185,7 +185,7 @@ public class ComputeUtils {
     public static <K, V, C extends Serializable, D extends AutoCloseable> D getData(
         Ignite ignite,
         String upstreamCacheName, IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerBuilder<K, V> transformerBuilder,
+        UpstreamTransformerBuilder transformerBuilder,
         String datasetCacheName, UUID datasetId,
         PartitionDataBuilder<K, V, C, D> partDataBuilder,
         LearningEnvironment env) {
@@ -208,8 +208,8 @@ public class ComputeUtils {
             qry.setPartition(part);
             qry.setFilter(filter);
 
-            UpstreamTransformer<K, V> transformer = transformerBuilder.build(env);
-            UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer);
+            UpstreamTransformer transformer = transformerBuilder.build(env);
+            UpstreamTransformer transformerCp = Utils.copy(transformer);
 
             long cnt = computeCount(upstreamCache, qry, transformer);
 
@@ -218,9 +218,8 @@ public class ComputeUtils {
                     e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
                     Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
-                    Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt));
-                    it = transformedStream.iterator();
-
+                    Stream<UpstreamEntry> transformedStream = transformerCp.transform(Utils.asStream(it, cnt).map(x -> (UpstreamEntry)x));
+                    it = Utils.asStream(transformedStream.iterator()).map(x -> (UpstreamEntry<K, V>)x).iterator();
 
                     Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(it, cnt,
                         "Cache expected to be not modified during dataset data building [partition=" + part + ']');
@@ -268,7 +267,7 @@ public class ComputeUtils {
     public static <K, V, C extends Serializable> void initContext(
         Ignite ignite,
         String upstreamCacheName,
-        UpstreamTransformerBuilder<K, V> transformerBuilder,
+        UpstreamTransformerBuilder transformerBuilder,
         IgniteBiPredicate<K, V> filter,
         String datasetCacheName,
         PartitionContextBuilder<K, V, C> ctxBuilder,
@@ -287,8 +286,8 @@ public class ComputeUtils {
             qry.setFilter(filter);
 
             C ctx;
-            UpstreamTransformer<K, V> transformer = transformerBuilder.build(env);
-            UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer);
+            UpstreamTransformer transformer = transformerBuilder.build(env);
+            UpstreamTransformer transformerCp = Utils.copy(transformer);
 
             long cnt = computeCount(locUpstreamCache, qry, transformer);
 
@@ -296,8 +295,8 @@ public class ComputeUtils {
                 e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
                 Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
-                Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt));
-                it = transformedStream.iterator();
+                Stream<UpstreamEntry> transformedStream = transformerCp.transform(Utils.asStream(it, cnt).map(x -> (UpstreamEntry)x));
+                it = Utils.asStream(transformedStream.iterator()).map(x -> (UpstreamEntry<K, V>)x).iterator();
 
                 Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(
                     it,
@@ -334,7 +333,7 @@ public class ComputeUtils {
         Ignite ignite,
         String upstreamCacheName,
         IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerBuilder<K, V> transformerBuilder,
+        UpstreamTransformerBuilder transformerBuilder,
         String datasetCacheName,
         PartitionContextBuilder<K, V, C> ctxBuilder,
         LearningEnvironmentBuilder envBuilder,
@@ -382,11 +381,11 @@ public class ComputeUtils {
     private static <K, V> long computeCount(
         IgniteCache<K, V> cache,
         ScanQuery<K, V> qry,
-        UpstreamTransformer<K, V> transformer) {
+        UpstreamTransformer transformer) {
         try (QueryCursor<UpstreamEntry<K, V>> cursor = cache.query(qry,
             e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
-            return computeCount(transformer.transform(Utils.asStream(cursor.iterator())).iterator());
+            return computeCount(transformer.transform(Utils.asStream(cursor.iterator()).map(x -> (UpstreamEntry<K, V>)x)).iterator());
         }
     }
 
index b8cd8dc..84f3e08 100644 (file)
@@ -54,7 +54,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     private final IgniteBiPredicate<K, V> filter;
 
     /** Upstream transformers. */
-    private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder;
+    private final UpstreamTransformerBuilder upstreamTransformerBuilder;
 
     /**
      * Constructs a new instance of local dataset builder that makes {@link LocalDataset} with default predicate that
@@ -78,7 +78,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     public LocalDatasetBuilder(Map<K, V> upstreamMap,
         IgniteBiPredicate<K, V> filter,
         int partitions,
-        UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder) {
+        UpstreamTransformerBuilder upstreamTransformerBuilder) {
         this.upstreamMap = upstreamMap;
         this.filter = filter;
         this.partitions = partitions;
@@ -129,23 +129,26 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
             int cntBeforeTransform =
                 part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr);
             LearningEnvironment env = envs.get(part);
-            UpstreamTransformer<K, V> transformer1 = upstreamTransformerBuilder.build(env);
-            UpstreamTransformer<K, V> transformer2 = Utils.copy(transformer1);
-            UpstreamTransformer<K, V> transformer3 = Utils.copy(transformer1);
+            UpstreamTransformer transformer1 = upstreamTransformerBuilder.build(env);
+            UpstreamTransformer transformer2 = Utils.copy(transformer1);
+            UpstreamTransformer transformer3 = Utils.copy(transformer1);
 
             int cnt = (int)transformer1.transform(Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cntBeforeTransform))).count();
 
-            Iterator<UpstreamEntry<K, V>> iter =
-                transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform))).iterator();
+            Iterator<UpstreamEntry> iter =
+                transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform)).map(x -> (UpstreamEntry)x)).iterator();
+            Iterator<UpstreamEntry<K, V>> convertedBack = Utils.asStream(iter).map(x -> (UpstreamEntry<K, V>)x).iterator();
 
-            C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, iter, cnt) : null;
+            C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, convertedBack, cnt) : null;
 
-            Iterator<UpstreamEntry<K, V>> iter1 = transformer3.transform(
+            Iterator<UpstreamEntry> iter1 = transformer3.transform(
                     Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cntBeforeTransform))).iterator();
 
+            Iterator<UpstreamEntry<K, V>> convertedBack1 = Utils.asStream(iter1).map(x -> (UpstreamEntry<K, V>)x).iterator();
+
             D data = cntBeforeTransform > 0 ? partDataBuilder.build(
                 env,
-                iter1,
+                convertedBack1,
                 cnt,
                 ctx
             ) : null;
@@ -160,7 +163,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     }
 
     /** {@inheritDoc} */
-    @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) {
+    @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder) {
         return new LocalDatasetBuilder<>(upstreamMap, filter, partitions, upstreamTransformerBuilder.andThen(builder));
     }
 
index ed78e85..0552036 100644 (file)
@@ -19,7 +19,6 @@ package org.apache.ignite.ml.genetic;
 
 import java.util.Arrays;
 import java.util.concurrent.atomic.AtomicLong;
-
 import org.apache.ignite.cache.query.annotations.QuerySqlField;
 
 /**
index b03e7ca..d69911a 100644 (file)
@@ -22,11 +22,10 @@ import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.IgniteException;
 import org.apache.ignite.compute.ComputeJobAdapter;
+import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
 import org.apache.ignite.resources.IgniteInstanceResource;
 import org.apache.ignite.transactions.Transaction;
 
-import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
-
 /**
  * Responsible for applying mutation on respective Chromosome based on mutation Rate
  */
index c5302ee..f980e22 100644 (file)
@@ -20,7 +20,6 @@ package org.apache.ignite.ml.genetic.cache;
 import org.apache.ignite.cache.CacheMode;
 import org.apache.ignite.cache.CacheRebalanceMode;
 import org.apache.ignite.configuration.CacheConfiguration;
-
 import org.apache.ignite.ml.genetic.Gene;
 import org.apache.ignite.ml.genetic.functions.GAGridFunction;
 import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
index cae7c1a..6a8b2b4 100644 (file)
@@ -21,7 +21,6 @@ import org.apache.ignite.cache.CacheAtomicityMode;
 import org.apache.ignite.cache.CacheMode;
 import org.apache.ignite.cache.CacheRebalanceMode;
 import org.apache.ignite.configuration.CacheConfiguration;
-
 import org.apache.ignite.ml.genetic.Chromosome;
 import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
 
index c32ca56..0cdfc52 100644 (file)
@@ -102,7 +102,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(ANNClassificationModel mdl) {
+    @Override public boolean isUpdateable(ANNClassificationModel mdl) {
         return mdl.getDistanceMeasure().equals(distance) && mdl.getCandidates().rowSize() == k;
     }
 
index c52ad2b..16bf186 100644 (file)
@@ -60,7 +60,7 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(KNNClassificationModel mdl) {
+    @Override public boolean isUpdateable(KNNClassificationModel mdl) {
         return true;
     }
 }
index 9b348f3..e621801 100644 (file)
@@ -56,7 +56,7 @@ public class KNNRegressionTrainer extends SingleLabelDatasetTrainer<KNNRegressio
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(KNNRegressionModel mdl) {
+    @Override public boolean isUpdateable(KNNRegressionModel mdl) {
         return true;
     }
 }
index 4eca27f..a44b5b4 100644 (file)
@@ -101,7 +101,7 @@ public class OneVsRestTrainer<M extends IgniteModel<Vector, Double>>
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(MultiClassModel<M> mdl) {
+    @Override public boolean isUpdateable(MultiClassModel<M> mdl) {
         return true;
     }
 
index 0779b84..0179b31 100644 (file)
@@ -59,7 +59,7 @@ public class DiscreteNaiveBayesTrainer extends SingleLabelDatasetTrainer<Discret
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(DiscreteNaiveBayesModel mdl) {
+    @Override public boolean isUpdateable(DiscreteNaiveBayesModel mdl) {
         if (mdl.getBucketThresholds().length != bucketThresholds.length)
             return false;
 
@@ -124,7 +124,7 @@ public class DiscreteNaiveBayesTrainer extends SingleLabelDatasetTrainer<Discret
                     return a.merge(b);
                 });
 
-                if (mdl != null && checkState(mdl)) {
+                if (mdl != null && isUpdateable(mdl)) {
                     if (checkSumsHolder(sumsHolder, mdl.getSumsHolder()))
                         sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
                 }
index cdaac5a..c4ef1bd 100644 (file)
@@ -55,7 +55,7 @@ public class GaussianNaiveBayesTrainer extends SingleLabelDatasetTrainer<Gaussia
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(GaussianNaiveBayesModel mdl) {
+    @Override public boolean isUpdateable(GaussianNaiveBayesModel mdl) {
         return true;
     }
 
index ea0bb6c..cf511ec 100644 (file)
@@ -354,7 +354,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(MultilayerPerceptron mdl) {
+    @Override public boolean isUpdateable(MultilayerPerceptron mdl) {
         return true;
     }
 
index 6b2b11e..e273633 100644 (file)
@@ -79,7 +79,7 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea
     }
 
     /** {@inheritDoc} */
-    @Override public boolean checkState(LinearRegressionModel mdl) {
+    @Override public boolean isUpdateable(LinearRegressionModel mdl) {
         return true;
     }
 }
index 4132d35..7dc4df6 100644 (file)
@@ -160,7 +160,7 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(LinearRegressionModel mdl) {
+    @Override public boolean isUpdateable(LinearRegressionModel mdl) {
         return true;
     }
 
index 864187d..16ffac3 100644 (file)
@@ -139,7 +139,7 @@ public class LogisticRegressionSGDTrainer extends SingleLabelDatasetTrainer<Logi
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(LogisticRegressionModel mdl) {
+    @Override public boolean isUpdateable(LogisticRegressionModel mdl) {
         return true;
     }
 
index 67484ea..90bbe37 100644 (file)
@@ -121,7 +121,7 @@ public class SVMLinearClassificationTrainer extends SingleLabelDatasetTrainer<SV
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(SVMLinearClassificationModel mdl) {
+    @Override public boolean isUpdateable(SVMLinearClassificationModel mdl) {
         return true;
     }
 
index 4205286..4695946 100644 (file)
 package org.apache.ignite.ml.trainers;
 
 import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.DatasetMapping;
+import org.apache.ignite.ml.composition.combinators.sequential.TrainersSequentialComposition;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -46,6 +49,15 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW>
     /** Function used to convert output type of wrapped trainer. */
     private final IgniteFunction<OW, O> after;
 
+    /** Function which is applied after feature extractor. */
+    private final IgniteFunction<Vector, Vector> afterFeatureExtractor;
+
+    /** Function which is applied after label extractor. */
+    private final IgniteFunction<L, L> afterLabelExtractor;
+
+    /** Upstream transformer builder which will be used in dataset builder. */
+    private final UpstreamTransformerBuilder upstreamTransformerBuilder;
+
     /**
      * Construct instance of this class from a given {@link DatasetTrainer}.
      *
@@ -56,39 +68,65 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW>
      * @param <L> Type of labels.
      * @return Instance of this class.
      */
-    public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(DatasetTrainer<M, L> wrapped) {
-        return new AdaptableDatasetTrainer<>(IgniteFunction.identity(), wrapped, IgniteFunction.identity());
+    public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(
+        DatasetTrainer<M, L> wrapped) {
+        return new AdaptableDatasetTrainer<>(IgniteFunction.identity(),
+            wrapped,
+            IgniteFunction.identity(),
+            IgniteFunction.identity(),
+            IgniteFunction.identity(),
+            UpstreamTransformerBuilder.identity());
     }
 
     /**
      * Construct instance of this class with specified wrapped trainer and converter functions.
      *
      * @param before Function used to convert input type of wrapped trainer.
-     * @param wrapped  Wrapped trainer.
+     * @param wrapped Wrapped trainer.
      * @param after Function used to convert output type of wrapped trainer.
+     * @param extractor Function which is applied after label extractor.
+     * @param builder Upstream transformer builder which will be used in dataset builder.
      */
-    private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M, L> wrapped, IgniteFunction<OW, O> after) {
+    private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M, L> wrapped,
+        IgniteFunction<OW, O> after,
+        IgniteFunction<Vector, Vector> afterFeatureExtractor,
+        IgniteFunction<L, L> extractor, UpstreamTransformerBuilder builder) {
         this.before = before;
         this.wrapped = wrapped;
         this.after = after;
+        this.afterFeatureExtractor = afterFeatureExtractor;
+        afterLabelExtractor = extractor;
+        upstreamTransformerBuilder = builder;
     }
 
     /** {@inheritDoc} */
     @Override public <K, V> AdaptableDatasetModel<I, O, IW, OW, M> fit(DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
-        M fit = wrapped.fit(datasetBuilder, featureExtractor, lbExtractor);
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor) {
+        M fit = wrapped.fit(
+            datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
+            featureExtractor.andThen(afterFeatureExtractor),
+            lbExtractor.andThen(afterLabelExtractor));
+
         return new AdaptableDatasetModel<>(before, fit, after);
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(AdaptableDatasetModel<I, O, IW, OW, M> mdl) {
-        return wrapped.checkState(mdl.innerModel());
+    @Override public boolean isUpdateable(AdaptableDatasetModel<I, O, IW, OW, M> mdl) {
+        return wrapped.isUpdateable(mdl.innerModel());
     }
 
     /** {@inheritDoc} */
-    @Override protected <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel(AdaptableDatasetModel<I, O, IW, OW, M> mdl, DatasetBuilder<K, V> datasetBuilder,
+    @Override protected <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel(
+        AdaptableDatasetModel<I, O, IW, OW, M> mdl, DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
-        return mdl.withInnerModel(wrapped.updateModel(mdl.innerModel(), datasetBuilder, featureExtractor, lbExtractor));
+        M updated = wrapped.updateModel(
+            mdl.innerModel(),
+            datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
+            featureExtractor.andThen(afterFeatureExtractor),
+            lbExtractor.andThen(afterLabelExtractor));
+
+        return mdl.withInnerModel(updated);
     }
 
     /**
@@ -101,7 +139,12 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW>
      * original trainer.
      */
     public <O1> AdaptableDatasetTrainer<I, O1, IW, OW, M, L> afterTrainedModel(IgniteFunction<O, O1> after) {
-        return new AdaptableDatasetTrainer<>(before, wrapped, i -> after.apply(this.after.apply(i)));
+        return new AdaptableDatasetTrainer<>(before,
+            wrapped,
+            i -> after.apply(this.after.apply(i)),
+            afterFeatureExtractor,
+            afterLabelExtractor,
+            upstreamTransformerBuilder);
     }
 
     /**
@@ -115,6 +158,116 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW>
      */
     public <I1> AdaptableDatasetTrainer<I1, O, IW, OW, M, L> beforeTrainedModel(IgniteFunction<I1, I> before) {
         IgniteFunction<I1, IW> function = i -> this.before.apply(before.apply(i));
-        return new AdaptableDatasetTrainer<>(function, wrapped, after);
+        return new AdaptableDatasetTrainer<>(function,
+            wrapped,
+            after,
+            afterFeatureExtractor,
+            afterLabelExtractor,
+            upstreamTransformerBuilder);
+    }
+
+    /**
+     * Specify {@link DatasetMapping} which will be applied to dataset before fitting and updating.
+     *
+     * @param mapping {@link DatasetMapping} which will be applied to dataset before fitting and updating.
+     * @return New trainer of the same type, but with specified mapping applied to dataset before fitting and updating.
+     */
+    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withDatasetMapping(DatasetMapping<L, L> mapping) {
+        return of(new DatasetTrainer<M, L>() {
+            @Override public <K, V> M fit(
+                DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+                IgniteBiFunction<K, V, L> lbExtractor) {
+                IgniteBiFunction<K, V, Vector> fe = featureExtractor.andThen(mapping::mapFeatures);
+                IgniteBiFunction<K, V, L> le = lbExtractor.andThen(mapping::mapLabels);
+
+                return wrapped.fit(datasetBuilder,
+                    fe,
+                    le);
+            }
+
+            @Override public <K, V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                return wrapped.update(mdl, datasetBuilder,
+                    featureExtractor.andThen(mapping::mapFeatures),
+                    lbExtractor.andThen((IgniteFunction<L, L>)mapping::mapLabels));
+            }
+
+            @Override public boolean isUpdateable(M mdl) {
+                return false;
+            }
+
+            @Override protected <K, V> M updateModel(M mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                return null;
+            }
+        }).beforeTrainedModel(before).afterTrainedModel(after);
+    }
+
+    /**
+     * Create a {@link TrainersSequentialComposition} of this trainer and specified trainer.
+     *
+     * @param tr Trainer to compose with.
+     * @param datasetMappingProducer {@link DatasetMapping} producer specifying dependency between this trainer and
+     * trainer to compose with.
+     * @param <O1> Type of output of trainer to compose with.
+     * @param <M1> Type of model produced by the trainer to compose with.
+     * @return A {@link TrainersSequentialComposition} of this trainer and specified trainer.
+     */
+    public <O1, M1 extends IgniteModel<O, O1>> TrainersSequentialComposition<I, O, O1, L> andThen(
+        DatasetTrainer<M1, L> tr,
+        IgniteFunction<AdaptableDatasetModel<I, O, IW, OW, M>, DatasetMapping<L, L>> datasetMappingProducer) {
+        IgniteFunction<IgniteModel<I, O>, DatasetMapping<L, L>> coercedMapping = mdl ->
+            datasetMappingProducer.apply((AdaptableDatasetModel<I, O, IW, OW, M>)mdl);
+        return new TrainersSequentialComposition<>(this,
+            tr,
+            coercedMapping);
+    }
+
+    /**
+     * Specify function which will be applied after feature extractor.
+     *
+     * @param after Function which will be applied after feature extractor.
+     * @return New trainer with same parameters as this trainer except that specified function will be applied
+     * after feature extractor.
+     */
+    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterFeatureExtractor(IgniteFunction<Vector, Vector> after) {
+        return new AdaptableDatasetTrainer<>(before,
+            wrapped,
+            this.after,
+            after,
+            afterLabelExtractor,
+            upstreamTransformerBuilder);
+    }
+
+    /**
+     * Specify function which will be applied after label extractor.
+     *
+     * @param after Function which will be applied after label extractor.
+     * @return New trainer with same parameters as this trainer has except that specified function will be applied
+     * after label extractor.
+     */
+    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterLabelExtractor(IgniteFunction<L, L> after) {
+        return new AdaptableDatasetTrainer<>(before,
+            wrapped,
+            this.after,
+            afterFeatureExtractor,
+            after,
+            upstreamTransformerBuilder);
+    }
+
+    /**
+     * Specify which {@link UpstreamTransformerBuilder} will be used.
+     *
+     * @param upstreamTransformerBuilder {@link UpstreamTransformerBuilder} to use.
+     * @return New trainer with same parameters as this trainer has except that specified {@link UpstreamTransformerBuilder} will be used.
+     */
+    public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withUpstreamTransformerBuilder(
+        UpstreamTransformerBuilder upstreamTransformerBuilder) {
+        return new AdaptableDatasetTrainer<>(before,
+            wrapped,
+            after,
+            afterFeatureExtractor,
+            afterLabelExtractor,
+            upstreamTransformerBuilder);
     }
 }
index 88c4bcd..42cac07 100644 (file)
@@ -71,12 +71,11 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> {
      * @param <V> Type of a value in {@code upstream} data.
      * @return Updated model.
      */
-    //
     public <K,V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
 
         if(mdl != null) {
-            if (checkState(mdl))
+            if (isUpdateable(mdl))
                 return updateModel(mdl, datasetBuilder, featureExtractor, lbExtractor);
             else {
                 environment.logger(getClass()).log(
@@ -94,7 +93,7 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> {
      * @param mdl Model.
      * @return true if current critical for training parameters correspond to parameters from last training.
      */
-    protected abstract boolean checkState(M mdl);
+    public abstract boolean isUpdateable(M mdl);
 
     /**
      * Used on update phase when given dataset is empty.
@@ -308,12 +307,12 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> {
     }
 
     /**
-     * Creates {@code DatasetTrainer} with same training logic, but able to accept labels of given new type
+     * Creates {@link DatasetTrainer} with same training logic, but able to accept labels of given new type
      * of labels.
      *
      * @param new2Old Converter of new labels to old labels.
      * @param <L1> New labels type.
-     * @return {@code DatasetTrainer} with same training logic, but able to accept labels of given new type
+     * @return {@link DatasetTrainer} with same training logic, but able to accept labels of given new type
      * of labels.
      */
     public <L1> DatasetTrainer<M, L1> withConvertedLabels(IgniteFunction<L1, L> new2Old) {
@@ -326,8 +325,8 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> {
             }
 
             /** {@inheritDoc} */
-            @Override protected boolean checkState(M mdl) {
-                return old.checkState(mdl);
+            @Override public boolean isUpdateable(M mdl) {
+                return old.isUpdateable(mdl);
             }
 
             /** {@inheritDoc} */
@@ -362,4 +361,31 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> {
         }
     }
 
+    /**
+     * Returns the trainer which returns identity model.
+     *
+     * @param <I> Type of model input.
+     * @param <L> Type of labels in dataset.
+     * @return Trainer which returns identity model.
+     */
+    public static <I, L> DatasetTrainer<IgniteModel<I, I>, L> identityTrainer() {
+        return new DatasetTrainer<IgniteModel<I, I>, L>() {
+            @Override public <K, V> IgniteModel<I, I> fit(DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor,
+                IgniteBiFunction<K, V, L> lbExtractor) {
+                return x -> x;
+            }
+
+            /** {@inheritDoc} */
+            @Override public boolean isUpdateable(IgniteModel<I, I> mdl) {
+                return true;
+            }
+
+            /** {@inheritDoc} */
+            @Override protected <K, V> IgniteModel<I, I> updateModel(IgniteModel<I, I> mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                return x -> x;
+            }
+        };
+    }
 }
index 43c1600..db5522e 100644 (file)
@@ -24,6 +24,7 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.apache.ignite.ml.IgniteModel;
 import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
 import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.environment.LearningEnvironment;
@@ -48,12 +49,11 @@ public class TrainerTransformers {
      * @param ensembleSize Size of ensemble.
      * @param subsampleRatio Subsample ratio to whole dataset.
      * @param aggregator Aggregator.
-     * @param <M> Type of one model in ensemble.
      * @param <L> Type of labels.
      * @return Bagged trainer.
      */
-    public static <M extends IgniteModel<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
-        DatasetTrainer<M, L> trainer,
+    public static <L> BaggedTrainer<L> makeBagged(
+        DatasetTrainer<? extends IgniteModel, L> trainer,
         int ensembleSize,
         double subsampleRatio,
         PredictionsAggregator aggregator) {
@@ -71,58 +71,19 @@ public class TrainerTransformers {
      * @param <L> Type of labels.
      * @return Bagged trainer.
      */
-    public static <M extends IgniteModel<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
+    public static <M extends IgniteModel<Vector, Double>, L> BaggedTrainer<L> makeBagged(
         DatasetTrainer<M, L> trainer,
         int ensembleSize,
         double subsampleRatio,
         int featureVectorSize,
         int featuresSubspaceDim,
         PredictionsAggregator aggregator) {
-        return new DatasetTrainer<ModelsComposition, L>() {
-            /** {@inheritDoc} */
-            @Override public <K, V> ModelsComposition fit(
-                DatasetBuilder<K, V> datasetBuilder,
-                IgniteBiFunction<K, V, Vector> featureExtractor,
-                IgniteBiFunction<K, V, L> lbExtractor) {
-                return runOnEnsemble(
-                    (db, i, fe) -> (() -> trainer.fit(db, fe, lbExtractor)),
-                    datasetBuilder,
-                    ensembleSize,
-                    subsampleRatio,
-                    featureVectorSize,
-                    featuresSubspaceDim,
-                    featureExtractor,
-                    aggregator,
-                    environment);
-            }
-
-            /** {@inheritDoc} */
-            @Override protected boolean checkState(ModelsComposition mdl) {
-                return mdl.getModels().stream().allMatch(m -> trainer.checkState((M)m));
-            }
-
-            /** {@inheritDoc} */
-            @Override protected <K, V> ModelsComposition updateModel(
-                ModelsComposition mdl,
-                DatasetBuilder<K, V> datasetBuilder,
-                IgniteBiFunction<K, V, Vector> featureExtractor,
-                IgniteBiFunction<K, V, L> lbExtractor) {
-                return runOnEnsemble(
-                    (db, i, fe) -> (() -> trainer.updateModel(
-                        ((ModelWithMapping<Vector, Double, M>)mdl.getModels().get(i)).model(),
-                        db,
-                        fe,
-                        lbExtractor)),
-                    datasetBuilder,
-                    ensembleSize,
-                    subsampleRatio,
-                    featureVectorSize,
-                    featuresSubspaceDim,
-                    featureExtractor,
-                    aggregator,
-                    environment);
-            }
-        }.withEnvironmentBuilder(trainer.envBuilder);
+        return new BaggedTrainer<>(trainer,
+            aggregator,
+            ensembleSize,
+            subsampleRatio,
+            featureVectorSize,
+            featuresSubspaceDim);
     }
 
     /**
index 7f45fdd..36e7867 100644 (file)
@@ -28,11 +28,8 @@ import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
  * This class encapsulates the logic needed to do bagging (bootstrap aggregating) by features.
  * The action of this class on a given upstream is to replicate each entry in accordance to
  * Poisson distribution.
- *
- * @param <K> Type of upstream keys.
- * @param <V> Type of upstream values.
  */
-public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, V> {
+public class BaggingUpstreamTransformer implements UpstreamTransformer {
     /** Serial version uid. */
     private static final long serialVersionUID = -913152523469994149L;
 
@@ -51,8 +48,8 @@ public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K,
      * @param <V> Type of upstream values.
      * @return Builder of {@link BaggingUpstreamTransformer}.
      */
-    public static <K, V> UpstreamTransformerBuilder<K, V> builder(double subsampleRatio, int mdlIdx) {
-        return env -> new BaggingUpstreamTransformer<>(env.randomNumbersGenerator().nextLong() + mdlIdx, subsampleRatio);
+    public static <K, V> UpstreamTransformerBuilder builder(double subsampleRatio, int mdlIdx) {
+        return env -> new BaggingUpstreamTransformer(env.randomNumbersGenerator().nextLong() + mdlIdx, subsampleRatio);
     }
 
     /**
@@ -67,7 +64,7 @@ public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K,
     }
 
     /** {@inheritDoc} */
-    @Override public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) {
+    @Override public Stream<UpstreamEntry> transform(Stream<UpstreamEntry> upstream) {
         PoissonDistribution poisson = new PoissonDistribution(
             new Well19937c(seed),
             subsampleRatio,
index 35d1ea4..f3fc4ce 100644 (file)
@@ -106,7 +106,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(DecisionTreeNode mdl) {
+    @Override public boolean isUpdateable(DecisionTreeNode mdl) {
         return true;
     }
 
index d9b8e30..6d92948 100644 (file)
@@ -239,7 +239,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(ModelsComposition mdl) {
+    @Override public boolean isUpdateable(ModelsComposition mdl) {
         ModelsComposition fakeComposition = buildComposition(Collections.emptyList());
         return mdl.getPredictionsAggregator().getClass() == fakeComposition.getPredictionsAggregator().getClass();
     }
index c2fd652..e57c5ba 100644 (file)
@@ -126,7 +126,7 @@ public interface DataStreamGenerator {
      * @return Dataset builder.
      */
     public default DatasetBuilder<Vector, Double> asDatasetBuilder(int datasetSize, IgniteBiPredicate<Vector, Double> filter,
-        int partitions, UpstreamTransformerBuilder<Vector, Double> upstreamTransformerBuilder) {
+        int partitions, UpstreamTransformerBuilder upstreamTransformerBuilder) {
 
         return new DatasetBuilderAdapter(this, datasetSize, filter, partitions, upstreamTransformerBuilder);
     }
index 189e053..7e5060e 100644 (file)
@@ -48,7 +48,7 @@ class DatasetBuilderAdapter extends LocalDatasetBuilder<Vector, Double> {
      */
     public DatasetBuilderAdapter(DataStreamGenerator generator, int datasetSize,
         IgniteBiPredicate<Vector, Double> filter, int partitions,
-        UpstreamTransformerBuilder<Vector, Double> upstreamTransformerBuilder) {
+        UpstreamTransformerBuilder upstreamTransformerBuilder) {
 
         super(generator.asMap(datasetSize), filter, partitions, upstreamTransformerBuilder);
     }
index fc3bf5c..ed23373 100644 (file)
@@ -429,7 +429,7 @@ public class TestUtils {
             }
 
             /** {@inheritDoc} */
-            @Override public boolean checkState(M mdl) {
+            @Override public boolean isUpdateable(M mdl) {
                 return true;
             }
 
index dd4b11e..4f8f412 100644 (file)
@@ -22,6 +22,8 @@ import java.util.Map;
 import org.apache.ignite.ml.IgniteModel;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.bagging.BaggedModel;
+import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
 import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
 import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
 import org.apache.ignite.ml.dataset.Dataset;
@@ -77,18 +79,16 @@ public class BaggingTest extends TrainerTest {
                 .withBatchSize(10)
                 .withSeed(123L);
 
-        trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder());
-
-        DatasetTrainer<ModelsComposition, Double> baggedTrainer =
-            TrainerTransformers.makeBagged(
-                trainer,
-                10,
-                0.7,
-                2,
-                2,
-                new OnMajorityPredictionsAggregator());
+        BaggedTrainer<Double> baggedTrainer = TrainerTransformers.makeBagged(
+            trainer,
+            10,
+            0.7,
+            2,
+            2,
+            new OnMajorityPredictionsAggregator())
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder());
 
-        ModelsComposition mdl = baggedTrainer.fit(
+        BaggedModel mdl = baggedTrainer.fit(
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -111,14 +111,17 @@ public class BaggingTest extends TrainerTest {
 
         double subsampleRatio = 0.3;
 
-        ModelsComposition mdl = TrainerTransformers.makeBagged(
+        BaggedModel mdl = TrainerTransformers.makeBagged(
             cntTrainer,
             100,
             subsampleRatio,
             2,
             2,
             new MeanValuePredictionsAggregator())
-            .fit(cacheMock, parts, null, null);
+            .fit(cacheMock,
+                parts,
+                (integer, doubles) -> VectorUtils.of(doubles),
+                (integer, doubles) -> doubles[doubles.length - 1]);
 
         Double res = mdl.predict(null);
 
@@ -177,7 +180,7 @@ public class BaggingTest extends TrainerTest {
         }
 
         /** {@inheritDoc} */
-        @Override protected boolean checkState(IgniteModel<Vector, Double> mdl) {
+        @Override public boolean isUpdateable(IgniteModel<Vector, Double> mdl) {
             return true;
         }
 
index d253ea0..874547f 100644 (file)
@@ -103,7 +103,7 @@ public class LearningEnvironmentTest {
             }
 
             /** {@inheritDoc} */
-            @Override protected boolean checkState(IgniteModel<Object, Vector> mdl) {
+            @Override public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
                 return false;
             }
 
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java
new file mode 100644 (file)
index 0000000..9c089ce
--- /dev/null
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer;
+import org.apache.ignite.ml.composition.stacking.StackedModel;
+import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static junit.framework.TestCase.assertEquals;
+
+/**
+ * Tests stacked trainers.
+ */
+public class StackingTest extends TrainerTest {
+    /** Rule to check exceptions. */
+    @Rule
+    public ExpectedException thrown = ExpectedException.none();
+
+    /**
+     * Tests simple stack training.
+     */
+    @Test
+    public void testSimpleStack() {
+        StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer =
+            new StackedDatasetTrainer<>();
+
+        UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        );
+
+        MLPArchitecture arch = new MLPArchitecture(2).
+            withAddedLayer(10, true, Activators.RELU).
+            withAddedLayer(1, false, Activators.SIGMOID);
+
+        MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>(
+            arch,
+            LossFunctions.MSE,
+            updatesStgy,
+            3000,
+            10,
+            50,
+            123L
+        );
+
+        // Convert model trainer to produce Vector -> Vector model
+        DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer =
+            AdaptableDatasetTrainer.of(trainer1)
+                .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
+                .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
+                .withConvertedLabels(VectorUtils::num2Arr);
+
+        final double factor = 3;
+
+        StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
+            .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
+            .addTrainer(mlpTrainer)
+            .withAggregatorInputMerger(VectorUtils::concat)
+            .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
+            .withVector2SubmodelInputConverter(IgniteFunction.identity())
+            .withOriginalFeaturesKept(IgniteFunction.identity())
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder())
+            .fit(getCacheMock(xor),
+                parts,
+                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+                (k, v) -> v[v.length - 1]);
+
+        assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
+        assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
+    }
+
+    /**
+     * Tests simple stack training using the vector-specialized {@link StackedVectorDatasetTrainer}.
+     */
+    @Test
+    public void testSimpleVectorStack() {
+        StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> trainer =
+            new StackedVectorDatasetTrainer<>();
+
+        UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        );
+
+        MLPArchitecture arch = new MLPArchitecture(2).
+            withAddedLayer(10, true, Activators.RELU).
+            withAddedLayer(1, false, Activators.SIGMOID);
+
+        DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>(
+            arch,
+            LossFunctions.MSE,
+            updatesStgy,
+            3000,
+            10,
+            50,
+            123L
+        ).withConvertedLabels(VectorUtils::num2Arr);
+
+        final double factor = 3;
+
+        StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
+            .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
+            .addMatrix2MatrixTrainer(mlpTrainer)
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder())
+            .fit(getCacheMock(xor),
+                parts,
+                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+                (k, v) -> v[v.length - 1]);
+
+        assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
+        assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
+    }
+
+    /**
+     * Tests that if there is no way for the input of the first layer to propagate to the second layer,
+     * an exception will be thrown.
+     */
+    @Test
+    public void testINoWaysOfPropagation() {
+        StackedDatasetTrainer<Void, Void, Void, IgniteModel<Void, Void>, Void> trainer =
+            new StackedDatasetTrainer<>();
+        thrown.expect(IllegalStateException.class);
+        trainer.fit(null, null, null);
+    }
+}
index f2899c2..d711fc4 100644 (file)
@@ -147,8 +147,8 @@ public class DataStreamGeneratorTest {
         DatasetBuilder<Vector, Double> b2 = generator.asDatasetBuilder(N, (v, l) -> l == 0, 2);
         counter.set(0);
         DatasetBuilder<Vector, Double> b3 = generator.asDatasetBuilder(N, (v, l) -> l == 1, 2,
-            new UpstreamTransformerBuilder<Vector, Double>() {
-                @Override public UpstreamTransformer<Vector, Double> build(LearningEnvironment env) {
+            new UpstreamTransformerBuilder() {
+                @Override public UpstreamTransformer build(LearningEnvironment env) {
                     return new UpstreamTransformerForTest();
                 }
             });
@@ -201,10 +201,10 @@ public class DataStreamGeneratorTest {
     }
 
     /** */
-    private static class UpstreamTransformerForTest implements UpstreamTransformer<Vector, Double> {
-        @Override public Stream<UpstreamEntry<Vector, Double>> transform(
-            Stream<UpstreamEntry<Vector, Double>> upstream) {
-            return upstream.map(entry -> new UpstreamEntry<>(entry.getKey(), -entry.getValue()));
+    private static class UpstreamTransformerForTest implements UpstreamTransformer {
+        @Override public Stream<UpstreamEntry> transform(
+            Stream<UpstreamEntry> upstream) {
+            return upstream.map(entry -> new UpstreamEntry<>(entry.getKey(), -((double)entry.getValue())));
         }
     }
 }