IGNITE-10272: [ML] Inject learning environment into scope of dataset compute task
author: Artem Malykh <amalykhgh@gmail.com>
        Tue, 4 Dec 2018 11:57:31 +0000 (14:57 +0300)
committer: Yury Babak <ybabak@gridgain.com>
        Tue, 4 Dec 2018 11:57:31 +0000 (14:57 +0300)

This closes #5484

91 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java
examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java
modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/Dataset.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java [deleted file]
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/DatasetWrapper.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/context/EmptyContextBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteFunction.java
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilderTest.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilderTest.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleDatasetTest.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/primitive/SimpleLabeledDatasetTest.java
modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilderTest.java
modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java

index 4d42d19..5148d9a 100644 (file)
@@ -73,7 +73,7 @@ public class AlgorithmSpecificDatasetExample {
             try (AlgorithmSpecificDataset dataset = DatasetFactory.create(
                 ignite,
                 persons,
-                (upstream, upstreamSize) -> new AlgorithmSpecificPartitionContext(),
+                (env, upstream, upstreamSize) -> new AlgorithmSpecificPartitionContext(),
                 new SimpleLabeledDatasetDataBuilder<Integer, Person, AlgorithmSpecificPartitionContext>(
                     (k, v) -> VectorUtils.of(v.getAge()),
                     (k, v) -> new double[] {v.getSalary()}
index baf513a..44fb77e 100644 (file)
@@ -81,8 +81,7 @@ public class BaggedLogisticRegressionSGDTrainerExample {
                 0.6,
                 4,
                 3,
-                new OnMajorityPredictionsAggregator(),
-                123L);
+                new OnMajorityPredictionsAggregator());
 
             System.out.println(">>> Perform evaluation of the model.");
 
index 3bf2c8e..a3c33cb 100644 (file)
@@ -31,7 +31,7 @@ import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
 import org.apache.ignite.examples.ml.util.SandboxMLCache;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
-import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.environment.logging.ConsoleLogger;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
@@ -80,10 +80,9 @@ public class RandomForestRegressionExample {
                 .withSubSampleSize(0.3)
                 .withSeed(0);
 
-            trainer.setEnvironment(LearningEnvironment.builder()
-                .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL)
-                .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW))
-                .build()
+            trainer.withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder()
+                .withParallelismStrategyTypeDependency(part -> ParallelismStrategy.Type.ON_DEFAULT_POOL)
+                .withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.LOW))
             );
 
             System.out.println(">>> Configured trainer: " + trainer.getClass().getSimpleName());
index a20d5da..88ea9b9 100644 (file)
@@ -32,6 +32,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.distances.DistanceMeasure;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -78,6 +79,11 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
     }
 
     /** {@inheritDoc} */
+    @Override public KMeansTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (KMeansTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
+
+    /** {@inheritDoc} */
     @Override protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
@@ -91,7 +97,8 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
         Vector[] centers;
 
         try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
             final Integer cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
index 8682a46..3acca14 100644 (file)
@@ -25,12 +25,14 @@ import org.apache.ignite.ml.composition.boosting.loss.LogLoss;
 import org.apache.ignite.ml.composition.boosting.loss.Loss;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.structures.LabeledVector;
 import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
+import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
 
 /**
  * Trainer for binary classifier using Gradient Boosting. As preparing stage this algorithm learn labels in dataset and
@@ -69,7 +71,10 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
         IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lExtractor) {
 
-        Set<Double> uniqLabels = builder.build(new EmptyContextBuilder<>(), new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor))
+        Set<Double> uniqLabels = builder.build(
+            envBuilder,
+            new EmptyContextBuilder<>(),
+            new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor))
             .compute((IgniteFunction<LabeledVectorSet<Double, LabeledVector>, Set<Double>>)x ->
                     Arrays.stream(x.labels()).boxed().collect(Collectors.toSet()), (a, b) -> {
                     if (a == null)
@@ -102,4 +107,9 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
         double internalCls = sigma < 0.5 ? 0.0 : 1.0;
         return internalCls == 0.0 ? externalFirstCls : externalSecondCls;
     }
+
+    /** {@inheritDoc} */
+    @Override public GDBBinaryClassifierOnTreesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (GDBBinaryClassifierOnTreesTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
 }
index e689b91..0b87748 100644 (file)
@@ -29,6 +29,7 @@ import org.apache.ignite.ml.composition.boosting.loss.Loss;
 import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
@@ -41,8 +42,11 @@ import org.jetbrains.annotations.NotNull;
  * Learning strategy for gradient boosting.
  */
 public class GDBLearningStrategy {
-    /** Learning environment. */
-    protected LearningEnvironment environment;
+    /** Learning environment builder. */
+    protected LearningEnvironmentBuilder envBuilder;
+
+    /** Learning environment used for trainer. */
+    protected LearningEnvironment trainerEnvironment;
 
     /** Count of iterations. */
     protected int cntOfIterations;
@@ -101,6 +105,8 @@ public class GDBLearningStrategy {
     public <K,V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
         DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
+        if (trainerEnvironment == null)
+            throw new IllegalStateException("Learning environment builder is not set.");
 
         List<Model<Vector, Double>> models = initLearningState(mdlToUpdate);
 
@@ -113,7 +119,7 @@ public class GDBLearningStrategy {
 
             WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
             ModelsComposition currComposition = new ModelsComposition(models, aggregator);
-            if (convCheck.isConverged(datasetBuilder, currComposition))
+            if (convCheck.isConverged(envBuilder, datasetBuilder, currComposition))
                 break;
 
             IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> {
@@ -125,7 +131,7 @@ public class GDBLearningStrategy {
             long startTs = System.currentTimeMillis();
             models.add(trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrap));
             double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
-            environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
+            trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
         }
 
         return models;
@@ -157,10 +163,11 @@ public class GDBLearningStrategy {
     /**
      * Sets learning environment.
      *
-     * @param environment Learning Environment.
+     * @param envBuilder Learning Environment.
      */
-    public GDBLearningStrategy withEnvironment(LearningEnvironment environment) {
-        this.environment = environment;
+    public GDBLearningStrategy withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        this.envBuilder = envBuilder;
+        this.trainerEnvironment = envBuilder.buildForTrainer();
         return this;
     }
 
index 8c1afd7..3dc95ee 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.composition.boosting;
 
 import org.apache.ignite.ml.composition.boosting.loss.SquaredError;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 
@@ -53,4 +54,9 @@ public abstract class GDBRegressionTrainer extends GDBTrainer {
     @Override protected double internalLabelToExternal(double x) {
         return x;
     }
+
+    /** {@inheritDoc} */
+    @Override public GDBRegressionTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (GDBRegressionTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
 }
index 89cc6b1..03772ec 100644 (file)
@@ -30,6 +30,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -99,7 +100,11 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
         if (!learnLabels(datasetBuilder, featureExtractor, lbExtractor))
             return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
 
-        IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(datasetBuilder, featureExtractor, lbExtractor);
+        IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(
+            envBuilder,
+            datasetBuilder,
+            featureExtractor,
+            lbExtractor);
         if(initAndSampleSize == null)
             return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
 
@@ -112,7 +117,7 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
             .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
             .withExternalLabelToInternal(this::externalLabelToInternal)
             .withCntOfIterations(cntOfIterations)
-            .withEnvironment(environment)
+            .withEnvironmentBuilder(envBuilder)
             .withLossGradient(loss)
             .withSampleSize(sampleSize)
             .withMeanLabelValue(mean)
@@ -140,6 +145,11 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
         return mdl instanceof GDBModel;
     }
 
+    /** {@inheritDoc} */
+    @Override public GDBTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (GDBTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
+
     /**
      * Defines unique labels in dataset if need (useful in case of classification).
      *
@@ -175,14 +185,18 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
      * Compute mean value of label as first approximation.
      *
      * @param builder Dataset builder.
+     * @param envBuilder Learning environment builder.
      * @param featureExtractor Feature extractor.
      * @param lbExtractor Label extractor.
      */
-    protected <V, K> IgniteBiTuple<Double, Long> computeInitialValue(DatasetBuilder<K, V> builder,
+    protected <V, K> IgniteBiTuple<Double, Long> computeInitialValue(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> builder,
         IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
 
         try (Dataset<EmptyContext, DecisionTreeData> dataset = builder.build(
+            envBuilder,
             new EmptyContextBuilder<>(),
             new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, false)
         )) {
index 88841e2..e383e39 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
 import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -88,11 +89,16 @@ public abstract class ConvergenceChecker<K, V> implements Serializable {
     /**
      * Checks convergency on dataset.
      *
+     * @param envBuilder Learning environment builder.
      * @param currMdl Current model.
      * @return true if GDB is converged.
      */
-    public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
+    public boolean isConverged(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
+        ModelsComposition currMdl) {
         try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
+            envBuilder,
             new EmptyContextBuilder<>(),
             new FeatureMatrixWithLabelsOnHeapDataBuilder<>(featureExtractor, lbExtractor)
         )) {
index 98cfbe1..193afaf 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -60,7 +61,7 @@ public class ConvergenceCheckerStub<K,V> extends ConvergenceChecker<K,V> {
     }
 
     /** {@inheritDoc} */
-    @Override public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
+    @Override public boolean isConverged(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
         return false;
     }
 
index 230a467..d821fe3 100644 (file)
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.dataset;
 import java.io.Serializable;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset;
 import org.apache.ignite.ml.dataset.impl.local.LocalDataset;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiConsumer;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
@@ -56,49 +57,50 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend
      * Applies the specified {@code map} function to every partition {@code data}, {@code context} and partition
      * index in the dataset and then reduces {@code map} results to final result by using the {@code reduce} function.
      *
-     * @param map Function applied to every partition {@code data}, {@code context} and partition index.
+     * @param map Function applied to every partition {@code data}, {@code context} and {@link LearningEnvironment}.
      * @param reduce Function applied to results of {@code map} to get final result.
      * @param identity Identity.
      * @param <R> Type of a result.
      * @return Final result.
      */
-    public <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity);
+    public <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity);
 
     /**
-     * Applies the specified {@code map} function to every partition {@code data} and partition index in the dataset
-     * and then reduces {@code map} results to final result by using the {@code reduce} function.
+     * Applies the specified {@code map} function to every partition {@code data} and {@link LearningEnvironment}
+     * in the dataset and then reduces {@code map} results to final result by using the {@code reduce} function.
      *
-     * @param map Function applied to every partition {@code data} and partition index.
+     * @param map Function applied to every partition {@code data} and {@link LearningEnvironment}.
      * @param reduce Function applied to results of {@code map} to get final result.
      * @param identity Identity.
      * @param <R> Type of a result.
      * @return Final result.
      */
-    public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity);
+    public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity);
 
     /**
-     * Applies the specified {@code map} function to every partition {@code data}, {@code context} and partition
-     * index in the dataset and then reduces {@code map} results to final result by using the {@code reduce} function.
+     * Applies the specified {@code map} function to every partition {@code data}, {@code context} and
+     * {@link LearningEnvironment} in the dataset and then reduces {@code map} results to final
+     * result by using the {@code reduce} function.
      *
-     * @param map Function applied to every partition {@code data}, {@code context} and partition index.
+     * @param map Function applied to every partition {@code data}, {@code context} and {@link LearningEnvironment}.
      * @param reduce Function applied to results of {@code map} to get final result.
      * @param <R> Type of a result.
      * @return Final result.
      */
-    public default <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce) {
+    public default <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce) {
         return computeWithCtx(map, reduce, null);
     }
 
     /**
-     * Applies the specified {@code map} function to every partition {@code data} and partition index in the dataset
-     * and then reduces {@code map} results to final result by using the {@code reduce} function.
+     * Applies the specified {@code map} function to every partition {@code data} and {@link LearningEnvironment}
+     * in the dataset and then reduces {@code map} results to final result by using the {@code reduce} function.
      *
-     * @param map Function applied to every partition {@code data} and partition index.
+     * @param map Function applied to every partition {@code data} and {@link LearningEnvironment}.
      * @param reduce Function applied to results of {@code map} to get final result.
      * @param <R> Type of a result.
      * @return Final result.
      */
-    public default <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce) {
+    public default <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce) {
         return compute(map, reduce, null);
     }
 
@@ -113,7 +115,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend
      * @return Final result.
      */
     public default <R> R computeWithCtx(IgniteBiFunction<C, D, R> map, IgniteBinaryOperator<R> reduce, R identity) {
-        return computeWithCtx((ctx, data, partIdx) -> map.apply(ctx, data), reduce, identity);
+        return computeWithCtx((ctx, data, env) -> map.apply(ctx, data), reduce, identity);
     }
 
     /**
@@ -127,7 +129,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend
      * @return Final result.
      */
     public default <R> R compute(IgniteFunction<D, R> map, IgniteBinaryOperator<R> reduce, R identity) {
-        return compute((data, partIdx) -> map.apply(data), reduce, identity);
+        return compute((data, env) -> map.apply(data), reduce, identity);
     }
 
     /**
@@ -140,7 +142,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend
      * @return Final result.
      */
     public default <R> R computeWithCtx(IgniteBiFunction<C, D, R> map, IgniteBinaryOperator<R> reduce) {
-        return computeWithCtx((ctx, data, partIdx) -> map.apply(ctx, data), reduce);
+        return computeWithCtx((ctx, data, env) -> map.apply(ctx, data), reduce);
     }
 
     /**
@@ -153,30 +155,31 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend
      * @return Final result.
      */
     public default <R> R compute(IgniteFunction<D, R> map, IgniteBinaryOperator<R> reduce) {
-        return compute((data, partIdx) -> map.apply(data), reduce);
+        return compute((data, env) -> map.apply(data), reduce);
     }
 
     /**
-     * Applies the specified {@code map} function to every partition {@code data}, {@code context} and partition
-     * index in the dataset.
+     * Applies the specified {@code map} function to every partition {@code data}, {@code context} and
+     * {@link LearningEnvironment} in the dataset.
      *
-     * @param map Function applied to every partition {@code data}, {@code context} and partition index.
+     * @param map Function applied to every partition {@code data}, {@code context} and {@link LearningEnvironment}.
      */
-    public default void computeWithCtx(IgniteTriConsumer<C, D, Integer> map) {
-        computeWithCtx((ctx, data, partIdx) -> {
-            map.accept(ctx, data, partIdx);
+    public default void computeWithCtx(IgniteTriConsumer<C, D, LearningEnvironment> map) {
+        computeWithCtx((ctx, data, env) -> {
+            map.accept(ctx, data, env);
             return null;
         }, (a, b) -> null);
     }
 
     /**
-     * Applies the specified {@code map} function to every partition {@code data} in the dataset and partition index.
+     * Applies the specified {@code map} function to every partition {@code data} in the dataset and
+     * {@link LearningEnvironment}.
      *
-     * @param map Function applied to every partition {@code data} and partition index.
+     * @param map Function applied to every partition {@code data} and {@link LearningEnvironment}.
      */
-    public default void compute(IgniteBiConsumer<D, Integer> map) {
-        compute((data, partIdx) -> {
-            map.accept(data, partIdx);
+    public default void compute(IgniteBiConsumer<D, LearningEnvironment> map) {
+        compute((data, env) -> {
+            map.accept(data, env);
             return null;
         }, (a, b) -> null);
     }
@@ -187,7 +190,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend
      * @param map Function applied to every partition {@code data} and {@code context}.
      */
     public default void computeWithCtx(IgniteBiConsumer<C, D> map) {
-        computeWithCtx((ctx, data, partIdx) -> map.accept(ctx, data));
+        computeWithCtx((ctx, data, env) -> map.accept(ctx, data));
     }
 
     /**
@@ -196,7 +199,7 @@ public interface Dataset<C extends Serializable, D extends AutoCloseable> extend
      * @param map Function applied to every partition {@code data}.
      */
     public default void compute(IgniteConsumer<D> map) {
-        compute((data, partIdx) -> map.accept(data));
+        compute((data, env) -> map.accept(data));
     }
 
     /**
index 4dd0a96..9900659 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
 
 /**
@@ -40,6 +41,7 @@ public interface DatasetBuilder<K, V> {
      * Constructs a new instance of {@link Dataset} that includes allocation required data structures and
      * initialization of {@code context} part of partitions.
      *
+     * @param envBuilder Learning environment builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param partDataBuilder Partition {@code data} builder.
      * @param <C> Type of a partition {@code context}.
@@ -47,18 +49,25 @@ public interface DatasetBuilder<K, V> {
      * @return Dataset.
      */
     public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build(
-        PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder);
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
+        PartitionDataBuilder<K, V, C, D> partDataBuilder);
 
     /**
-     * Get upstream transformers chain. This chain is applied to upstream data before it is passed
+     * Returns new instance of {@link DatasetBuilder} with new {@link UpstreamTransformerBuilder} added
+     * to the chain of upstream transformer builders. When needed, each builder in the chain is first transformed
+     * into an {@link UpstreamTransformer}; these are in turn composed together one after another, forming
+     * the final {@link UpstreamTransformer}.
+     * This transformer is applied to upstream data before it is passed
      * to {@link PartitionDataBuilder} and {@link PartitionContextBuilder}. This is needed to allow
      * transformation to upstream data which are agnostic of any changes that happen after.
      * Such transformations may be used for deriving meta-algorithms such as bagging
      * (see {@link BaggingUpstreamTransformer}).
      *
-     * @return Upstream transformers chain.
+     * @return Returns new instance of {@link DatasetBuilder} with new {@link UpstreamTransformerBuilder} added
+     * to chain of upstream transformer builders.
      */
-    public UpstreamTransformerChain<K, V> upstreamTransformersChain();
+    public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder);
 
     /**
      * Returns new instance of DatasetBuilder using conjunction of internal filter and {@code filterToAdd}.
index 1623a2b..ef8eb23 100644 (file)
@@ -31,6 +31,7 @@ import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetD
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData;
 import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 
@@ -76,6 +77,7 @@ public class DatasetFactory {
      * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with
      * any desired partition {@code context} and {@code data}.
      *
+     * @param envBuilder Learning environment builder.
      * @param datasetBuilder Dataset builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param partDataBuilder Partition {@code data} builder.
@@ -86,13 +88,42 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create(
-        DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
         PartitionDataBuilder<K, V, C, D> partDataBuilder) {
         return datasetBuilder.build(
+            envBuilder,
             partCtxBuilder,
             partDataBuilder
         );
     }
+
+    /**
+     * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and
+     * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with
+     * any desired partition {@code context} and {@code data}.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param partCtxBuilder Partition {@code context} builder.
+     * @param partDataBuilder Partition {@code data} builder.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @param <C> Type of a partition {@code context}.
+     * @param <D> Type of a partition {@code data}.
+     * @return Dataset.
+     */
+    public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create(
+        DatasetBuilder<K, V> datasetBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
+        PartitionDataBuilder<K, V, C, D> partDataBuilder) {
+        return datasetBuilder.build(
+            LearningEnvironmentBuilder.defaultBuilder(),
+            partCtxBuilder,
+            partDataBuilder
+        );
+    }
+
     /**
      * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and
      * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with
@@ -100,6 +131,7 @@ public class DatasetFactory {
      *
      * @param ignite Ignite instance.
      * @param upstreamCache Ignite Cache with {@code upstream} data.
+     * @param envBuilder Learning environment builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param partDataBuilder Partition {@code data} builder.
      * @param <K> Type of a key in {@code upstream} data.
@@ -109,7 +141,36 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create(
-        Ignite ignite, IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder,
+        Ignite ignite, IgniteCache<K, V> upstreamCache,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
+        PartitionDataBuilder<K, V, C, D> partDataBuilder) {
+        return create(
+            new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            envBuilder,
+            partCtxBuilder,
+            partDataBuilder
+        );
+    }
+
+    /**
+     * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and
+     * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with
+     * any desired partition {@code context} and {@code data}.
+     *
+     * @param ignite Ignite instance.
+     * @param upstreamCache Ignite Cache with {@code upstream} data.
+     * @param partCtxBuilder Partition {@code context} builder.
+     * @param partDataBuilder Partition {@code data} builder.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @param <C> Type of a partition {@code context}.
+     * @param <D> Type of a partition {@code data}.
+     * @return Dataset.
+     */
+    public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create(
+        Ignite ignite, IgniteCache<K, V> upstreamCache,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
         PartitionDataBuilder<K, V, C, D> partDataBuilder) {
         return create(
             new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
@@ -124,6 +185,7 @@ public class DatasetFactory {
      * allows to use any desired type of partition {@code context}.
      *
      * @param datasetBuilder Dataset builder.
+     * @param envBuilder Learning environment builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
@@ -132,10 +194,13 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(
-        DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor) {
         return create(
             datasetBuilder,
+            envBuilder,
             partCtxBuilder,
             new SimpleDatasetDataBuilder<>(featureExtractor)
         ).wrap(SimpleDataset::new);
@@ -148,6 +213,7 @@ public class DatasetFactory {
      *
      * @param ignite Ignite instance.
      * @param upstreamCache Ignite Cache with {@code upstream} data.
+     * @param envBuilder Learning environment builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
@@ -156,10 +222,13 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(Ignite ignite,
-        IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder,
+        IgniteCache<K, V> upstreamCache,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor) {
         return createSimpleDataset(
             new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            envBuilder,
             partCtxBuilder,
             featureExtractor
         );
@@ -171,6 +240,7 @@ public class DatasetFactory {
      * {@link SimpleLabeledDatasetData}, but allows to use any desired type of partition {@code context}.
      *
      * @param datasetBuilder Dataset builder.
+     * @param envBuilder Learning environment builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
@@ -180,10 +250,14 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
-        DatasetBuilder<K, V> datasetBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
+        DatasetBuilder<K, V> datasetBuilder,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, double[]> lbExtractor) {
         return create(
             datasetBuilder,
+            envBuilder,
             partCtxBuilder,
             new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
         ).wrap(SimpleLabeledDataset::new);
@@ -196,6 +270,7 @@ public class DatasetFactory {
      *
      * @param ignite Ignite instance.
      * @param upstreamCache Ignite Cache with {@code upstream} data.
+     * @param envBuilder Learning environment builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
@@ -204,11 +279,16 @@ public class DatasetFactory {
      * @param <C> Type of a partition {@code context}.
      * @return Dataset.
      */
-    public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(Ignite ignite,
-        IgniteCache<K, V> upstreamCache, PartitionContextBuilder<K, V, C> partCtxBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
+    public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
+        Ignite ignite,
+        IgniteCache<K, V> upstreamCache,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, double[]> lbExtractor) {
         return createSimpleLabeledDataset(
             new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            envBuilder,
             partCtxBuilder,
             featureExtractor,
             lbExtractor
@@ -221,15 +301,19 @@ public class DatasetFactory {
      * {@link SimpleDatasetData}.
      *
      * @param datasetBuilder Dataset builder.
+     * @param envBuilder Learning environment builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
-    public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(DatasetBuilder<K, V> datasetBuilder,
+    public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(
+        DatasetBuilder<K, V> datasetBuilder,
+        LearningEnvironmentBuilder envBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor) {
         return createSimpleDataset(
             datasetBuilder,
+            envBuilder,
             new EmptyContextBuilder<>(),
             featureExtractor
         );
@@ -242,15 +326,43 @@ public class DatasetFactory {
      *
      * @param ignite Ignite instance.
      * @param upstreamCache Ignite Cache with {@code upstream} data.
+     * @param envBuilder Learning environment builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
-    public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(Ignite ignite, IgniteCache<K, V> upstreamCache,
+    public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(
+        Ignite ignite,
+        IgniteCache<K, V> upstreamCache,
+        LearningEnvironmentBuilder envBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor) {
         return createSimpleDataset(
             new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            envBuilder,
+            featureExtractor
+        );
+    }
+
+    /**
+     * Creates a new instance of distributed {@link SimpleDataset} using the specified {@code featureExtractor}. This
+     * methods determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be
+     * {@link SimpleDatasetData}.
+     *
+     * @param ignite Ignite instance.
+     * @param upstreamCache Ignite Cache with {@code upstream} data.
+     * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Dataset.
+     */
+    public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(
+        Ignite ignite,
+        IgniteCache<K, V> upstreamCache,
+        IgniteBiFunction<K, V, Vector> featureExtractor) {
+        return createSimpleDataset(
+            new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            LearningEnvironmentBuilder.defaultBuilder(),
             featureExtractor
         );
     }
@@ -261,6 +373,7 @@ public class DatasetFactory {
      * partition {@code data} to be {@link SimpleLabeledDatasetData}.
      *
      * @param datasetBuilder Dataset builder.
+     * @param envBuilder Learning environment builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
@@ -268,10 +381,13 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(
-        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        DatasetBuilder<K, V> datasetBuilder,
+        LearningEnvironmentBuilder envBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, double[]> lbExtractor) {
         return createSimpleLabeledDataset(
             datasetBuilder,
+            envBuilder,
             new EmptyContextBuilder<>(),
             featureExtractor,
             lbExtractor
@@ -285,17 +401,21 @@ public class DatasetFactory {
      *
      * @param ignite Ignite instance.
      * @param upstreamCache Ignite Cache with {@code upstream} data.
+     * @param envBuilder Learning environment builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
-    public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(Ignite ignite,
-        IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor,
+    public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(
+        Ignite ignite, IgniteCache<K, V> upstreamCache,
+        LearningEnvironmentBuilder envBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, double[]> lbExtractor) {
         return createSimpleLabeledDataset(
             new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
+            envBuilder,
             featureExtractor,
             lbExtractor
         );
@@ -309,6 +429,7 @@ public class DatasetFactory {
      * @param upstreamMap {@code Map} with {@code upstream} data.
+     * @param envBuilder Learning environment builder.
      * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param partDataBuilder Partition {@code data} builder.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
@@ -317,10 +438,13 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V, C extends Serializable, D extends AutoCloseable> Dataset<C, D> create(
-        Map<K, V> upstreamMap, int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder,
+        Map<K, V> upstreamMap,
+        LearningEnvironmentBuilder envBuilder,
+        int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder,
         PartitionDataBuilder<K, V, C, D> partDataBuilder) {
         return create(
             new LocalDatasetBuilder<>(upstreamMap, partitions),
+            envBuilder,
             partCtxBuilder,
             partDataBuilder
         );
@@ -333,6 +457,7 @@ public class DatasetFactory {
      *
      * @param upstreamMap {@code Map} with {@code upstream} data.
      * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
+     * @param envBuilder Learning environment builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
@@ -340,11 +465,15 @@ public class DatasetFactory {
      * @param <C> Type of a partition {@code context}.
      * @return Dataset.
      */
-    public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(Map<K, V> upstreamMap,
-        int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder,
+    public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(
+        Map<K, V> upstreamMap,
+        int partitions,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor) {
         return createSimpleDataset(
             new LocalDatasetBuilder<>(upstreamMap, partitions),
+            envBuilder,
             partCtxBuilder,
             featureExtractor
         );
@@ -357,6 +486,7 @@ public class DatasetFactory {
      *
      * @param upstreamMap {@code Map} with {@code upstream} data.
      * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
+     * @param envBuilder Learning environment builder.
      * @param partCtxBuilder Partition {@code context} builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
@@ -366,10 +496,14 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
-        Map<K, V> upstreamMap, int partitions, PartitionContextBuilder<K, V, C> partCtxBuilder,
+        Map<K, V> upstreamMap,
+        int partitions,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
         return createSimpleLabeledDataset(
             new LocalDatasetBuilder<>(upstreamMap, partitions),
+            envBuilder,
             partCtxBuilder,
             featureExtractor, lbExtractor
         );
@@ -382,15 +516,18 @@ public class DatasetFactory {
      *
      * @param upstreamMap {@code Map} with {@code upstream} data.
      * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
+     * @param envBuilder Learning environment builder.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @return Dataset.
      */
     public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(Map<K, V> upstreamMap, int partitions,
+        LearningEnvironmentBuilder envBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor) {
         return createSimpleDataset(
             new LocalDatasetBuilder<>(upstreamMap, partitions),
+            envBuilder,
             featureExtractor
         );
     }
@@ -402,6 +539,7 @@ public class DatasetFactory {
      *
      * @param upstreamMap {@code Map} with {@code upstream} data.
+     * @param envBuilder Learning environment builder.
      * @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
      * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
      * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
      * @param <K> Type of a key in {@code upstream} data.
@@ -409,10 +547,12 @@ public class DatasetFactory {
      * @return Dataset.
      */
     public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(Map<K, V> upstreamMap,
+        LearningEnvironmentBuilder envBuilder,
         int partitions, IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, double[]> lbExtractor) {
         return createSimpleLabeledDataset(
             new LocalDatasetBuilder<>(upstreamMap, partitions),
+            envBuilder,
             featureExtractor,
             lbExtractor
         );
index 6e1fec3..c5eac88 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.Iterator;
 import java.util.stream.Stream;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 
 /**
@@ -43,11 +44,12 @@ public interface PartitionContextBuilder<K, V, C extends Serializable> extends S
      * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating
      * entries. For example it can be useful for bootstrapping.
      *
+     * @param env Learning environment.
      * @param upstreamData Partition {@code upstream} data.
      * @param upstreamDataSize Partition {@code upstream} data size.
      * @return Partition {@code context}.
      */
-    public C build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize);
+    public C build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize);
 
 
     /**
@@ -57,12 +59,13 @@ public interface PartitionContextBuilder<K, V, C extends Serializable> extends S
      * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating
      * entries. For example it can be useful for bootstrapping.
      *
+     * @param env Learning environment.
      * @param upstreamData Partition {@code upstream} data.
      * @param upstreamDataSize Partition {@code upstream} data size.
      * @return Partition {@code context}.
      */
-    public default C build(Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) {
-        return build(upstreamData.iterator(), upstreamDataSize);
+    public default C build(LearningEnvironment env, Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) {
+        return build(env, upstreamData.iterator(), upstreamDataSize);
     }
 
     /**
@@ -74,6 +77,6 @@ public interface PartitionContextBuilder<K, V, C extends Serializable> extends S
      * @return Composed partition {@code context} builder.
      */
     public default <C2 extends Serializable> PartitionContextBuilder<K, V, C2> andThen(IgniteFunction<C, C2> fun) {
-        return (upstreamData, upstreamDataSize) -> fun.apply(build(upstreamData, upstreamDataSize));
+        return (env, upstreamData, upstreamDataSize) -> fun.apply(build(env, upstreamData, upstreamDataSize));
     }
 }
index 106084b..4a0e68e 100644 (file)
@@ -22,6 +22,7 @@ import java.util.Iterator;
 import java.util.stream.Stream;
 import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 
 /**
@@ -46,12 +47,13 @@ public interface PartitionDataBuilder<K, V, C extends Serializable, D extends Au
      * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating
      * entries. For example it can be useful for bootstrapping.
      *
+     * @param env Learning environment.
      * @param upstreamData Partition {@code upstream} data.
      * @param upstreamDataSize Partition {@code upstream} data size.
      * @param ctx Partition {@code context}.
      * @return Partition {@code data}.
      */
-    public D build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx);
+    public D build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx);
 
     /**
      * Builds a new partition {@code data} from a partition {@code upstream} data and partition {@code context}.
@@ -60,13 +62,14 @@ public interface PartitionDataBuilder<K, V, C extends Serializable, D extends Au
      * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating
      * entries. For example it can be useful for bootstrapping.
      *
+     * @param env Learning environment.
      * @param upstreamData Partition {@code upstream} data.
      * @param upstreamDataSize Partition {@code upstream} data size.
      * @param ctx Partition {@code context}.
      * @return Partition {@code data}.
      */
-    public default D build(Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
-        return build(upstreamData.iterator(), upstreamDataSize, ctx);
+    public default D build(LearningEnvironment env, Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+        return build(env, upstreamData.iterator(), upstreamDataSize, ctx);
     }
 
     /**
@@ -79,6 +82,7 @@ public interface PartitionDataBuilder<K, V, C extends Serializable, D extends Au
      */
     public default <D2 extends AutoCloseable> PartitionDataBuilder<K, V, C, D2> andThen(
         IgniteBiFunction<D, C, D2> fun) {
-        return (upstreamData, upstreamDataSize, ctx) -> fun.apply(build(upstreamData, upstreamDataSize, ctx), ctx);
+        return (env, upstreamData, upstreamDataSize, ctx) ->
+            fun.apply(build(env, upstreamData, upstreamDataSize, ctx), ctx);
     }
 }
index ba70e2e..11b250b 100644 (file)
@@ -18,7 +18,6 @@
 package org.apache.ignite.ml.dataset;
 
 import java.io.Serializable;
-import java.util.Random;
 import java.util.stream.Stream;
 
 /**
@@ -27,16 +26,25 @@ import java.util.stream.Stream;
  * @param <K> Type of keys in the upstream.
  * @param <V> Type of values in the upstream.
  */
+// TODO: IGNITE-10297: Investigate possibility of API change.
 @FunctionalInterface
 public interface UpstreamTransformer<K, V> extends Serializable {
     /**
-     * Perform transformation of upstream.
+     * Transform upstream.
      *
-     * @param rnd Random numbers generator.
-     * @param upstream Upstream.
+     * @param upstream Upstream to transform.
      * @return Transformed upstream.
      */
-    // TODO: IGNITE-10296: Inject capabilities of randomization through learning environment.
-    // TODO: IGNITE-10297: Investigate possibility of API change.
-    public Stream<UpstreamEntry<K, V>> transform(Random rnd, Stream<UpstreamEntry<K, V>> upstream);
+    public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream);
+
+    /**
+     * Get composition of this transformer and other transformer, which is
+     * itself an {@link UpstreamTransformer} applying this transformer and then the other transformer.
+     *
+     * @param other Other transformer.
+     * @return Composition of this and other transformer.
+     */
+    default UpstreamTransformer<K, V> andThen(UpstreamTransformer<K, V> other) {
+        return upstream -> other.transform(transform(upstream));
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java
new file mode 100644 (file)
index 0000000..9adfab5
--- /dev/null
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+
+/**
+ * Builder of {@link UpstreamTransformer}.
+ * @param <K> Type of keys in upstream.
+ * @param <V> Type of values in upstream.
+ */
+@FunctionalInterface
+public interface UpstreamTransformerBuilder<K, V> extends Serializable {
+    /**
+     * Create {@link UpstreamTransformer} based on learning environment.
+     *
+     * @param env Learning environment.
+     * @return Upstream transformer.
+     */
+    public UpstreamTransformer<K, V> build(LearningEnvironment env);
+
+    /**
+     * Combines two builders (this and other respectively)
+     * <pre>
+     * env -> transformer1
+     * env -> transformer2
+     * </pre>
+     * into
+     * <pre>
+     * env -> transformer2 . transformer1
+     * </pre>
+     *
+     * @param other Builder to combine with.
+     * @return Compositional builder.
+     */
+    public default UpstreamTransformerBuilder<K, V> andThen(UpstreamTransformerBuilder<K, V> other) {
+        UpstreamTransformerBuilder<K, V> self = this;
+        return env -> {
+            UpstreamTransformer<K, V> transformer1 = self.build(env);
+            UpstreamTransformer<K, V> transformer2 = other.build(env);
+
+            return upstream -> transformer2.transform(transformer1.transform(upstream));
+        };
+    }
+
+    /**
+     * Returns identity upstream transformer.
+     *
+     * @param <K> Type of keys in upstream.
+     * @param <V> Type of values in upstream.
+     * @return Identity upstream transformer.
+     */
+    public static <K, V> UpstreamTransformerBuilder<K, V> identity() {
+        return env -> upstream -> upstream;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java
deleted file mode 100644 (file)
index 3ad6446..0000000
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.dataset;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-import java.util.stream.Stream;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-
-/**
- * Class representing chain of transformers applied to upstream.
- *
- * @param <K> Type of upstream keys.
- * @param <V> Type of upstream values.
- */
-public class UpstreamTransformerChain<K, V> implements Serializable {
-    /** Seed used for transformations. */
-    private Long seed;
-
-    /** List of upstream transformations. */
-    private List<UpstreamTransformer<K, V>> list;
-
-    /**
-     * Creates empty upstream transformers chain (basically identity function).
-     *
-     * @param <K> Type of upstream keys.
-     * @param <V> Type of upstream values.
-     * @return Empty upstream transformers chain.
-     */
-    public static <K, V> UpstreamTransformerChain<K, V> empty() {
-        return new UpstreamTransformerChain<>();
-    }
-
-    /**
-     * Creates upstream transformers chain consisting of one specified transformer.
-     *
-     * @param <K> Type of upstream keys.
-     * @param <V> Type of upstream values.
-     * @return Upstream transformers chain consisting of one specified transformer.
-     */
-    public static <K, V> UpstreamTransformerChain<K, V> of(UpstreamTransformer<K, V> trans) {
-        UpstreamTransformerChain<K, V> res = new UpstreamTransformerChain<>();
-        return res.addUpstreamTransformer(trans);
-    }
-
-    /**
-     * Construct instance of this class.
-     */
-    private UpstreamTransformerChain() {
-        list = new ArrayList<>();
-        seed = new Random().nextLong();
-    }
-
-    /**
-     * Adds upstream transformer to this chain.
-     *
-     * @param next Transformer to add.
-     * @return This chain with added transformer.
-     */
-    public UpstreamTransformerChain<K, V> addUpstreamTransformer(UpstreamTransformer<K, V> next) {
-        list.add(next);
-
-        return this;
-    }
-
-    /**
-     * Add upstream transformer based on given lambda.
-     *
-     * @param transformer Transformer.
-     * @return This object.
-     */
-    public UpstreamTransformerChain<K, V> addUpstreamTransformer(IgniteFunction<Stream<UpstreamEntry<K, V>>,
-        Stream<UpstreamEntry<K, V>>> transformer) {
-        return addUpstreamTransformer((rnd, upstream) -> transformer.apply(upstream));
-    }
-
-    /**
-     * Performs stream transformation using RNG based on provided seed as pseudo-randomness source for all
-     * transformers in the chain.
-     *
-     * @param upstream Upstream.
-     * @return Transformed upstream.
-     */
-    public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) {
-        Random rnd = new Random(seed);
-
-        Stream<UpstreamEntry<K, V>> res = upstream;
-
-        for (UpstreamTransformer<K, V> kvUpstreamTransformer : list)
-            res = kvUpstreamTransformer.transform(rnd, res);
-
-        return res;
-    }
-
-    /**
-     * Checks if this chain is empty.
-     *
-     * @return Result of check if this chain is empty.
-     */
-    public boolean isEmpty() {
-        return list.isEmpty();
-    }
-
-    /**
-     * Set seed for transformations.
-     *
-     * @param seed Seed.
-     * @return This object.
-     */
-    public UpstreamTransformerChain<K, V> setSeed(long seed) {
-        this.seed = seed;
-
-        return this;
-    }
-
-    /**
-     * Modifies seed for transformations if it is present.
-     *
-     * @param f Modification function.
-     * @return This object.
-     */
-    public UpstreamTransformerChain<K, V> modifySeed(IgniteFunction<Long, Long> f) {
-        seed = f.apply(seed);
-
-        return this;
-    }
-
-    /**
-     * Get seed used for RNG in transformations.
-     *
-     * @return Seed used for RNG in transformations.
-     */
-    public Long seed() {
-        return seed;
-    }
-}
index 8707e3a..c8d78dd 100644 (file)
@@ -20,9 +20,11 @@ package org.apache.ignite.ml.dataset.impl.bootstrapping;
 import java.util.Arrays;
 import java.util.Iterator;
 import org.apache.commons.math3.distribution.PoissonDistribution;
+import org.apache.commons.math3.random.Well19937c;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 
@@ -69,13 +71,22 @@ public class BootstrappedDatasetBuilder<K,V> implements PartitionDataBuilder<K,V
     }
 
     /** {@inheritDoc} */
-    @Override public BootstrappedDatasetPartition build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize,
+    @Override public BootstrappedDatasetPartition build(
+        LearningEnvironment env,
+        Iterator<UpstreamEntry<K, V>> upstreamData,
+        long upstreamDataSize,
         EmptyContext ctx) {
 
         BootstrappedVector[] dataset = new BootstrappedVector[Math.toIntExact(upstreamDataSize)];
 
         int cntr = 0;
-        PoissonDistribution poissonDistribution = new PoissonDistribution(subsampleSize);
+
+        PoissonDistribution poissonDistribution = new PoissonDistribution(
+            new Well19937c(env.randomNumbersGenerator().nextLong()),
+            subsampleSize,
+            PoissonDistribution.DEFAULT_EPSILON,
+            PoissonDistribution.DEFAULT_MAX_ITERATIONS);
+
         while(upstreamData.hasNext()) {
             UpstreamEntry<K, V> nextRow = upstreamData.next();
             Vector features = featureExtractor.apply(nextRow.getKey(), nextRow.getValue());
index 0736906..bde4bb6 100644 (file)
@@ -28,8 +28,10 @@ import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
-import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtils;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
@@ -61,8 +63,8 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
     /** Filter for {@code upstream} data. */
     private final IgniteBiPredicate<K, V> filter;
 
-    /** Chain of transformers applied to upstream. */
-    private final UpstreamTransformerChain<K, V> upstreamTransformers;
+    /** Builder of transformation applied to upstream. */
+    private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder;
 
     /** Ignite Cache with partition {@code context}. */
     private final IgniteCache<Integer, C> datasetCache;
@@ -73,6 +75,9 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
     /** Dataset ID that is used to identify dataset in local storage on the node where computation is performed. */
     private final UUID datasetId;
 
+    /** Learning environment builder. */
+    private final LearningEnvironmentBuilder envBuilder;
+
     /**
      * Constructs a new instance of dataset based on Ignite Cache, which is used as {@code upstream} and as reliable storage for
      * partition {@code context} as well.
@@ -80,7 +85,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
      * @param ignite Ignite instance.
      * @param upstreamCache Ignite Cache with {@code upstream} data.
      * @param filter Filter for {@code upstream} data.
-     * @param upstreamTransformers Transformers of upstream data (see description in {@link DatasetBuilder}).
+     * @param upstreamTransformerBuilder Transformer of upstream data (see description in {@link DatasetBuilder}).
      * @param datasetCache Ignite Cache with partition {@code context}.
      * @param partDataBuilder Partition {@code data} builder.
      * @param datasetId Dataset ID.
@@ -89,39 +94,45 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
         Ignite ignite,
         IgniteCache<K, V> upstreamCache,
         IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerChain<K, V> upstreamTransformers,
-        IgniteCache<Integer, C> datasetCache, PartitionDataBuilder<K, V, C, D> partDataBuilder,
+        UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder,
+        IgniteCache<Integer, C> datasetCache,
+        LearningEnvironmentBuilder envBuilder,
+        PartitionDataBuilder<K, V, C, D> partDataBuilder,
         UUID datasetId) {
         this.ignite = ignite;
         this.upstreamCache = upstreamCache;
         this.filter = filter;
-        this.upstreamTransformers = upstreamTransformers;
+        this.upstreamTransformerBuilder = upstreamTransformerBuilder;
         this.datasetCache = datasetCache;
         this.partDataBuilder = partDataBuilder;
+        this.envBuilder = envBuilder;
         this.datasetId = datasetId;
     }
 
     /** {@inheritDoc} */
-    @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) {
+    @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) {
         String upstreamCacheName = upstreamCache.getName();
         String datasetCacheName = datasetCache.getName();
 
         return computeForAllPartitions(part -> {
+            LearningEnvironment env = ComputeUtils.getLearningEnvironment(ignite, datasetId, part, envBuilder);
+
             C ctx = ComputeUtils.getContext(Ignition.localIgnite(), datasetCacheName, part);
 
             D data = ComputeUtils.getData(
                 Ignition.localIgnite(),
                 upstreamCacheName,
                 filter,
-                upstreamTransformers,
+                upstreamTransformerBuilder,
                 datasetCacheName,
                 datasetId,
-                part,
-                partDataBuilder
+                partDataBuilder,
+                env
             );
 
+
             if (data != null) {
-                R res = map.apply(ctx, data, part);
+                R res = map.apply(ctx, data, env);
 
                 // Saves partition context after update.
                 ComputeUtils.saveContext(Ignition.localIgnite(), datasetCacheName, part, ctx);
@@ -134,23 +145,24 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
     }
 
     /** {@inheritDoc} */
-    @Override public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) {
+    @Override public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) {
         String upstreamCacheName = upstreamCache.getName();
         String datasetCacheName = datasetCache.getName();
 
         return computeForAllPartitions(part -> {
+            LearningEnvironment env = ComputeUtils.getLearningEnvironment(Ignition.localIgnite(), datasetId, part, envBuilder);
+
             D data = ComputeUtils.getData(
                 Ignition.localIgnite(),
                 upstreamCacheName,
                 filter,
-                upstreamTransformers,
+                upstreamTransformerBuilder,
                 datasetCacheName,
                 datasetId,
-                part,
-                partDataBuilder
+                partDataBuilder,
+                env
             );
-
-            return data != null ? map.apply(data, part) : null;
+            return data != null ? map.apply(data, env) : null;
         }, reduce, identity);
     }
 
@@ -158,6 +170,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
     @Override public void close() {
         datasetCache.destroy();
         ComputeUtils.removeData(ignite, datasetId);
+        ComputeUtils.removeLearningEnv(ignite, datasetId);
     }
 
     /**
index 1d00875..be40158 100644 (file)
@@ -27,9 +27,10 @@ import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
-import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtils;
 import org.apache.ignite.ml.dataset.impl.cache.util.DatasetAffinityFunctionWrapper;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 
 /**
  * A dataset builder that makes {@link CacheBasedDataset}. Encapsulate logic of building cache based dataset such as
@@ -57,8 +58,8 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     /** Filter for {@code upstream} data. */
     private final IgniteBiPredicate<K, V> filter;
 
-    /** Chain of upstream transformers. */
-    private final UpstreamTransformerChain<K, V> transformersChain;
+    /** Upstream transformer builder. */
+    private final UpstreamTransformerBuilder<K, V> transformerBuilder;
 
     /**
      * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default
@@ -79,16 +80,32 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
      * @param filter Filter for {@code upstream} data.
      */
     public CacheBasedDatasetBuilder(Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter) {
+        this(ignite, upstreamCache, filter, UpstreamTransformerBuilder.identity());
+    }
+
+    /**
+     * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset}.
+     *
+     * @param ignite Ignite instance.
+     * @param upstreamCache Ignite Cache with {@code upstream} data.
+     * @param filter Filter for {@code upstream} data.
+     */
+    public CacheBasedDatasetBuilder(Ignite ignite,
+        IgniteCache<K, V> upstreamCache,
+        IgniteBiPredicate<K, V> filter,
+        UpstreamTransformerBuilder<K, V> transformerBuilder) {
         this.ignite = ignite;
         this.upstreamCache = upstreamCache;
         this.filter = filter;
-        transformersChain = UpstreamTransformerChain.empty();
+        this.transformerBuilder = transformerBuilder;
     }
 
     /** {@inheritDoc} */
     @SuppressWarnings("unchecked")
     @Override public <C extends Serializable, D extends AutoCloseable> CacheBasedDataset<K, V, C, D> build(
-        PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) {
+        LearningEnvironmentBuilder envBuilder,
+        PartitionContextBuilder<K, V, C> partCtxBuilder,
+        PartitionDataBuilder<K, V, C, D> partDataBuilder) {
         UUID datasetId = UUID.randomUUID();
 
         // Retrieves affinity function of the upstream Ignite Cache.
@@ -106,25 +123,24 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
         ComputeUtils.initContext(
             ignite,
             upstreamCache.getName(),
+            transformerBuilder,
             filter,
-            transformersChain,
             datasetCache.getName(),
             partCtxBuilder,
+            envBuilder,
             RETRIES,
             RETRY_INTERVAL
         );
 
-        return new CacheBasedDataset<>(ignite, upstreamCache, filter, transformersChain, datasetCache, partDataBuilder, datasetId);
+        return new CacheBasedDataset<>(ignite, upstreamCache, filter, transformerBuilder, datasetCache, envBuilder, partDataBuilder, datasetId);
     }
 
     /** {@inheritDoc} */
-    @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
-        return transformersChain;
+    @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) {
+        return new CacheBasedDatasetBuilder<>(ignite, upstreamCache, filter, transformerBuilder.andThen(builder));
     }
 
-    /**
-     * {@inheritDoc}
-     */
+    /** {@inheritDoc} */
     @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
         return new CacheBasedDatasetBuilder<>(ignite, upstreamCache,
             (e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2));
index 4f18a18..1dc5591 100644 (file)
@@ -26,6 +26,8 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.locks.LockSupport;
 import java.util.stream.Stream;
 import org.apache.ignite.Ignite;
@@ -41,7 +43,10 @@ import org.apache.ignite.lang.IgniteFuture;
 import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
-import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
+import org.apache.ignite.ml.dataset.UpstreamTransformer;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.util.Utils;
 
@@ -49,11 +54,12 @@ import org.apache.ignite.ml.util.Utils;
  * Util class that provides common methods to perform computations on top of the Ignite Compute Grid.
  */
 public class ComputeUtils {
-    /**
-     * Template of the key used to store partition {@code data} in local storage.
-     */
+    /** Template of the key used to store partition {@code data} in local storage. */
     private static final String DATA_STORAGE_KEY_TEMPLATE = "part_data_storage_%s";
 
+    /** Template of the key used to store partition {@link LearningEnvironment} in local storage. */
+    private static final String ENVIRONMENT_STORAGE_KEY_TEMPLATE = "part_environment_storage_%s";
+
     /**
      * Calls the specified {@code fun} function on all partitions so that is't guaranteed that partitions with the same
      * index of all specified caches will be placed on the same node and will not be moved before computation is
@@ -134,6 +140,30 @@ public class ComputeUtils {
     }
 
     /**
+     * Gets learning environment for given partition. If learning environment is not found in local node map,
+     * it will be created with specified {@link LearningEnvironmentBuilder}.
+     *
+     * @param ignite Ignite instance.
+     * @param datasetId Dataset id.
+     * @param part Partition index.
+     * @param envBuilder {@link LearningEnvironmentBuilder}.
+     * @return Learning environment for given partition.
+     */
+    public static LearningEnvironment getLearningEnvironment(Ignite ignite,
+        UUID datasetId,
+        int part,
+        LearningEnvironmentBuilder envBuilder) {
+
+        @SuppressWarnings("unchecked")
+        ConcurrentMap<Integer, LearningEnvironment> envStorage = (ConcurrentMap<Integer, LearningEnvironment>)ignite
+            .cluster()
+            .nodeLocalMap()
+            .computeIfAbsent(String.format(ENVIRONMENT_STORAGE_KEY_TEMPLATE, datasetId), key -> new ConcurrentHashMap<>());
+
+        return envStorage.computeIfAbsent(part, envBuilder::buildForWorker);
+    }
+
+    /**
      * Extracts partition {@code data} from the local storage, if it's not found in local storage recovers this {@code
      * data} from a partition {@code upstream} and {@code context}. Be aware that this method should be called from
      * the node where partition is placed.
@@ -141,11 +171,11 @@ public class ComputeUtils {
      * @param ignite Ignite instance.
      * @param upstreamCacheName Name of an {@code upstream} cache.
      * @param filter Filter for {@code upstream} data.
-     * @param transformersChain Upstream transformers.
+     * @param transformerBuilder Builder of upstream transformers.
      * @param datasetCacheName Name of a partition {@code context} cache.
      * @param datasetId Dataset ID.
-     * @param part Partition index.
      * @param partDataBuilder Partition data builder.
+     * @param env Learning environment.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
@@ -155,17 +185,18 @@ public class ComputeUtils {
     public static <K, V, C extends Serializable, D extends AutoCloseable> D getData(
         Ignite ignite,
         String upstreamCacheName, IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerChain<K, V> transformersChain,
-        String datasetCacheName,
-        UUID datasetId,
-        int part,
-        PartitionDataBuilder<K, V, C, D> partDataBuilder) {
+        UpstreamTransformerBuilder<K, V> transformerBuilder,
+        String datasetCacheName, UUID datasetId,
+        PartitionDataBuilder<K, V, C, D> partDataBuilder,
+        LearningEnvironment env) {
 
         PartitionDataStorage dataStorage = (PartitionDataStorage)ignite
             .cluster()
             .nodeLocalMap()
             .computeIfAbsent(String.format(DATA_STORAGE_KEY_TEMPLATE, datasetId), key -> new PartitionDataStorage());
 
+        final int part = env.partition();
+
         return dataStorage.computeDataIfAbsent(part, () -> {
             IgniteCache<Integer, C> learningCtxCache = ignite.cache(datasetCacheName);
             C ctx = learningCtxCache.get(part);
@@ -177,25 +208,24 @@ public class ComputeUtils {
             qry.setPartition(part);
             qry.setFilter(filter);
 
-            UpstreamTransformerChain<K, V> chainCopy = Utils.copy(transformersChain);
-            chainCopy.modifySeed(s -> s + part);
+            UpstreamTransformer<K, V> transformer = transformerBuilder.build(env);
+            UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer);
 
-            long cnt = computeCount(upstreamCache, qry, chainCopy);
+            long cnt = computeCount(upstreamCache, qry, transformer);
 
             if (cnt > 0) {
                 try (QueryCursor<UpstreamEntry<K, V>> cursor = upstreamCache.query(qry,
                     e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
                     Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
-                    if (!chainCopy.isEmpty()) {
-                        Stream<UpstreamEntry<K, V>> transformedStream = chainCopy.transform(Utils.asStream(it, cnt));
-                        it = transformedStream.iterator();
-                    }
+                    Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt));
+                    it = transformedStream.iterator();
+
 
                     Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(it, cnt,
                         "Cache expected to be not modified during dataset data building [partition=" + part + ']');
 
-                    return partDataBuilder.build(iter, cnt, ctx);
+                    return partDataBuilder.build(env, iter, cnt, ctx);
                 }
             }
 
@@ -214,27 +244,40 @@ public class ComputeUtils {
     }
 
     /**
+     * Remove learning environment from local cache by Dataset ID.
+     *
+     * @param ignite Ignite instance.
+     * @param datasetId Dataset ID.
+     */
+    public static void removeLearningEnv(Ignite ignite, UUID datasetId) {
+        ignite.cluster().nodeLocalMap().remove(String.format(ENVIRONMENT_STORAGE_KEY_TEMPLATE, datasetId));
+    }
+
+    /**
      * Initializes partition {@code context} by loading it from a partition {@code upstream}.
-     *  @param <K> Type of a key in {@code upstream} data.
-     * @param <V> Type of a value in {@code upstream} data.
-     * @param <C> Type of a partition {@code context}.
      * @param ignite Ignite instance.
      * @param upstreamCacheName Name of an {@code upstream} cache.
      * @param filter Filter for {@code upstream} data.
-     * @param transformersChain Upstream data {@link Stream} transformers chain.
+     * @param transformerBuilder Upstream transformer builder.
      * @param ctxBuilder Partition {@code context} builder.
+     * @param envBuilder Environment builder.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @param <C> Type of a partition {@code context}.
      */
     public static <K, V, C extends Serializable> void initContext(
         Ignite ignite,
         String upstreamCacheName,
+        UpstreamTransformerBuilder<K, V> transformerBuilder,
         IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerChain<K, V> transformersChain,
         String datasetCacheName,
         PartitionContextBuilder<K, V, C> ctxBuilder,
+        LearningEnvironmentBuilder envBuilder,
         int retries,
         int interval) {
         affinityCallWithRetries(ignite, Arrays.asList(datasetCacheName, upstreamCacheName), part -> {
             Ignite locIgnite = Ignition.localIgnite();
+            LearningEnvironment env = envBuilder.buildForWorker(part);
 
             IgniteCache<K, V> locUpstreamCache = locIgnite.cache(upstreamCacheName);
 
@@ -244,25 +287,24 @@ public class ComputeUtils {
             qry.setFilter(filter);
 
             C ctx;
-            UpstreamTransformerChain<K, V> chainCopy = Utils.copy(transformersChain);
-            chainCopy.modifySeed(s -> s + part);
+            UpstreamTransformer<K, V> transformer = transformerBuilder.build(env);
+            UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer);
 
-            long cnt = computeCount(locUpstreamCache, qry, transformersChain);
+            long cnt = computeCount(locUpstreamCache, qry, transformer);
 
             try (QueryCursor<UpstreamEntry<K, V>> cursor = locUpstreamCache.query(qry,
                 e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
                 Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
-                if (!chainCopy.isEmpty()) {
-                    Stream<UpstreamEntry<K, V>> transformedStream = chainCopy.transform(Utils.asStream(it, cnt));
-                    it = transformedStream.iterator();
-                }
+                Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt));
+                it = transformedStream.iterator();
+
                 Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(
                     it,
                     cnt,
                     "Cache expected to be not modified during dataset data building [partition=" + part + ']');
 
-                ctx = ctxBuilder.build(iter, cnt);
+                ctx = ctxBuilder.build(env, iter, cnt);
             }
 
             IgniteCache<Integer, C> datasetCache = locIgnite.cache(datasetCacheName);
@@ -279,9 +321,10 @@ public class ComputeUtils {
      * @param ignite Ignite instance.
      * @param upstreamCacheName Name of an {@code upstream} cache.
      * @param filter Filter for {@code upstream} data.
-     * @param transformersChain Transformers of upstream data.
+     * @param transformerBuilder Builder of transformer of upstream data.
      * @param datasetCacheName Name of a partition {@code context} cache.
      * @param ctxBuilder Partition {@code context} builder.
+     * @param envBuilder Environment builder.
      * @param retries Number of retries for the case when one of partitions not found on the node.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
@@ -291,11 +334,12 @@ public class ComputeUtils {
         Ignite ignite,
         String upstreamCacheName,
         IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerChain<K, V> transformersChain,
+        UpstreamTransformerBuilder<K, V> transformerBuilder,
         String datasetCacheName,
         PartitionContextBuilder<K, V, C> ctxBuilder,
+        LearningEnvironmentBuilder envBuilder,
         int retries) {
-        initContext(ignite, upstreamCacheName, filter, transformersChain, datasetCacheName, ctxBuilder, retries, 0);
+        initContext(ignite, upstreamCacheName, transformerBuilder, filter, datasetCacheName, ctxBuilder, envBuilder, retries, 0);
     }
 
     /**
@@ -328,25 +372,21 @@ public class ComputeUtils {
     /**
      * Computes number of entries selected from the cache by the query.
      *
-     * @param <K> Type of a key in {@code upstream} data.
-     * @param <V> Type of a value in {@code upstream} data.
      * @param cache Ignite cache with upstream data.
      * @param qry Cache query.
-     * @param transformersChain Transformers of stream of upstream data.
+     * @param transformer Upstream transformer.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
      * @return Number of entries supplied by the iterator.
      */
     private static <K, V> long computeCount(
         IgniteCache<K, V> cache,
         ScanQuery<K, V> qry,
-        UpstreamTransformerChain<K, V> transformersChain) {
+        UpstreamTransformer<K, V> transformer) {
         try (QueryCursor<UpstreamEntry<K, V>> cursor = cache.query(qry,
             e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
 
-            // 'If' statement below is just for optimization, to avoid unnecessary iterator -> stream -> iterator
-            // operations.
-            return transformersChain.isEmpty() ?
-                computeCount(cursor.iterator()) :
-                computeCount(transformersChain.transform(Utils.asStream(cursor.iterator())).iterator());
+            return computeCount(transformer.transform(Utils.asStream(cursor.iterator())).iterator());
         }
     }
 
index 975beda..8c67c02 100644 (file)
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.dataset.impl.local;
 import java.io.Serializable;
 import java.util.List;
 import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
 import org.apache.ignite.ml.math.functions.IgniteTriFunction;
@@ -32,6 +33,9 @@ import org.apache.ignite.ml.math.functions.IgniteTriFunction;
  * @param <D> Type of a partition {@code data}.
  */
 public class LocalDataset<C extends Serializable, D extends AutoCloseable> implements Dataset<C, D> {
+    /** Learning environments for each partition. */
+    private final List<LearningEnvironment> envs;
+
     /** Partition {@code context} storage. */
     private final List<C> ctx;
 
@@ -42,38 +46,42 @@ public class LocalDataset<C extends Serializable, D extends AutoCloseable> imple
      * Constructs a new instance of dataset based on local data structures such as {@code Map} and {@code List} and
     * doesn't require an Ignite environment.
      *
+     * @param envs List of {@link LearningEnvironment}.
      * @param ctx Partition {@code context} storage.
      * @param data Partition {@code data} storage.
      */
-    LocalDataset(List<C> ctx, List<D> data) {
+    LocalDataset(List<LearningEnvironment> envs, List<C> ctx, List<D> data) {
+        this.envs = envs;
         this.ctx = ctx;
         this.data = data;
     }
 
     /** {@inheritDoc} */
-    @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce,
+    @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce,
         R identity) {
         R res = identity;
 
         for (int part = 0; part < ctx.size(); part++) {
             D partData = data.get(part);
+            LearningEnvironment env = envs.get(part);
 
             if (partData != null)
-                res = reduce.apply(res, map.apply(ctx.get(part), partData, part));
+                res = reduce.apply(res, map.apply(ctx.get(part), partData, env));
         }
 
         return res;
     }
 
     /** {@inheritDoc} */
-    @Override public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) {
+    @Override public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) {
         R res = identity;
 
         for (int part = 0; part < data.size(); part++) {
             D partData = data.get(part);
+            LearningEnvironment env = envs.get(part);
 
             if (partData != null)
-                res = reduce.apply(res, map.apply(partData, part));
+                res = reduce.apply(res, map.apply(partData, env));
         }
 
         return res;
index 2514f3e..b8cd8dc 100644 (file)
@@ -22,12 +22,17 @@ import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
-import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
+import org.apache.ignite.ml.dataset.UpstreamTransformer;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.util.Utils;
 
@@ -49,7 +54,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     private final IgniteBiPredicate<K, V> filter;
 
     /** Upstream transformers. */
-    private final UpstreamTransformerChain<K, V> upstreamTransformers;
+    private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder;
 
     /**
      * Constructs a new instance of local dataset builder that makes {@link LocalDataset} with default predicate that
@@ -68,16 +73,34 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
      * @param upstreamMap {@code Map} with upstream data.
      * @param filter Filter for {@code upstream} data.
      * @param partitions Number of partitions.
+     * @param upstreamTransformerBuilder Builder of upstream transformer.
      */
-    public LocalDatasetBuilder(Map<K, V> upstreamMap, IgniteBiPredicate<K, V> filter, int partitions) {
+    public LocalDatasetBuilder(Map<K, V> upstreamMap,
+        IgniteBiPredicate<K, V> filter,
+        int partitions,
+        UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder) {
         this.upstreamMap = upstreamMap;
         this.filter = filter;
         this.partitions = partitions;
-        this.upstreamTransformers = UpstreamTransformerChain.empty();
+        this.upstreamTransformerBuilder = upstreamTransformerBuilder;
+    }
+
+    /**
+     * Constructs a new instance of local dataset builder that makes {@link LocalDataset}.
+     *
+     * @param upstreamMap {@code Map} with upstream data.
+     * @param filter Filter for {@code upstream} data.
+     * @param partitions Number of partitions.
+     */
+    public LocalDatasetBuilder(Map<K, V> upstreamMap,
+        IgniteBiPredicate<K, V> filter,
+        int partitions) {
+        this(upstreamMap, filter, partitions, UpstreamTransformerBuilder.identity());
     }
 
     /** {@inheritDoc} */
     @Override public <C extends Serializable, D extends AutoCloseable> LocalDataset<C, D> build(
+        LearningEnvironmentBuilder envBuilder,
         PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) {
         List<C> ctxList = new ArrayList<>();
         List<D> dataList = new ArrayList<>();
@@ -99,36 +122,29 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
 
         int ptr = 0;
 
+        List<LearningEnvironment> envs = IntStream.range(0, partitions).boxed().map(envBuilder::buildForWorker)
+            .collect(Collectors.toList());
+
         for (int part = 0; part < partitions; part++) {
-            int cnt = part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr);
-
-            int p = part;
-            upstreamTransformers.modifySeed(s -> s + p);
-
-            if (!upstreamTransformers.isEmpty()) {
-                cnt = (int)upstreamTransformers.transform(
-                    Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cnt))).count();
-            }
-
-            Iterator<UpstreamEntry<K, V>> iter;
-            if (upstreamTransformers.isEmpty())
-                iter = new IteratorWindow<>(firstKeysIter, k -> k, cnt);
-
-            else {
-                iter = upstreamTransformers.transform(
-                    Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cnt))).iterator();
-            }
-            C ctx = cnt > 0 ? partCtxBuilder.build(iter, cnt) : null;
-
-            Iterator<UpstreamEntry<K, V>> iter1;
-            if (upstreamTransformers.isEmpty()) {
-                iter1 = upstreamTransformers.transform(
-                    Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cnt))).iterator();
-            }
-            else
-                iter1 = new IteratorWindow<>(secondKeysIter, k -> k, cnt);
-
-            D data = cnt > 0 ? partDataBuilder.build(
+            int cntBeforeTransform =
+                part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr);
+            LearningEnvironment env = envs.get(part);
+            UpstreamTransformer<K, V> transformer1 = upstreamTransformerBuilder.build(env);
+            UpstreamTransformer<K, V> transformer2 = Utils.copy(transformer1);
+            UpstreamTransformer<K, V> transformer3 = Utils.copy(transformer1);
+
+            int cnt = (int)transformer1.transform(Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cntBeforeTransform))).count();
+
+            Iterator<UpstreamEntry<K, V>> iter =
+                transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform))).iterator();
+
+            C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, iter, cnt) : null;
+
+            Iterator<UpstreamEntry<K, V>> iter1 = transformer3.transform(
+                    Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cntBeforeTransform))).iterator();
+
+            D data = cntBeforeTransform > 0 ? partDataBuilder.build(
+                env,
                 iter1,
                 cnt,
                 ctx
@@ -137,20 +153,18 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
             ctxList.add(ctx);
             dataList.add(data);
 
-            ptr += cnt;
+            ptr += cntBeforeTransform;
         }
 
-        return new LocalDataset<>(ctxList, dataList);
+        return new LocalDataset<>(envs, ctxList, dataList);
     }
 
     /** {@inheritDoc} */
-    @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
-        return upstreamTransformers;
+    @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) {
+        return new LocalDatasetBuilder<>(upstreamMap, filter, partitions, upstreamTransformerBuilder.andThen(builder));
     }
 
-    /**
-     * {@inheritDoc}
-     */
+    /** {@inheritDoc} */
     @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
         return new LocalDatasetBuilder<>(upstreamMap,
             (e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2), partitions);
@@ -164,24 +178,16 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
      * @param <T> Target type of entries.
      */
     private static class IteratorWindow<K, T> implements Iterator<T> {
-        /**
-         * Delegate iterator.
-         */
+        /** Delegate iterator. */
         private final Iterator<K> delegate;
 
-        /**
-         * Transformer that transforms entries from one type to another.
-         */
+        /** Transformer that transforms entries from one type to another. */
         private final IgniteFunction<K, T> map;
 
-        /**
-         * Count of entries to produce.
-         */
+        /** Count of entries to produce. */
         private final int cnt;
 
-        /**
-         * Number of already produced entries.
-         */
+        /** Number of already produced entries. */
         private int ptr;
 
         /**
@@ -197,16 +203,12 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
             this.cnt = cnt;
         }
 
-        /**
-         * {@inheritDoc}
-         */
+        /** {@inheritDoc} */
         @Override public boolean hasNext() {
             return delegate.hasNext() && ptr < cnt;
         }
 
-        /**
-         * {@inheritDoc}
-         */
+        /** {@inheritDoc} */
         @Override public T next() {
             ++ptr;
 
index 578a149..270c7eb 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset.primitive;
 
 import java.io.Serializable;
 import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
 import org.apache.ignite.ml.math.functions.IgniteTriFunction;
@@ -46,13 +47,13 @@ public class DatasetWrapper<C extends Serializable, D extends AutoCloseable> imp
     }
 
     /** {@inheritDoc} */
-    @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, Integer, R> map, IgniteBinaryOperator<R> reduce,
+    @Override public <R> R computeWithCtx(IgniteTriFunction<C, D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce,
         R identity) {
         return delegate.computeWithCtx(map, reduce, identity);
     }
 
     /** {@inheritDoc} */
-    @Override public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) {
+    @Override public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) {
         return delegate.compute(map, reduce, identity);
     }
 
index be1724c..5273fa6 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.Iterator;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
@@ -56,7 +57,11 @@ public class FeatureMatrixWithLabelsOnHeapDataBuilder<K, V, C extends Serializab
     }
 
     /** {@inheritDoc} */
-    @Override public FeatureMatrixWithLabelsOnHeapData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+    @Override public FeatureMatrixWithLabelsOnHeapData build(
+        LearningEnvironment env,
+        Iterator<UpstreamEntry<K, V>> upstreamData,
+        long upstreamDataSize,
+        C ctx) {
         double[][] features = new double[Math.toIntExact(upstreamDataSize)][];
         double[] labels = new double[Math.toIntExact(upstreamDataSize)];
 
index 03b69b5..9fd77b5 100644 (file)
@@ -21,6 +21,7 @@ import java.util.Iterator;
 import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 
 /**
  * A partition {@code context} builder that makes {@link EmptyContext}.
@@ -33,7 +34,7 @@ public class EmptyContextBuilder<K, V> implements PartitionContextBuilder<K, V,
     private static final long serialVersionUID = 6620781747993467186L;
 
     /** {@inheritDoc} */
-    @Override public EmptyContext build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) {
+    @Override public EmptyContext build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) {
         return new EmptyContext();
     }
 }
index cf5bc7a..b14d8a2 100644 (file)
@@ -22,6 +22,7 @@ import java.util.Iterator;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 
@@ -50,7 +51,9 @@ public class SimpleDatasetDataBuilder<K, V, C extends Serializable>
     }
 
     /** {@inheritDoc} */
-    @Override public SimpleDatasetData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+    @Override public SimpleDatasetData build(
+        LearningEnvironment env,
+        Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
         // Prepares the matrix of features in flat column-major format.
         int cols = -1;
         double[] features = null;
index 6286255..48166ee 100644 (file)
@@ -22,6 +22,7 @@ import java.util.Iterator;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 
@@ -56,7 +57,9 @@ public class SimpleLabeledDatasetDataBuilder<K, V, C extends Serializable>
     }
 
     /** {@inheritDoc} */
-    @Override public SimpleLabeledDatasetData build(Iterator<UpstreamEntry<K, V>> upstreamData,
+    @Override public SimpleLabeledDatasetData build(
+        LearningEnvironment env,
+        Iterator<UpstreamEntry<K, V>> upstreamData,
         long upstreamDataSize, C ctx) {
         // Prepares the matrix of features in flat column-major format.
         int featureCols = -1;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/DefaultLearningEnvironmentBuilder.java
new file mode 100644 (file)
index 0000000..4aef8f2
--- /dev/null
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.environment;
+
+import java.util.Random;
+import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.environment.logging.NoOpLogger;
+import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy;
+import org.apache.ignite.ml.environment.parallelism.NoParallelismStrategy;
+import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+import static org.apache.ignite.ml.math.functions.IgniteFunction.constant;
+
+/**
+ * Builder for {@link LearningEnvironment}.
+ */
+public class DefaultLearningEnvironmentBuilder implements LearningEnvironmentBuilder {
+    /** Serial version id. */
+    private static final long serialVersionUID = 8502532880517447662L;
+
+    /** Dependency (partition -> Parallelism strategy). */
+    private IgniteFunction<Integer, ParallelismStrategy> parallelismStgy;
+
+    /** Dependency (partition -> Logging factory). */
+    private IgniteFunction<Integer, MLLogger.Factory> loggingFactory;
+
+    /** Dependency (partition -> Random number generator seed). */
+    private IgniteFunction<Integer, Long> seed;
+
+    /** Dependency (partition -> Random numbers generator supplier). */
+    private IgniteFunction<Integer, Random> rngSupplier;
+
+    /**
+     * Creates an instance of DefaultLearningEnvironmentBuilder.
+     */
+    DefaultLearningEnvironmentBuilder() {
+        parallelismStgy = constant(NoParallelismStrategy.INSTANCE);
+        loggingFactory = constant(NoOpLogger.factory());
+        seed = constant(new Random().nextLong());
+        rngSupplier = constant(new Random());
+    }
+
+    /** {@inheritDoc} */
+    @Override public LearningEnvironmentBuilder withRNGSeedDependency(IgniteFunction<Integer, Long> seed) {
+        this.seed = seed;
+
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @Override public LearningEnvironmentBuilder withRandomDependency(IgniteFunction<Integer, Random> rngSupplier) {
+        this.rngSupplier = rngSupplier;
+
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @Override public DefaultLearningEnvironmentBuilder withParallelismStrategyDependency(
+        IgniteFunction<Integer, ParallelismStrategy> stgy) {
+        this.parallelismStgy = stgy;
+
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @Override public DefaultLearningEnvironmentBuilder withParallelismStrategyTypeDependency(
+        IgniteFunction<Integer, ParallelismStrategy.Type> stgyType) {
+        this.parallelismStgy = part -> strategyByType(stgyType.apply(part));
+
+        return this;
+    }
+
+    /**
+     * Get parallelism strategy by {@link ParallelismStrategy.Type}.
+     *
+     * @param stgyType Strategy type.
+     * @return {@link ParallelismStrategy}.
+     */
+    private static ParallelismStrategy strategyByType(ParallelismStrategy.Type stgyType) {
+        switch (stgyType) {
+            case NO_PARALLELISM:
+                return NoParallelismStrategy.INSTANCE;
+            case ON_DEFAULT_POOL:
+                return new DefaultParallelismStrategy();
+        }
+        throw new IllegalStateException("Wrong type");
+    }
+
+
+    /** {@inheritDoc} */
+    @Override public DefaultLearningEnvironmentBuilder withLoggingFactoryDependency(
+        IgniteFunction<Integer, MLLogger.Factory> loggingFactory) {
+        this.loggingFactory = loggingFactory;
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @Override public LearningEnvironment buildForWorker(int part) {
+        Random random = rngSupplier.apply(part);
+        random.setSeed(seed.apply(part));
+        return new LearningEnvironmentImpl(part, random, parallelismStgy.apply(part), loggingFactory.apply(part));
+    }
+
+    /** Default LearningEnvironment implementation. */
+    private class LearningEnvironmentImpl implements LearningEnvironment {
+        /** Parallelism strategy. */
+        private final ParallelismStrategy parallelismStgy;
+
+        /** Logging factory. */
+        private final MLLogger.Factory loggingFactory;
+
+        /** Partition. */
+        private final int part;
+
+        /** Random numbers generator. */
+        private final Random randomNumGen;
+
+        /**
+         * Creates an instance of LearningEnvironmentImpl.
+         *
+         * @param part Partition.
+         * @param rng Random numbers generator.
+         * @param parallelismStgy Parallelism strategy.
+         * @param loggingFactory Logging factory.
+         */
+        private LearningEnvironmentImpl(
+            int part,
+            Random rng,
+            ParallelismStrategy parallelismStgy,
+            MLLogger.Factory loggingFactory) {
+            this.part = part;
+            this.parallelismStgy = parallelismStgy;
+            this.loggingFactory = loggingFactory;
+            randomNumGen = rng;
+        }
+
+        /** {@inheritDoc} */
+        @Override public ParallelismStrategy parallelismStrategy() {
+            return parallelismStgy;
+        }
+
+        /** {@inheritDoc} */
+        @Override public MLLogger logger() {
+            return loggingFactory.create(getClass());
+        }
+
+        /** {@inheritDoc} */
+        @Override public Random randomNumbersGenerator() {
+            return randomNumGen;
+        }
+
+        /** {@inheritDoc} */
+        @Override public <T> MLLogger logger(Class<T> clazz) {
+            return loggingFactory.create(clazz);
+        }
+
+        /** {@inheritDoc} */
+        @Override public int partition() {
+            return part;
+        }
+    }
+}
index f5fb693..f1e4f32 100644 (file)
@@ -17,6 +17,8 @@
 
 package org.apache.ignite.ml.environment;
 
+import java.util.Random;
+import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
 
@@ -26,7 +28,7 @@ import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
  */
 public interface LearningEnvironment {
     /** Default environment */
-    public static final LearningEnvironment DEFAULT = builder().build();
+    public static final LearningEnvironment DEFAULT_TRAINER_ENV = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
 
     /**
      * Returns Parallelism Strategy instance.
@@ -39,6 +41,13 @@ public interface LearningEnvironment {
     public MLLogger logger();
 
     /**
+     * Random numbers generator.
+     *
+     * @return Random numbers generator.
+     */
+    public Random randomNumbersGenerator();
+
+    /**
      * Returns an instance of logger for specific class.
      *
      * @param forCls Logging class context.
@@ -46,9 +55,9 @@ public interface LearningEnvironment {
     public <T> MLLogger logger(Class<T> forCls);
 
     /**
-     * Creates an instance of LearningEnvironmentBuilder.
+     * Gets current partition. If this is called outside of one of the compute tasks of {@link Dataset}, -1 will be returned.
+     *
+     * @return Partition.
      */
-    public static LearningEnvironmentBuilder builder() {
-        return new LearningEnvironmentBuilder();
-    }
+    public int partition();
 }
index 98f584f..8fcc6b2 100644 (file)
 
 package org.apache.ignite.ml.environment;
 
+import java.io.Serializable;
+import java.util.Random;
 import org.apache.ignite.ml.environment.logging.MLLogger;
-import org.apache.ignite.ml.environment.logging.NoOpLogger;
-import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy;
-import org.apache.ignite.ml.environment.parallelism.NoParallelismStrategy;
 import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+import static org.apache.ignite.ml.math.functions.IgniteFunction.constant;
 
 /**
- * Builder for LearningEnvironment.
+ * Builder of learning environment.
  */
-public class LearningEnvironmentBuilder {
-    /** Parallelism strategy. */
-    private ParallelismStrategy parallelismStgy;
-    /** Logging factory. */
-    private MLLogger.Factory loggingFactory;
+public interface LearningEnvironmentBuilder extends Serializable {
+    /**
+     * Builds {@link LearningEnvironment} for worker on given partition.
+     *
+     * @param part Partition.
+     * @return {@link LearningEnvironment} for worker on given partition.
+     */
+    public LearningEnvironment buildForWorker(int part);
 
     /**
-     * Creates an instance of LearningEnvironmentBuilder.
+     * Builds learning environment for trainer.
+     *
+     * @return Learning environment for trainer.
      */
-    public LearningEnvironmentBuilder() {
-        parallelismStgy = NoParallelismStrategy.INSTANCE;
-        loggingFactory = NoOpLogger.factory();
+    public default LearningEnvironment buildForTrainer() {
+        return buildForWorker(-1);
     }
 
     /**
-     * Specifies Parallelism Strategy for LearningEnvironment.
+     * Specifies dependency (partition -> Parallelism Strategy Type for LearningEnvironment).
      *
-     * @param stgy Parallelism Strategy.
+     * @param stgyType Function describing dependency (partition -> Parallelism Strategy Type).
+     * @return This object.
      */
-    public <T> LearningEnvironmentBuilder withParallelismStrategy(ParallelismStrategy stgy) {
-        this.parallelismStgy = stgy;
+    public LearningEnvironmentBuilder withParallelismStrategyTypeDependency(
+        IgniteFunction<Integer, ParallelismStrategy.Type> stgyType);
 
-        return this;
+    /**
+     * Specifies Parallelism Strategy Type for LearningEnvironment. Same strategy type will be used for all partitions.
+     *
+     * @param stgyType Parallelism Strategy Type.
+     * @return This object.
+     */
+    public default LearningEnvironmentBuilder withParallelismStrategyType(ParallelismStrategy.Type stgyType) {
+        return withParallelismStrategyTypeDependency(constant(stgyType));
     }
 
     /**
-     * Specifies Parallelism Strategy for LearningEnvironment.
+     * Specifies dependency (partition -> Parallelism Strategy for LearningEnvironment).
      *
-     * @param stgyType Parallelism Strategy Type.
+     * @param stgy Function describing dependency (partition -> Parallelism Strategy).
+     * @return This object.
      */
-    public LearningEnvironmentBuilder withParallelismStrategy(ParallelismStrategy.Type stgyType) {
-        switch (stgyType) {
-            case NO_PARALLELISM:
-                this.parallelismStgy = NoParallelismStrategy.INSTANCE;
-                break;
-            case ON_DEFAULT_POOL:
-                this.parallelismStgy = new DefaultParallelismStrategy();
-                break;
-        }
-        return this;
+    public LearningEnvironmentBuilder withParallelismStrategyDependency(IgniteFunction<Integer, ParallelismStrategy> stgy);
+
+    /**
+     * Specifies Parallelism Strategy for LearningEnvironment. Same strategy will be used for all partitions.
+     *
+     * @param stgy Parallelism Strategy.
+     * @param <T> Parallelism strategy type.
+     * @return This object.
+     */
+    public default <T extends ParallelismStrategy & Serializable> LearningEnvironmentBuilder withParallelismStrategy(T stgy) {
+        return withParallelismStrategyDependency(constant(stgy));
     }
 
 
     /**
-     * Specifies Logging factory for LearningEnvironment.
+     * Specifies dependency (partition -> logging factory).
      *
-     * @param loggingFactory Logging Factory.
+     * @param loggingFactory Function describing (partition -> logging factory).
+     * @return This object.
      */
-    public LearningEnvironmentBuilder withLoggingFactory(MLLogger.Factory loggingFactory) {
-        this.loggingFactory = loggingFactory;
-        return this;
+    public LearningEnvironmentBuilder withLoggingFactoryDependency(IgniteFunction<Integer, MLLogger.Factory> loggingFactory);
+
+    /**
+     * Specify logging factory.
+     *
+     * @param loggingFactory Logging factory.
+     * @return This object.
+     */
+    public default <T extends MLLogger.Factory & Serializable> LearningEnvironmentBuilder withLoggingFactory(T loggingFactory) {
+        return withLoggingFactoryDependency(constant(loggingFactory));
     }
 
     /**
-     * Create an instance of LearningEnvironment.
+     * Specifies dependency (partition -> seed for random number generator).
+     *
+     * @param seed Function describing dependency (partition -> seed for random number generator).
+     * @return This object.
      */
-    public LearningEnvironment build() {
-        return new LearningEnvironmentImpl(parallelismStgy, loggingFactory);
+    public LearningEnvironmentBuilder withRNGSeedDependency(IgniteFunction<Integer, Long> seed);
+
+    /**
+     * Specifies seed for random number generator. Same seed will be used for all partitions.
+     *
+     * @param seed Seed for random number generator.
+     * @return This object.
+     */
+    public default LearningEnvironmentBuilder withRNGSeed(long seed) {
+        return withRNGSeedDependency(constant(seed));
     }
 
     /**
-     * Default LearningEnvironment implementation.
+     * Specifies dependency (partition -> random numbers generator).
+     *
+     * @param rngSupplier Function describing dependency (partition -> random numbers generator).
+     * @return This object.
+     */
+    public LearningEnvironmentBuilder withRandomDependency(IgniteFunction<Integer, Random> rngSupplier);
+
+    /**
+     * Specifies random numbers generator for learning environment. Same random will be used for all partitions.
+     *
+     * @param random Random numbers generator for learning environment.
+     * @return This object.
+     */
+    public default LearningEnvironmentBuilder withRandom(Random random) {
+        return withRandomDependency(constant(random));
+    }
+
+    /**
+     * Get default {@link LearningEnvironmentBuilder}.
+     *
+     * @return Default {@link LearningEnvironmentBuilder}.
      */
-    private class LearningEnvironmentImpl implements LearningEnvironment {
-        /** Parallelism strategy. */
-        private final ParallelismStrategy parallelismStgy;
-        /** Logging factory. */
-        private final MLLogger.Factory loggingFactory;
-
-        /**
-         * Creates an instance of LearningEnvironmentImpl.
-         *
-         * @param parallelismStgy Parallelism strategy.
-         * @param loggingFactory Logging factory.
-         */
-        private LearningEnvironmentImpl(ParallelismStrategy parallelismStgy,
-            MLLogger.Factory loggingFactory) {
-            this.parallelismStgy = parallelismStgy;
-            this.loggingFactory = loggingFactory;
-        }
-
-        /** {@inheritDoc} */
-        @Override public ParallelismStrategy parallelismStrategy() {
-            return parallelismStgy;
-        }
-
-        /** {@inheritDoc} */
-        @Override public MLLogger logger() {
-            return loggingFactory.create(getClass());
-        }
-
-        /** {@inheritDoc} */
-        @Override public <T> MLLogger logger(Class<T> clazz) {
-            return loggingFactory.create(clazz);
-        }
+    public static LearningEnvironmentBuilder defaultBuilder() {
+        return new DefaultLearningEnvironmentBuilder();
     }
 }
index e064fc3..c124e06 100644 (file)
@@ -82,6 +82,9 @@ public class ConsoleLogger implements MLLogger {
      * ConsoleLogger factory.
      */
     private static class Factory implements MLLogger.Factory {
+        /** Serial version uuid. */
+        private static final long serialVersionUID = 5864605548782107893L;
+
         /** Max Verbose level. */
         private final VerboseLevel maxVerboseLevel;
 
index e7228f8..329ce89 100644 (file)
@@ -26,7 +26,6 @@ import org.apache.ignite.ml.math.functions.IgniteSupplier;
  * bagging, learning submodels for One-vs-All model, Cross-Validation etc.
  */
 public interface ParallelismStrategy {
-
     /**
      * The type of parallelism.
      */
index d7bccd8..8239ebd 100644 (file)
@@ -21,6 +21,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.structures.LabeledVector;
@@ -35,12 +36,15 @@ public class KNNUtils {
     /**
      * Builds dataset.
      *
+     * @param envBuilder Learning environment builder.
      * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
      * @param lbExtractor Label extractor.
      * @return Dataset.
      */
-    @Nullable public static <K, V> Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+    @Nullable public static <K, V> Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> buildDataset(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder
             = new LabeledDatasetPartitionDataBuilderOnHeap<>(
             featureExtractor,
@@ -51,7 +55,8 @@ public class KNNUtils {
 
         if (datasetBuilder != null) {
             dataset = datasetBuilder.build(
-                (upstream, upstreamSize) -> new EmptyContext(),
+                envBuilder,
+                (env, upstream, upstreamSize) -> new EmptyContext(),
                 partDataBuilder
             );
         }
index e56a10a..c32ca56 100644 (file)
@@ -31,6 +31,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.distances.DistanceMeasure;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -105,6 +106,12 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
         return mdl.getDistanceMeasure().equals(distance) && mdl.getCandidates().rowSize() == k;
     }
 
+    /** {@inheritDoc} */
+    @Override public ANNClassificationTrainer withEnvironmentBuilder(
+        LearningEnvironmentBuilder envBuilder) {
+        return (ANNClassificationTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
+
     /** */
     @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> buildLabelsForCandidates(List<Vector> centers,
         CentroidStat centroidStat) {
@@ -180,7 +187,8 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
         );
 
         try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
             return dataset.compute(data -> {
index 1a3ff73..ed55318 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.ignite.ml.knn.classification;
 
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.knn.KNNUtils;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -46,7 +47,7 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass
         DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
 
-        KNNClassificationModel res = new KNNClassificationModel(KNNUtils.buildDataset(datasetBuilder,
+        KNNClassificationModel res = new KNNClassificationModel(KNNUtils.buildDataset(envBuilder, datasetBuilder,
             featureExtractor, lbExtractor));
         if (mdl != null)
             res.copyStateFrom(mdl);
@@ -54,6 +55,11 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass
     }
 
     /** {@inheritDoc} */
+    @Override public KNNClassificationTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (KNNClassificationTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
+
+    /** {@inheritDoc} */
     @Override protected boolean checkState(KNNClassificationModel mdl) {
         return true;
     }
index 7a42dc8..9b348f3 100644 (file)
@@ -42,10 +42,13 @@ public class KNNRegressionTrainer extends SingleLabelDatasetTrainer<KNNRegressio
     }
 
     /** {@inheritDoc} */
-    @Override public <K, V> KNNRegressionModel updateModel(KNNRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+    @Override public <K, V> KNNRegressionModel updateModel(
+        KNNRegressionModel mdl,
+        DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
 
-        KNNRegressionModel res = new KNNRegressionModel(KNNUtils.buildDataset(datasetBuilder,
+        KNNRegressionModel res = new KNNRegressionModel(KNNUtils.buildDataset(envBuilder, datasetBuilder,
             featureExtractor, lbExtractor));
         if (mdl != null)
             res.copyStateFrom(mdl);
index 9d19592..2673b90 100644 (file)
@@ -26,5 +26,15 @@ import java.util.function.Function;
  * @see java.util.function.Function
  */
 public interface IgniteFunction<T, R> extends Function<T, R>, Serializable {
-
+    /**
+     * {@link IgniteFunction} returning specified constant.
+     *
+     * @param r Constant to return.
+     * @param <T> Type of input.
+     * @param <R> Type of output.
+     * @return {@link IgniteFunction} returning specified constant.
+     */
+    static <T, R> IgniteFunction<T, R> constant(R r) {
+        return t -> r;
+    }
 }
index 14356e1..e0376b8 100644 (file)
@@ -23,6 +23,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 
 /**
  * Distributed implementation of LSQR algorithm based on {@link AbstractLSQR} and {@link Dataset}.
@@ -35,12 +36,15 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
      * Constructs a new instance of OnHeap LSQR algorithm implementation.
      *
      * @param datasetBuilder Dataset builder.
+     * @param envBuilder Learning environment builder.
      * @param partDataBuilder Partition data builder.
      */
     public LSQROnHeap(DatasetBuilder<K, V> datasetBuilder,
+        LearningEnvironmentBuilder envBuilder,
         PartitionDataBuilder<K, V, LSQRPartitionContext, SimpleLabeledDatasetData> partDataBuilder) {
         this.dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new LSQRPartitionContext(),
+            envBuilder,
+            (env, upstream, upstreamSize) -> new LSQRPartitionContext(),
             partDataBuilder
         );
     }
index 5e1341b..3c580c3 100644 (file)
@@ -40,6 +40,17 @@ public class VectorUtils {
     }
 
     /**
+     * Create new vector of specified size n with specified value.
+     *
+     * @param val Value.
+     * @param n Size.
+     * @return New vector of specified size n with specified value.
+     */
+    public static DenseVector fill(double val, int n) {
+        return (DenseVector)new DenseVector(n).assign(val);
+    }
+
+    /**
      * Turn number into a local Vector of given size with one-hot encoding.
      *
      * @param num Number to turn into vector.
index 7426506..f265318 100644 (file)
@@ -115,7 +115,8 @@ public class OneVsRestTrainer<M extends Model<Vector, Double>>
         List<Double> res = new ArrayList<>();
 
         try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
             final Set<Double> clsLabels = dataset.compute(data -> {
index 7ee423d..cdaac5a 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
@@ -59,14 +60,20 @@ public class GaussianNaiveBayesTrainer extends SingleLabelDatasetTrainer<Gaussia
     }
 
     /** {@inheritDoc} */
+    @Override public GaussianNaiveBayesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (GaussianNaiveBayesTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
+
+    /** {@inheritDoc} */
     @Override protected <K, V> GaussianNaiveBayesModel updateModel(GaussianNaiveBayesModel mdl,
         DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
         assert datasetBuilder != null;
 
         try (Dataset<EmptyContext, GaussianNaiveBayesSumsHolder> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
-            (upstream, upstreamSize, ctx) -> {
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
+            (env, upstream, upstreamSize, ctx) -> {
 
                 GaussianNaiveBayesSumsHolder res = new GaussianNaiveBayesSumsHolder();
                 while (upstream.hasNext()) {
index c75c5bb..ea0bb6c 100644 (file)
@@ -124,6 +124,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
         assert updatesStgy!= null;
 
         try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
+            envBuilder,
             new EmptyContextBuilder<>(),
             new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
         )) {
index 8bfcb34..1aeac6b 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
@@ -53,6 +54,9 @@ public class Pipeline<K, V, R> {
     /** Final trainer stage. */
     private DatasetTrainer finalStage;
 
+    /** Learning environment builder. */
+    private LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder();
+
     /**
      * Adds feature extractor as a zero stage.
      *
@@ -110,6 +114,15 @@ public class Pipeline<K, V, R> {
     }
 
     /**
+     * Set learning environment builder.
+     *
+     * @param envBuilder Learning environment builder.
+     */
+    public void setEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        this.envBuilder = envBuilder;
+    }
+
+    /**
      * Fits the pipeline to the input mock data.
      *
      * @param data Data.
@@ -132,6 +145,7 @@ public class Pipeline<K, V, R> {
         preprocessors.forEach(e -> {
 
             finalFeatureExtractor = e.fit(
+                envBuilder,
                 datasetBuilder,
                 finalFeatureExtractor
             );
index b977864..89751eb 100644 (file)
@@ -23,6 +23,7 @@ import org.apache.ignite.IgniteCache;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 
 /**
@@ -37,24 +38,61 @@ public interface PreprocessingTrainer<K, V, T, R> {
     /**
      * Fits preprocessor.
      *
+     * @param envBuilder Learning environment builder.
      * @param datasetBuilder Dataset builder.
      * @param basePreprocessor Base preprocessor.
      * @return Preprocessor.
      */
-    public IgniteBiFunction<K, V, R> fit(DatasetBuilder<K, V> datasetBuilder,
+    public IgniteBiFunction<K, V, R> fit(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, T> basePreprocessor);
 
     /**
      * Fits preprocessor.
      *
+     * @param datasetBuilder Dataset builder.
+     * @param basePreprocessor Base preprocessor.
+     * @return Preprocessor.
+     */
+    public default IgniteBiFunction<K, V, R> fit(
+        DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, T> basePreprocessor) {
+        return fit(LearningEnvironmentBuilder.defaultBuilder(), datasetBuilder, basePreprocessor);
+    }
+
+    /**
+     * Fits preprocessor.
+     *
+     * @param ignite Ignite instance.
+     * @param cache Ignite cache.
+     * @param basePreprocessor Base preprocessor.
+     * @return Preprocessor.
+     */
+    public default IgniteBiFunction<K, V, R> fit(
+        Ignite ignite, IgniteCache<K, V> cache,
+        IgniteBiFunction<K, V, T> basePreprocessor) {
+        return fit(
+            new CacheBasedDatasetBuilder<>(ignite, cache),
+            basePreprocessor
+        );
+    }
+
+    /**
+     * Fits preprocessor.
+     *
+     * @param envBuilder Learning environment builder.
      * @param ignite Ignite instance.
      * @param cache Ignite cache.
      * @param basePreprocessor Base preprocessor.
      * @return Preprocessor.
      */
-    public default IgniteBiFunction<K, V, R> fit(Ignite ignite, IgniteCache<K, V> cache,
+    public default IgniteBiFunction<K, V, R> fit(
+        LearningEnvironmentBuilder envBuilder,
+        Ignite ignite, IgniteCache<K, V> cache,
         IgniteBiFunction<K, V, T> basePreprocessor) {
         return fit(
+            envBuilder,
             new CacheBasedDatasetBuilder<>(ignite, cache),
             basePreprocessor
         );
@@ -68,7 +106,29 @@ public interface PreprocessingTrainer<K, V, T, R> {
      * @param basePreprocessor Base preprocessor.
      * @return Preprocessor.
      */
-    public default IgniteBiFunction<K, V, R> fit(Map<K, V> data, int parts,
+    public default IgniteBiFunction<K, V, R> fit(
+        LearningEnvironmentBuilder envBuilder,
+        Map<K, V> data,
+        int parts,
+        IgniteBiFunction<K, V, T> basePreprocessor) {
+        return fit(
+            envBuilder,
+            new LocalDatasetBuilder<>(data, parts),
+            basePreprocessor
+        );
+    }
+
+    /**
+     * Fits preprocessor.
+     *
+     * @param data Data.
+     * @param parts Number of partitions.
+     * @param basePreprocessor Base preprocessor.
+     * @return Preprocessor.
+     */
+    public default IgniteBiFunction<K, V, R> fit(
+        Map<K, V> data,
+        int parts,
         IgniteBiFunction<K, V, T> basePreprocessor) {
         return fit(
             new LocalDatasetBuilder<>(data, parts),
index ad8c90e..039794c 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.ignite.ml.preprocessing.binarization;
 
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
@@ -33,7 +34,9 @@ public class BinarizationTrainer<K, V> implements PreprocessingTrainer<K, V, Vec
     private double threshold;
 
     /** {@inheritDoc} */
-    @Override public BinarizationPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public BinarizationPreprocessor<K, V> fit(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> basePreprocessor) {
         return new BinarizationPreprocessor<>(threshold, basePreprocessor);
     }
index d5668e4..14a509e 100644 (file)
@@ -29,6 +29,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
@@ -53,14 +54,17 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[]
     private EncoderSortingStrategy encoderSortingStgy = EncoderSortingStrategy.FREQUENCY_DESC;
 
     /** {@inheritDoc} */
-    @Override public EncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
-                                                   IgniteBiFunction<K, V, Object[]> basePreprocessor) {
+    @Override public EncoderPreprocessor<K, V> fit(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Object[]> basePreprocessor) {
         if (handledIndices.isEmpty())
             throw new RuntimeException("Add indices of handled features");
 
         try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
-            (upstream, upstreamSize, ctx) -> {
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
+            (env, upstream, upstreamSize, ctx) -> {
                 // This array will contain not null values for handled indices
                 Map<String, Integer>[] categoryFrequencies = null;
 
index 090b0a4..e8920f3 100644 (file)
@@ -23,8 +23,10 @@ import java.util.Map;
 import java.util.Optional;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -43,11 +45,13 @@ public class ImputerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector,
     private ImputingStrategy imputingStgy = ImputingStrategy.MEAN;
 
     /** {@inheritDoc} */
-    @Override public ImputerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public ImputerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> basePreprocessor) {
+        PartitionContextBuilder<K, V, EmptyContext> builder = (env, upstream, upstreamSize) -> new EmptyContext();
         try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
-            (upstream, upstreamSize, ctx) -> {
+            envBuilder,
+            builder,
+            (env, upstream, upstreamSize, ctx) -> {
                 double[] sums = null;
                 int[] counts = null;
                 Map<Double, Integer>[] valuesByFreq = null;
index c8b1dca..52acea3 100644 (file)
@@ -21,6 +21,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
@@ -33,11 +34,14 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
  */
 public class MaxAbsScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> {
     /** {@inheritDoc} */
-    @Override public MaxAbsScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public MaxAbsScalerPreprocessor<K, V> fit(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> basePreprocessor) {
         try (Dataset<EmptyContext, MaxAbsScalerPartitionData> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
-            (upstream, upstreamSize, ctx) -> {
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
+            (env, upstream, upstreamSize, ctx) -> {
                 double[] maxAbs = null;
 
                 while (upstream.hasNext()) {
index 6a39236..71f2afc 100644 (file)
@@ -19,8 +19,10 @@ package org.apache.ignite.ml.preprocessing.minmaxscaling;
 
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
@@ -33,11 +35,15 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
  */
 public class MinMaxScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> {
     /** {@inheritDoc} */
-    @Override public MinMaxScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public MinMaxScalerPreprocessor<K, V> fit(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> basePreprocessor) {
+        PartitionContextBuilder<K, V, EmptyContext> ctxBuilder = (env, upstream, upstreamSize) -> new EmptyContext();
         try (Dataset<EmptyContext, MinMaxScalerPartitionData> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
-            (upstream, upstreamSize, ctx) -> {
+            envBuilder,
+            ctxBuilder,
+            (env, upstream, upstreamSize, ctx) -> {
                 double[] min = null;
                 double[] max = null;
 
index b2dc6ed..08c4a68 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.ignite.ml.preprocessing.normalization;
 
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
@@ -33,7 +34,9 @@ public class NormalizationTrainer<K, V> implements PreprocessingTrainer<K, V, Ve
     private int p = 2;
 
     /** {@inheritDoc} */
-    @Override public NormalizationPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public NormalizationPreprocessor<K, V> fit(
+        LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> basePreprocessor) {
         return new NormalizationPreprocessor<>(p, basePreprocessor);
     }
index 5147b05..604f0b0 100644 (file)
@@ -21,6 +21,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
@@ -33,9 +34,10 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
  */
 public class StandardScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> {
     /** {@inheritDoc} */
-    @Override public StandardScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public StandardScalerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> basePreprocessor) {
-        StandardScalerData standardScalerData = computeSum(datasetBuilder, basePreprocessor);
+        StandardScalerData standardScalerData = computeSum(envBuilder, datasetBuilder, basePreprocessor);
 
         int n = standardScalerData.sum.length;
         long cnt = standardScalerData.cnt;
@@ -51,11 +53,13 @@ public class StandardScalerTrainer<K, V> implements PreprocessingTrainer<K, V, V
     }
 
     /** Computes sum, squared sum and row count. */
-    private StandardScalerData computeSum(DatasetBuilder<K, V> datasetBuilder,
+    private StandardScalerData computeSum(LearningEnvironmentBuilder envBuilder,
+        DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> basePreprocessor) {
         try (Dataset<EmptyContext, StandardScalerData> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
-            (upstream, upstreamSize, ctx) -> {
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
+            (env, upstream, upstreamSize, ctx) -> {
                 double[] sum = null;
                 double[] squaredSum = null;
                 long cnt = 0;
index 5497177..dc245d2 100644 (file)
@@ -50,6 +50,7 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea
 
         try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(
             datasetBuilder,
+            envBuilder,
             new SimpleLabeledDatasetDataBuilder<>(
                 new FeatureExtractorWrapper<>(featureExtractor),
                 lbExtractor.andThen(e -> new double[] {e})
index 71d54fa..fd5a624 100644 (file)
@@ -136,7 +136,8 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
         List<Double> res = new ArrayList<>();
 
         try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
             final Set<Double> clsLabels = dataset.compute(data -> {
index 4fba028..5549b08 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.Iterator;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 
 /**
@@ -48,8 +49,11 @@ public class LabelPartitionDataBuilderOnHeap<K, V, C extends Serializable>
     }
 
     /** {@inheritDoc} */
-    @Override public LabelPartitionDataOnHeap build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize,
-                                        C ctx) {
+    @Override public LabelPartitionDataOnHeap build(
+        LearningEnvironment env,
+        Iterator<UpstreamEntry<K, V>> upstreamData,
+        long upstreamDataSize,
+        C ctx) {
         double[] y = new double[Math.toIntExact(upstreamDataSize)];
 
         int ptr = 0;
index 0351037..0d054f6 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.Iterator;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.structures.LabeledVector;
@@ -57,8 +58,10 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab
     }
 
     /** {@inheritDoc} */
-    @Override public LabeledVectorSet<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
-                                                                   long upstreamDataSize, C ctx) {
+    @Override public LabeledVectorSet<Double, LabeledVector> build(
+        LearningEnvironment env,
+        Iterator<UpstreamEntry<K, V>> upstreamData,
+        long upstreamDataSize, C ctx) {
         int xCols = -1;
         double[][] x = null;
         double[] y = new double[Math.toIntExact(upstreamDataSize)];
index 47666f4..7ceb53b 100644 (file)
@@ -89,7 +89,8 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
         Vector weights;
 
         try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
             if (mdl == null) {
index b161914..94f2a99 100644 (file)
@@ -157,7 +157,8 @@ public class SVMLinearMultiClassClassificationTrainer
         List<Double> res = new ArrayList<>();
 
         try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
-            (upstream, upstreamSize) -> new EmptyContext(),
+            envBuilder,
+            (env, upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
             final Set<Double> clsLabels = dataset.compute(data -> {
index f321744..dabf66a 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -38,8 +39,11 @@ import org.jetbrains.annotations.NotNull;
  * @param <L> Type of a label.
  */
 public abstract class DatasetTrainer<M extends Model, L> {
+    /** Learning environment builder. */
+    protected LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder();
+
     /** Learning Environment. */
-    protected LearningEnvironment environment = LearningEnvironment.DEFAULT;
+    protected LearningEnvironment environment = envBuilder.buildForTrainer();
 
     /**
      * Trains model based on the specified data.
@@ -289,11 +293,25 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
-     * Sets learning Environment
-     * @param environment Environment.
+     * Changes the learning environment.
+     *
+     * @param envBuilder Learning environment builder.
+     */
+    // TODO: IGNITE-10441 Think about more elegant ways to perform fluent API.
+    public DatasetTrainer<M, L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        this.envBuilder  = envBuilder;
+        this.environment = envBuilder.buildForTrainer();
+
+        return this;
+    }
+
+    /**
+     * Get learning environment.
+     *
+     * @return Learning environment.
      */
-    public void setEnvironment(LearningEnvironment environment) {
-        this.environment = environment;
+    public LearningEnvironment learningEnvironment() {
+        return environment;
     }
 
     /**
index 05504c3..1019a39 100644 (file)
 
 package org.apache.ignite.ml.trainers;
 
-import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Random;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
-import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.dataset.PartitionContextBuilder;
-import org.apache.ignite.ml.dataset.PartitionDataBuilder;
-import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
 import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.environment.parallelism.Promise;
@@ -49,7 +43,7 @@ import org.apache.ignite.ml.util.Utils;
  */
 public class TrainerTransformers {
     /**
-     * Add bagging logic to a given trainer.
+     * Add bagging logic to a given trainer. No feature bootstrapping is done.
      *
      * @param ensembleSize Size of ensemble.
      * @param subsampleRatio Subsample ratio to whole dataset.
@@ -63,9 +57,8 @@ public class TrainerTransformers {
         int ensembleSize,
         double subsampleRatio,
         PredictionsAggregator aggregator) {
-        return makeBagged(trainer, ensembleSize, subsampleRatio, -1, -1, aggregator, new Random().nextLong());
+        return makeBagged(trainer, ensembleSize, subsampleRatio, -1, -1, aggregator);
     }
-
     /**
      * Add bagging logic to a given trainer.
      *
@@ -74,31 +67,23 @@ public class TrainerTransformers {
      * @param aggregator Aggregator.
      * @param featureVectorSize Feature vector dimensionality.
      * @param featuresSubspaceDim Feature subspace dimensionality.
-     * @param transformationSeed Transformations seed.
      * @param <M> Type of one model in ensemble.
      * @param <L> Type of labels.
      * @return Bagged trainer.
      */
-    // TODO: IGNITE-10296: Inject capabilities of seeding through learning environment (remove).
     public static <M extends Model<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
         DatasetTrainer<M, L> trainer,
         int ensembleSize,
         double subsampleRatio,
         int featureVectorSize,
         int featuresSubspaceDim,
-        PredictionsAggregator aggregator,
-        Long transformationSeed) {
+        PredictionsAggregator aggregator) {
         return new DatasetTrainer<ModelsComposition, L>() {
             /** {@inheritDoc} */
             @Override public <K, V> ModelsComposition fit(
                 DatasetBuilder<K, V> datasetBuilder,
                 IgniteBiFunction<K, V, Vector> featureExtractor,
                 IgniteBiFunction<K, V, L> lbExtractor) {
-                datasetBuilder.upstreamTransformersChain().setSeed(
-                    transformationSeed == null
-                        ? new Random().nextLong()
-                        : transformationSeed);
-
                 return runOnEnsemble(
                     (db, i, fe) -> (() -> trainer.fit(db, fe, lbExtractor)),
                     datasetBuilder,
@@ -172,21 +157,17 @@ public class TrainerTransformers {
         log.log(MLLogger.VerboseLevel.LOW, "Start learning.");
 
         List<int[]> mappings = null;
-        if (featuresVectorSize > 0) {
+        if (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) {
             mappings = IntStream.range(0, ensembleSize).mapToObj(
                 modelIdx -> getMapping(
                     featuresVectorSize,
                     featureSubspaceDim,
-                    datasetBuilder.upstreamTransformersChain().seed() + modelIdx))
+                    environment.randomNumbersGenerator().nextLong() + modelIdx))
                 .collect(Collectors.toList());
         }
 
         Long startTs = System.currentTimeMillis();
 
-        datasetBuilder
-            .upstreamTransformersChain()
-            .addUpstreamTransformer(new BaggingUpstreamTransformer<>(subsampleRatio));
-
         List<IgniteSupplier<M>> tasks = new ArrayList<>();
         List<IgniteBiFunction<K, V, Vector>> extractors = new ArrayList<>();
         if (mappings != null) {
@@ -195,10 +176,8 @@ public class TrainerTransformers {
         }
 
         for (int i = 0; i < ensembleSize; i++) {
-            UpstreamTransformerChain<K, V> newChain = Utils.copy(datasetBuilder.upstreamTransformersChain());
-            DatasetBuilder<K, V> newBuilder = withNewChain(datasetBuilder, newChain);
-            int j = i;
-            newChain.modifySeed(s -> s * s + j);
+            DatasetBuilder<K, V> newBuilder =
+                datasetBuilder.withUpstreamTransformer(BaggingUpstreamTransformer.builder(subsampleRatio, i));
             tasks.add(
                 trainingTaskGenerator.apply(newBuilder, i, mappings != null ? extractors.get(i) : extractor));
         }
@@ -338,37 +317,4 @@ public class TrainerTransformers {
             return mapping;
         }
     }
-
-    /**
-     * Creates new dataset builder which is delegate of a given dataset builder in everything except
-     * new transformations chain.
-     *
-     * @param builder Initial builder.
-     * @param chain New chain.
-     * @param <K> Type of keys.
-     * @param <V> Type of values.
-     * @return new dataset builder which is delegate of a given dataset builder in everything except
-     * new transformations chain.
-     */
-    private static <K, V> DatasetBuilder<K, V> withNewChain(
-        DatasetBuilder<K, V> builder,
-        UpstreamTransformerChain<K, V> chain) {
-        return new DatasetBuilder<K, V>() {
-            /** {@inheritDoc} */
-            @Override public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build(
-                PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) {
-                return builder.build(partCtxBuilder, partDataBuilder);
-            }
-
-            /** {@inheritDoc} */
-            @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
-                return chain;
-            }
-
-            /** {@inheritDoc} */
-            @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
-                return builder.withFilter(filterToAdd);
-            }
-        };
-    }
 }
index f935ebd..7f45fdd 100644 (file)
 
 package org.apache.ignite.ml.trainers.transformers;
 
-import java.util.Random;
 import java.util.stream.Stream;
 import org.apache.commons.math3.distribution.PoissonDistribution;
 import org.apache.commons.math3.random.Well19937c;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.dataset.UpstreamTransformer;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
 
 /**
  * This class encapsulates the logic needed to do bagging (bootstrap aggregating) by features.
@@ -33,22 +33,43 @@ import org.apache.ignite.ml.dataset.UpstreamTransformer;
  * @param <V> Type of upstream values.
  */
 public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, V> {
+    /** Serial version uid. */
+    private static final long serialVersionUID = -913152523469994149L;
+
     /** Ratio of subsample to entire upstream size */
     private double subsampleRatio;
 
+    /** Seed used for generating the Poisson distribution. */
+    private long seed;
+
+    /**
+     * Get builder of {@link BaggingUpstreamTransformer} for a model with a specified index in the ensemble.
+     *
+     * @param subsampleRatio Subsample ratio.
+     * @param mdlIdx Index of model in ensemble.
+     * @param <K> Type of upstream keys.
+     * @param <V> Type of upstream values.
+     * @return Builder of {@link BaggingUpstreamTransformer}.
+     */
+    public static <K, V> UpstreamTransformerBuilder<K, V> builder(double subsampleRatio, int mdlIdx) {
+        return env -> new BaggingUpstreamTransformer<>(env.randomNumbersGenerator().nextLong() + mdlIdx, subsampleRatio);
+    }
+
     /**
      * Construct instance of this transformer with a given subsample ratio.
      *
+     * @param seed Seed used for generating the Poisson distribution, which in turn is used to make subsamples.
      * @param subsampleRatio Subsample ratio.
      */
-    public BaggingUpstreamTransformer(double subsampleRatio) {
+    public BaggingUpstreamTransformer(long seed, double subsampleRatio) {
         this.subsampleRatio = subsampleRatio;
+        this.seed = seed;
     }
 
     /** {@inheritDoc} */
-    @Override public Stream<UpstreamEntry<K, V>> transform(Random rnd, Stream<UpstreamEntry<K, V>> upstream) {
+    @Override public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) {
         PoissonDistribution poisson = new PoissonDistribution(
-            new Well19937c(rnd.nextLong()),
+            new Well19937c(seed),
             subsampleRatio,
             PoissonDistribution.DEFAULT_EPSILON,
             PoissonDistribution.DEFAULT_MAX_ITERATIONS);
index 482c938..510d26e 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
@@ -76,6 +77,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
     @Override public <K, V> DecisionTreeNode fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
         try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
+            envBuilder,
             new EmptyContextBuilder<>(),
             new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, usingIdx)
         )) {
@@ -108,6 +110,11 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
         return true;
     }
 
+    /** {@inheritDoc} */
+    @Override public DecisionTree<T> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (DecisionTree<T>)super.withEnvironmentBuilder(envBuilder);
+    }
+
     /** */
     public <K,V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> dataset) {
         return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
index 58552f4..321e65f 100644 (file)
@@ -23,6 +23,7 @@ import java.util.Map;
 import java.util.Set;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
 import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
 import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasure;
@@ -129,4 +130,9 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity
 
         return new GiniImpurityMeasureCalculator(encoder, usingIdx);
     }
+
+    /** {@inheritDoc} */
+    @Override public DecisionTreeClassificationTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (DecisionTreeClassificationTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
 }
index ea57bcc..2b259f2 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.tree;
 
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
 import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
 import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasure;
@@ -69,4 +70,9 @@ public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasu
 
         return new MSEImpurityMeasureCalculator(usingIdx);
     }
+
+    /** {@inheritDoc} */
+    @Override public DecisionTreeRegressionTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (DecisionTreeRegressionTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
 }
index b99dc2f..b19652d 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.tree.boosting;
 
 import org.apache.ignite.ml.composition.boosting.GDBBinaryClassifierTrainer;
 import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
 import org.jetbrains.annotations.NotNull;
 
@@ -61,6 +62,11 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine
         return new GDBOnTreesLearningStrategy(usingIdx);
     }
 
+    /** {@inheritDoc} */
+    @Override public GDBBinaryClassifierOnTreesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (GDBBinaryClassifierOnTreesTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
+
     /**
      * Set useIndex parameter and returns trainer instance.
      *
index caac168..71e840c 100644 (file)
@@ -70,6 +70,7 @@ public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
             externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor);
 
         try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
+            envBuilder,
             new EmptyContextBuilder<>(),
             new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIdx)
         )) {
@@ -95,7 +96,7 @@ public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
                 long startTs = System.currentTimeMillis();
                 models.add(decisionTreeTrainer.fit(dataset));
                 double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
-                environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
+                trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
             }
         }
         catch (Exception e) {
index b6c0b48..9c588ce 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.tree.boosting;
 
 import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
 import org.apache.ignite.ml.composition.boosting.GDBRegressionTrainer;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
 import org.jetbrains.annotations.NotNull;
 
@@ -120,4 +121,9 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer {
     @Override protected GDBLearningStrategy getLearningStrategy() {
         return new GDBOnTreesLearningStrategy(usingIdx);
     }
+
+    /** {@inheritDoc} */
+    @Override public GDBRegressionOnTreesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (GDBRegressionOnTreesTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
 }
index 4436b07..1378120 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.Iterator;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.environment.LearningEnvironment;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 
@@ -60,7 +61,11 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable>
     }
 
     /** {@inheritDoc} */
-    @Override public DecisionTreeData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+    @Override public DecisionTreeData build(
+        LearningEnvironment envBuilder,
+        Iterator<UpstreamEntry<K, V>> upstreamData,
+        long upstreamDataSize,
+        C ctx) {
         double[][] features = new double[Math.toIntExact(upstreamDataSize)][];
         double[] labels = new double[Math.toIntExact(upstreamDataSize)];
 
index 72a97c4..3ee90cb 100644 (file)
@@ -114,6 +114,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
         List<TreeRoot> models = null;
         try (Dataset<EmptyContext, BootstrappedDatasetPartition> dataset = datasetBuilder.build(
+            envBuilder,
             new EmptyContextBuilder<>(),
             new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, amountOfTrees, subSampleSize))) {
 
index 4b472cc..1103ef0 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.ignite.ml;
 
 import java.util.stream.IntStream;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.primitives.matrix.Matrix;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.junit.Assert;
@@ -325,4 +326,82 @@ public class TestUtils {
 
         }
     }
+
+    /**
+     * Gets test learning environment builder.
+     *
+     * @return Test learning environment builder.
+     */
+    public static LearningEnvironmentBuilder testEnvBuilder() {
+        return testEnvBuilder(123L);
+    }
+
+    /**
+     * Gets test learning environment builder with a given seed.
+     *
+     * @param seed Seed.
+     * @return Test learning environment builder.
+     */
+    public static LearningEnvironmentBuilder testEnvBuilder(long seed) {
+        return LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed);
+    }
+
+    /**
+     * Simple wrapper class which adds {@link AutoCloseable} to given type.
+     *
+     * @param <T> Type to wrap.
+     */
+    public static class DataWrapper<T> implements AutoCloseable {
+        /**
+         * Value to wrap.
+         */
+        T val;
+
+        /**
+         * Wrap given value in {@link AutoCloseable}.
+         *
+         * @param val Value to wrap.
+         * @param <T> Type of value to wrap.
+         * @return Value wrapped as {@link AutoCloseable}.
+         */
+        public static <T> DataWrapper<T> of(T val) {
+            return new DataWrapper<>(val);
+        }
+
+        /**
+         * Construct instance of this class from given value.
+         *
+         * @param val Value to wrap.
+         */
+        public DataWrapper(T val) {
+            this.val = val;
+        }
+
+        /**
+         * Get wrapped value.
+         *
+         * @return Wrapped value.
+         */
+        public T val() {
+            return val;
+        }
+
+        /** {@inheritDoc} */
+        @Override public void close() throws Exception {
+            if (val instanceof AutoCloseable)
+                ((AutoCloseable)val).close();
+        }
+    }
+
+    /**
+     * Return model which returns given constant.
+     *
+     * @param v Constant value.
+     * @param <T> Type of input.
+     * @param <V> Type of output.
+     * @return Model which returns given constant.
+     */
+    public static <T, V> Model<T, V> constantModel(V v) {
+        return t -> v;
+    }
 }
index 0b42db8..c218a74 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.composition.boosting.convergence.mean;
 
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
 import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest;
 import org.apache.ignite.ml.dataset.impl.local.LocalDataset;
@@ -25,6 +26,7 @@ import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
 import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Assert;
 import org.junit.Test;
@@ -39,11 +41,14 @@ public class MeanAbsValueConvergenceCheckerTest extends ConvergenceCheckerTest {
             new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder);
 
         double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
+        LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder();
+
         Assert.assertEquals(1.9, error, 0.01);
-        Assert.assertFalse(checker.isConverged(datasetBuilder, notConvergedMdl));
-        Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl));
+        Assert.assertFalse(checker.isConverged(envBuilder, datasetBuilder, notConvergedMdl));
+        Assert.assertTrue(checker.isConverged(envBuilder, datasetBuilder, convergedMdl));
 
         try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
+            envBuilder,
             new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
 
             double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
@@ -62,6 +67,7 @@ public class MeanAbsValueConvergenceCheckerTest extends ConvergenceCheckerTest {
             new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder);
 
         try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
+            TestUtils.testEnvBuilder(),
             new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
 
             double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
index d6880b4..0476a37 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.composition.boosting.convergence.median;
 
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
 import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest;
 import org.apache.ignite.ml.dataset.impl.local.LocalDataset;
@@ -25,6 +26,7 @@ import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
 import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Assert;
 import org.junit.Test;
@@ -42,10 +44,14 @@ public class MedianOfMedianConvergenceCheckerTest extends ConvergenceCheckerTest
 
         double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
         Assert.assertEquals(1.9, error, 0.01);
-        Assert.assertFalse(checker.isConverged(datasetBuilder, notConvergedMdl));
-        Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl));
+
+        LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder();
+
+        Assert.assertFalse(checker.isConverged(envBuilder, datasetBuilder, notConvergedMdl));
+        Assert.assertTrue(checker.isConverged(envBuilder, datasetBuilder, convergedMdl));
 
         try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
+            envBuilder,
             new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
 
             double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
index 1cf6dbf..815bd86 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.cluster.ClusterNode;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
@@ -66,8 +67,9 @@ public class CacheBasedDatasetBuilderTest extends GridCommonAbstractTest {
         CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);
 
         CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build(
-            (upstream, upstreamSize) -> upstreamSize,
-            (upstream, upstreamSize, ctx) -> null
+            TestUtils.testEnvBuilder(),
+            (env, upstream, upstreamSize) -> upstreamSize,
+            (env, upstream, upstreamSize, ctx) -> null
         );
 
         Affinity<Integer> upstreamAffinity = ignite.affinity(upstreamCache.getName());
@@ -105,14 +107,15 @@ public class CacheBasedDatasetBuilderTest extends GridCommonAbstractTest {
         );
 
         CacheBasedDataset<Integer, Integer, Long, AutoCloseable> dataset = builder.build(
-            (upstream, upstreamSize) -> {
+            TestUtils.testEnvBuilder(),
+            (env, upstream, upstreamSize) -> {
                 UpstreamEntry<Integer, Integer> entry = upstream.next();
                 assertEquals(Integer.valueOf(2), entry.getKey());
                 assertEquals(Integer.valueOf(2), entry.getValue());
                 assertFalse(upstream.hasNext());
                 return 0L;
             },
-            (upstream, upstreamSize, ctx) -> {
+            (env, upstream, upstreamSize, ctx) -> {
                 UpstreamEntry<Integer, Integer> entry = upstream.next();
                 assertEquals(Integer.valueOf(2), entry.getKey());
                 assertEquals(Integer.valueOf(2), entry.getValue());
index a892530..7e31b07 100644 (file)
@@ -38,6 +38,7 @@ import org.apache.ignite.internal.processors.cache.distributed.dht.topology.Grid
 import org.apache.ignite.internal.util.IgniteUtils;
 import org.apache.ignite.internal.util.typedef.G;
 import org.apache.ignite.lang.IgnitePredicate;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
@@ -83,8 +84,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest {
         CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);
 
         CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build(
-            (upstream, upstreamSize) -> upstreamSize,
-            (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0)
+            TestUtils.testEnvBuilder(),
+            (env, upstream, upstreamSize) -> upstreamSize,
+            (env, upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0)
         );
 
         assertEquals("Upstream cache name from dataset",
@@ -138,8 +140,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest {
         CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);
 
         CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build(
-            (upstream, upstreamSize) -> upstreamSize,
-            (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0)
+            TestUtils.testEnvBuilder(),
+            (env, upstream, upstreamSize) -> upstreamSize,
+            (env, upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0)
         );
 
         assertTrue("Before computation all partitions should not be reserved",
index cee8f4f..202b6bc 100644 (file)
@@ -32,8 +32,9 @@ import org.apache.ignite.cache.affinity.AffinityFunctionContext;
 import org.apache.ignite.cluster.ClusterNode;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
-import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
 /**
@@ -179,18 +180,18 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
                     ignite,
                     upstreamCacheName,
                     (k, v) -> true,
-                    UpstreamTransformerChain.empty(),
+                    UpstreamTransformerBuilder.identity(),
                     datasetCacheName,
                     datasetId,
-                    0,
-                    (upstream, upstreamSize, ctx) -> {
+                    (env, upstream, upstreamSize, ctx) -> {
                         cnt.incrementAndGet();
 
                         assertEquals(1, upstreamSize);
 
                         UpstreamEntry<Integer, Integer> e = upstream.next();
                         return new TestPartitionData(e.getKey() + e.getValue());
-                    }
+                    },
+                    TestUtils.testEnvBuilder().buildForWorker(part)
                 ),
                 0
             );
@@ -229,15 +230,16 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
             ignite,
             upstreamCacheName,
             (k, v) -> true,
-            UpstreamTransformerChain.empty(),
+            UpstreamTransformerBuilder.identity(),
             datasetCacheName,
-            (upstream, upstreamSize) -> {
+            (env, upstream, upstreamSize) -> {
 
                 assertEquals(1, upstreamSize);
 
                 UpstreamEntry<Integer, Integer> e = upstream.next();
                 return e.getKey() + e.getValue();
             },
+            TestUtils.testEnvBuilder(),
             0
         );
 
index 8dc9354..6088140 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicLong;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.dataset.PartitionContextBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.junit.Test;
@@ -47,7 +48,7 @@ public class LocalDatasetBuilderTest {
 
         AtomicLong cnt = new AtomicLong();
 
-        dataset.compute((partData, partIdx) -> {
+        dataset.compute((partData, env) -> {
            cnt.incrementAndGet();
 
            int[] arr = partData.data;
@@ -55,7 +56,7 @@ public class LocalDatasetBuilderTest {
            assertEquals(10, arr.length);
 
            for (int i = 0; i < 10; i++)
-               assertEquals(partIdx * 10 + i, arr[i]);
+               assertEquals(env.partition() * 10 + i, arr[i]);
         });
 
         assertEquals(10, cnt.intValue());
@@ -74,7 +75,7 @@ public class LocalDatasetBuilderTest {
 
         AtomicLong cnt = new AtomicLong();
 
-        dataset.compute((partData, partIdx) -> {
+        dataset.compute((partData, env) -> {
             cnt.incrementAndGet();
 
             int[] arr = partData.data;
@@ -82,7 +83,7 @@ public class LocalDatasetBuilderTest {
             assertEquals(5, arr.length);
 
             for (int i = 0; i < 5; i++)
-                assertEquals((partIdx * 5 + i) * 2, arr[i]);
+                assertEquals((env.partition() * 5 + i) * 2, arr[i]);
         });
 
         assertEquals(10, cnt.intValue());
@@ -91,10 +92,10 @@ public class LocalDatasetBuilderTest {
     /** */
     private LocalDataset<Serializable, TestPartitionData> buildDataset(
         LocalDatasetBuilder<Integer, Integer> builder) {
-        PartitionContextBuilder<Integer, Integer, Serializable> partCtxBuilder = (upstream, upstreamSize) -> null;
+        PartitionContextBuilder<Integer, Integer, Serializable> partCtxBuilder = (env, upstream, upstreamSize) -> null;
 
         PartitionDataBuilder<Integer, Integer, Serializable, TestPartitionData> partDataBuilder
-            = (upstream, upstreamSize, ctx) -> {
+            = (env, upstream, upstreamSize, ctx) -> {
             int[] arr = new int[Math.toIntExact(upstreamSize)];
 
             int ptr = 0;
@@ -105,6 +106,7 @@ public class LocalDatasetBuilderTest {
         };
 
         return builder.build(
+            TestUtils.testEnvBuilder(),
             partCtxBuilder.andThen(x -> null),
             partDataBuilder.andThen((x, y) -> x)
         );
index eaa03d2..33c0677 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset.primitive;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.dataset.DatasetFactory;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
@@ -43,6 +44,7 @@ public class SimpleDatasetTest {
         try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(
             dataPoints,
             2,
+            TestUtils.testEnvBuilder(),
             (k, v) -> VectorUtils.of(v.getAge(), v.getSalary())
         )) {
             assertArrayEquals("Mean values.", new double[] {37.75, 66000.0}, dataset.mean(), 0);
index f7b0f13..36e540b 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset.primitive;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.dataset.DatasetFactory;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
@@ -47,14 +48,16 @@ public class SimpleLabeledDatasetTest {
         // Creates a local simple dataset containing features and providing standard dataset API.
         try (SimpleLabeledDataset<?> dataset = DatasetFactory.createSimpleLabeledDataset(
             dataPoints,
+            TestUtils.testEnvBuilder(),
             2,
             (k, v) -> VectorUtils.of(v.getAge(), v.getSalary()),
             (k, v) -> new double[] {k, v.getAge(), v.getSalary()}
         )) {
-            assertNull(dataset.compute((data, partIdx) -> {
-                actualFeatures[partIdx] = data.getFeatures();
-                actualLabels[partIdx] = data.getLabels();
-                actualRows[partIdx] = data.getRows();
+            assertNull(dataset.compute((data, env) -> {
+                int part = env.partition();
+                actualFeatures[part] = data.getFeatures();
+                actualLabels[part] = data.getLabels();
+                actualRows[part] = data.getRows();
                 return null;
             }, (k, v) -> null));
         }
index 56f262b..7769092 100644 (file)
@@ -39,7 +39,7 @@ public class LearningEnvironmentBuilderTest {
     /** */
     @Test
     public void basic() {
-        LearningEnvironment env = LearningEnvironment.DEFAULT;
+        LearningEnvironment env = LearningEnvironment.DEFAULT_TRAINER_ENV;
 
         assertNotNull("Strategy", env.parallelismStrategy());
         assertNotNull("Logger", env.logger());
@@ -49,42 +49,44 @@ public class LearningEnvironmentBuilderTest {
     /** */
     @Test
     public void withParallelismStrategy() {
-        assertTrue(LearningEnvironment.builder().withParallelismStrategy(NoParallelismStrategy.INSTANCE).build()
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withParallelismStrategyDependency(part -> NoParallelismStrategy.INSTANCE)
+            .buildForTrainer()
             .parallelismStrategy() instanceof NoParallelismStrategy);
 
-        assertTrue(LearningEnvironment.builder().withParallelismStrategy(new DefaultParallelismStrategy()).build()
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withParallelismStrategyDependency(part -> new DefaultParallelismStrategy())
+            .buildForTrainer()
             .parallelismStrategy() instanceof DefaultParallelismStrategy);
     }
 
     /** */
     @Test
     public void withParallelismStrategyType() {
-        assertTrue(LearningEnvironment.builder().withParallelismStrategy(NO_PARALLELISM).build()
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withParallelismStrategyType(NO_PARALLELISM).buildForTrainer()
             .parallelismStrategy() instanceof NoParallelismStrategy);
 
-        assertTrue(LearningEnvironment.builder().withParallelismStrategy(ON_DEFAULT_POOL).build()
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withParallelismStrategyType(ON_DEFAULT_POOL).buildForTrainer()
             .parallelismStrategy() instanceof DefaultParallelismStrategy);
     }
 
     /** */
     @Test
     public void withLoggingFactory() {
-        assertTrue(LearningEnvironment.builder().withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH))
-            .build().logger() instanceof ConsoleLogger);
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH))
+            .buildForTrainer().logger() instanceof ConsoleLogger);
 
-        assertTrue(LearningEnvironment.builder().withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH))
-            .build().logger(this.getClass()) instanceof ConsoleLogger);
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.HIGH))
+            .buildForTrainer().logger(this.getClass()) instanceof ConsoleLogger);
 
-        assertTrue(LearningEnvironment.builder().withLoggingFactory(NoOpLogger.factory())
-            .build().logger() instanceof NoOpLogger);
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> NoOpLogger.factory())
+            .buildForTrainer().logger() instanceof NoOpLogger);
 
-        assertTrue(LearningEnvironment.builder().withLoggingFactory(NoOpLogger.factory())
-            .build().logger(this.getClass()) instanceof NoOpLogger);
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> NoOpLogger.factory())
+            .buildForTrainer().logger(this.getClass()) instanceof NoOpLogger);
 
-        assertTrue(LearningEnvironment.builder().withLoggingFactory(CustomMLLogger.factory(new NullLogger()))
-            .build().logger() instanceof CustomMLLogger);
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> CustomMLLogger.factory(new NullLogger()))
+            .buildForTrainer().logger() instanceof CustomMLLogger);
 
-        assertTrue(LearningEnvironment.builder().withLoggingFactory(CustomMLLogger.factory(new NullLogger()))
-            .build().logger(this.getClass()) instanceof CustomMLLogger);
+        assertTrue(LearningEnvironmentBuilder.defaultBuilder().withLoggingFactoryDependency(part -> CustomMLLogger.factory(new NullLogger()))
+            .buildForTrainer().logger(this.getClass()) instanceof CustomMLLogger);
     }
 }
index 73192f0..4b44196 100644 (file)
 
 package org.apache.ignite.ml.environment;
 
+import java.util.Map;
+import java.util.Random;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.environment.logging.ConsoleLogger;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy;
 import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer;
 import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies;
 import org.junit.Test;
 
+import static org.apache.ignite.ml.TestUtils.constantModel;
 import static org.junit.Assert.assertEquals;
 
 /**
@@ -48,13 +62,115 @@ public class LearningEnvironmentTest {
             .withSubSampleSize(0.3)
             .withSeed(0);
 
-        LearningEnvironment environment = LearningEnvironment.builder()
-            .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL)
-            .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW))
-            .build();
-        trainer.setEnvironment(environment);
-        assertEquals(DefaultParallelismStrategy.class, environment.parallelismStrategy().getClass());
-        assertEquals(ConsoleLogger.class, environment.logger().getClass());
+        LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder()
+            .withParallelismStrategyType(ParallelismStrategy.Type.ON_DEFAULT_POOL)
+            .withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.LOW));
+
+        trainer.withEnvironmentBuilder(envBuilder);
+
+        assertEquals(DefaultParallelismStrategy.class, trainer.learningEnvironment().parallelismStrategy().getClass());
+        assertEquals(ConsoleLogger.class, trainer.learningEnvironment().logger().getClass());
+    }
+
+    /**
+     * Test random number generator provided by {@link LearningEnvironment}.
+     * We test that:
+     * 1. Correct random generator is returned for each partition.
+     * 2. Its state is saved between compute calls (for this we do several iterations of compute).
+     */
+    @Test
+    public void testRandomNumbersGenerator() {
+        // We construct environment builders such that each partition's random number generator returns partition index * iteration from nextInt.
+        LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);
+        int partitions = 10;
+        int iterations = 2;
+
+        DatasetTrainer<Model<Object, Vector>, Void> trainer = new DatasetTrainer<Model<Object, Vector>, Void>() {
+            /** {@inheritDoc} */
+            @Override public <K, V> Model<Object, Vector> fit(DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Void> lbExtractor) {
+                Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(envBuilder,
+                    new EmptyContextBuilder<>(),
+                    (PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>>)(env, upstreamData, upstreamDataSize, ctx) ->
+                        TestUtils.DataWrapper.of(env.partition()));
+
+                Vector v = null;
+                for (int iter = 0; iter < iterations; iter++) {
+                    v = ds.compute((dw, env) -> VectorUtils.fill(-1, partitions).set(env.partition(), env.randomNumbersGenerator().nextInt()),
+                        (v1, v2) -> zipOverridingEmpty(v1, v2, -1));
+                }
+                return constantModel(v);
+            }
+
+            /** {@inheritDoc} */
+            @Override protected boolean checkState(Model<Object, Vector> mdl) {
+                return false;
+            }
+
+            /** {@inheritDoc} */
+            @Override protected <K, V> Model<Object, Vector> updateModel(Model<Object, Vector> mdl,
+                DatasetBuilder<K, V> datasetBuilder,
+                IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Void> lbExtractor) {
+                return null;
+            }
+        };
+        trainer.withEnvironmentBuilder(envBuilder);
+        Model<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null, null);
+
+        Vector exp = VectorUtils.zeroes(partitions);
+        for (int i = 0; i < partitions; i++)
+            exp.set(i, i * iterations);
+
+
+        Vector res = mdl.apply(null);
+        assertEquals(exp, res);
+    }
+
+    /**
+     * For given two vectors {@code v1, v2} produce vector {@code v} where each component of {@code v}
+     * is produced from corresponding components {@code c1, c2} of {@code v1, v2} respectively in following way
+     * {@code c = c1 != empty ? c1 : c2}. For example, zipping [2, -1, -1], [-1, 3, -1] will result in [2, 3, -1].
+     *
+     * @param v1 First vector.
+     * @param v2 Second vector.
+     * @param empty Value treated as empty.
+     * @return Result of zipping as described above.
+     */
+    private static Vector zipOverridingEmpty(Vector v1, Vector v2, double empty) {
+        return v1 != null ? (v2 != null ? VectorUtils.zipWith(v1, v2, (d1, d2) -> d1 != empty ? d1 : d2) : v1) : v2;
+    }
+
+    /** Get cache mock */
+    private Map<Integer, Integer> getCacheMock(int partsCnt) {
+        return IntStream.range(0, partsCnt).boxed().collect(Collectors.toMap(x -> x, x -> x));
+    }
+
+    /** Mock random numbers generator. */
+    private static class MockRandom extends Random {
+        /** Serial version uuid. */
+        private static final long serialVersionUID = -7738558243461112988L;
+
+        /** Start value. */
+        private int startVal;
+
+        /** Iteration. */
+        private int iter;
+
+        /**
+         * Constructs instance of this class with a specified start value.
+         *
+         * @param startVal Start value.
+         */
+        MockRandom(int startVal) {
+            this.startVal = startVal;
+            iter = 0;
+        }
+
+        /** {@inheritDoc} */
+        @Override public int nextInt() {
+            iter++;
+            return startVal * iter;
+        }
     }
 }
 
index b720695..b743a37 100644 (file)
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.math.isolve.lsqr;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
@@ -47,6 +48,7 @@ public class LSQROnHeapTest extends TrainerTest {
 
         LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
             datasetBuilder,
+            TestUtils.testEnvBuilder(),
             new SimpleLabeledDatasetDataBuilder<>(
                 (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
                 (k, v) -> new double[]{v[3]}
@@ -80,6 +82,7 @@ public class LSQROnHeapTest extends TrainerTest {
 
         LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
             datasetBuilder,
+            TestUtils.testEnvBuilder(),
             new SimpleLabeledDatasetDataBuilder<>(
                 (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
                 (k, v) -> new double[]{v[3]}
@@ -113,6 +116,7 @@ public class LSQROnHeapTest extends TrainerTest {
 
         try (LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
             datasetBuilder,
+            TestUtils.testEnvBuilder(),
             new SimpleLabeledDatasetDataBuilder<>(
                 (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
                 (k, v) -> new double[]{v[4]}
index 4b7fa33..b611104 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.binarization;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
@@ -51,6 +52,7 @@ public class BinarizationTrainerTest extends TrainerTest {
         assertEquals(10., binarizationTrainer.getThreshold(), 0);
 
         BinarizationPreprocessor<Integer, double[]> preprocessor = binarizationTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> VectorUtils.of(v)
         );
@@ -75,6 +77,7 @@ public class BinarizationTrainerTest extends TrainerTest {
         assertEquals(10., binarizationTrainer.getThreshold(), 0);
 
         IgniteBiFunction<Integer, double[], Vector> preprocessor = binarizationTrainer.fit(
+            TestUtils.testEnvBuilder(),
             data,
             parts,
             (k, v) -> VectorUtils.of(v)
index 7c7eabe..f9d56a9 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.encoding;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
@@ -51,6 +52,7 @@ public class EncoderTrainerTest extends TrainerTest {
             .withEncodedFeature(1);
 
         EncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
@@ -77,6 +79,7 @@ public class EncoderTrainerTest extends TrainerTest {
             .withEncodedFeature(1);
 
         EncoderPreprocessor<Integer, Object[]> preprocessor = strEncoderTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
@@ -103,6 +106,7 @@ public class EncoderTrainerTest extends TrainerTest {
             .withEncodedFeature(1);
 
         EncoderPreprocessor<Integer, Object[]> preprocessor = strEncoderTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
@@ -136,6 +140,7 @@ public class EncoderTrainerTest extends TrainerTest {
             .withEncodedFeature(1);
 
         EncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
index 9c11d13..f8a5e78 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.imputing;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
@@ -47,6 +48,7 @@ public class ImputerTrainerTest extends TrainerTest {
             .withImputingStrategy(ImputingStrategy.MOST_FREQUENT);
 
         ImputerPreprocessor<Integer, Vector> preprocessor = imputerTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
index 844468e..fc3433b 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.maxabsscaling;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
@@ -46,6 +47,7 @@ public class MaxAbsScalerTrainerTest extends TrainerTest {
         MaxAbsScalerTrainer<Integer, Vector> standardizationTrainer = new MaxAbsScalerTrainer<>();
 
         MaxAbsScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
index 4c0a99f..8716324 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.minmaxscaling;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
@@ -46,6 +47,7 @@ public class MinMaxScalerTrainerTest extends TrainerTest {
         MinMaxScalerTrainer<Integer, Vector> standardizationTrainer = new MinMaxScalerTrainer<>();
 
         MinMaxScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
index 9d39354..d8a8191 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.normalization;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
@@ -50,6 +51,7 @@ public class NormalizationTrainerTest extends TrainerTest {
         assertEquals(3., normalizationTrainer.p(), 0);
 
         NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> VectorUtils.of(v)
         );
index 6f10b37..839cb20 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.preprocessing.standardscaling;
 
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
@@ -62,6 +63,7 @@ public class StandardScalerTrainerTest extends TrainerTest {
         double[] expectedMeans = new double[] {0.5, 1.75, 4.5, 0.875};
 
         StandardScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
@@ -75,6 +77,7 @@ public class StandardScalerTrainerTest extends TrainerTest {
         double[] expectedSigmas = new double[] {0.5, 1.47901995, 14.51723114, 0.93374247};
 
         StandardScalerPreprocessor<Integer, Vector> preprocessor = standardizationTrainer.fit(
+            TestUtils.testEnvBuilder(),
             datasetBuilder,
             (k, v) -> v
         );
index 6f7aa36..1abf7f0 100644 (file)
@@ -48,6 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 import org.apache.ignite.thread.IgniteThread;
 
+import static org.apache.ignite.ml.TestUtils.testEnvBuilder;
 import static org.junit.Assert.assertArrayEquals;
 
 /**
@@ -288,19 +289,24 @@ public class EvaluatorTest extends GridCommonAbstractTest {
                 .withEncoderType(EncoderType.STRING_ENCODER)
                 .withEncodedFeature(1)
                 .withEncodedFeature(6) // <--- Changed index here
-                .fit(ignite,
+                .fit(
+                    testEnvBuilder(123L),
+                    ignite,
                     cache,
                     featureExtractor
                 );
 
             IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
-                .fit(ignite,
+                .fit(
+                    testEnvBuilder(124L),
+                    ignite,
                     cache,
                     strEncoderPreprocessor
                 );
 
             IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
                 .fit(
+                    testEnvBuilder(125L),
                     ignite,
                     cache,
                     imputingPreprocessor
@@ -309,6 +315,7 @@ public class EvaluatorTest extends GridCommonAbstractTest {
             return new NormalizationTrainer<Integer, Object[]>()
                 .withP(2)
                 .fit(
+                    testEnvBuilder(126L),
                     ignite,
                     cache,
                     minMaxScalerPreprocessor
index a82374b..1b96ce2 100644 (file)
@@ -28,6 +28,8 @@ import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictio
 import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteTriFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -75,11 +77,15 @@ public class BaggingTest extends TrainerTest {
                 .withBatchSize(10)
                 .withSeed(123L);
 
+        trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder());
+
         DatasetTrainer<ModelsComposition, Double> baggedTrainer =
             TrainerTransformers.makeBagged(
                 trainer,
                 10,
                 0.7,
+                2,
+                2,
                 new OnMajorityPredictionsAggregator());
 
         ModelsComposition mdl = baggedTrainer.fit(
@@ -98,14 +104,20 @@ public class BaggingTest extends TrainerTest {
      *
      * @param counter Function specifying which data we should count.
      */
-    protected void count(IgniteTriFunction<Long, CountData, Integer, Long> counter) {
+    protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) {
         Map<Integer, Double[]> cacheMock = getCacheMock();
 
         CountTrainer countTrainer = new CountTrainer(counter);
 
         double subsampleRatio = 0.3;
 
-        ModelsComposition model = TrainerTransformers.makeBagged(countTrainer, 100, subsampleRatio, new MeanValuePredictionsAggregator())
+        ModelsComposition model = TrainerTransformers.makeBagged(
+            countTrainer,
+            100,
+            subsampleRatio,
+            2,
+            2,
+            new MeanValuePredictionsAggregator())
             .fit(cacheMock, parts, null, null);
 
         Double res = model.apply(null);
@@ -155,14 +167,14 @@ public class BaggingTest extends TrainerTest {
         /**
          * Function specifying which entries to count.
          */
-        private final IgniteTriFunction<Long, CountData, Integer, Long> counter;
+        private final IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter;
 
         /**
          * Construct instance of this class.
          *
          * @param counter Function specifying which entries to count.
          */
-        public CountTrainer(IgniteTriFunction<Long, CountData, Integer, Long> counter) {
+        public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long> counter) {
             this.counter = counter;
         }
 
@@ -172,8 +184,9 @@ public class BaggingTest extends TrainerTest {
             IgniteBiFunction<K, V, Vector> featureExtractor,
             IgniteBiFunction<K, V, Double> lbExtractor) {
             Dataset<Long, CountData> dataset = datasetBuilder.build(
-                (upstreamData, upstreamDataSize) -> upstreamDataSize,
-                (upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize)
+                TestUtils.testEnvBuilder(),
+                (env, upstreamData, upstreamDataSize) -> upstreamDataSize,
+                (env, upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize)
             );
 
             Long cnt = dataset.computeWithCtx(counter, BaggingTest::plusOfNullables);
@@ -193,6 +206,11 @@ public class BaggingTest extends TrainerTest {
             IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
             return fit(datasetBuilder, featureExtractor, lbExtractor);
         }
+
+        /** {@inheritDoc} */
+        @Override public CountTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+            return (CountTrainer)super.withEnvironmentBuilder(envBuilder);
+        }
     }
 
     /** Data for count trainer. */