IGNITE-10480: [ML] Stacking for training and inference
author     Artem Malykh <amalykhgh@gmail.com>
           Fri, 14 Dec 2018 16:28:39 +0000 (19:28 +0300)
committer  Yury Babak <ybabak@gridgain.com>
           Fri, 14 Dec 2018 16:28:39 +0000 (19:28 +0300)
This closes #5635

34 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
modules/ml/src/main/java/org/apache/ignite/ml/Model.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/SimpleStackedDatasetTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteFunction.java
modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/impl/DenseMatrix.java
modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/storage/DenseVectorStorage.java
modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetModel.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
modules/ml/src/test/java/org/apache/ignite/ml/common/TrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java [moved from modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java with 83% similarity]
modules/ml/src/test/java/org/apache/ignite/ml/composition/CompositionTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java
modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java
modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java

index 3a5941e..58f739d 100644 (file)
@@ -62,7 +62,7 @@ public class BaggedLogisticRegressionSGDTrainerExample {
                 .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
 
             System.out.println(">>> Create new logistic regression trainer object.");
-            LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+            LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
                 .withUpdatesStgy(new UpdatesStrategy<>(
                     new SimpleGDUpdateCalculator(0.2),
                     SimpleGDParameterUpdate::sumLocal,
index 8ce46cc..65cf4d1 100644 (file)
@@ -63,7 +63,7 @@ public class LogisticRegressionSGDTrainerExample {
                 .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
 
             System.out.println(">>> Create new logistic regression trainer object.");
-            LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+            LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
                 .withUpdatesStgy(new UpdatesStrategy<>(
                     new SimpleGDUpdateCalculator(0.2),
                     SimpleGDParameterUpdate::sumLocal,
index 2c6a820..58466bd 100644 (file)
@@ -124,7 +124,7 @@ public class Step_9_Go_to_LogReg {
                                             minMaxScalerPreprocessor
                                         );
 
-                                    LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+                                    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
                                         .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(learningRate),
                                             SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
                                         .withMaxIterations(maxIterations)
@@ -188,7 +188,7 @@ public class Step_9_Go_to_LogReg {
                         minMaxScalerPreprocessor
                     );
 
-                LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+                LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
                     .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(bestLearningRate),
                         SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
                     .withMaxIterations(bestMaxIterations)
index 8428b0a..87458a1 100644 (file)
@@ -34,6 +34,17 @@ public interface Model<T, V> extends IgniteFunction<T, V> {
     }
 
     /**
+     * Get a composition model of the form {@code x -> after(mdl(x))}.
+     *
+     * @param after Function to apply after this model.
+     * @param <V1> Type of output of function applied after this model.
+     * @return Composition model of the form {@code x -> after(mdl(x))}.
+     */
+    public default <V1> Model<T, V1> andThen(IgniteFunction<V, V1> after) {
+        return t -> after.apply(apply(t));
+    }
+
+    /**
      * @param pretty Use pretty mode.
      */
     public default String toString(boolean pretty) {
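Usage sketch (not part of the patch) for the new Model#andThen default method; the lambda-based model and the threshold function below are illustrative assumptions:

    import org.apache.ignite.ml.Model;
    import org.apache.ignite.ml.math.functions.IgniteFunction;

    public class ModelAndThenSketch {
        public static void main(String[] args) {
            // Toy model mapping a raw feature to a score (illustrative only).
            Model<Double, Double> score = x -> 2 * x - 1;

            // Compose a thresholding step after the model: x -> after(mdl(x)).
            IgniteFunction<Double, Integer> threshold = s -> s > 0 ? 1 : 0;
            Model<Double, Integer> clf = score.andThen(threshold);

            System.out.println(clf.apply(0.7)); // 1
            System.out.println(clf.apply(0.2)); // 0
        }
    }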
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/SimpleStackedDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/SimpleStackedDatasetTrainer.java
new file mode 100644 (file)
index 0000000..c4c082f
--- /dev/null
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.stacking;
+
+import java.util.ArrayList;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * {@link DatasetTrainer} with the same type of input and output for submodels.
+ *
+ * @param <I> Type of submodels input.
+ * @param <O> Type of aggregator model output.
+ * @param <AM> Type of aggregator model.
+ * @param <L> Type of labels.
+ */
+public class SimpleStackedDatasetTrainer<I, O, AM extends Model<I, O>, L> extends StackedDatasetTrainer<I, I, O, AM, L> {
+    /**
+     * Construct instance of this class.
+     *
+     * @param aggregatingTrainer Aggregator trainer.
+     * @param aggregatingInputMerger Function used to merge submodel outputs into one.
+     * @param submodelInput2AggregatingInputConverter Function used to convert submodel input to aggregator input;
+     * this function is used if the user chooses to keep original features.
+     * @param vector2SubmodelInputConverter Function used to convert a {@link Vector} to submodel input.
+     * @param submodelOutput2VectorConverter Function used to convert submodel output to a {@link Vector}.
+     */
+    public SimpleStackedDatasetTrainer(DatasetTrainer<AM, L> aggregatingTrainer,
+        IgniteBinaryOperator<I> aggregatingInputMerger,
+        IgniteFunction<I, I> submodelInput2AggregatingInputConverter,
+        IgniteFunction<Vector, I> vector2SubmodelInputConverter,
+        IgniteFunction<I, Vector> submodelOutput2VectorConverter) {
+        super(aggregatingTrainer,
+            aggregatingInputMerger,
+            submodelInput2AggregatingInputConverter,
+            new ArrayList<>(),
+            vector2SubmodelInputConverter,
+            submodelOutput2VectorConverter);
+    }
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param aggregatingTrainer Aggregator trainer.
+     * @param aggregatingInputMerger Function used to merge submodel outputs into one.
+     */
+    public SimpleStackedDatasetTrainer(DatasetTrainer<AM, L> aggregatingTrainer,
+        IgniteBinaryOperator<I> aggregatingInputMerger) {
+        super(aggregatingTrainer, aggregatingInputMerger, IgniteFunction.identity());
+    }
+
+    /**
+     * Constructs instance of this class.
+     */
+    public SimpleStackedDatasetTrainer() {
+        super();
+    }
+
+    //TODO: IGNITE-10441 -- Look for options to avoid boilerplate overrides.
+    /** {@inheritDoc} */
+    @Override public <M1 extends Model<I, I>> SimpleStackedDatasetTrainer<I, O, AM, L> addTrainer(
+        DatasetTrainer<M1, L> trainer) {
+        return (SimpleStackedDatasetTrainer<I, O, AM, L>)super.addTrainer(trainer);
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleStackedDatasetTrainer<I, O, AM, L> withAggregatorTrainer(
+        DatasetTrainer<AM, L> aggregatorTrainer) {
+        return (SimpleStackedDatasetTrainer<I, O, AM, L>)super.withAggregatorTrainer(aggregatorTrainer);
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleStackedDatasetTrainer<I, O, AM, L> withOriginalFeaturesDropped() {
+        return (SimpleStackedDatasetTrainer<I, O, AM, L>)super.withOriginalFeaturesDropped();
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleStackedDatasetTrainer<I, O, AM, L> withOriginalFeaturesKept(
+        IgniteFunction<I, I> submodelInput2AggregatingInputConverter) {
+        return (SimpleStackedDatasetTrainer<I, O, AM, L>)super.withOriginalFeaturesKept(
+            submodelInput2AggregatingInputConverter);
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleStackedDatasetTrainer<I, O, AM, L> withAggregatorInputMerger(IgniteBinaryOperator<I> merger) {
+        return (SimpleStackedDatasetTrainer<I, O, AM, L>)super.withAggregatorInputMerger(merger);
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleStackedDatasetTrainer<I, O, AM, L> withEnvironmentBuilder(
+        LearningEnvironmentBuilder envBuilder) {
+        return (SimpleStackedDatasetTrainer<I, O, AM, L>)super.withEnvironmentBuilder(envBuilder);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <L1> SimpleStackedDatasetTrainer<I, O, AM, L1> withConvertedLabels(IgniteFunction<L1, L> new2Old) {
+        return (SimpleStackedDatasetTrainer<I, O, AM, L1>)super.withConvertedLabels(new2Old);
+    }
+
+    /**
+     * Keep original features using {@link IgniteFunction#identity()} as submodelInput2AggregatingInputConverter.
+     *
+     * @return This object.
+     */
+    public SimpleStackedDatasetTrainer<I, O, AM, L> withOriginalFeaturesKept() {
+        return (SimpleStackedDatasetTrainer<I, O, AM, L>)super.withOriginalFeaturesKept(IgniteFunction.identity());
+    }
+}
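Construction sketch (not part of the patch), assuming Vector as the common submodel input/output type and LogisticRegressionModel (a Model<Vector, Double> from this module) as the aggregator model type:

    // Merge submodel outputs by concatenation and keep original features as-is.
    SimpleStackedDatasetTrainer<Vector, Double, LogisticRegressionModel, Double> trainer =
        new SimpleStackedDatasetTrainer<>(
            new LogisticRegressionSGDTrainer(), // Aggregator trainer.
            VectorUtils::concat,                // Merge submodel outputs into one vector.
            IgniteFunction.identity(),          // Keep original features unchanged.
            IgniteFunction.identity(),          // Vector -> submodel input.
            IgniteFunction.identity());         // Submodel output -> Vector.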
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
new file mode 100644 (file)
index 0000000..bb870cf
--- /dev/null
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.stacking;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.environment.parallelism.Promise;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * {@link DatasetTrainer} encapsulating the stacking technique for model training.
+ * A model produced by this trainer consists of two layers. The first layer is a model {@code IS -> IA}.
+ * This layer is a "parallel" composition of several "submodels", each of which is itself a model
+ * {@code IS -> IA}, with their outputs {@code [IA]} merged into a single {@code IA}.
+ * The second layer is an aggregator model {@code IA -> O}.
+ * Training corresponds to this layered structure in the following way:
+ * <pre>
+ * 1. train models of first layer;
+ * 2. train aggregator model on dataset augmented with outputs of first layer models converted to vectors.
+ * </pre>
+ * During the second step we can choose whether to keep the original features along with the converted outputs of the
+ * first-layer models, or to use only the converted outputs. This choice also affects inference.
+ * This class is the most general stacked trainer; {@link StackedVectorDatasetTrainer} is a shortcut version of it
+ * with some types and functions already specified.
+ *
+ * @param <IS> Type of submodels input.
+ * @param <IA> Type of aggregator input.
+ * @param <O> Type of aggregator output.
+ * @param <AM> Type of aggregator model.
+ * @param <L> Type of labels.
+ */
+public class StackedDatasetTrainer<IS, IA, O, AM extends Model<IA, O>, L>
+    extends DatasetTrainer<StackedModel<IS, IA, O, AM>, L> {
+    /** Operator that merges inputs for aggregating model. */
+    private IgniteBinaryOperator<IA> aggregatingInputMerger;
+
+    /** Function transforming input for submodels to input for aggregating model. */
+    private IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter;
+
+    /** Trainers of submodels with converters from and to {@link Vector}. */
+    private List<DatasetTrainer<Model<IS, IA>, L>> submodelsTrainers;
+
+    /** Aggregating trainer. */
+    private DatasetTrainer<AM, L> aggregatorTrainer;
+
+    /** Function used for conversion of {@link Vector} to submodel input. */
+    private IgniteFunction<Vector, IS> vector2SubmodelInputConverter;
+
+    /** Function used for conversion of submodel output to {@link Vector}. */
+    private IgniteFunction<IA, Vector> submodelOutput2VectorConverter;
+
+    /**
+     * Create instance of this class.
+     *
+     * @param aggregatorTrainer Trainer of model used for aggregation of results of submodels.
+     * @param aggregatingInputMerger Binary operator used to merge outputs of submodels into one output passed to
+     * aggregator model.
+     * @param submodelInput2AggregatingInputConverter Function used to convert submodel input to aggregator input;
+     * this function is used if the user chooses to keep original features.
+     * @param submodelsTrainers List of submodel trainers.
+     * @param vector2SubmodelInputConverter Function used to convert a {@link Vector} to submodel input.
+     * @param submodelOutput2VectorConverter Function used to convert submodel output to a {@link Vector}.
+     */
+    public StackedDatasetTrainer(DatasetTrainer<AM, L> aggregatorTrainer,
+        IgniteBinaryOperator<IA> aggregatingInputMerger,
+        IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter,
+        List<DatasetTrainer<Model<IS, IA>, L>> submodelsTrainers,
+        IgniteFunction<Vector, IS> vector2SubmodelInputConverter,
+        IgniteFunction<IA, Vector> submodelOutput2VectorConverter) {
+        this.aggregatorTrainer = aggregatorTrainer;
+        this.aggregatingInputMerger = aggregatingInputMerger;
+        this.submodelInput2AggregatingInputConverter = submodelInput2AggregatingInputConverter;
+        this.submodelsTrainers = new ArrayList<>(submodelsTrainers);
+        this.vector2SubmodelInputConverter = vector2SubmodelInputConverter;
+        this.submodelOutput2VectorConverter = submodelOutput2VectorConverter;
+    }
+
+    /**
+     * Constructs instance of this class.
+     *
+     * @param aggregatorTrainer Trainer of model used for aggregation of results of submodels.
+     * @param aggregatingInputMerger Binary operator used to merge outputs of submodels into one output passed to
+     * aggregator model.
+     * @param submodelInput2AggregatingInputConverter Function used to convert submodel input to aggregator input;
+     * this function is used if the user chooses to keep original features.
+     */
+    public StackedDatasetTrainer(DatasetTrainer<AM, L> aggregatorTrainer,
+        IgniteBinaryOperator<IA> aggregatingInputMerger,
+        IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter) {
+        this(aggregatorTrainer,
+            aggregatingInputMerger,
+            submodelInput2AggregatingInputConverter,
+            new ArrayList<>(),
+            null,
+            null);
+    }
+
+    /**
+     * Constructs instance of this class.
+     */
+    public StackedDatasetTrainer() {
+        this(null, null, null, new ArrayList<>(), null, null);
+    }
+
+    /**
+     * Keep original features during training and propagate submodel input to the aggregator during inference
+     * using the given function.
+     * Note that if this option is on, training will be done on vectors obtained by
+     * concatenating the features passed to the submodel trainers with the outputs of submodels converted to vectors.
+     * This can, for example, influence the aggregator model input vector dimension (if {@code IS = Vector}), or,
+     * more generally, some {@code IS} parameters which are not reflected just by its type, so the converter should
+     * be written accordingly.
+     *
+     * @param submodelInput2AggregatingInputConverter Function used to propagate submodels input to aggregator.
+     * @return This object.
+     */
+    public StackedDatasetTrainer<IS, IA, O, AM, L> withOriginalFeaturesKept(
+        IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter) {
+        this.submodelInput2AggregatingInputConverter = submodelInput2AggregatingInputConverter;
+
+        return this;
+    }
+
+    /**
+     * Drop original features during training and inference.
+     *
+     * @return This object.
+     */
+    public StackedDatasetTrainer<IS, IA, O, AM, L> withOriginalFeaturesDropped() {
+        submodelInput2AggregatingInputConverter = null;
+
+        return this;
+    }
+
+    /**
+     * Set the function used for conversion of submodel output to {@link Vector}. This function is used during
+     * the building of the dataset for training the aggregator model; that dataset is augmented with the results
+     * of submodels converted to {@link Vector}s.
+     *
+     * @param submodelOutput2VectorConverter Function used for conversion of submodel output to {@link Vector}.
+     * @return This object.
+     */
+    public StackedDatasetTrainer<IS, IA, O, AM, L> withSubmodelOutput2VectorConverter(
+        IgniteFunction<IA, Vector> submodelOutput2VectorConverter) {
+        this.submodelOutput2VectorConverter = submodelOutput2VectorConverter;
+
+        return this;
+    }
+
+    /**
+     * Set the function used for conversion of {@link Vector} to submodel input. This function is used during
+     * the building of the dataset for training the aggregator model; that dataset is augmented with the results
+     * of submodels applied to {@link Vector}s in the original dataset.
+     *
+     * @param vector2SubmodelInputConverter Function used for conversion of {@link Vector} to submodel input.
+     * @return This object.
+     */
+    public StackedDatasetTrainer<IS, IA, O, AM, L> withVector2SubmodelInputConverter(
+        IgniteFunction<Vector, IS> vector2SubmodelInputConverter) {
+        this.vector2SubmodelInputConverter = vector2SubmodelInputConverter;
+
+        return this;
+    }
+
+    /**
+     * Specify aggregator trainer.
+     *
+     * @param aggregatorTrainer Aggregator trainer.
+     * @return This object.
+     */
+    public StackedDatasetTrainer<IS, IA, O, AM, L> withAggregatorTrainer(DatasetTrainer<AM, L> aggregatorTrainer) {
+        this.aggregatorTrainer = aggregatorTrainer;
+
+        return this;
+    }
+
+    /**
+     * Specify the binary operator used to merge submodel outputs into one.
+     *
+     * @param merger Binary operator used to merge submodel outputs into one.
+     * @return This object.
+     */
+    public StackedDatasetTrainer<IS, IA, O, AM, L> withAggregatorInputMerger(IgniteBinaryOperator<IA> merger) {
+        aggregatingInputMerger = merger;
+
+        return this;
+    }
+
+    /**
+     * Adds a submodel trainer; the converters configured on this object are used at training and inference stages.
+     *
+     * @param trainer Submodel trainer.
+     * @return This object.
+     */
+    @SuppressWarnings({"unchecked"})
+    public <M1 extends Model<IS, IA>> StackedDatasetTrainer<IS, IA, O, AM, L> addTrainer(
+        DatasetTrainer<M1, L> trainer) {
+        // Unsafely coerce DatasetTrainer<M1, L> to DatasetTrainer<Model<IS, IA>, L>. We fully control
+        // usages of this unsafely coerced object, and the coercion makes working with
+        // submodelsTrainers easier.
+        submodelsTrainers.add(new DatasetTrainer<Model<IS, IA>, L>() {
+            /** {@inheritDoc} */
+            @Override public <K, V> Model<IS, IA> fit(DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                return trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
+            }
+
+            /** {@inheritDoc} */
+            @Override public <K, V> Model<IS, IA> update(Model<IS, IA> mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                DatasetTrainer<Model<IS, IA>, L> trainer1 = (DatasetTrainer<Model<IS, IA>, L>)trainer;
+                return trainer1.update(mdl, datasetBuilder, featureExtractor, lbExtractor);
+            }
+
+            /** {@inheritDoc} */
+            @Override protected boolean checkState(Model<IS, IA> mdl) {
+                return true;
+            }
+
+            /** {@inheritDoc} */
+            @Override protected <K, V> Model<IS, IA> updateModel(Model<IS, IA> mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                return null;
+            }
+        });
+
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> StackedModel<IS, IA, O, AM> fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor) {
+
+        return update(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> StackedModel<IS, IA, O, AM> update(StackedModel<IS, IA, O, AM> mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor) {
+        return runOnSubmodels(
+            ensemble -> {
+                List<IgniteSupplier<Model<IS, IA>>> res = new ArrayList<>();
+                for (int i = 0; i < ensemble.size(); i++) {
+                    final int j = i;
+                    res.add(() -> {
+                        DatasetTrainer<Model<IS, IA>, L> trainer = ensemble.get(j);
+                        return mdl == null ?
+                            trainer.fit(datasetBuilder, featureExtractor, lbExtractor) :
+                            trainer.update(mdl.submodels().get(j), datasetBuilder, featureExtractor, lbExtractor);
+                    });
+                }
+                return res;
+            },
+            (at, extr) -> mdl == null ?
+                at.fit(datasetBuilder, extr, lbExtractor) :
+                at.update(mdl.aggregatorModel(), datasetBuilder, extr, lbExtractor),
+            featureExtractor
+        );
+    }
+
+    /** {@inheritDoc} */
+    @Override public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder(
+        LearningEnvironmentBuilder envBuilder) {
+        submodelsTrainers =
+            submodelsTrainers.stream().map(x -> x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList());
+        aggregatorTrainer = aggregatorTrainer.withEnvironmentBuilder(envBuilder);
+
+        return this;
+    }
+
+    /**
+     * <pre>
+     * 1. Obtain models produced by running specified tasks;
+     * 2. run another specified task on the dataset augmented with results of the models from step 1.
+     * </pre>
+     *
+     * @param taskSupplier Function used to generate tasks for first step.
+     * @param aggregatorProcessor Function used to obtain the aggregator model on the second step.
+     * @param featureExtractor Feature extractor.
+     * @param <K> Type of keys in upstream.
+     * @param <V> Type of values in upstream.
+     * @return {@link StackedModel}.
+     */
+    private <K, V> StackedModel<IS, IA, O, AM> runOnSubmodels(
+        IgniteFunction<List<DatasetTrainer<Model<IS, IA>, L>>, List<IgniteSupplier<Model<IS, IA>>>> taskSupplier,
+        IgniteBiFunction<DatasetTrainer<AM, L>, IgniteBiFunction<K, V, Vector>, AM> aggregatorProcessor,
+        IgniteBiFunction<K, V, Vector> featureExtractor) {
+
+        // Make sure there is at least one way for submodel input to propagate to aggregator.
+        if (submodelInput2AggregatingInputConverter == null && submodelsTrainers.isEmpty())
+            throw new IllegalStateException("There should be at least one way for submodels " +
+                "input to be propageted to aggregator.");
+
+        if (submodelOutput2VectorConverter == null || vector2SubmodelInputConverter == null)
+            throw new IllegalStateException("There should be a specified way to convert vectors to submodels " +
+                "input and submodels output to vector");
+
+        if (aggregatingInputMerger == null)
+            throw new IllegalStateException("Binary operator used to convert outputs of submodels is not specified");
+
+        List<IgniteSupplier<Model<IS, IA>>> mdlSuppliers = taskSupplier.apply(submodelsTrainers);
+
+        List<Model<IS, IA>> subMdls = environment.parallelismStrategy().submit(mdlSuppliers).stream()
+            .map(Promise::unsafeGet)
+            .collect(Collectors.toList());
+
+        // Augment features with new columns consisting of submodel outputs.
+        IgniteBiFunction<K, V, Vector> augmentedExtractor = getFeatureExtractorForAggregator(featureExtractor,
+            subMdls,
+            submodelInput2AggregatingInputConverter,
+            submodelOutput2VectorConverter,
+            vector2SubmodelInputConverter);
+
+        AM aggregator = aggregatorProcessor.apply(aggregatorTrainer, augmentedExtractor);
+
+        StackedModel<IS, IA, O, AM> res = new StackedModel<>(
+            aggregator,
+            aggregatingInputMerger,
+            submodelInput2AggregatingInputConverter);
+
+        for (Model<IS, IA> subMdl : subMdls)
+            res.addSubmodel(subMdl);
+
+        return res;
+    }
+
+    /**
+     * Get the feature extractor to be used by the aggregator trainer, derived from the original feature extractor.
+     * This method is static to make sure that we do not capture the enclosing instance during serialization.
+     *
+     * @param featureExtractor Original feature extractor.
+     * @param subMdls Submodels.
+     * @param submodelInput2AggregatingInputConverter Function used to convert submodel input to aggregator input.
+     * @param submodelOutput2VectorConverter Function used to convert submodel output to a {@link Vector}.
+     * @param vector2SubmodelInputConverter Function used to convert a {@link Vector} to submodel input.
+     * @param <K> Type of upstream keys.
+     * @param <V> Type of upstream values.
+     * @return Feature extractor to be used by the aggregator trainer.
+     */
+    private static <IS, IA, K, V> IgniteBiFunction<K, V, Vector> getFeatureExtractorForAggregator(
+        IgniteBiFunction<K, V, Vector> featureExtractor, List<Model<IS, IA>> subMdls,
+        IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter,
+        IgniteFunction<IA, Vector> submodelOutput2VectorConverter,
+        IgniteFunction<Vector, IS> vector2SubmodelInputConverter) {
+        if (submodelInput2AggregatingInputConverter != null)
+            return featureExtractor.andThen((Vector v) -> {
+                Vector[] vs = subMdls.stream().map(sm ->
+                    applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
+                return VectorUtils.concat(v, vs);
+            });
+        else
+            return featureExtractor.andThen((Vector v) -> {
+                Vector[] vs = subMdls.stream().map(sm ->
+                    applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
+                return VectorUtils.concat(vs);
+            });
+    }
+
+    /**
+     * Apply submodel to {@link Vector}.
+     *
+     * @param mdl Submodel.
+     * @param submodelOutput2VectorConverter Function for conversion of submodel output to {@link Vector}.
+     * @param vector2SubmodelInputConverter Function used for conversion of {@link Vector} to submodel input.
+     * @param v Vector.
+     * @param <IS> Type of submodel input.
+     * @param <IA> Type of submodel output.
+     * @return Result of application of {@code submodelOutput2VectorConverter . mdl . vector2SubmodelInputConverter}
+     * where the dot denotes function composition.
+     */
+    private static <IS, IA> Vector applyToVector(Model<IS, IA> mdl,
+        IgniteFunction<IA, Vector> submodelOutput2VectorConverter,
+        IgniteFunction<Vector, IS> vector2SubmodelInputConverter,
+        Vector v) {
+        return vector2SubmodelInputConverter.andThen(mdl).andThen(submodelOutput2VectorConverter).apply(v);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> StackedModel<IS, IA, O, AM> updateModel(StackedModel<IS, IA, O, AM> mdl,
+        DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor) {
+        // This method is never called, we override "update" instead.
+        return null;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(StackedModel<IS, IA, O, AM> mdl) {
+        return true;
+    }
+}
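The feature augmentation performed by getFeatureExtractorForAggregator can be sketched as follows (illustrative values; assumes VectorUtils.of from this module):

    // With original features kept, the aggregator sees the original vector
    // concatenated with each submodel's vectorized output.
    Vector v = VectorUtils.of(1.0, 2.0);
    Model<Vector, Vector> subMdl = x -> VectorUtils.num2Vec(x.get(0) + x.get(1));
    Vector augmented = VectorUtils.concat(v, subMdl.apply(v)); // [1.0, 2.0, 3.0]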
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java
new file mode 100644 (file)
index 0000000..cb64d01
--- /dev/null
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.stacking;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Model consisting of two layers:
+ * <pre>
+ *     1. Submodels layer {@code (IS -> IA)}.
+ *     2. Aggregator layer {@code (IA -> O)}.
+ * </pre>
+ * The submodels layer is a "parallel" composition of several models {@code IS -> IA}, each of them getting the same
+ * input {@code IS} and producing its own output; these outputs {@code [IA]}
+ * are combined into a single output with a given binary "merger" operator {@code IA -> IA -> IA}. The result of the
+ * merge is then passed to the aggregator layer.
+ * The aggregator layer consists of a model {@code IA -> O}.
+ *
+ * @param <IS> Type of submodels input.
+ * @param <IA> Type of submodels output (same as aggregator model input).
+ * @param <O> Type of aggregator model output.
+ * @param <AM> Type of aggregator model.
+ */
+public class StackedModel<IS, IA, O, AM extends Model<IA, O>> implements Model<IS, O> {
+    /** Submodels layer. */
+    private Model<IS, IA> subModelsLayer;
+
+    /** Aggregator model. */
+    private final AM aggregatorMdl;
+
+    /** Models constituting submodels layer. */
+    private List<Model<IS, IA>> submodels;
+
+    /** Binary operator merging submodels outputs. */
+    private final IgniteBinaryOperator<IA> aggregatingInputMerger;
+
+    /**
+     * Constructs instance of this class.
+     *
+     * @param aggregatorMdl Aggregator model.
+     * @param aggregatingInputMerger Binary operator used to merge submodels outputs.
+     * @param subMdlInput2AggregatingInput Function converting submodel input to aggregator input. (This function
+     * is needed when the option to keep original features is chosen in {@link StackedDatasetTrainer}.)
+     */
+    StackedModel(AM aggregatorMdl,
+        IgniteBinaryOperator<IA> aggregatingInputMerger,
+        IgniteFunction<IS, IA> subMdlInput2AggregatingInput) {
+        this.aggregatorMdl = aggregatorMdl;
+        this.aggregatingInputMerger = aggregatingInputMerger;
+        this.subModelsLayer = subMdlInput2AggregatingInput != null ? subMdlInput2AggregatingInput::apply : null;
+        submodels = new ArrayList<>();
+    }
+
+    /**
+     * Get submodels constituting first layer of this model.
+     *
+     * @return Submodels constituting first layer of this model.
+     */
+    List<Model<IS, IA>> submodels() {
+        return submodels;
+    }
+
+    /**
+     * Get aggregator model.
+     *
+     * @return Aggregator model.
+     */
+    AM aggregatorModel() {
+        return aggregatorMdl;
+    }
+
+    /**
+     * Add submodel into first layer.
+     *
+     * @param subMdl Submodel to add.
+     */
+    void addSubmodel(Model<IS, IA> subMdl) {
+        submodels.add(subMdl);
+        subModelsLayer = subModelsLayer != null ? subModelsLayer.combine(subMdl, aggregatingInputMerger)
+            : subMdl;
+    }
+
+    /** {@inheritDoc} */
+    @Override public O apply(IS is) {
+        return subModelsLayer.andThen(aggregatorMdl).apply(is);
+    }
+}
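The merge performed in addSubmodel relies on the existing Model#combine from the composition API; a sketch with toy submodels (illustrative only):

    // Two toy submodels over the same input, merged by vector concatenation.
    Model<Vector, Vector> m1 = v -> VectorUtils.num2Vec(v.get(0));
    Model<Vector, Vector> m2 = v -> VectorUtils.num2Vec(v.get(1));
    Model<Vector, Vector> layer = m1.combine(m2, VectorUtils::concat);
    // layer.apply(VectorUtils.of(3.0, 4.0)) yields [3.0, 4.0].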
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java
new file mode 100644 (file)
index 0000000..16eaec2
--- /dev/null
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.stacking;
+
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * {@link StackedDatasetTrainer} with {@link Vector} as submodels input and output.
+ *
+ * @param <O> Type of aggregator model output.
+ * @param <L> Type of labels.
+ * @param <AM> Type of aggregator model.
+ */
+public class StackedVectorDatasetTrainer<O, AM extends Model<Vector, O>, L>
+    extends SimpleStackedDatasetTrainer<Vector, O, AM, L> {
+    /**
+     * Constructs instance of this class.
+     *
+     * @param aggregatingTrainer Aggregator trainer.
+     */
+    public StackedVectorDatasetTrainer(DatasetTrainer<AM, L> aggregatingTrainer) {
+        super(aggregatingTrainer,
+            VectorUtils::concat,
+            IgniteFunction.identity(),
+            IgniteFunction.identity(),
+            IgniteFunction.identity());
+    }
+
+    /**
+     * Constructs instance of this class.
+     */
+    public StackedVectorDatasetTrainer() {
+        this(null);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <M1 extends Model<Vector, Vector>> StackedVectorDatasetTrainer<O, AM, L> addTrainer(
+        DatasetTrainer<M1, L> trainer) {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.addTrainer(trainer);
+    }
+
+    //TODO: IGNITE-10441 -- Look for options to avoid boilerplate overrides.
+    /** {@inheritDoc} */
+    @Override public StackedVectorDatasetTrainer<O, AM, L> withAggregatorTrainer(
+        DatasetTrainer<AM, L> aggregatorTrainer) {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.withAggregatorTrainer(aggregatorTrainer);
+    }
+
+    /** {@inheritDoc} */
+    @Override public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesKept() {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesKept();
+    }
+
+    /** {@inheritDoc} */
+    @Override public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesDropped() {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesDropped();
+    }
+
+    /** {@inheritDoc} */
+    @Override public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesKept(
+        IgniteFunction<Vector, Vector> submodelInput2AggregatingInputConverter) {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesKept(
+            submodelInput2AggregatingInputConverter);
+    }
+
+    /** {@inheritDoc} */
+    @Override public StackedVectorDatasetTrainer<O, AM, L> withSubmodelOutput2VectorConverter(
+        IgniteFunction<Vector, Vector> submodelOutput2VectorConverter) {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.withSubmodelOutput2VectorConverter(
+            submodelOutput2VectorConverter);
+    }
+
+    /** {@inheritDoc} */
+    @Override public StackedVectorDatasetTrainer<O, AM, L> withVector2SubmodelInputConverter(
+        IgniteFunction<Vector, Vector> vector2SubmodelInputConverter) {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.withVector2SubmodelInputConverter(
+            vector2SubmodelInputConverter);
+    }
+
+    /** {@inheritDoc} */
+    @Override public StackedVectorDatasetTrainer<O, AM, L> withAggregatorInputMerger(
+        IgniteBinaryOperator<Vector> merger) {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.withAggregatorInputMerger(merger);
+    }
+
+    /** {@inheritDoc} */
+    @Override public StackedVectorDatasetTrainer<O, AM, L> withEnvironmentBuilder(
+        LearningEnvironmentBuilder envBuilder) {
+        return (StackedVectorDatasetTrainer<O, AM, L>)super.withEnvironmentBuilder(envBuilder);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <L1> StackedVectorDatasetTrainer<O, AM, L1> withConvertedLabels(
+        IgniteFunction<L1, L> new2Old) {
+        return (StackedVectorDatasetTrainer<O, AM, L1>)super.withConvertedLabels(new2Old);
+    }
+
+    /**
+     * Shortcut for adding a trainer of a {@code Vector -> Double} model, treating it as {@code Vector -> Vector},
+     * where the output {@link Vector} is constructed by wrapping the double value.
+     *
+     * @param trainer Submodel trainer.
+     * @param <M1> Type of submodel trainer model.
+     * @return This object.
+     */
+    public <M1 extends Model<Vector, Double>> StackedVectorDatasetTrainer<O, AM, L> addTrainerWithDoubleOutput(
+        DatasetTrainer<M1, L> trainer) {
+        return addTrainer(AdaptableDatasetTrainer.of(trainer).afterTrainedModel(VectorUtils::num2Vec));
+    }
+
+    /**
+     * Shortcut for adding a trainer of a {@code Matrix -> Matrix} model, treating it as {@code Vector -> Vector},
+     * where the input {@link Vector} is turned into a {@code 1 x cols} {@link Matrix} and the output is the first
+     * row of the output {@link Matrix}.
+     *
+     * @param trainer Submodel trainer.
+     * @param <M1> Type of submodel trainer model.
+     * @return This object.
+     */
+    public <M1 extends Model<Matrix, Matrix>> StackedVectorDatasetTrainer<O, AM, L> addMatrix2MatrixTrainer(
+        DatasetTrainer<M1, L> trainer) {
+        AdaptableDatasetTrainer<Vector, Vector, Matrix, Matrix, M1, L> adapted = AdaptableDatasetTrainer.of(trainer)
+            .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
+            .afterTrainedModel((Matrix mtx) -> mtx.getRow(0));
+
+        return addTrainer(adapted);
+    }
+}
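End-to-end sketch of the new API (not part of the patch), stacking a decision tree under a logistic regression aggregator; the Ignite instance and the cache of double[] rows with the label in the last position are assumptions for illustration:

    StackedVectorDatasetTrainer<Double, LogisticRegressionModel, Double> trainer =
        new StackedVectorDatasetTrainer<Double, LogisticRegressionModel, Double>(
            new LogisticRegressionSGDTrainer())
            .withOriginalFeaturesKept()
            .addTrainerWithDoubleOutput(new DecisionTreeClassificationTrainer(5, 0));

    StackedModel<Vector, Vector, Double, LogisticRegressionModel> mdl = trainer.fit(
        ignite,
        dataCache, // IgniteCache<Integer, double[]>, label in the last position.
        (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
        (k, v) -> v[v.length - 1]);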
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/package-info.java
new file mode 100644 (file)
index 0000000..7facdbf
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains classes used for training with the stacking technique.
+ */
+package org.apache.ignite.ml.composition.stacking;
index 11b250b..9c0e281 100644 (file)
@@ -44,7 +44,7 @@ public interface UpstreamTransformer<K, V> extends Serializable {
      * @param other Other transformer.
      * @return Composition of this and other transformer.
      */
-    default UpstreamTransformer<K, V> andThen(UpstreamTransformer<K, V> other) {
+    public default UpstreamTransformer<K, V> andThen(UpstreamTransformer<K, V> other) {
         return upstream -> other.transform(transform(upstream));
     }
 }
index 1dc5591..7fa1efa 100644 (file)
@@ -76,7 +76,7 @@ public class ComputeUtils {
      */
     public static <R> Collection<R> affinityCallWithRetries(Ignite ignite, Collection<String> cacheNames,
         IgniteFunction<Integer, R> fun, int retries, int interval) {
-        assert cacheNames.size() > 0;
+        assert !cacheNames.isEmpty();
         assert interval >= 0;
 
         String primaryCache = cacheNames.iterator().next();
index ed55318..c52ad2b 100644 (file)
@@ -43,7 +43,7 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass
     }
 
     /** {@inheritDoc} */
-    @Override public <K, V> KNNClassificationModel updateModel(KNNClassificationModel mdl,
+    @Override protected <K, V> KNNClassificationModel updateModel(KNNClassificationModel mdl,
         DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
 
index 2673b90..d81a7e4 100644 (file)
@@ -34,7 +34,18 @@ public interface IgniteFunction<T, R> extends Function<T, R>, Serializable {
      * @param <R> Type of output.
      * @return {@link IgniteFunction} returning specified constant.
      */
-    static <T, R> IgniteFunction<T, R> constant(R r) {
-        return t -> r;
+    // TODO: IGNITE-10653 Maybe we should add toString description to identity and constant.
+    public static <T, R> IgniteFunction<T, R> constant(R r) {
+        return (IgniteFunction<T, R>)t -> r;
+    }
+
+    /**
+     * Identity function.
+     *
+     * @param <T> Type of input and output.
+     * @return Identity function.
+     */
+    public static <T> IgniteFunction<T, T> identity() {
+        return (IgniteFunction<T, T>)t -> t;
     }
 }
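A short sketch of why the serializable factories matter: unlike java.util.function.Function#identity(), the result can be captured in closures shipped to remote nodes:

    IgniteFunction<Vector, Vector> id = IgniteFunction.identity();     // Serializable identity.
    IgniteFunction<Vector, Double> one = IgniteFunction.constant(1.0); // Serializable constant.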
index 6dce522..84be5f5 100644 (file)
@@ -92,7 +92,7 @@ public class DenseMatrix extends AbstractMatrix implements OrderedMatrix {
      * Build new matrix from flat raw array.
      */
     public DenseMatrix(double[] mtx, int rows) {
-        this(mtx, StorageConstants.ROW_STORAGE_MODE, rows);
+        this(mtx, rows, StorageConstants.ROW_STORAGE_MODE);
     }
 
     /** */
index 3c580c3..eaf7f91 100644 (file)
@@ -51,14 +51,34 @@ public class VectorUtils {
     }
 
     /**
+     * Wrap specified value into vector.
+     *
+     * @param val Value to wrap.
+     * @return Specified value wrapped into vector.
+     */
+    public static Vector num2Vec(double val) {
+        return fill(val, 1);
+    }
+
+    /**
      * Turn number into a local Vector of given size with one-hot encoding.
      *
      * @param num Number to turn into vector.
      * @param vecSize Vector size of output vector.
      * @return One-hot encoded number.
      */
-    public static Vector num2Vec(int num, int vecSize) {
-        return num2Vec(num, vecSize, false);
+    public static Vector oneHot(int num, int vecSize) {
+        return oneHot(num, vecSize, false);
+    }
+
+    /**
+     * Turn number into a 1-sized array.
+     *
+     * @param val Value to wrap in array.
+     * @return Number wrapped in 1-sized array.
+     */
+    public static double[] num2Arr(double val) {
+        return new double[] {val};
     }
 
     /**
@@ -69,7 +89,7 @@ public class VectorUtils {
      * @param isDistributed Flag indicating if distributed vector should be created.
      * @return One-hot encoded number.
      */
-    public static Vector num2Vec(int num, int vecSize, boolean isDistributed) {
+    public static Vector oneHot(int num, int vecSize, boolean isDistributed) {
         Vector res = new DenseVector(vecSize);
         return res.setX(num, 1);
     }
@@ -197,4 +217,50 @@ public class VectorUtils {
 
         return answer;
     }
+
+    /**
+     * Concatenates two given vectors.
+     *
+     * @param v1 First vector.
+     * @param v2 Second vector.
+     * @return Concatenation result.
+     */
+    public static Vector concat(Vector v1, Vector v2) {
+        int size1 = v1.size();
+        int size2 = v2.size();
+        double[] vals = new double[size1 + size2];
+        System.arraycopy(v1.asArray(), 0, vals, 0, size1);
+        System.arraycopy(v2.asArray(), 0, vals, size1, size2);
+
+        return new DenseVector(vals);
+    }
+
+    /**
+     * Concatenates given vectors.
+     *
+     * @param v1 First vector.
+     * @param vs Other vectors.
+     * @return Concatenation result.
+     */
+    public static Vector concat(Vector v1, Vector... vs) {
+        Vector res = v1;
+        for (Vector v : vs)
+            res = concat(res, v);
+        return res;
+    }
+
+    /**
+     * Concatenates given vectors.
+     *
+     * @param vs Other vectors.
+     * @return Concatenation result.
+     */
+    public static Vector concat(Vector... vs) {
+        Vector res = vs.length == 0 ? new DenseVector() : vs[0];
+        for (int i = 1; i < vs.length; i++) {
+            Vector v = vs[i];
+            res = concat(res, v);
+        }
+        return res;
+    }
 }
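Illustrative expected values for the new and renamed helpers (a sketch, assuming VectorUtils.of from this class):

    Vector a = VectorUtils.of(1.0, 2.0);
    Vector b = VectorUtils.num2Vec(3.0);  // [3.0]
    Vector c = VectorUtils.concat(a, b);  // [1.0, 2.0, 3.0]
    Vector oh = VectorUtils.oneHot(1, 3); // [0.0, 1.0, 0.0] (renamed from num2Vec)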
index 5392cf2..830d494 100644 (file)
@@ -42,7 +42,7 @@ public class DenseVectorStorage implements VectorStorage {
      * @param size Vector size.
      */
     public DenseVectorStorage(int size) {
-        assert size > 0;
+        assert size >= 0;
 
         data = new double[size];
     }
index f265318..0ddad53 100644 (file)
@@ -68,7 +68,7 @@ public class OneVsRestTrainer<M extends Model<Vector, Double>>
     }
 
     /** {@inheritDoc} */
-    @Override public <K, V> MultiClassModel<M> updateModel(MultiClassModel<M> newMdl,
+    @Override protected <K, V> MultiClassModel<M> updateModel(MultiClassModel<M> newMdl,
         DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
 
index dc245d2..6b2b11e 100644 (file)
@@ -79,7 +79,7 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea
     }
 
     /** {@inheritDoc} */
-    @Override protected boolean checkState(LinearRegressionModel mdl) {
+    @Override public boolean checkState(LinearRegressionModel mdl) {
         return true;
     }
 }
index cdbfe4c..864187d 100644 (file)
@@ -17,7 +17,6 @@
 
 package org.apache.ignite.ml.regressions.logistic;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -41,7 +40,7 @@ import org.jetbrains.annotations.NotNull;
 /**
  * Trainer of the logistic regression model based on stochastic gradient descent algorithm.
  */
-public class LogisticRegressionSGDTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogisticRegressionModel> {
+public class LogisticRegressionSGDTrainer extends SingleLabelDatasetTrainer<LogisticRegressionModel> {
     /** Update strategy. */
     private UpdatesStrategy updatesStgy = new UpdatesStrategy<>(
         new SimpleGDUpdateCalculator(0.2),
@@ -150,7 +149,7 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
      * @param maxIterations The parameter value.
      * @return Model with new max number of iterations before convergence parameter value.
      */
-    public LogisticRegressionSGDTrainer<P> withMaxIterations(int maxIterations) {
+    public LogisticRegressionSGDTrainer withMaxIterations(int maxIterations) {
         this.maxIterations = maxIterations;
         return this;
     }
@@ -161,7 +160,7 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
      * @param batchSize The size of learning batch.
      * @return Trainer with new batch size parameter value.
      */
-    public LogisticRegressionSGDTrainer<P> withBatchSize(int batchSize) {
+    public LogisticRegressionSGDTrainer withBatchSize(int batchSize) {
         this.batchSize = batchSize;
         return this;
     }
@@ -172,7 +171,7 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
      * @param amountOfLocIterations The parameter value.
      * @return Trainer with new locIterations parameter value.
      */
-    public LogisticRegressionSGDTrainer<P> withLocIterations(int amountOfLocIterations) {
+    public LogisticRegressionSGDTrainer withLocIterations(int amountOfLocIterations) {
         this.locIterations = amountOfLocIterations;
         return this;
     }
@@ -183,7 +182,7 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
      * @param seed Seed for random generator.
      * @return Trainer with new seed parameter value.
      */
-    public LogisticRegressionSGDTrainer<P> withSeed(long seed) {
+    public LogisticRegressionSGDTrainer withSeed(long seed) {
         this.seed = seed;
         return this;
     }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetModel.java
new file mode 100644 (file)
index 0000000..0e80325
--- /dev/null
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers;
+
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Model which is a composition of the form {@code before `andThen` innerMdl `andThen` after}.
+ *
+ * @param <I> Type of input of this model.
+ * @param <O> Type of output of this model.
+ * @param <IW> Type of input of inner model.
+ * @param <OW> Type of output of inner model.
+ * @param <M> Type of inner model.
+ */
+public class AdaptableDatasetModel<I, O, IW, OW, M extends Model<IW, OW>> implements Model<I, O> {
+    /** Function applied before inner model. */
+    private final IgniteFunction<I, IW> before;
+
+    /** Function applied after inner model.*/
+    private final IgniteFunction<OW, O> after;
+
+    /** Inner model. */
+    private final M mdl;
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param before Function applied before wrapped model.
+     * @param mdl Inner model.
+     * @param after Function applied after wrapped model.
+     */
+    public AdaptableDatasetModel(IgniteFunction<I, IW> before, M mdl, IgniteFunction<OW, O> after) {
+        this.before = before;
+        this.after = after;
+        this.mdl = mdl;
+    }
+
+    /**
+     * Result of application of this model is the result of the composition {@code before `andThen` innerMdl `andThen` after}.
+     */
+    @Override public O apply(I i) {
+        return before.andThen(mdl).andThen(after).apply(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <O1> AdaptableDatasetModel<I, O1, IW, OW, M> andThen(IgniteFunction<O, O1> after) {
+        return new AdaptableDatasetModel<>(before, mdl, i -> after.apply(this.after.apply(i)));
+    }
+
+    /**
+     * Create new {@code AdaptableDatasetModel} which is a composition of the form {@code thisMdl . before}.
+     *
+     * @param before Function applied before this model.
+     * @param <I1> Type of input of the function applied before this model.
+     * @return New {@code AdaptableDatasetModel} which is a composition of the form {@code thisMdl . before}.
+     */
+    public <I1> AdaptableDatasetModel<I1, O, IW, OW, M> andBefore(IgniteFunction<I1, I> before) {
+        IgniteFunction<I1, IW> function = i -> this.before.apply(before.apply(i));
+        return new AdaptableDatasetModel<>(function, mdl, after);
+    }
+
+    /**
+     * Get inner model.
+     *
+     * @return Inner model.
+     */
+    public M innerModel() {
+        return mdl;
+    }
+
+    /**
+     * Create new instance of this class with changed inner model.
+     *
+     * @param mdl Inner model.
+     * @param <M1> Type of inner model.
+     * @return New instance of this class with changed inner model.
+     */
+    public <M1 extends Model<IW, OW>> AdaptableDatasetModel<I, O, IW, OW, M1> withInnerModel(M1 mdl) {
+        return new AdaptableDatasetModel<>(before, mdl, after);
+    }
+}
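A minimal sketch of what this composition buys, using hypothetical toy types that are not part of
this patch: an inner Double -> Double model adapted to consume and produce String.

    // Hypothetical toy example: adapt a Double -> Double model to String -> String.
    Model<Double, Double> inner = x -> 2 * x;

    AdaptableDatasetModel<String, String, Double, Double, Model<Double, Double>> mdl =
        new AdaptableDatasetModel<>(Double::parseDouble, inner, d -> Double.toString(d));

    // before `andThen` inner `andThen` after: "21" -> 21.0 -> 42.0 -> "42.0".
    mdl.apply("21");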
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java
new file mode 100644 (file)
index 0000000..7e2e810
--- /dev/null
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers;
+
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Type used to adapt input and output types of a wrapped {@link DatasetTrainer}.
+ * Produces a model which is a composition of the form {@code before `andThen` wMdl `andThen` after}, where
+ * {@code wMdl} is the model produced by the wrapped trainer.
+ *
+ * @param <I> Input type of model produced by this trainer.
+ * @param <O> Output type of model produced by this trainer.
+ * @param <IW> Input type of model produced by wrapped trainer.
+ * @param <OW> Output type of model produced by wrapped trainer.
+ * @param <M> Type of model produced by wrapped trainer.
+ * @param <L> Type of labels.
+ */
+public class AdaptableDatasetTrainer<I, O, IW, OW, M extends Model<IW, OW>, L>
+    extends DatasetTrainer<AdaptableDatasetModel<I, O, IW, OW, M>, L> {
+    /** Wrapped trainer. */
+    private final DatasetTrainer<M, L> wrapped;
+
+    /** Function used to convert input type of wrapped trainer. */
+    private final IgniteFunction<I, IW> before;
+
+    /** Function used to convert output type of wrapped trainer. */
+    private final IgniteFunction<OW, O> after;
+
+    /**
+     * Construct instance of this class from a given {@link DatasetTrainer}.
+     *
+     * @param wrapped Wrapped trainer.
+     * @param <I> Input type of wrapped trainer.
+     * @param <O> Output type of wrapped trainer.
+     * @param <M> Type of model produced by wrapped trainer.
+     * @param <L> Type of labels.
+     * @return Instance of this class.
+     */
+    public static <I, O, M extends Model<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(DatasetTrainer<M, L> wrapped) {
+        return new AdaptableDatasetTrainer<>(IgniteFunction.identity(), wrapped, IgniteFunction.identity());
+    }
+
+    /**
+     * Construct instance of this class with specified wrapped trainer and converter functions.
+     *
+     * @param before Function used to convert input type of wrapped trainer.
+     * @param wrapped Wrapped trainer.
+     * @param after Function used to convert output type of wrapped trainer.
+     */
+    private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M, L> wrapped, IgniteFunction<OW, O> after) {
+        this.before = before;
+        this.wrapped = wrapped;
+        this.after = after;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> AdaptableDatasetModel<I, O, IW, OW, M> fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        M fit = wrapped.fit(datasetBuilder, featureExtractor, lbExtractor);
+        return new AdaptableDatasetModel<>(before, fit, after);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(AdaptableDatasetModel<I, O, IW, OW, M> mdl) {
+        return wrapped.checkState(mdl.innerModel());
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel(AdaptableDatasetModel<I, O, IW, OW, M> mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        return mdl.withInnerModel(wrapped.updateModel(mdl.innerModel(), datasetBuilder, featureExtractor, lbExtractor));
+    }
+
+    /**
+     * Let this trainer produce model {@code mdl}. This method produces a trainer which produces {@code mdl1}, where
+     * {@code mdl1 = mdl `andThen` after}.
+     *
+     * @param after Function applied after the produced model.
+     * @param <O1> Type of produced model output.
+     * @return New {@link DatasetTrainer} which produces composition of specified function and model produced by
+     * original trainer.
+     */
+    public <O1> AdaptableDatasetTrainer<I, O1, IW, OW, M, L> afterTrainedModel(IgniteFunction<O, O1> after) {
+        return new AdaptableDatasetTrainer<>(before, wrapped, i -> after.apply(this.after.apply(i)));
+    }
+
+    /**
+     * Let this trainer produce model {@code mdl}. This method produces a trainer which produces {@code mdl1}, where
+     * {@code mdl1 = before `andThen` mdl}.
+     *
+     * @param before Function applied before the produced model.
+     * @param <I1> Type of produced model input.
+     * @return New {@link DatasetTrainer} which produces composition of specified function and model produced by
+     * original trainer.
+     */
+    public <I1> AdaptableDatasetTrainer<I1, O, IW, OW, M, L> beforeTrainedModel(IgniteFunction<I1, I> before) {
+        IgniteFunction<I1, IW> function = i -> this.before.apply(before.apply(i));
+        return new AdaptableDatasetTrainer<>(function, wrapped, after);
+    }
+}
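For orientation, this is the pattern used in StackingTest further down: an MLPTrainer, whose models
map Matrix to Matrix, is adapted into a trainer producing Vector -> Vector models ({@code trainer1}
here stands for an MLPTrainer configured as in that test).

    // Adapt a trainer of Matrix -> Matrix models into one producing Vector -> Vector models.
    DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer =
        AdaptableDatasetTrainer.of(trainer1)
            .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1)) // wrap input vector as 1-row matrix
            .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))                  // unwrap first row of output matrix
            .withConvertedLabels(VectorUtils::num2Arr);                        // accept Double instead of double[] labels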
index dabf66a..161a40c 100644 (file)
@@ -29,6 +29,7 @@ import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.jetbrains.annotations.NotNull;
 
@@ -70,6 +71,7 @@ public abstract class DatasetTrainer<M extends Model, L> {
      * @param <V> Type of a value in {@code upstream} data.
      * @return Updated model.
      */
     public <K,V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
 
@@ -306,6 +308,37 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Creates a {@code DatasetTrainer} with the same training logic, but able to accept labels of a given new
+     * type.
+     *
+     * @param new2Old Converter of new labels to old labels.
+     * @param <L1> New labels type.
+     * @return {@code DatasetTrainer} with the same training logic, but able to accept labels of a given new
+     * type.
+     */
+    public <L1> DatasetTrainer<M, L1> withConvertedLabels(IgniteFunction<L1, L> new2Old) {
+        DatasetTrainer<M, L> old = this;
+        return new DatasetTrainer<M, L1>() {
+            /** {@inheritDoc} */
+            @Override public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+                IgniteBiFunction<K, V, L1> lbExtractor) {
+                return old.fit(datasetBuilder, featureExtractor, lbExtractor.andThen(new2Old));
+            }
+
+            /** {@inheritDoc} */
+            @Override protected boolean checkState(M mdl) {
+                return old.checkState(mdl);
+            }
+
+            /** {@inheritDoc} */
+            @Override protected <K, V> M updateModel(M mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L1> lbExtractor) {
+                return old.update(mdl, datasetBuilder, featureExtractor, lbExtractor.andThen(new2Old));
+            }
+        };
+    }
+
+    /**
      * Get learning environment.
      *
      * @return Learning environment.
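A quick sketch of the new label-conversion hook, lifted from the stacking test below ({@code arch}
and {@code updatesStgy} are assumed to be configured as in that test): a trainer that natively
expects double[] labels is made to accept Double labels.

    // Each incoming Double label is converted to double[] before training.
    DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>(
        arch, LossFunctions.MSE, updatesStgy, 3000, 10, 50, 123L
    ).withConvertedLabels(VectorUtils::num2Arr);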
index 510d26e..573759e 100644 (file)
@@ -99,7 +99,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
      * @param <V> Type of a value in {@code upstream} data.
      * @return New model based on new dataset.
      */
-    @Override public <K, V> DecisionTreeNode updateModel(DecisionTreeNode mdl, DatasetBuilder<K, V> datasetBuilder,
+    @Override protected <K, V> DecisionTreeNode updateModel(DecisionTreeNode mdl, DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
         return fit(datasetBuilder, featureExtractor, lbExtractor);
index 78d6659..b4cd832 100644 (file)
@@ -34,7 +34,6 @@ import org.apache.ignite.ml.regressions.RegressionsTestSuite;
 import org.apache.ignite.ml.selection.SelectionTestSuite;
 import org.apache.ignite.ml.structures.StructuresTestSuite;
 import org.apache.ignite.ml.svm.SVMTestSuite;
-import org.apache.ignite.ml.trainers.BaggingTest;
 import org.apache.ignite.ml.tree.DecisionTreeTestSuite;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
@@ -62,7 +61,6 @@ import org.junit.runners.Suite;
     StructuresTestSuite.class,
     CommonTestSuite.class,
     InferenceTestSuite.class,
-    BaggingTest.class,
     MultiClassTestSuite.class
 })
 public class IgniteMLTestSuite {
index 1103ef0..b85a5c3 100644 (file)
 package org.apache.ignite.ml;
 
 import java.util.stream.IntStream;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.matrix.Matrix;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.junit.Assert;
 
 import static org.junit.Assert.assertTrue;
@@ -227,7 +230,8 @@ public class TestUtils {
     public static boolean checkIsInEpsilonNeighbourhoodBoolean(Vector v1, Vector v2, double epsilon) {
         try {
             checkIsInEpsilonNeighbourhood(new Vector[] {v1}, new Vector[] {v2}, epsilon);
-        } catch (Throwable e) {
+        }
+        catch (Throwable e) {
             return false;
         }
 
@@ -404,4 +408,36 @@ public class TestUtils {
     public static <T, V> Model<T, V> constantModel(V v) {
         return t -> v;
     }
+
+    /**
+     * Returns a trainer which outputs the given model regardless of the dataset.
+     *
+     * @param ml Model.
+     * @param <I> Type of model input.
+     * @param <O> Type of model output.
+     * @param <M> Type of model.
+     * @param <L> Type of dataset labels.
+     * @return Trainer which outputs the given model regardless of the dataset.
+     */
+    public static <I, O, M extends Model<I, O>, L> DatasetTrainer<M, L> constantTrainer(M ml) {
+        return new DatasetTrainer<M, L>() {
+            /** {@inheritDoc} */
+            @Override public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor,
+                IgniteBiFunction<K, V, L> lbExtractor) {
+                return ml;
+            }
+
+            /** {@inheritDoc} */
+            @Override public boolean checkState(M mdl) {
+                return true;
+            }
+
+            /** {@inheritDoc} */
+            @Override public <K, V> M updateModel(M mdl, DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+                return ml;
+            }
+        };
+    }
 }
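For instance, a test can pin a stage of a composition to a fixed model (a hypothetical sketch;
the null dataset arguments are simply ignored by this trainer):

    // The trainer ignores its dataset and always returns `constant`.
    Model<Double, Double> constant = x -> 42.0;
    DatasetTrainer<Model<Double, Double>, Double> trainer = TestUtils.constantTrainer(constant);
    Model<Double, Double> mdl = trainer.fit(null, null, null); // dataset arguments unused
    assert mdl.apply(0.0) == 42.0;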
index 5d3bb5f..c078066 100644 (file)
@@ -18,7 +18,9 @@
 package org.apache.ignite.ml.common;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
@@ -1158,4 +1160,31 @@ public class TrainerTest {
         {3, 9.959296741639132, -9.762961500922069},
         {3, 9.882357321966778, -9.069477551120192}
     };
+
+    /** XOR truth table. */
+    protected static final double[][] xor = {
+        {0.0, 0.0, 0.0},
+        {0.0, 1.0, 1.0},
+        {1.0, 0.0, 1.0},
+        {1.0, 1.0, 0.0}
+    };
+
+    /**
+     * Create cache mock.
+     *
+     * @param vals Values for cache mock.
+     * @return Cache mock.
+     */
+    protected Map<Integer, Double[]> getCacheMock(double[][] vals) {
+        Map<Integer, Double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < vals.length; i++) {
+            double[] row = vals[i];
+            Double[] convertedRow = new Double[row.length];
+            for (int j = 0; j < row.length; j++)
+                convertedRow[j] = row[j];
+            cacheMock.put(i, convertedRow);
+        }
+        return cacheMock;
+    }
 }
  * limitations under the License.
  */
 
-package org.apache.ignite.ml.trainers;
+package org.apache.ignite.ml.composition;
 
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
 import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
 import org.apache.ignite.ml.dataset.Dataset;
@@ -39,6 +37,8 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpda
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.trainers.TrainerTransformers;
 import org.junit.Test;
 
 /**
@@ -66,10 +66,10 @@ public class BaggingTest extends TrainerTest {
      */
     @Test
     public void testNaiveBaggingLogRegression() {
-        Map<Integer, Double[]> cacheMock = getCacheMock();
+        Map<Integer, Double[]> cacheMock = getCacheMock(twoLinearlySeparableClasses);
 
         DatasetTrainer<LogisticRegressionModel, Double> trainer =
-            (LogisticRegressionSGDTrainer<?>)new LogisticRegressionSGDTrainer<>()
+            new LogisticRegressionSGDTrainer()
                 .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                     SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
                 .withMaxIterations(30000)
@@ -102,17 +102,17 @@ public class BaggingTest extends TrainerTest {
     /**
      * Method used to test counts of data passed in context and in data builders.
      *
-     * @param counter Function specifying which data we should count.
+     * @param cntr Function specifying which data we should count.
      */
-    protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) {
-        Map<Integer, Double[]> cacheMock = getCacheMock();
+    protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr) {
+        Map<Integer, Double[]> cacheMock = getCacheMock(twoLinearlySeparableClasses);
 
-        CountTrainer countTrainer = new CountTrainer(counter);
+        CountTrainer cntTrainer = new CountTrainer(cntr);
 
         double subsampleRatio = 0.3;
 
-        ModelsComposition model = TrainerTransformers.makeBagged(
-            countTrainer,
+        ModelsComposition mdl = TrainerTransformers.makeBagged(
+            cntTrainer,
             100,
             subsampleRatio,
             2,
@@ -120,30 +120,12 @@ public class BaggingTest extends TrainerTest {
             new MeanValuePredictionsAggregator())
             .fit(cacheMock, parts, null, null);
 
-        Double res = model.apply(null);
+        Double res = mdl.apply(null);
 
         TestUtils.assertEquals(twoLinearlySeparableClasses.length * subsampleRatio, res, twoLinearlySeparableClasses.length / 10);
     }
 
     /**
-     * Create cache mock.
-     *
-     * @return Cache mock.
-     */
-    private Map<Integer, Double[]> getCacheMock() {
-        Map<Integer, Double[]> cacheMock = new HashMap<>();
-
-        for (int i = 0; i < twoLinearlySeparableClasses.length; i++) {
-            double[] row = twoLinearlySeparableClasses[i];
-            Double[] convertedRow = new Double[row.length];
-            for (int j = 0; j < row.length; j++)
-                convertedRow[j] = row[j];
-            cacheMock.put(i, convertedRow);
-        }
-        return cacheMock;
-    }
-
-    /**
      * Get sum of two Long values each of which can be null.
      *
      * @param a First value.
@@ -167,15 +149,15 @@ public class BaggingTest extends TrainerTest {
         /**
          * Function specifying which entries to count.
          */
-        private final IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter;
+        private final IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr;
 
         /**
          * Construct instance of this class.
          *
-         * @param counter Function specifying which entries to count.
+         * @param cntr Function specifying which entries to count.
          */
-        public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) {
-            this.counter = counter;
+        public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> cntr) {
+            this.cntr = cntr;
         }
 
         /** {@inheritDoc} */
@@ -189,7 +171,7 @@ public class BaggingTest extends TrainerTest {
                 (env, upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize)
             );
 
-            Long cnt = dataset.computeWithCtx(counter, BaggingTest::plusOfNullables);
+            Long cnt = dataset.computeWithCtx(cntr, BaggingTest::plusOfNullables);
 
             return x -> Double.valueOf(cnt);
         }
index 8714eb2..87d56cd 100644 (file)
@@ -25,13 +25,15 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
 /**
- * Test suite for all tests located in org.apache.ignite.ml.composition package.
+ * Test suite for all ensemble model tests.
  */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
     GDBTrainerTest.class,
     MeanValuePredictionsAggregatorTest.class,
     OnMajorityPredictionsAggregatorTest.class,
+    BaggingTest.class,
+    StackingTest.class,
     WeightedPredictionsAggregatorTest.class
 })
 public class CompositionTestSuite {
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java
new file mode 100644 (file)
index 0000000..3336470
--- /dev/null
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer;
+import org.apache.ignite.ml.composition.stacking.StackedModel;
+import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.trainers.AdaptableDatasetModel;
+import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static junit.framework.TestCase.assertEquals;
+
+/**
+ * Tests stacked trainers.
+ */
+public class StackingTest extends TrainerTest {
+    /** Rule to check exceptions. */
+    @Rule
+    public ExpectedException thrown = ExpectedException.none();
+
+    /**
+     * Tests simple stack training.
+     */
+    @Test
+    public void testSimpleStack() {
+        StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer =
+            new StackedDatasetTrainer<>();
+
+        UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        );
+
+        MLPArchitecture arch = new MLPArchitecture(2).
+            withAddedLayer(10, true, Activators.RELU).
+            withAddedLayer(1, false, Activators.SIGMOID);
+
+        MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>(
+            arch,
+            LossFunctions.MSE,
+            updatesStgy,
+            3000,
+            10,
+            50,
+            123L
+        );
+
+        // Adapt the MLP trainer so that its models accept and return vectors.
+        DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer =
+            AdaptableDatasetTrainer.of(trainer1)
+                .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
+                .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
+                .withConvertedLabels(VectorUtils::num2Arr);
+
+        final double factor = 3;
+
+        StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
+            .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
+            .addTrainer(mlpTrainer)
+            .withAggregatorInputMerger(VectorUtils::concat)
+            .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
+            .withVector2SubmodelInputConverter(IgniteFunction.identity())
+            .withOriginalFeaturesKept(IgniteFunction.identity())
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder())
+            .fit(getCacheMock(xor),
+                parts,
+                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+                (k, v) -> v[v.length - 1]);
+
+        assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(0.0, 0.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(0.0, 1.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(1.0, 0.0)), 0.3);
+        assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(1.0, 1.0)), 0.3);
+    }
+
+    /**
+     * Tests simple stack training with {@link StackedVectorDatasetTrainer}.
+     */
+    @Test
+    public void testSimpleVectorStack() {
+        StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> trainer =
+            new StackedVectorDatasetTrainer<>();
+
+        UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        );
+
+        MLPArchitecture arch = new MLPArchitecture(2).
+            withAddedLayer(10, true, Activators.RELU).
+            withAddedLayer(1, false, Activators.SIGMOID);
+
+        DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>(
+            arch,
+            LossFunctions.MSE,
+            updatesStgy,
+            3000,
+            10,
+            50,
+            123L
+        ).withConvertedLabels(VectorUtils::num2Arr);
+
+        final double factor = 3;
+
+        StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
+            .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
+            .addMatrix2MatrixTrainer(mlpTrainer)
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder())
+            .fit(getCacheMock(xor),
+                parts,
+                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+                (k, v) -> v[v.length - 1]);
+
+        assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(0.0, 0.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(0.0, 1.0)), 0.3);
+        assertEquals(1.0 * factor, mdl.apply(VectorUtils.of(1.0, 0.0)), 0.3);
+        assertEquals(0.0 * factor, mdl.apply(VectorUtils.of(1.0, 1.0)), 0.3);
+    }
+
+    /**
+     * Tests that an exception is thrown when there is no way for the input of the first layer
+     * to propagate to the second layer.
+     */
+    @Test
+    public void testNoWaysOfPropagation() {
+        StackedDatasetTrainer<Void, Void, Void, Model<Void, Void>, Void> trainer =
+            new StackedDatasetTrainer<>();
+        thrown.expect(IllegalStateException.class);
+        trainer.fit(null, null, null);
+    }
+}
index 61f9fc4..74841a3 100644 (file)
@@ -47,7 +47,7 @@ public class OneVsRestTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(1000)
@@ -80,7 +80,7 @@ public class OneVsRestTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(1000)
index bd31b19..5ee50a6 100644 (file)
@@ -106,7 +106,7 @@ public class MLPTrainerMnistIntegrationTest extends GridCommonAbstractTest {
             ignite,
             trainingSet,
             (k, v) -> VectorUtils.of(v.getPixels()),
-            (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
+            (k, v) -> VectorUtils.oneHot(v.getLabel(), 10).getStorage().data()
         );
         System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms");
 
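The rename makes the intent explicit: the helper one-hot encodes a label index. A sketch of the
assumed behavior (exact signature inferred from this call site):

    // Assumed: a 10-dimensional vector with a single 1.0 at the label's index,
    // e.g. label 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
    Vector encoded = VectorUtils.oneHot(3, 10);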
index 6a17d18..9396009 100644 (file)
@@ -76,7 +76,7 @@ public class MLPTrainerMnistTest {
             trainingSet,
             1,
             (k, v) -> VectorUtils.of(v.getPixels()),
-            (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
+            (k, v) -> VectorUtils.oneHot(v.getLabel(), 10).getStorage().data()
         );
         System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms");
 
index fec6220..694dcd3 100644 (file)
@@ -51,7 +51,7 @@ public class PipelineTest extends TrainerTest {
             cacheMock.put(i, convertedRow);
         }
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(100000)
index c343ab9..681cb72 100644 (file)
@@ -43,7 +43,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(100000)
@@ -70,7 +70,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+        LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
             .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
                 SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
             .withMaxIterations(100000)