IGNITE-8867: [ML] Bagging on learning sample
authorArtem Malykh <amalykhgh@gmail.com>
Sun, 18 Nov 2018 21:59:56 +0000 (00:59 +0300)
committerYury Babak <ybabak@gridgain.com>
Sun, 18 Nov 2018 21:59:56 +0000 (00:59 +0300)
this closes #5058

24 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java [new file with mode: 0644]
examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java [new file with mode: 0644]

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
new file mode 100644 (file)
index 0000000..baf513a
--- /dev/null
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression.logistic.bagged;
+
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.cv.CrossValidation;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.trainers.TrainerTransformers;
+
+/**
+ * This example shows how bagging technique may be applied to arbitrary trainer.
+ * As an example (a bit synthetic) logistic regression is considered.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the
+ * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set"></a>Iris dataset</a>).</p>
+ * <p>
+ * After that it trains bootstrapped (or bagged) version of logistic regression trainer. Bootstrapping is done
+ * on both samples and features (<a href="https://en.wikipedia.org/wiki/Bootstrap_aggregating"></a>Samples bagging</a>,
+ * <a href="https://en.wikipedia.org/wiki/Random_subspace_method"></a>Features bagging</a>).</p>
+ * <p>
+ * Finally, this example applies cross-validation to resulted model and prints accuracy if each fold.</p>
+ */
+public class BaggedLogisticRegressionSGDTrainerExample {
+    /** Run example. */
+    public static void main(String[] args) throws FileNotFoundException {
+        System.out.println();
+        System.out.println(">>> Logistic regression model over partitioned dataset usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+                .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+
+            System.out.println(">>> Create new logistic regression trainer object.");
+            LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+                .withUpdatesStgy(new UpdatesStrategy<>(
+                    new SimpleGDUpdateCalculator(0.2),
+                    SimpleGDParameterUpdate::sumLocal,
+                    SimpleGDParameterUpdate::avg
+                ))
+                .withMaxIterations(100000)
+                .withLocIterations(100)
+                .withBatchSize(10)
+                .withSeed(123L);
+
+            System.out.println(">>> Perform the training to get the model.");
+
+            DatasetTrainer< ModelsComposition, Double> baggedTrainer = TrainerTransformers.makeBagged(
+                trainer,
+                10,
+                0.6,
+                4,
+                3,
+                new OnMajorityPredictionsAggregator(),
+                123L);
+
+            System.out.println(">>> Perform evaluation of the model.");
+
+            double[] score = new CrossValidation<ModelsComposition, Double, Integer, Vector>().score(
+                baggedTrainer,
+                new Accuracy<>(),
+                ignite,
+                dataCache,
+                (k, v) -> v.copyOfRange(1, v.size()),
+                (k, v) -> v.get(0),
+                3
+            );
+
+            System.out.println(">>> ---------------------------------");
+
+            Arrays.stream(score).forEach(sc -> {
+                System.out.println("\n>>> Accuracy " + sc);
+            });
+
+            System.out.println(">>> Bagged logistic regression model over partitioned dataset usage example completed.");
+        }
+    }
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/package-info.java
new file mode 100644 (file)
index 0000000..ea0d19e
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML bagged logistic regression examples.
+ */
+package org.apache.ignite.examples.ml.regression.logistic.bagged;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
deleted file mode 100644 (file)
index 493c1da..0000000
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.composition;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
-import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.environment.logging.MLLogger;
-import org.apache.ignite.ml.environment.parallelism.Promise;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.util.Utils;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * Abstract trainer implementing bagging logic. In each learning iteration the algorithm trains one model on subset of
- * learning sample and subspace of features space. Each model is produced from same model-class [e.g. Decision Trees].
- */
-public abstract class BaggingModelTrainer extends DatasetTrainer<ModelsComposition, Double> {
-    /**
-     * Predictions aggregator.
-     */
-    private final PredictionsAggregator predictionsAggregator;
-    /**
-     * Number of features to draw from original features vector to train each model.
-     */
-    private final int maximumFeaturesCntPerMdl;
-    /**
-     * Ensemble size.
-     */
-    private final int ensembleSize;
-    /**
-     * Size of sample part in percent to train one model.
-     */
-    private final double samplePartSizePerMdl;
-    /**
-     * Feature vector size.
-     */
-    private final int featureVectorSize;
-
-    /**
-     * Constructs new instance of BaggingModelTrainer.
-     *
-     * @param predictionsAggregator Predictions aggregator.
-     * @param featureVectorSize Feature vector size.
-     * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model.
-     * @param ensembleSize Ensemble size.
-     * @param samplePartSizePerMdl Size of sample part in percent to train one model.
-     */
-    public BaggingModelTrainer(PredictionsAggregator predictionsAggregator,
-        int featureVectorSize,
-        int maximumFeaturesCntPerMdl,
-        int ensembleSize,
-        double samplePartSizePerMdl) {
-
-        this.predictionsAggregator = predictionsAggregator;
-        this.maximumFeaturesCntPerMdl = maximumFeaturesCntPerMdl;
-        this.ensembleSize = ensembleSize;
-        this.samplePartSizePerMdl = samplePartSizePerMdl;
-        this.featureVectorSize = featureVectorSize;
-    }
-
-    /** {@inheritDoc} */
-    @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor,
-        IgniteBiFunction<K, V, Double> lbExtractor) {
-
-        MLLogger log = environment.logger(getClass());
-        log.log(MLLogger.VerboseLevel.LOW, "Start learning");
-
-        Long startTs = System.currentTimeMillis();
-
-        List<IgniteSupplier<ModelOnFeaturesSubspace>> tasks = new ArrayList<>();
-        for(int i = 0; i < ensembleSize; i++)
-            tasks.add(() -> learnModel(datasetBuilder, featureExtractor, lbExtractor));
-
-        List<Model<Vector, Double>> models = environment.parallelismStrategy().submit(tasks)
-            .stream().map(Promise::unsafeGet)
-            .collect(Collectors.toList());
-
-        double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
-        log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);
-        log.log(MLLogger.VerboseLevel.LOW, "Learning finished");
-        return new ModelsComposition(models, predictionsAggregator);
-    }
-
-    /**
-     * Trains one model on part of sample and features subspace.
-     *
-     * @param datasetBuilder Dataset builder.
-     * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     */
-    @NotNull private <K, V> ModelOnFeaturesSubspace learnModel(
-        DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor,
-        IgniteBiFunction<K, V, Double> lbExtractor) {
-
-        Random rnd = new Random();
-        SHA256UniformMapper<K, V> sampleFilter = new SHA256UniformMapper<>(rnd);
-        long featureExtractorSeed = rnd.nextLong();
-        Map<Integer, Integer> featuresMapping = createFeaturesMapping(featureExtractorSeed, featureVectorSize);
-
-        //TODO: IGNITE-8867 Need to implement bootstrapping algorithm
-        Long startTs = System.currentTimeMillis();
-        Model<Vector, Double> mdl = buildDatasetTrainerForModel().fit(
-            datasetBuilder.withFilter((features, answer) -> sampleFilter.map(features, answer) < samplePartSizePerMdl),
-            wrapFeatureExtractor(featureExtractor, featuresMapping),
-            lbExtractor);
-        double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
-        environment.logger(getClass()).log(MLLogger.VerboseLevel.HIGH, "One model training time was %.2fs", learningTime);
-
-        return new ModelOnFeaturesSubspace(featuresMapping, mdl);
-    }
-
-    /**
-     * Constructs mapping from original feature vector to subspace.
-     *
-     * @param seed Seed.
-     * @param featuresVectorSize Features vector size.
-     */
-    private Map<Integer, Integer> createFeaturesMapping(long seed, int featuresVectorSize) {
-        int[] featureIdxs = Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed));
-        Map<Integer, Integer> locFeaturesMapping = new HashMap<>();
-
-        IntStream.range(0, maximumFeaturesCntPerMdl)
-            .forEach(localId -> locFeaturesMapping.put(localId, featureIdxs[localId]));
-
-        return locFeaturesMapping;
-    }
-
-    /**
-     * Creates trainer specific to ensemble.
-     */
-    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildDatasetTrainerForModel();
-
-    /**
-     * Wraps the original feature extractor with features subspace mapping applying.
-     *
-     * @param featureExtractor Feature extractor.
-     * @param featureMapping Feature mapping.
-     */
-    private <K, V> IgniteBiFunction<K, V, Vector> wrapFeatureExtractor(
-        IgniteBiFunction<K, V, Vector> featureExtractor,
-        Map<Integer, Integer> featureMapping) {
-
-        return featureExtractor.andThen((IgniteFunction<Vector, Vector>)featureValues -> {
-            double[] newFeaturesValues = new double[featureMapping.size()];
-            featureMapping.forEach((localId, featureValueId) -> newFeaturesValues[localId] = featureValues.get(featureValueId));
-            return VectorUtils.of(newFeaturesValues);
-        });
-    }
-
-    /**
-     * Learn new models on dataset and create new Compositions over them and already learned models.
-     *
-     * @param mdl Learned model.
-     * @param datasetBuilder Dataset builder.
-     * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
-     * @param <K> Type of a key in {@code upstream} data.
-     * @param <V> Type of a value in {@code upstream} data.
-     * @return New models composition.
-     */
-    @Override public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
-
-        ArrayList<Model<Vector, Double>> newModels = new ArrayList<>(mdl.getModels());
-        newModels.addAll(fit(datasetBuilder, featureExtractor, lbExtractor).getModels());
-
-        return new ModelsComposition(newModels, predictionsAggregator);
-    }
-}
index 19bdde9..4dd0a96 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
 
 /**
  * A builder constructing instances of a {@link Dataset}. Implementations of this interface encapsulate logic of
@@ -48,6 +49,16 @@ public interface DatasetBuilder<K, V> {
     public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build(
         PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder);
 
+    /**
+     * Get upstream transformers chain. This chain is applied to upstream data before it is passed
+     * to {@link PartitionDataBuilder} and {@link PartitionContextBuilder}. This is needed to allow
+     * transformation to upstream data which are agnostic of any changes that happen after.
+     * Such transformations may be used for deriving meta-algorithms such as bagging
+     * (see {@link BaggingUpstreamTransformer}).
+     *
+     * @return Upstream transformers chain.
+     */
+    public UpstreamTransformerChain<K, V> upstreamTransformersChain();
 
     /**
      * Returns new instance of DatasetBuilder using conjunction of internal filter and {@code filterToAdd}.
index 027ec34..6e1fec3 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset;
 
 import java.io.Serializable;
 import java.util.Iterator;
+import java.util.stream.Stream;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 
@@ -37,6 +38,10 @@ import org.apache.ignite.ml.math.functions.IgniteFunction;
 public interface PartitionContextBuilder<K, V, C extends Serializable> extends Serializable {
     /**
      * Builds a new partition {@code context} from an {@code upstream} data.
+     * Important: there is no guarantee that there will be no more than one UpstreamEntry with given key,
+     * UpstreamEntry should be thought rather as a container saving all data from upstream, but omitting uniqueness
+     * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating
+     * entries. For example it can be useful for bootstrapping.
      *
      * @param upstreamData Partition {@code upstream} data.
      * @param upstreamDataSize Partition {@code upstream} data size.
@@ -44,6 +49,22 @@ public interface PartitionContextBuilder<K, V, C extends Serializable> extends S
      */
     public C build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize);
 
+
+    /**
+     * Builds a new partition {@code context} from an {@code upstream} data.
+     * Important: there is no guarantee that there will be no more than one UpstreamEntry with given key,
+     * UpstreamEntry should be thought rather as a container saving all data from upstream, but omitting uniqueness
+     * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating
+     * entries. For example it can be useful for bootstrapping.
+     *
+     * @param upstreamData Partition {@code upstream} data.
+     * @param upstreamDataSize Partition {@code upstream} data size.
+     * @return Partition {@code context}.
+     */
+    public default C build(Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) {
+        return build(upstreamData.iterator(), upstreamDataSize);
+    }
+
     /**
      * Makes a composed partition {@code context} builder that first builds a {@code context} and then applies the
      * specified function on the result.
index c1391b1..54c7611 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset;
 
 import java.io.Serializable;
 import java.util.Iterator;
+import java.util.stream.Stream;
 import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -39,7 +40,11 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 @FunctionalInterface
 public interface PartitionDataBuilder<K, V, C extends Serializable, D extends AutoCloseable> extends Serializable {
     /**
-     * Builds a new partition {@code data} from a partition {@code upstream} data and partition {@code context}
+     * Builds a new partition {@code data} from a partition {@code upstream} data and partition {@code context}.
+     * Important: there is no guarantee that there will be no more than one UpstreamEntry with given key,
+     * UpstreamEntry should be thought rather as a container saving all data from upstream, but omitting uniqueness
+     * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating
+     * entries. For example it can be useful for bootstrapping.
      *
      * @param upstreamData Partition {@code upstream} data.
      * @param upstreamDataSize Partition {@code upstream} data size.
@@ -48,6 +53,10 @@ public interface PartitionDataBuilder<K, V, C extends Serializable, D extends Au
      */
     public D build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx);
 
+    public default D build(Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+        return build(upstreamData.iterator(), upstreamDataSize, ctx);
+    }
+
     /**
      * Makes a composed partition {@code data} builder that first builds a {@code data} and then applies the specified
      * function on the result.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
new file mode 100644 (file)
index 0000000..ba70e2e
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset;
+
+import java.io.Serializable;
+import java.util.Random;
+import java.util.stream.Stream;
+
+/**
+ * Interface of transformer of upstream.
+ *
+ * @param <K> Type of keys in the upstream.
+ * @param <V> Type of values in the upstream.
+ */
+@FunctionalInterface
+public interface UpstreamTransformer<K, V> extends Serializable {
+    /**
+     * Perform transformation of upstream.
+     *
+     * @param rnd Random numbers generator.
+     * @param upstream Upstream.
+     * @return Transformed upstream.
+     */
+    // TODO: IGNITE-10296: Inject capabilities of randomization through learning environment.
+    // TODO: IGNITE-10297: Investigate possibility of API change.
+    public Stream<UpstreamEntry<K, V>> transform(Random rnd, Stream<UpstreamEntry<K, V>> upstream);
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java
new file mode 100644 (file)
index 0000000..dc83926
--- /dev/null
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Class representing chain of transformers applied to upstream.
+ *
+ * @param <K> Type of upstream keys.
+ * @param <V> Type of upstream values.
+ */
+public class UpstreamTransformerChain<K, V> implements Serializable {
+    /** Seed used for transformations. */
+    private Long seed;
+
+    /** List of upstream transformations. */
+    private List<UpstreamTransformer<K, V>> list;
+
+    /**
+     * Creates empty upstream transformers chain (basically identity function).
+     *
+     * @param <K> Type of upstream keys.
+     * @param <V> Type of upstream values.
+     * @return Empty upstream transformers chain.
+     */
+    public static <K, V> UpstreamTransformerChain<K, V> empty() {
+        return new UpstreamTransformerChain<>();
+    }
+
+    /**
+     * Creates upstream transformers chain consisting of one specified transformer.
+     *
+     * @param <K> Type of upstream keys.
+     * @param <V> Type of upstream values.
+     * @return Upstream transformers chain consisting of one specified transformer.
+     */
+    public static <K, V> UpstreamTransformerChain<K, V> of(UpstreamTransformer<K, V> trans) {
+        UpstreamTransformerChain<K, V> res = new UpstreamTransformerChain<>();
+        return res.addUpstreamTransformer(trans);
+    }
+
+    /**
+     * Construct instance of this class.
+     */
+    private UpstreamTransformerChain() {
+        list = new ArrayList<>();
+        seed = new Random().nextLong();
+    }
+
+    /**
+     * Adds upstream transformer to this chain.
+     *
+     * @param next Transformer to add.
+     * @return This chain with added transformer.
+     */
+    public UpstreamTransformerChain<K, V> addUpstreamTransformer(UpstreamTransformer<K, V> next) {
+        list.add(next);
+
+        return this;
+    }
+
+    /**
+     * Add upstream transformer based on given lambda.
+     *
+     * @param transformer Transformer.
+     * @return This object.
+     */
+    public UpstreamTransformerChain<K, V> addUpstreamTransformer(IgniteFunction<Stream<UpstreamEntry<K, V>>,
+        Stream<UpstreamEntry<K, V>>> transformer) {
+        return addUpstreamTransformer((rnd, upstream) -> transformer.apply(upstream));
+    }
+
+    /**
+     * Performs stream transformation using RNG based on provided seed as pseudo-randomness source for all
+     * transformers in the chain.
+     *
+     * @param upstream Upstream.
+     * @return Transformed upstream.
+     */
+    public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) {
+        Random rnd = new Random(seed);
+
+        Stream<UpstreamEntry<K, V>> res = upstream;
+
+        for (UpstreamTransformer<K, V> kvUpstreamTransformer : list) {
+            res = kvUpstreamTransformer.transform(rnd, res);
+        }
+
+        return res;
+    }
+
+    /**
+     * Checks if this chain is empty.
+     *
+     * @return Result of check if this chain is empty.
+     */
+    public boolean isEmpty() {
+        return list.isEmpty();
+    }
+
+    /**
+     * Set seed for transformations.
+     *
+     * @param seed Seed.
+     * @return This object.
+     */
+    public UpstreamTransformerChain<K, V> setSeed(long seed) {
+        this.seed = seed;
+
+        return this;
+    }
+
+    /**
+     * Modifies seed for transformations if it is present.
+     *
+     * @param f Modification function.
+     * @return This object.
+     */
+    public UpstreamTransformerChain<K, V> modifySeed(IgniteFunction<Long, Long> f) {
+        seed = f.apply(seed);
+
+        return this;
+    }
+
+    /**
+     * Get seed used for RNG in transformations.
+     *
+     * @return Seed used for RNG in transformations.
+     */
+    public Long seed() {
+        return seed;
+    }
+}
index e5eb483..0736906 100644 (file)
@@ -26,7 +26,9 @@ import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
 import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
 import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtils;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
@@ -59,6 +61,9 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
     /** Filter for {@code upstream} data. */
     private final IgniteBiPredicate<K, V> filter;
 
+    /** Chain of transformers applied to upstream. */
+    private final UpstreamTransformerChain<K, V> upstreamTransformers;
+
     /** Ignite Cache with partition {@code context}. */
     private final IgniteCache<Integer, C> datasetCache;
 
@@ -75,16 +80,22 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
      * @param ignite Ignite instance.
      * @param upstreamCache Ignite Cache with {@code upstream} data.
      * @param filter Filter for {@code upstream} data.
+     * @param upstreamTransformers Transformers of upstream data (see description in {@link DatasetBuilder}).
      * @param datasetCache Ignite Cache with partition {@code context}.
      * @param partDataBuilder Partition {@code data} builder.
      * @param datasetId Dataset ID.
      */
-    public CacheBasedDataset(Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
+    public CacheBasedDataset(
+        Ignite ignite,
+        IgniteCache<K, V> upstreamCache,
+        IgniteBiPredicate<K, V> filter,
+        UpstreamTransformerChain<K, V> upstreamTransformers,
         IgniteCache<Integer, C> datasetCache, PartitionDataBuilder<K, V, C, D> partDataBuilder,
         UUID datasetId) {
         this.ignite = ignite;
         this.upstreamCache = upstreamCache;
         this.filter = filter;
+        this.upstreamTransformers = upstreamTransformers;
         this.datasetCache = datasetCache;
         this.partDataBuilder = partDataBuilder;
         this.datasetId = datasetId;
@@ -102,6 +113,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
                 Ignition.localIgnite(),
                 upstreamCacheName,
                 filter,
+                upstreamTransformers,
                 datasetCacheName,
                 datasetId,
                 part,
@@ -131,6 +143,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
                 Ignition.localIgnite(),
                 upstreamCacheName,
                 filter,
+                upstreamTransformers,
                 datasetCacheName,
                 datasetId,
                 part,
index 335ce63..1d00875 100644 (file)
@@ -27,6 +27,7 @@ import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
 import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtils;
 import org.apache.ignite.ml.dataset.impl.cache.util.DatasetAffinityFunctionWrapper;
 
@@ -56,6 +57,9 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     /** Filter for {@code upstream} data. */
     private final IgniteBiPredicate<K, V> filter;
 
+    /** Chain of upstream transformers. */
+    private final UpstreamTransformerChain<K, V> transformersChain;
+
     /**
      * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default
      * predicate that passes all upstream entries to dataset.
@@ -78,6 +82,7 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
         this.ignite = ignite;
         this.upstreamCache = upstreamCache;
         this.filter = filter;
+        transformersChain = UpstreamTransformerChain.empty();
     }
 
     /** {@inheritDoc} */
@@ -102,16 +107,24 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
             ignite,
             upstreamCache.getName(),
             filter,
+            transformersChain,
             datasetCache.getName(),
             partCtxBuilder,
             RETRIES,
             RETRY_INTERVAL
         );
 
-        return new CacheBasedDataset<>(ignite, upstreamCache, filter, datasetCache, partDataBuilder, datasetId);
+        return new CacheBasedDataset<>(ignite, upstreamCache, filter, transformersChain, datasetCache, partDataBuilder, datasetId);
     }
 
     /** {@inheritDoc} */
+    @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
+        return transformersChain;
+    }
+
+    /**
+     * {@inheritDoc}
+     */
     @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
         return new CacheBasedDatasetBuilder<>(ignite, upstreamCache,
             (e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2));
index a5cdd3b..6646e89 100644 (file)
@@ -27,6 +27,7 @@ import java.util.Iterator;
 import java.util.Map;
 import java.util.UUID;
 import java.util.concurrent.locks.LockSupport;
+import java.util.stream.Stream;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.IgniteException;
@@ -40,13 +41,17 @@ import org.apache.ignite.lang.IgniteFuture;
 import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.util.Utils;
 
 /**
  * Util class that provides common methods to perform computations on top of the Ignite Compute Grid.
  */
 public class ComputeUtils {
-    /** Template of the key used to store partition {@code data} in local storage. */
+    /**
+     * Template of the key used to store partition {@code data} in local storage.
+     */
     private static final String DATA_STORAGE_KEY_TEMPLATE = "part_data_storage_%s";
 
     /**
@@ -136,6 +141,7 @@ public class ComputeUtils {
      * @param ignite Ignite instance.
      * @param upstreamCacheName Name of an {@code upstream} cache.
      * @param filter Filter for {@code upstream} data.
+     * @param transformersChain Upstream transformers.
      * @param datasetCacheName Name of a partition {@code context} cache.
      * @param datasetId Dataset ID.
      * @param part Partition index.
@@ -146,8 +152,13 @@ public class ComputeUtils {
      * @param <D> Type of a partition {@code data}.
      * @return Partition {@code data}.
      */
-    public static <K, V, C extends Serializable, D extends AutoCloseable> D getData(Ignite ignite,
-        String upstreamCacheName, IgniteBiPredicate<K, V> filter, String datasetCacheName, UUID datasetId, int part,
+    public static <K, V, C extends Serializable, D extends AutoCloseable> D getData(
+        Ignite ignite,
+        String upstreamCacheName, IgniteBiPredicate<K, V> filter,
+        UpstreamTransformerChain<K, V> transformersChain,
+        String datasetCacheName,
+        UUID datasetId,
+        int part,
         PartitionDataBuilder<K, V, C, D> partDataBuilder) {
 
         PartitionDataStorage dataStorage = (PartitionDataStorage)ignite
@@ -166,13 +177,22 @@ public class ComputeUtils {
             qry.setPartition(part);
             qry.setFilter(filter);
 
-            long cnt = computeCount(upstreamCache, qry);
+            UpstreamTransformerChain<K, V> chainCopy = Utils.copy(transformersChain);
+            chainCopy.modifySeed(s -> s + part);
+
+            long cnt = computeCount(upstreamCache, qry, chainCopy);
 
             if (cnt > 0) {
                 try (QueryCursor<UpstreamEntry<K, V>> cursor = upstreamCache.query(qry,
                     e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
-                    Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(cursor.iterator(), cnt,
+                    Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
+                    if (!chainCopy.isEmpty()) {
+                        Stream<UpstreamEntry<K, V>> transformedStream = chainCopy.transform(Utils.asStream(it, cnt));
+                        it = transformedStream.iterator();
+                    }
+
+                    Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(it, cnt,
                         "Cache expected to be not modified during dataset data building [partition=" + part + ']');
 
                     return partDataBuilder.build(iter, cnt, ctx);
@@ -193,21 +213,25 @@ public class ComputeUtils {
         ignite.cluster().nodeLocalMap().remove(String.format(DATA_STORAGE_KEY_TEMPLATE, datasetId));
     }
 
-
     /**
      * Initializes partition {@code context} by loading it from a partition {@code upstream}.
-     *
+     *  @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @param <C> Type of a partition {@code context}.
      * @param ignite Ignite instance.
      * @param upstreamCacheName Name of an {@code upstream} cache.
      * @param filter Filter for {@code upstream} data.
-     * @param datasetCacheName Name of a partition {@code context} cache.
+     * @param transformersChain Upstream data {@link Stream} transformers chain.
      * @param ctxBuilder Partition {@code context} builder.
-     * @param <K> Type of a key in {@code upstream} data.
-     * @param <V> Type of a value in {@code upstream} data.
-     * @param <C> Type of a partition {@code context}.
      */
-    public static <K, V, C extends Serializable> void initContext(Ignite ignite, String upstreamCacheName,
-        IgniteBiPredicate<K, V> filter, String datasetCacheName, PartitionContextBuilder<K, V, C> ctxBuilder, int retries,
+    public static <K, V, C extends Serializable> void initContext(
+        Ignite ignite,
+        String upstreamCacheName,
+        IgniteBiPredicate<K, V> filter,
+        UpstreamTransformerChain<K, V> transformersChain,
+        String datasetCacheName,
+        PartitionContextBuilder<K, V, C> ctxBuilder,
+        int retries,
         int interval) {
         affinityCallWithRetries(ignite, Arrays.asList(datasetCacheName, upstreamCacheName), part -> {
             Ignite locIgnite = Ignition.localIgnite();
@@ -219,13 +243,23 @@ public class ComputeUtils {
             qry.setPartition(part);
             qry.setFilter(filter);
 
-            long cnt = computeCount(locUpstreamCache, qry);
-
             C ctx;
+            UpstreamTransformerChain<K, V> chainCopy = Utils.copy(transformersChain);
+            chainCopy.modifySeed(s -> s + part);
+
+            long cnt = computeCount(locUpstreamCache, qry, transformersChain);
+
             try (QueryCursor<UpstreamEntry<K, V>> cursor = locUpstreamCache.query(qry,
                 e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
-                Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(cursor.iterator(), cnt,
+                Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
+                if (!chainCopy.isEmpty()) {
+                    Stream<UpstreamEntry<K, V>> transformedStream = chainCopy.transform(Utils.asStream(it, cnt));
+                    it = transformedStream.iterator();
+                }
+                Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(
+                    it,
+                    cnt,
                     "Cache expected to be not modified during dataset data building [partition=" + part + ']');
 
                 ctx = ctxBuilder.build(iter, cnt);
@@ -245,6 +279,7 @@ public class ComputeUtils {
      * @param ignite Ignite instance.
      * @param upstreamCacheName Name of an {@code upstream} cache.
      * @param filter Filter for {@code upstream} data.
+     * @param transformersChain Transformers of upstream data.
      * @param datasetCacheName Name of a partition {@code context} cache.
      * @param ctxBuilder Partition {@code context} builder.
      * @param retries Number of retries for the case when one of partitions not found on the node.
@@ -252,10 +287,15 @@ public class ComputeUtils {
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
      */
-    public static <K, V, C extends Serializable> void initContext(Ignite ignite, String upstreamCacheName,
-        IgniteBiPredicate<K, V> filter, String datasetCacheName, PartitionContextBuilder<K, V, C> ctxBuilder,
+    public static <K, V, C extends Serializable> void initContext(
+        Ignite ignite,
+        String upstreamCacheName,
+        IgniteBiPredicate<K, V> filter,
+        UpstreamTransformerChain<K, V> transformersChain,
+        String datasetCacheName,
+        PartitionContextBuilder<K, V, C> ctxBuilder,
         int retries) {
-        initContext(ignite, upstreamCacheName, filter, datasetCacheName, ctxBuilder, retries, 0);
+        initContext(ignite, upstreamCacheName, filter, transformersChain, datasetCacheName, ctxBuilder, retries, 0);
     }
 
     /**
@@ -288,16 +328,25 @@ public class ComputeUtils {
     /**
      * Computes number of entries selected from the cache by the query.
      *
-     * @param cache Ignite cache with upstream data.
-     * @param qry Cache query.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
+     * @param cache Ignite cache with upstream data.
+     * @param qry Cache query.
+     * @param transformersChain Transformers of stream of upstream data.
      * @return Number of entries supplied by the iterator.
      */
-    private static  <K, V> long computeCount(IgniteCache<K, V> cache, ScanQuery<K, V> qry) {
+    private static <K, V> long computeCount(
+        IgniteCache<K, V> cache,
+        ScanQuery<K, V> qry,
+        UpstreamTransformerChain<K, V> transformersChain) {
         try (QueryCursor<UpstreamEntry<K, V>> cursor = cache.query(qry,
             e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
-            return computeCount(cursor.iterator());
+
+            // 'If' statement below is just for optimization, to avoid unnecessary iterator -> stream -> iterator
+            // operations.
+            return transformersChain.isEmpty() ?
+                computeCount(cursor.iterator()) :
+                computeCount(transformersChain.transform(Utils.asStream(cursor.iterator())).iterator());
         }
     }
 
index e312b20..975beda 100644 (file)
@@ -25,7 +25,7 @@ import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
 import org.apache.ignite.ml.math.functions.IgniteTriFunction;
 
 /**
- * An implementation of dataset based on local data structures such as {@code Map} and {@code List} and doesn't requires
+ * An implementation of dataset based on local data structures such as {@code Map} and {@code List} and doesn't require
  * Ignite environment. Introduces for testing purposes mostly, but can be used for simple local computations as well.
  *
  * @param <C> Type of a partition {@code context}.
index 6e0df2f..ce909ff 100644 (file)
@@ -19,7 +19,6 @@ package org.apache.ignite.ml.dataset.impl.local;
 
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -28,7 +27,9 @@ import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.util.Utils;
 
 /**
  * A dataset builder that makes {@link LocalDataset}. Encapsulate logic of building local dataset such as allocation
@@ -47,6 +48,9 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     /** Filter for {@code upstream} data. */
     private final IgniteBiPredicate<K, V> filter;
 
+    /** Upstream transformers. */
+    private final UpstreamTransformerChain<K, V> upstreamTransformers;
+
     /**
      * Constructs a new instance of local dataset builder that makes {@link LocalDataset} with default predicate that
      * passes all upstream entries to dataset.
@@ -69,6 +73,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
         this.upstreamMap = upstreamMap;
         this.filter = filter;
         this.partitions = partitions;
+        this.upstreamTransformers = UpstreamTransformerChain.empty();
     }
 
     /** {@inheritDoc} */
@@ -77,28 +82,55 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
         List<C> ctxList = new ArrayList<>();
         List<D> dataList = new ArrayList<>();
 
-        Map<K, V> filteredMap = new HashMap<>();
-        upstreamMap.forEach((key, val) -> {
-            if (filter.apply(key, val))
-                filteredMap.put(key, val);
-        });
+        List<UpstreamEntry<K, V>> entriesList = new ArrayList<>();
+
+        upstreamMap
+            .entrySet()
+            .stream()
+            .filter(en -> filter.apply(en.getKey(), en.getValue()))
+            .map(en -> new UpstreamEntry<>(en.getKey(), en.getValue()))
+            .forEach(entriesList::add);
 
-        int partSize = Math.max(1, filteredMap.size() / partitions);
+        int partSize = Math.max(1, entriesList.size() / partitions);
 
-        Iterator<K> firstKeysIter = filteredMap.keySet().iterator();
-        Iterator<K> secondKeysIter = filteredMap.keySet().iterator();
+        Iterator<UpstreamEntry<K, V>> firstKeysIter = entriesList.iterator();
+        Iterator<UpstreamEntry<K, V>> secondKeysIter = entriesList.iterator();
+        Iterator<UpstreamEntry<K, V>> thirdKeysIter = entriesList.iterator();
 
         int ptr = 0;
-        for (int part = 0; part < partitions; part++) {
-            int cnt = part == partitions - 1 ? filteredMap.size() - ptr : Math.min(partSize, filteredMap.size() - ptr);
 
-            C ctx = cnt > 0 ? partCtxBuilder.build(
-                new IteratorWindow<>(firstKeysIter, k -> new UpstreamEntry<>(k, filteredMap.get(k)), cnt),
-                cnt
-            ) : null;
+        for (int part = 0; part < partitions; part++) {
+            int cnt = part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr);
+
+            int p = part;
+            upstreamTransformers.modifySeed(s -> s + p);
+
+            if (!upstreamTransformers.isEmpty()) {
+                cnt = (int)upstreamTransformers.transform(
+                    Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cnt))).count();
+            }
+
+            Iterator<UpstreamEntry<K, V>> iter;
+            if (upstreamTransformers.isEmpty()) {
+                iter = new IteratorWindow<>(firstKeysIter, k -> k, cnt);
+            }
+            else {
+                iter = upstreamTransformers.transform(
+                    Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cnt))).iterator();
+            }
+            C ctx = cnt > 0 ? partCtxBuilder.build(iter, cnt) : null;
+
+            Iterator<UpstreamEntry<K, V>> iter1;
+            if (upstreamTransformers.isEmpty()) {
+                iter1 = upstreamTransformers.transform(
+                    Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cnt))).iterator();
+            }
+            else {
+                iter1 = new IteratorWindow<>(secondKeysIter, k -> k, cnt);
+            }
 
             D data = cnt > 0 ? partDataBuilder.build(
-                new IteratorWindow<>(secondKeysIter, k -> new UpstreamEntry<>(k, filteredMap.get(k)), cnt),
+                iter1,
                 cnt,
                 ctx
             ) : null;
@@ -113,6 +145,13 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     }
 
     /** {@inheritDoc} */
+    @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
+        return upstreamTransformers;
+    }
+
+    /**
+     * {@inheritDoc}
+     */
     @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
         return new LocalDatasetBuilder<>(upstreamMap,
             (e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2), partitions);
@@ -126,16 +165,24 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
      * @param <T> Target type of entries.
      */
     private static class IteratorWindow<K, T> implements Iterator<T> {
-        /** Delegate iterator. */
+        /**
+         * Delegate iterator.
+         */
         private final Iterator<K> delegate;
 
-        /** Transformer that transforms entries from one type to another. */
+        /**
+         * Transformer that transforms entries from one type to another.
+         */
         private final IgniteFunction<K, T> map;
 
-        /** Count of entries to produce. */
+        /**
+         * Count of entries to produce.
+         */
         private final int cnt;
 
-        /** Number of already produced entries. */
+        /**
+         * Number of already produced entries.
+         */
         private int ptr;
 
         /**
@@ -151,12 +198,16 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
             this.cnt = cnt;
         }
 
-        /** {@inheritDoc} */
+        /**
+         * {@inheritDoc}
+         */
         @Override public boolean hasNext() {
             return delegate.hasNext() && ptr < cnt;
         }
 
-        /** {@inheritDoc} */
+        /**
+         * {@inheritDoc}
+         */
         @Override public T next() {
             ++ptr;
 
index 91e832d..98f584f 100644 (file)
@@ -35,7 +35,7 @@ public class LearningEnvironmentBuilder {
     /**
      * Creates an instance of LearningEnvironmentBuilder.
      */
-    LearningEnvironmentBuilder() {
+    public LearningEnvironmentBuilder() {
         parallelismStgy = NoParallelismStrategy.INSTANCE;
         loggingFactory = NoOpLogger.factory();
     }
index 74a296d..47fa59d 100644 (file)
@@ -74,16 +74,15 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
         IgniteBiFunction<K, V, Double> lbExtractor) {
 
         IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
-            int cols = dataset.compute(data -> {
+            Integer cols = dataset.compute(data -> {
                 if (data.getFeatures() == null)
                     return null;
                 return data.getFeatures().length / data.getRows();
             }, (a, b) -> {
+                // If both are null then zero will be propagated, no good.
                 if (a == null)
-                    return b == null ? 0 : b;
-                if (b == null)
-                    return a;
-                return b;
+                    return b;
+                return a;
             });
 
             MLPArchitecture architecture = new MLPArchitecture(cols);
index 5c3913e..f321744 100644 (file)
@@ -310,4 +310,5 @@ public abstract class DatasetTrainer<M extends Model, L> {
             super("Cannot train model on empty dataset");
         }
     }
+
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
new file mode 100644 (file)
index 0000000..4f11327
--- /dev/null
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionContextBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.environment.parallelism.Promise;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
+import org.apache.ignite.ml.util.Utils;
+
+/**
+ * Class containing various trainer transformers.
+ */
+public class TrainerTransformers {
+    /**
+     * Add bagging logic to a given trainer.
+     *
+     * @param ensembleSize Size of ensemble.
+     * @param subsampleRatio Subsample ratio to whole dataset.
+     * @param aggregator Aggregator.
+     * @param <M> Type of one model in ensemble.
+     * @param <L> Type of labels.
+     * @return Bagged trainer.
+     */
+    public static <M extends Model<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
+        DatasetTrainer<M, L> trainer,
+        int ensembleSize,
+        double subsampleRatio,
+        PredictionsAggregator aggregator) {
+        return makeBagged(trainer, ensembleSize, subsampleRatio, -1, -1, aggregator, new Random().nextLong());
+    }
+
+    /**
+     * Add bagging logic to a given trainer.
+     *
+     * @param ensembleSize Size of ensemble.
+     * @param subsampleRatio Subsample ratio to whole dataset.
+     * @param aggregator Aggregator.
+     * @param featureVectorSize Feature vector dimensionality.
+     * @param featuresSubspaceDim Feature subspace dimensionality.
+     * @param transformationSeed Transformations seed.
+     * @param <M> Type of one model in ensemble.
+     * @param <L> Type of labels.
+     * @return Bagged trainer.
+     */
+    // TODO: IGNITE-10296: Inject capabilities of seeding through learning environment (remove).
+    public static <M extends Model<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
+        DatasetTrainer<M, L> trainer,
+        int ensembleSize,
+        double subsampleRatio,
+        int featureVectorSize,
+        int featuresSubspaceDim,
+        PredictionsAggregator aggregator,
+        Long transformationSeed) {
+        return new DatasetTrainer<ModelsComposition, L>() {
+            /** {@inheritDoc} */
+            @Override public <K, V> ModelsComposition fit(
+                DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor,
+                IgniteBiFunction<K, V, L> lbExtractor) {
+                datasetBuilder.upstreamTransformersChain().setSeed(
+                    transformationSeed == null
+                        ? new Random().nextLong()
+                        : transformationSeed);
+
+                return runOnEnsemble(
+                    (db, i, fe) -> (() -> trainer.fit(db, fe, lbExtractor)),
+                    datasetBuilder,
+                    ensembleSize,
+                    subsampleRatio,
+                    featureVectorSize,
+                    featuresSubspaceDim,
+                    featureExtractor,
+                    aggregator,
+                    environment);
+            }
+
+            /** {@inheritDoc} */
+            @Override protected boolean checkState(ModelsComposition mdl) {
+                return mdl.getModels().stream().allMatch(m -> trainer.checkState((M)m));
+            }
+
+            /** {@inheritDoc} */
+            @Override protected <K, V> ModelsComposition updateModel(
+                ModelsComposition mdl,
+                DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor,
+                IgniteBiFunction<K, V, L> lbExtractor) {
+                return runOnEnsemble(
+                    (db, i, fe) -> (() -> trainer.updateModel(
+                        ((ModelWithMapping<Vector, Double, M>)mdl.getModels().get(i)).model(),
+                        db,
+                        fe,
+                        lbExtractor)),
+                    datasetBuilder,
+                    ensembleSize,
+                    subsampleRatio,
+                    featureVectorSize,
+                    featuresSubspaceDim,
+                    featureExtractor,
+                    aggregator,
+                    environment);
+            }
+        };
+    }
+
+    /**
+     * This method accepts function which for given dataset builder and index of model in ensemble generates
+     * task of training this model.
+     *
+     * @param trainingTaskGenerator Training test generator.
+     * @param datasetBuilder Dataset builder.
+     * @param ensembleSize Size of ensemble.
+     * @param subsampleRatio Ratio (subsample size) / (initial dataset size).
+     * @param featuresVectorSize Dimensionality of feature vector.
+     * @param featureSubspaceDim Dimensionality of feature subspace.
+     * @param aggregator Aggregator of models.
+     * @param environment Environment.
+     * @param <K> Type of keys in dataset builder.
+     * @param <V> Type of values in dataset builder.
+     * @param <M> Type of model.
+     * @return Composition of models trained on bagged dataset.
+     */
+    private static <K, V, M extends Model<Vector, Double>> ModelsComposition runOnEnsemble(
+        IgniteTriFunction<DatasetBuilder<K, V>, Integer, IgniteBiFunction<K, V, Vector>, IgniteSupplier<M>> trainingTaskGenerator,
+        DatasetBuilder<K, V> datasetBuilder,
+        int ensembleSize,
+        double subsampleRatio,
+        int featuresVectorSize,
+        int featureSubspaceDim,
+        IgniteBiFunction<K, V, Vector> extractor,
+        PredictionsAggregator aggregator,
+        LearningEnvironment environment) {
+
+        MLLogger log = environment.logger(datasetBuilder.getClass());
+        log.log(MLLogger.VerboseLevel.LOW, "Start learning.");
+
+        List<int[]> mappings = null;
+        if (featuresVectorSize > 0) {
+            mappings = IntStream.range(0, ensembleSize).mapToObj(
+                modelIdx -> getMapping(
+                    featuresVectorSize,
+                    featureSubspaceDim,
+                    datasetBuilder.upstreamTransformersChain().seed() + modelIdx))
+                .collect(Collectors.toList());
+        }
+
+        Long startTs = System.currentTimeMillis();
+
+        datasetBuilder
+            .upstreamTransformersChain()
+            .addUpstreamTransformer(new BaggingUpstreamTransformer<>(subsampleRatio));
+
+        List<IgniteSupplier<M>> tasks = new ArrayList<>();
+        List<IgniteBiFunction<K, V, Vector>> extractors = new ArrayList<>();
+        if (mappings != null) {
+            for (int[] mapping : mappings) {
+                extractors.add(wrapExtractor(extractor, mapping));
+            }
+        }
+
+        for (int i = 0; i < ensembleSize; i++) {
+            UpstreamTransformerChain<K, V> newChain = Utils.copy(datasetBuilder.upstreamTransformersChain());
+            DatasetBuilder<K, V> newBuilder = withNewChain(datasetBuilder, newChain);
+            int j = i;
+            newChain.modifySeed(s -> s * s + j);
+            tasks.add(
+                trainingTaskGenerator.apply(newBuilder, i, mappings != null ? extractors.get(i) : extractor));
+        }
+
+        List<ModelWithMapping<Vector, Double, M>> models = environment.parallelismStrategy().submit(tasks)
+            .stream()
+            .map(Promise::unsafeGet)
+            .map(ModelWithMapping<Vector, Double, M>::new)
+            .collect(Collectors.toList());
+
+        // If we need to do projection, do it.
+        if (mappings != null) {
+            for (int i = 0; i < models.size(); i++) {
+                models.get(i).setMapping(getProjector(mappings.get(i)));
+            }
+        }
+
+        double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
+        log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs.", learningTime);
+        log.log(MLLogger.VerboseLevel.LOW, "Learning finished.");
+
+        return new ModelsComposition(models, aggregator);
+    }
+
+    /**
+     * Get mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl.
+     *
+     * @param featuresVectorSize Features vector size (Dimension of initial space).
+     * @param maximumFeaturesCntPerMdl Dimension of target space.
+     * @param seed Seed.
+     * @return Mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl.
+     */
+    public static int[] getMapping(int featuresVectorSize, int maximumFeaturesCntPerMdl, long seed) {
+        return Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed));
+    }
+
+    /**
+     * Get projector from index mapping.
+     *
+     * @param mapping Index mapping.
+     * @return Projector.
+     */
+    public static IgniteFunction<Vector, Vector> getProjector(int[] mapping) {
+        return v -> {
+            Vector res = VectorUtils.zeroes(mapping.length);
+            for (int i = 0; i < mapping.length; i++) {
+                res.set(i, v.get(mapping[i]));
+            }
+            return res;
+        };
+    }
+
+    /**
+     * Creates feature extractor which is a composition of given feature extractor and projection given by
+     * coordinate indexes mapping.
+     *
+     * @param featureExtractor Initial feature extractor.
+     * @param featureMapping Coordinate indexes mapping.
+     * @param <K> Type of keys.
+     * @param <V> Type of values.
+     * @return Composition of given feature extractor and projection given by coordinate indexes mapping.
+     */
+    private static <K, V> IgniteBiFunction<K, V, Vector> wrapExtractor(IgniteBiFunction<K, V, Vector> featureExtractor,
+        int[] featureMapping) {
+        return featureExtractor.andThen((IgniteFunction<Vector, Vector>)featureValues -> {
+            double[] newFeaturesValues = new double[featureMapping.length];
+            for (int i = 0; i < featureMapping.length; i++) {
+                newFeaturesValues[i] = featureValues.get(featureMapping[i]);
+            }
+            return VectorUtils.of(newFeaturesValues);
+        });
+    }
+
+    /**
+     * Model with mapping from X to X.
+     *
+     * @param <X> Input space.
+     * @param <Y> Output space.
+     * @param <M> Model.
+     */
+    private static class ModelWithMapping<X, Y, M extends Model<X, Y>> implements Model<X, Y> {
+        /** Model. */
+        private final M model;
+
+        /** Mapping. */
+        private IgniteFunction<X, X> mapping;
+
+        /**
+         * Create instance of this class from a given model.
+         * Identity mapping will be used as a mapping.
+         *
+         * @param model Model.
+         */
+        public ModelWithMapping(M model) {
+            this(model, x -> x);
+        }
+
+        /**
+         * Create instance of this class from given model and mapping.
+         *
+         * @param model Model.
+         * @param mapping Mapping.
+         */
+        public ModelWithMapping(M model, IgniteFunction<X, X> mapping) {
+            this.model = model;
+            this.mapping = mapping;
+        }
+
+        /**
+         * Sets mapping.
+         *
+         * @param mapping Mapping.
+         */
+        public void setMapping(IgniteFunction<X, X> mapping) {
+            this.mapping = mapping;
+        }
+
+        /** {@inheritDoc} */
+        @Override public Y apply(X x) {
+            return model.apply(mapping.apply(x));
+        }
+
+        /**
+         * Gets model.
+         *
+         * @return Model.
+         */
+        public M model() {
+            return model;
+        }
+
+        /**
+         * Gets mapping.
+         *
+         * @return Mapping.
+         */
+        public IgniteFunction<X, X> mapping() {
+            return mapping;
+        }
+    }
+
+    /**
+     * Creates new dataset builder which is delegate of a given dataset builder in everything except
+     * new transformations chain.
+     *
+     * @param builder Initial builder.
+     * @param chain New chain.
+     * @param <K> Type of keys.
+     * @param <V> Type of values.
+     * @return new dataset builder which is delegate of a given dataset builder in everything except
+     * new transformations chain.
+     */
+    private static <K, V> DatasetBuilder<K, V> withNewChain(
+        DatasetBuilder<K, V> builder,
+        UpstreamTransformerChain<K, V> chain) {
+        return new DatasetBuilder<K, V>() {
+            /** {@inheritDoc} */
+            @Override public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build(
+                PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) {
+                return builder.build(partCtxBuilder, partDataBuilder);
+            }
+
+            /** {@inheritDoc} */
+            @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
+                return chain;
+            }
+
+            /** {@inheritDoc} */
+            @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
+                return builder.withFilter(filterToAdd);
+            }
+        };
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
new file mode 100644 (file)
index 0000000..f935ebd
--- /dev/null
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.transformers;
+
+import java.util.Random;
+import java.util.stream.Stream;
+import org.apache.commons.math3.distribution.PoissonDistribution;
+import org.apache.commons.math3.random.Well19937c;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.UpstreamTransformer;
+
+/**
+ * This class encapsulates the logic needed to do bagging (bootstrap aggregating) by features.
+ * The action of this class on a given upstream is to replicate each entry in accordance to
+ * Poisson distribution.
+ *
+ * @param <K> Type of upstream keys.
+ * @param <V> Type of upstream values.
+ */
+public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, V> {
+    /** Ratio of subsample to entire upstream size */
+    private double subsampleRatio;
+
+    /**
+     * Construct instance of this transformer with a given subsample ratio.
+     *
+     * @param subsampleRatio Subsample ratio.
+     */
+    public BaggingUpstreamTransformer(double subsampleRatio) {
+        this.subsampleRatio = subsampleRatio;
+    }
+
+    /** {@inheritDoc} */
+    @Override public Stream<UpstreamEntry<K, V>> transform(Random rnd, Stream<UpstreamEntry<K, V>> upstream) {
+        PoissonDistribution poisson = new PoissonDistribution(
+            new Well19937c(rnd.nextLong()),
+            subsampleRatio,
+            PoissonDistribution.DEFAULT_EPSILON,
+            PoissonDistribution.DEFAULT_MAX_ITERATIONS);
+
+        return upstream.sequential().flatMap(en -> Stream.generate(() -> en).limit(poisson.sample()));
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/package-info.java
new file mode 100644 (file)
index 0000000..b698ead
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Various upstream transformers.
+ */
+package org.apache.ignite.ml.trainers.transformers;
\ No newline at end of file
index 8320461..d202441 100644 (file)
@@ -45,7 +45,7 @@ public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<Boot
     private static final long serialVersionUID = -4984067145908187508L;
 
     /**
-     * Computes histograms for each features.
+     * Computes histograms for each feature.
      *
      * @param roots Random forest roots.
      * @param histMeta Histograms meta.
index ed0ebd3..63a9f3c 100644 (file)
@@ -22,7 +22,12 @@ import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.util.Iterator;
 import java.util.Random;
+import java.util.Spliterator;
+import java.util.Spliterators;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
 import org.apache.ignite.IgniteException;
 
 /**
@@ -98,4 +103,31 @@ public class Utils {
     public static int[] selectKDistinct(int n, int k) {
         return selectKDistinct(n, k, new Random());
     }
+
+    /**
+     * Convert given iterator to a stream with known count of entries.
+     *
+     * @param iter Iterator.
+     * @param cnt Count.
+     * @param <T> Type of entries.
+     * @return Stream constructed from iterator.
+     */
+    public static <T> Stream<T> asStream(Iterator<T> iter, long cnt) {
+        return StreamSupport.stream(
+                Spliterators.spliterator(iter, cnt, Spliterator.ORDERED),
+                false);
+    }
+
+    /**
+     * Convert given iterator to a stream.
+     *
+     * @param iter Iterator.
+     * @param <T> Iterator content type.
+     * @return Stream constructed from iterator.
+     */
+    public static <T> Stream<T> asStream(Iterator<T> iter) {
+        return StreamSupport.stream(
+                Spliterators.spliteratorUnknownSize(iter, Spliterator.ORDERED),
+                false);
+    }
 }
index 481e1fa..e26b5b8 100644 (file)
@@ -32,6 +32,7 @@ import org.apache.ignite.ml.regressions.RegressionsTestSuite;
 import org.apache.ignite.ml.selection.SelectionTestSuite;
 import org.apache.ignite.ml.structures.StructuresTestSuite;
 import org.apache.ignite.ml.svm.SVMTestSuite;
+import org.apache.ignite.ml.trainers.BaggingTest;
 import org.apache.ignite.ml.tree.DecisionTreeTestSuite;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
@@ -57,7 +58,8 @@ import org.junit.runners.Suite;
     CompositionTestSuite.class,
     EnvironmentTestSuite.class,
     StructuresTestSuite.class,
-    CommonTestSuite.class
+    CommonTestSuite.class,
+    BaggingTest.class
 })
 public class IgniteMLTestSuite {
     // No-op.
index 952fc43..cee8f4f 100644 (file)
@@ -33,6 +33,7 @@ import org.apache.ignite.cluster.ClusterNode;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
 /**
@@ -178,6 +179,7 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
                     ignite,
                     upstreamCacheName,
                     (k, v) -> true,
+                    UpstreamTransformerChain.empty(),
                     datasetCacheName,
                     datasetId,
                     0,
@@ -227,6 +229,7 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
             ignite,
             upstreamCacheName,
             (k, v) -> true,
+            UpstreamTransformerChain.empty(),
             datasetCacheName,
             (upstream, upstreamSize) -> {
 
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
new file mode 100644 (file)
index 0000000..c22da04
--- /dev/null
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
+import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.junit.Test;
+
+/**
+ * Tests for bagging algorithm.
+ */
+public class BaggingTest extends TrainerTest {
+    /**
+     * Test that count of entries in context is equal to initial dataset size * subsampleRatio.
+     */
+    @Test
+    public void testBaggingContextCount() {
+        count((ctxCount, countData, integer) -> ctxCount);
+    }
+
+    /**
+     * Test that count of entries in data is equal to initial dataset size * subsampleRatio.
+     */
+    @Test
+    public void testBaggingDataCount() {
+        count((ctxCount, countData, integer) -> countData.cnt);
+    }
+
+    /**
+     * Test that bagged log regression makes correct predictions.
+     */
+    @Test
+    public void testNaiveBaggingLogRegression() {
+        Map<Integer, Double[]> cacheMock = getCacheMock();
+
+        DatasetTrainer<LogisticRegressionModel, Double> trainer =
+            (LogisticRegressionSGDTrainer<?>)new LogisticRegressionSGDTrainer<>()
+                .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+                    SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+                .withMaxIterations(30000)
+                .withLocIterations(100)
+                .withBatchSize(10)
+                .withSeed(123L);
+
+        DatasetTrainer<ModelsComposition, Double> baggedTrainer =
+            TrainerTransformers.makeBagged(
+                trainer,
+                10,
+                0.7,
+                new OnMajorityPredictionsAggregator());
+
+        ModelsComposition mdl = baggedTrainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
+        TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
+    }
+
+    /**
+     * Method used to test counts of data passed in context and in data builders.
+     *
+     * @param counter Function specifying which data we should count.
+     */
+    protected void count(IgniteTriFunction<Long, CountData, Integer, Long> counter) {
+        Map<Integer, Double[]> cacheMock = getCacheMock();
+
+        CountTrainer countTrainer = new CountTrainer(counter);
+
+        double subsampleRatio = 0.3;
+
+        ModelsComposition model = TrainerTransformers.makeBagged(countTrainer, 100, subsampleRatio, new MeanValuePredictionsAggregator())
+            .fit(cacheMock, parts, null, null);
+
+        Double res = model.apply(null);
+
+        TestUtils.assertEquals(twoLinearlySeparableClasses.length * subsampleRatio, res, twoLinearlySeparableClasses.length / 10);
+    }
+
+    /**
+     * Create cache mock.
+     *
+     * @return Cache mock.
+     */
+    private Map<Integer, Double[]> getCacheMock() {
+        Map<Integer, Double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++) {
+            double[] row = twoLinearlySeparableClasses[i];
+            Double[] convertedRow = new Double[row.length];
+            for (int j = 0; j < row.length; j++)
+                convertedRow[j] = row[j];
+            cacheMock.put(i, convertedRow);
+        }
+        return cacheMock;
+    }
+
+    /**
+     * Get sum of two Long values each of which can be null.
+     *
+     * @param a First value.
+     * @param b Second value.
+     * @return Sum of parameters.
+     */
+    protected static Long plusOfNullables(Long a, Long b) {
+        if (a == null) {
+            return b;
+        }
+        if (b == null) {
+            return a;
+        }
+
+        return a + b;
+    }
+
+    /**
+     * Trainer used to count entries in context or in data.
+     */
+    protected static class CountTrainer extends DatasetTrainer<Model<Vector, Double>, Double> {
+        /**
+         * Function specifying which entries to count.
+         */
+        private final IgniteTriFunction<Long, CountData, Integer, Long> counter;
+
+        /**
+         * Construct instance of this class.
+         *
+         * @param counter Function specifying which entries to count.
+         */
+        public CountTrainer(IgniteTriFunction<Long, CountData, Integer, Long> counter) {
+            this.counter = counter;
+        }
+
+        /** {@inheritDoc} */
+        @Override public <K, V> Model<Vector, Double> fit(
+            DatasetBuilder<K, V> datasetBuilder,
+            IgniteBiFunction<K, V, Vector> featureExtractor,
+            IgniteBiFunction<K, V, Double> lbExtractor) {
+            Dataset<Long, CountData> dataset = datasetBuilder.build(
+                (upstreamData, upstreamDataSize) -> upstreamDataSize,
+                (upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize)
+            );
+
+            Long cnt = dataset.computeWithCtx(counter, BaggingTest::plusOfNullables);
+
+            return x -> Double.valueOf(cnt);
+        }
+
+        /** {@inheritDoc} */
+        @Override protected boolean checkState(Model<Vector, Double> mdl) {
+            return true;
+        }
+
+        /** {@inheritDoc} */
+        @Override protected <K, V> Model<Vector, Double> updateModel(
+            Model<Vector, Double> mdl,
+            DatasetBuilder<K, V> datasetBuilder,
+            IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+            return fit(datasetBuilder, featureExtractor, lbExtractor);
+        }
+    }
+
+    /** Data for count trainer. */
+    protected static class CountData implements AutoCloseable {
+        /** Counter. */
+        private long cnt;
+
+        /**
+         * Construct instance of this class.
+         *
+         * @param cnt Counter.
+         */
+        public CountData(long cnt) {
+            this.cnt = cnt;
+        }
+
+        /** {@inheritDoc} */
+        @Override public void close() throws Exception {
+            // No-op
+        }
+    }
+}