IGNITE-9412: [ML] GDB convergence by error support.
author Alexey Platonov <aplatonovv@gmail.com>
Thu, 6 Sep 2018 09:08:36 +0000 (12:08 +0300)
committer Yury Babak <ybabak@gridgain.com>
Thu, 6 Sep 2018 09:08:36 +0000 (12:08 +0300)
this closes #4670

33 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/LogLoss.java [moved from modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/LossGradientPerPredictionFunctions.java with 53% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/SquaredError.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapData.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java

index 075eab2..e092e5c 100644 (file)
@@ -22,9 +22,8 @@ import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
@@ -59,10 +58,11 @@ public class GDBOnTreesClassificationTrainerExample {
                 IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg);
 
                 // Create regression trainer.
-                DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.);
+                DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.)
+                    .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1));
 
                 // Train decision tree model.
-                Model<Vector, Double> mdl = trainer.fit(
+                ModelsComposition mdl = trainer.fit(
                     ignite,
                     trainingSet,
                     (k, v) -> VectorUtils.of(v[0]),
@@ -81,6 +81,8 @@ public class GDBOnTreesClassificationTrainerExample {
                 }
 
                 System.out.println(">>> ---------------------------------");
+                System.out.println(">>> Count of trees = " + mdl.getModels().size());
+                System.out.println(">>> ---------------------------------");
 
                 System.out.println(">>> GDB classification trainer example completed.");
             });
index b2b08d0..3662973 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
@@ -59,7 +60,8 @@ public class GDBOnTreesRegressionTrainerExample {
                 IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg);
 
                 // Create regression trainer.
-                DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.);
+                DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.)
+                    .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001));
 
                 // Train decision tree model.
                 Model<Vector, Double> mdl = trainer.fit(
index 3701557..f6ddfed 100644 (file)
@@ -19,24 +19,23 @@ package org.apache.ignite.ml.composition.boosting;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
-import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.composition.boosting.loss.LogLoss;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteTriFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.structures.LabeledVector;
 import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
 
 /**
- * Trainer for binary classifier using Gradient Boosting.
- * As preparing stage this algorithm learn labels in dataset and create mapping dataset labels to 0 and 1.
- * This algorithm uses gradient of Logarithmic Loss metric [LogLoss] by default in each step of learning.
+ * Trainer for binary classifier using Gradient Boosting. As preparing stage this algorithm learn labels in dataset and
+ * create mapping dataset labels to 0 and 1. This algorithm uses gradient of Logarithmic Loss metric [LogLoss] by
+ * default in each step of learning.
  */
 public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
     /** External representation of first class. */
@@ -51,9 +50,7 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
      * @param cntOfIterations Count of learning iterations.
      */
     public GDBBinaryClassifierTrainer(double gradStepSize, Integer cntOfIterations) {
-        super(gradStepSize,
-            cntOfIterations,
-            LossGradientPerPredictionFunctions.LOG_LOSS);
+        super(gradStepSize, cntOfIterations, new LogLoss());
     }
 
     /**
@@ -61,35 +58,37 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
      *
      * @param gradStepSize Grad step size.
      * @param cntOfIterations Count of learning iterations.
-     * @param lossGradient Gradient of loss function. First argument is sample size, second argument is valid answer, third argument is current model prediction.
+     * @param loss Loss function.
      */
-    public GDBBinaryClassifierTrainer(double gradStepSize,
-        Integer cntOfIterations,
-        IgniteTriFunction<Long, Double, Double, Double> lossGradient) {
-
-        super(gradStepSize, cntOfIterations, lossGradient);
+    public GDBBinaryClassifierTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
+        super(gradStepSize, cntOfIterations, loss);
     }
 
     /** {@inheritDoc} */
-    @Override protected <V, K> void learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor,
+    @Override protected <V, K> boolean learnLabels(DatasetBuilder<K, V> builder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lExtractor) {
 
-        List<Double> uniqLabels = new ArrayList<Double>(
-            builder.build(new EmptyContextBuilder<>(), new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor))
-                .compute((IgniteFunction<LabeledVectorSet<Double,LabeledVector>, Set<Double>>) x ->
+        Set<Double> uniqLabels = builder.build(new EmptyContextBuilder<>(), new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor))
+            .compute((IgniteFunction<LabeledVectorSet<Double, LabeledVector>, Set<Double>>)x ->
                     Arrays.stream(x.labels()).boxed().collect(Collectors.toSet()), (a, b) -> {
-                        if (a == null)
-                            return b;
-                        if (b == null)
-                            return a;
-                        a.addAll(b);
+                    if (a == null)
+                        return b;
+                    if (b == null)
                         return a;
-                    }
-                ));
+                    a.addAll(b);
+                    return a;
+                }
+            );
 
-        A.ensure(uniqLabels.size() == 2, "Binary classifier expects two types of labels in learning dataset");
-        externalFirstCls = uniqLabels.get(0);
-        externalSecondCls = uniqLabels.get(1);
+        if (uniqLabels != null && uniqLabels.size() == 2) {
+            ArrayList<Double> lblsArray = new ArrayList<>(uniqLabels);
+            externalFirstCls = lblsArray.get(0);
+            externalSecondCls = lblsArray.get(1);
+            return true;
+        } else {
+            return false;
+        }
     }
 
     /** {@inheritDoc} */
index 375748a..737495e 100644 (file)
@@ -22,6 +22,10 @@ import java.util.Arrays;
 import java.util.List;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
+import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
 import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.environment.LearningEnvironment;
@@ -29,9 +33,9 @@ import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.functions.IgniteTriFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * Learning strategy for gradient boosting.
@@ -44,7 +48,7 @@ public class GDBLearningStrategy {
     protected int cntOfIterations;
 
     /** Loss of gradient. */
-    protected IgniteTriFunction<Long, Double, Double, Double> lossGradient;
+    protected Loss loss;
 
     /** External label to internal mapping. */
     protected IgniteFunction<Double, Double> externalLbToInternalMapping;
@@ -61,9 +65,15 @@ public class GDBLearningStrategy {
     /** Composition weights. */
     protected double[] compositionWeights;
 
+    /** Check convergence strategy factory. */
+    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001);
+
+    /** Default gradient step size. */
+    private double defaultGradStepSize;
+
     /**
-     * Implementation of gradient boosting iterations. At each step of iterations this algorithm
-     * build a regression model based on gradient of loss-function for current models composition.
+     * Implementation of gradient boosting iterations. At each step of iterations this algorithm build a regression
+     * model based on gradient of loss-function for current models composition.
      *
      * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
@@ -73,18 +83,43 @@ public class GDBLearningStrategy {
     public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
-        List<Model<Vector, Double>> models = new ArrayList<>();
+        return update(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /**
+     * Gets state of model in arguments, compare it with training parameters of trainer and if they are fit then
+     * trainer updates model in according to new data and return new model. In other case trains new model.
+     *
+     * @param mdlToUpdate Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated models list.
+     */
+    public <K,V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        List<Model<Vector, Double>> models = initLearningState(mdlToUpdate);
+
+        ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize,
+            externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor);
+
         DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
         for (int i = 0; i < cntOfIterations; i++) {
-            double[] weights = Arrays.copyOf(compositionWeights, i);
+            double[] weights = Arrays.copyOf(compositionWeights, models.size());
 
             WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue);
-            Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator);
+            ModelsComposition currComposition = new ModelsComposition(models, aggregator);
+            if (convCheck.isConverged(datasetBuilder, currComposition))
+                break;
 
             IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> {
                 Double realAnswer = externalLbToInternalMapping.apply(lbExtractor.apply(k, v));
                 Double mdlAnswer = currComposition.apply(featureExtractor.apply(k, v));
-                return -lossGradient.apply(sampleSize, realAnswer, mdlAnswer);
+                return -loss.gradient(sampleSize, realAnswer, mdlAnswer);
             };
 
             long startTs = System.currentTimeMillis();
@@ -97,6 +132,29 @@ public class GDBLearningStrategy {
     }
 
     /**
+     * Restores state of already learned model if can and sets learning parameters according to this state.
+     *
+     * @param mdlToUpdate Model to update.
+     * @return list of already learned models.
+     */
+    @NotNull protected List<Model<Vector, Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate) {
+        List<Model<Vector, Double>> models = new ArrayList<>();
+        if(mdlToUpdate != null) {
+            models.addAll(mdlToUpdate.getModels());
+            WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator();
+            meanLabelValue = aggregator.getBias();
+            compositionWeights = new double[models.size() + cntOfIterations];
+            for(int i = 0; i < models.size(); i++)
+                compositionWeights[i] = aggregator.getWeights()[i];
+        } else {
+            compositionWeights = new double[cntOfIterations];
+        }
+
+        Arrays.fill(compositionWeights, models.size(), compositionWeights.length, defaultGradStepSize);
+        return models;
+    }
+
+    /**
      * Sets learning environment.
      *
      * @param environment Learning Environment.
@@ -117,12 +175,12 @@ public class GDBLearningStrategy {
     }
 
     /**
-     * Sets gradient of loss function.
+     * Loss function.
      *
-     * @param lossGradient Loss gradient.
+     * @param loss Loss function.
      */
-    public GDBLearningStrategy withLossGradient(IgniteTriFunction<Long, Double, Double, Double> lossGradient) {
-        this.lossGradient = lossGradient;
+    public GDBLearningStrategy withLossGradient(Loss loss) {
+        this.loss = loss;
         return this;
     }
 
@@ -141,7 +199,8 @@ public class GDBLearningStrategy {
      *
      * @param buildBaseMdlTrainer Build base model trainer.
      */
-    public GDBLearningStrategy withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> buildBaseMdlTrainer) {
+    public GDBLearningStrategy withBaseModelTrainerBuilder(
+        IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> buildBaseMdlTrainer) {
         this.baseMdlTrainerBuilder = buildBaseMdlTrainer;
         return this;
     }
@@ -175,4 +234,34 @@ public class GDBLearningStrategy {
         this.compositionWeights = compositionWeights;
         return this;
     }
+
+    /**
+     * Sets CheckConvergenceStgyFactory.
+     *
+     * @param factory Factory.
+     */
+    public GDBLearningStrategy withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
+        this.checkConvergenceStgyFactory = factory;
+        return this;
+    }
+
+    /**
+     * Sets default gradient step size.
+     *
+     * @param defaultGradStepSize Default gradient step size.
+     */
+    public GDBLearningStrategy withDefaultGradStepSize(double defaultGradStepSize) {
+        this.defaultGradStepSize = defaultGradStepSize;
+        return this;
+    }
+
+    /** */
+    public double[] getCompositionWeights() {
+        return compositionWeights;
+    }
+
+    /** */
+    public double getMeanValue() {
+        return meanLabelValue;
+    }
 }
index 201586e..8c1afd7 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.composition.boosting;
 
+import org.apache.ignite.ml.composition.boosting.loss.SquaredError;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -33,15 +34,14 @@ public abstract class GDBRegressionTrainer extends GDBTrainer {
      * @param cntOfIterations Count of learning iterations.
      */
     public GDBRegressionTrainer(double gradStepSize, Integer cntOfIterations) {
-        super(gradStepSize,
-            cntOfIterations,
-            LossGradientPerPredictionFunctions.MSE);
+        super(gradStepSize, cntOfIterations, new SquaredError());
     }
 
     /** {@inheritDoc} */
-    @Override protected <V, K> void learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor,
+    @Override protected <V, K> boolean learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lExtractor) {
 
+        return true;
     }
 
     /** {@inheritDoc} */
index c7f21dd..85af798 100644 (file)
@@ -22,6 +22,9 @@ import java.util.List;
 import org.apache.ignite.lang.IgniteBiTuple;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
+import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
 import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -30,7 +33,7 @@ import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
@@ -60,24 +63,25 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
     private final int cntOfIterations;
 
     /**
-     * Gradient of loss function. First argument is sample size, second argument is valid answer, third argument is
-     * current model prediction.
+     * Loss function.
      */
-    protected final IgniteTriFunction<Long, Double, Double, Double> lossGradient;
+    protected final Loss loss;
+
+    /** Check convergence strategy factory. */
+    protected ConvergenceCheckerFactory checkConvergenceStgyFactory = new MeanAbsValueConvergenceCheckerFactory(0.001);
 
     /**
      * Constructs GDBTrainer instance.
      *
      * @param gradStepSize Grad step size.
      * @param cntOfIterations Count of learning iterations.
-     * @param lossGradient Gradient of loss function. First argument is sample size, second argument is valid answer
+     * @param loss Gradient of loss function. First argument is sample size, second argument is valid answer
      * third argument is current model prediction.
      */
-    public GDBTrainer(double gradStepSize, Integer cntOfIterations,
-        IgniteTriFunction<Long, Double, Double, Double> lossGradient) {
+    public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
         gradientStep = gradStepSize;
         this.cntOfIterations = cntOfIterations;
-        this.lossGradient = lossGradient;
+        this.loss = loss;
     }
 
     /** {@inheritDoc} */
@@ -85,53 +89,55 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
         IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
 
-        learnLabels(datasetBuilder, featureExtractor, lbExtractor);
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        if (!learnLabels(datasetBuilder, featureExtractor, lbExtractor))
+            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
+
+        IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(datasetBuilder, featureExtractor, lbExtractor);
+        if(initAndSampleSize == null)
+            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
 
-        IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(datasetBuilder,
-            featureExtractor, lbExtractor);
         Double mean = initAndSampleSize.get1();
         Long sampleSize = initAndSampleSize.get2();
 
-        double[] compositionWeights = new double[cntOfIterations];
-        Arrays.fill(compositionWeights, gradientStep);
-        WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(compositionWeights, mean);
-
         long learningStartTs = System.currentTimeMillis();
 
-        List<Model<Vector, Double>> models = getLearningStrategy()
+        GDBLearningStrategy stgy = getLearningStrategy()
             .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
             .withExternalLabelToInternal(this::externalLabelToInternal)
             .withCntOfIterations(cntOfIterations)
-            .withCompositionWeights(compositionWeights)
             .withEnvironment(environment)
-            .withLossGradient(lossGradient)
+            .withLossGradient(loss)
             .withSampleSize(sampleSize)
             .withMeanLabelValue(mean)
-            .learnModels(datasetBuilder, featureExtractor, lbExtractor);
+            .withDefaultGradStepSize(gradientStep)
+            .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory);
+
+        List<Model<Vector, Double>> models;
+        if (mdl != null)
+            models = stgy.update((GDBModel)mdl, datasetBuilder, featureExtractor, lbExtractor);
+        else
+            models = stgy.learnModels(datasetBuilder, featureExtractor, lbExtractor);
 
         double learningTime = (double)(System.currentTimeMillis() - learningStartTs) / 1000.0;
         environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);
 
-        return new ModelsComposition(models, resAggregator) {
-            @Override public Double apply(Vector features) {
-                return internalLabelToExternal(super.apply(features));
-            }
-        };
-    }
-
-
-    //TODO: This method will be implemented in IGNITE-9412
-    /** {@inheritDoc} */
-    @Override public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
-
-        throw new UnsupportedOperationException();
+        WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(
+            stgy.getCompositionWeights(),
+            stgy.getMeanValue()
+        );
+        return new GDBModel(models, resAggregator, this::internalLabelToExternal);
     }
 
-    //TODO: This method will be implemented in IGNITE-9412
     /** {@inheritDoc} */
     @Override protected boolean checkState(ModelsComposition mdl) {
-        throw new UnsupportedOperationException();
+        return mdl instanceof GDBModel;
     }
 
     /**
@@ -140,8 +146,9 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
      * @param builder Dataset builder.
      * @param featureExtractor Feature extractor.
      * @param lExtractor Labels extractor.
+     * @return true if labels learning was successful.
      */
-    protected abstract <V, K> void learnLabels(DatasetBuilder<K, V> builder,
+    protected abstract <V, K> boolean learnLabels(DatasetBuilder<K, V> builder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lExtractor);
 
     /**
@@ -196,7 +203,8 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
                 }
             );
 
-            meanTuple.set1(meanTuple.get1() / meanTuple.get2());
+            if (meanTuple != null)
+                meanTuple.set1(meanTuple.get1() / meanTuple.get2());
             return meanTuple;
         }
         catch (Exception e) {
@@ -205,6 +213,17 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
     }
 
     /**
+     * Sets CheckConvergenceStgyFactory.
+     *
+     * @param factory
+     * @return trainer.
+     */
+    public GDBTrainer withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
+        this.checkConvergenceStgyFactory = factory;
+        return this;
+    }
+
+    /**
      * Returns learning strategy.
      *
      * @return learning strategy.
@@ -212,4 +231,33 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
     protected GDBLearningStrategy getLearningStrategy() {
         return new GDBLearningStrategy();
     }
+
+    /** */
+    public static class GDBModel extends ModelsComposition {
+        /** Serial version uid. */
+        private static final long serialVersionUID = 3476661240155508004L;
+
+        /** Internal to external lbl mapping. */
+        private final IgniteFunction<Double, Double> internalToExternalLblMapping;
+
+        /**
+         * Creates an instance of GDBModel.
+         *
+         * @param models Models.
+         * @param predictionsAggregator Predictions aggregator.
+         * @param internalToExternalLblMapping Internal to external lbl mapping.
+         */
+        public GDBModel(List<? extends Model<Vector, Double>> models,
+            WeightedPredictionsAggregator predictionsAggregator,
+            IgniteFunction<Double, Double> internalToExternalLblMapping) {
+
+            super(models, predictionsAggregator);
+            this.internalToExternalLblMapping = internalToExternalLblMapping;
+        }
+
+        /** {@inheritDoc} */
+        @Override public Double apply(Vector features) {
+            return internalToExternalLblMapping.apply(super.apply(features));
+        }
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java
new file mode 100644 (file)
index 0000000..3f6e8ca
--- /dev/null
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
+import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Contains logic of error computing and convergence checking for Gradient Boosting algorithms.
+ *
+ * @param <K> Type of a key in upstream data.
+ * @param <V> Type of a value in upstream data.
+ */
+public abstract class ConvergenceChecker<K, V> implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 710762134746674105L;
+
+    /** Sample size. */
+    private long sampleSize;
+
+    /** External label to internal mapping. */
+    private IgniteFunction<Double, Double> externalLbToInternalMapping;
+
+    /** Loss function. */
+    private Loss loss;
+
+    /** Feature extractor. */
+    private IgniteBiFunction<K, V, Vector> featureExtractor;
+
+    /** Label extractor. */
+    private IgniteBiFunction<K, V, Double> lbExtractor;
+
+    /** Precision of convergence check. */
+    private double precision;
+
+    /**
+     * Constructs an instance of ConvergenceChecker.
+     *
+     * @param sampleSize Sample size.
+     * @param externalLbToInternalMapping External label to internal mapping.
+     * @param loss Loss gradient.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param precision
+     */
+    public ConvergenceChecker(long sampleSize,
+        IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor,
+        double precision) {
+
+        assert precision < 1 && precision >= 0;
+
+        this.sampleSize = sampleSize;
+        this.externalLbToInternalMapping = externalLbToInternalMapping;
+        this.loss = loss;
+        this.featureExtractor = featureExtractor;
+        this.lbExtractor = lbExtractor;
+        this.precision = precision;
+    }
+
+    /**
+     * Checks convergency on dataset.
+     *
+     * @param currMdl Current model.
+     * @return true if GDB is converged.
+     */
+    public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
+        try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(),
+            new FeatureMatrixWithLabelsOnHeapDataBuilder<>(featureExtractor, lbExtractor)
+        )) {
+            return isConverged(dataset, currMdl);
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Checks convergency on dataset.
+     *
+     * @param dataset Dataset.
+     * @param currMdl Current model.
+     * @return true if GDB is converged.
+     */
+    public boolean isConverged(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, ModelsComposition currMdl) {
+        Double error = computeMeanErrorOnDataset(dataset, currMdl);
+        return error < precision || error.isNaN();
+    }
+
+    /**
+     * Compute error for given model on learning dataset.
+     *
+     * @param dataset Learning dataset.
+     * @param mdl Model.
+     * @return error mean value.
+     */
+    public abstract Double computeMeanErrorOnDataset(
+        Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
+        ModelsComposition mdl);
+
+    /**
+     * Compute error for the specific vector of dataset.
+     *
+     * @param currMdl Current model.
+     * @return error.
+     */
+    public double computeError(Vector features, Double answer, ModelsComposition currMdl) {
+        Double realAnswer = externalLbToInternalMapping.apply(answer);
+        Double mdlAnswer = currMdl.apply(features);
+        return -loss.gradient(sampleSize, realAnswer, mdlAnswer);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java
new file mode 100644 (file)
index 0000000..7592f50
--- /dev/null
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence;
+
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
/**
 * Factory for {@link ConvergenceChecker} instances sharing one configured precision.
 */
public abstract class ConvergenceCheckerFactory {
    /** Precision of error checking. If error <= precision then it is equated to 0.0. */
    protected double precision;

    /**
     * Creates an instance of ConvergenceCheckerFactory.
     *
     * @param precision Precision [0 <= precision < 1].
     */
    public ConvergenceCheckerFactory(double precision) {
        this.precision = precision;
    }

    /**
     * Create an instance of ConvergenceChecker.
     *
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @param sampleSize Sample size.
     * @param externalLbToInternalMapping External label to internal mapping.
     * @param loss Loss function.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return ConvergenceChecker instance.
     */
    public abstract <K,V> ConvergenceChecker<K,V> create(long sampleSize,
        IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
        DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java
new file mode 100644 (file)
index 0000000..7340bfa
--- /dev/null
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.mean;
+
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+
+/**
+ * Use mean value of errors for estimating error on dataset.
+ *
+ * @param <K> Type of a key in upstream data.
+ * @param <V> Type of a value in upstream data.
+ */
+public class MeanAbsValueConvergenceChecker<K,V> extends ConvergenceChecker<K,V> {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 8534776439755210864L;
+
+    /**
+     * Creates an intance of MeanAbsValueConvergenceChecker.
+     *
+     * @param sampleSize Sample size.
+     * @param externalLbToInternalMapping External label to internal mapping.
+     * @param loss Loss.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     */
+    public MeanAbsValueConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> externalLbToInternalMapping,
+        Loss loss, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor,
+        double precision) {
+
+        super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor, precision);
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
+        ModelsComposition mdl) {
+
+        IgniteBiTuple<Double, Long> sumAndCnt = dataset.compute(
+            partition -> computeStatisticOnPartition(mdl, partition),
+            this::reduce
+        );
+
+        if(sumAndCnt == null || sumAndCnt.getValue() == 0)
+            return Double.NaN;
+        return sumAndCnt.getKey() / sumAndCnt.getValue();
+    }
+
+    /**
+     * Compute sum of absolute value of errors and count of rows in partition.
+     *
+     * @param mdl Model.
+     * @param part Partition.
+     * @return Tuple (sum of errors, count of rows)
+     */
+    private IgniteBiTuple<Double, Long> computeStatisticOnPartition(ModelsComposition mdl, FeatureMatrixWithLabelsOnHeapData part) {
+        Double sum = 0.0;
+
+        for(int i = 0; i < part.getFeatures().length; i++) {
+            double error = computeError(VectorUtils.of(part.getFeatures()[i]), part.getLabels()[i], mdl);
+            sum += Math.abs(error);
+        }
+
+        return new IgniteBiTuple<>(sum, (long) part.getLabels().length);
+    }
+
+    /**
+     * Merge left and right statistics from partitions.
+     *
+     * @param left Left.
+     * @param right Right.
+     * @return merged value.
+     */
+    private IgniteBiTuple<Double, Long> reduce(IgniteBiTuple<Double, Long> left, IgniteBiTuple<Double, Long> right) {
+        if (left == null) {
+            if (right != null)
+                return right;
+            else
+                return new IgniteBiTuple<>(0.0, 0L);
+        }
+
+        if (right == null)
+            return left;
+
+        return new IgniteBiTuple<>(
+            left.getKey() + right.getKey(),
+            right.getValue() + left.getValue()
+        );
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java
new file mode 100644 (file)
index 0000000..f02a606
--- /dev/null
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.mean;
+
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Factory for {@link MeanAbsValueConvergenceChecker}.
+ */
+public class MeanAbsValueConvergenceCheckerFactory extends ConvergenceCheckerFactory {
+    /**
+     * @param precision Precision.
+     */
+    public MeanAbsValueConvergenceCheckerFactory(double precision) {
+        super(precision);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize,
+        IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return new MeanAbsValueConvergenceChecker<>(sampleSize, externalLbToInternalMapping, loss,
+            datasetBuilder, featureExtractor, lbExtractor, precision);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/package-info.java
new file mode 100644 (file)
index 0000000..1ab6e66
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
 * Contains an implementation of convergence checking computed by the mean of absolute values of errors in the dataset.
+ */
+package org.apache.ignite.ml.composition.boosting.convergence.mean;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java
new file mode 100644 (file)
index 0000000..7e66a9c
--- /dev/null
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.median;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+
+/**
+ * Use median of median on partitions value of errors for estimating error on dataset. This algorithm may be less
+ * sensitive to
+ *
+ * @param <K> Type of a key in upstream data.
+ * @param <V> Type of a value in upstream data.
+ */
+public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K, V> {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 4902502002933415287L;
+
+    /**
+     * Creates an instance of MedianOfMedianConvergenceChecker.
+     *
+     * @param sampleSize Sample size.
+     * @param lblMapping External label to internal mapping.
+     * @param loss Loss function.
+     * @param datasetBuilder Dataset builder.
+     * @param fExtr Feature extractor.
+     * @param lbExtr Label extractor.
+     * @param precision Precision.
+     */
+    public MedianOfMedianConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> lblMapping, Loss loss,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> fExtr,
+        IgniteBiFunction<K, V, Double> lbExtr, double precision) {
+
+        super(sampleSize, lblMapping, loss, datasetBuilder, fExtr, lbExtr, precision);
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
+        ModelsComposition mdl) {
+
+        double[] medians = dataset.compute(
+            data -> computeMedian(mdl, data),
+            this::reduce
+        );
+
+        if(medians == null)
+            return Double.POSITIVE_INFINITY;
+        return getMedian(medians);
+    }
+
+    /**
+     * Compute median value on data partition.
+     *
+     * @param mdl Model.
+     * @param data Data.
+     * @return median value.
+     */
+    private double[] computeMedian(ModelsComposition mdl, FeatureMatrixWithLabelsOnHeapData data) {
+        double[] errors = new double[data.getLabels().length];
+        for (int i = 0; i < errors.length; i++)
+            errors[i] = Math.abs(computeError(VectorUtils.of(data.getFeatures()[i]), data.getLabels()[i], mdl));
+        return new double[] {getMedian(errors)};
+    }
+
+    /**
+     * Compute median value on array of errors.
+     *
+     * @param errors Error values.
+     * @return median value of errors.
+     */
+    private double getMedian(double[] errors) {
+        if(errors.length == 0)
+            return Double.POSITIVE_INFINITY;
+
+        Arrays.sort(errors);
+        final int middleIdx = (errors.length - 1) / 2;
+        if (errors.length % 2 == 1)
+            return errors[middleIdx];
+        else
+            return (errors[middleIdx + 1] + errors[middleIdx]) / 2;
+    }
+
+    /**
+     * Merge median values among partitions.
+     *
+     * @param left Left partition.
+     * @param right Right partition.
+     * @return merged median values.
+     */
+    private double[] reduce(double[] left, double[] right) {
+        if (left == null)
+            return right;
+        if(right == null)
+            return left;
+
+        double[] res = new double[left.length + right.length];
+        System.arraycopy(left, 0, res, 0, left.length);
+        System.arraycopy(right, 0, res, left.length, right.length);
+        return res;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java
new file mode 100644 (file)
index 0000000..a1affe0
--- /dev/null
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.median;
+
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Factory for {@link MedianOfMedianConvergenceChecker}.
+ */
+public class MedianOfMedianConvergenceCheckerFactory extends ConvergenceCheckerFactory {
+    /**
+     * @param precision Precision.
+     */
+    public MedianOfMedianConvergenceCheckerFactory(double precision) {
+        super(precision);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize,
+        IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return new MedianOfMedianConvergenceChecker<>(sampleSize, externalLbToInternalMapping, loss,
+            datasetBuilder, featureExtractor, lbExtractor, precision);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/package-info.java
new file mode 100644 (file)
index 0000000..3798ef9
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
 * Contains an implementation of convergence checking computed by the median of per-partition medians of errors in the dataset.
+ */
+package org.apache.ignite.ml.composition.boosting.convergence.median;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/package-info.java
new file mode 100644 (file)
index 0000000..6d42c62
--- /dev/null
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
 * Package contains implementations of convergence checking algorithms for gradient boosting.
 * These algorithms may stop the training of gradient boosting once the error on the dataset
 * becomes less than the precision specified by the user.
+ */
+package org.apache.ignite.ml.composition.boosting.convergence;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java
new file mode 100644 (file)
index 0000000..716d04e
--- /dev/null
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.simple;
+
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
/**
 * This strategy skips the estimating-error-on-dataset step.
 * According to this strategy, training will stop only after reaching the maximum number of iterations.
 *
 * @param <K> Type of a key in upstream data.
 * @param <V> Type of a value in upstream data.
 */
public class ConvergenceCheckerStub<K,V> extends ConvergenceChecker<K,V> {
    /** Serial version uid. */
    // NOTE(review): this uid duplicates MeanAbsValueConvergenceChecker's — harmless across distinct
    // classes, but it looks like a copy-paste; consider regenerating.
    private static final long serialVersionUID = 8534776439755210864L;

    /**
     * Creates an instance of ConvergenceCheckerStub.
     *
     * @param sampleSize Sample size.
     * @param externalLbToInternalMapping External label to internal mapping.
     * @param loss Loss function.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     */
    public ConvergenceCheckerStub(long sampleSize,
        IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
        DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

        // Precision 0.0 is a placeholder: both isConverged overloads below short-circuit to false.
        super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder,
            featureExtractor, lbExtractor, 0.0);
    }

    /** {@inheritDoc} Always {@code false}: the stub never reports convergence and skips building a dataset. */
    @Override public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
        return false;
    }

    /** {@inheritDoc} Always {@code false}: the stub never reports convergence. */
    @Override public boolean isConverged(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
        ModelsComposition currMdl) {
        return false;
    }

    /** {@inheritDoc} Unsupported: never invoked because both {@code isConverged} overloads are overridden. */
    @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
        ModelsComposition mdl) {

        throw new UnsupportedOperationException();
    }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java
new file mode 100644 (file)
index 0000000..a0f0d5c
--- /dev/null
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.simple;
+
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Factory for {@link ConvergenceCheckerStub}.
+ */
+public class ConvergenceCheckerStubFactory extends ConvergenceCheckerFactory {
+    /**
+     * Create an instance of ConvergenceCheckerStubFactory.
+     */
+    public ConvergenceCheckerStubFactory() {
+        super(0.0);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize,
+        IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return new ConvergenceCheckerStub<>(sampleSize, externalLbToInternalMapping, loss,
+            datasetBuilder, featureExtractor, lbExtractor);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/package-info.java
new file mode 100644 (file)
index 0000000..915903a
--- /dev/null
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains a stub implementation of convergence checking.
+ * With this implementation gradient boosting will train new submodels until the count of models reaches the
+ * maximum value (the count-of-iterations parameter).
+ */
+package org.apache.ignite.ml.composition.boosting.convergence.simple;
  * limitations under the License.
  */
 
-package org.apache.ignite.ml.composition.boosting;
-
-import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+package org.apache.ignite.ml.composition.boosting.loss;
 
 /**
- * Contains implementations of per-prediction loss functions for gradient boosting algorithm.
+ * Logistic regression loss function.
  */
-public class LossGradientPerPredictionFunctions {
-    /** Mean squared error loss for regression. */
-    public static IgniteTriFunction<Long, Double, Double, Double> MSE =
-        (sampleSize, answer, prediction) -> (2.0 / sampleSize) * (prediction - answer);
+public class LogLoss implements Loss {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 2251384437214194977L;
+
+    /** {@inheritDoc} */
+    @Override public double error(long sampleSize, double answer, double prediction) {
+        return -(answer * Math.log(prediction) + (1 - answer) * Math.log(1 - prediction));
+    }
 
-    /** Logarithmic loss for binary classification. */
-    public static IgniteTriFunction<Long, Double, Double, Double> LOG_LOSS =
-        (sampleSize, answer, prediction) -> (prediction - answer) / (prediction * (1.0 - prediction));
+    /** {@inheritDoc} */
+    @Override public double gradient(long sampleSize, double answer, double prediction) {
+        return (prediction - answer) / (prediction * (1.0 - prediction));
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java
new file mode 100644 (file)
index 0000000..72fff30
--- /dev/null
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.loss;
+
+import java.io.Serializable;
+
+/**
+ * Loss interface for computing the error, or the gradient of the error, on a specific row in a dataset.
+ */
+public interface Loss extends Serializable {
+    /**
+     * Error value for model answer.
+     *
+     * @param sampleSize Sample size.
+     * @param lb Label.
+     * @param mdlAnswer Model answer.
+     * @return error value.
+     */
+    public double error(long sampleSize, double lb, double mdlAnswer);
+
+    /**
+     * Error gradient value for model answer.
+     *
+     * @param sampleSize Sample size.
+     * @param lb Label.
+     * @param mdlAnswer Model answer.
+     * @return error gradient value.
+     */
+    public double gradient(long sampleSize, double lb, double mdlAnswer);
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/SquaredError.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/SquaredError.java
new file mode 100644 (file)
index 0000000..8f2f17e
--- /dev/null
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.loss;
+
+/**
+ * Represents the error function as E(label, prediction) = 1/N * (label - prediction)^2.
+ */
+public class SquaredError implements Loss {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 564886150646352157L;
+
+    /** {@inheritDoc} */
+    @Override public double error(long sampleSize, double lb, double prediction) {
+        return Math.pow(lb - prediction, 2) / sampleSize;
+    }
+
+    /** {@inheritDoc} */
+    @Override public double gradient(long sampleSize, double lb, double prediction) {
+        return (2.0 / sampleSize) * (prediction - lb);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/package-info.java
new file mode 100644 (file)
index 0000000..83a5e39
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains loss functions for Gradient Boosting algorithms.
+ */
+package org.apache.ignite.ml.composition.boosting.loss;
index 8a369ad..5e0f7f1 100644 (file)
@@ -86,4 +86,14 @@ public class WeightedPredictionsAggregator implements PredictionsAggregator {
         return builder.append(bias > 0 ? " + " : " - ").append(String.format("%.4f", bias))
             .append("]").toString();
     }
+
+    /** */
+    public double[] getWeights() {
+        return weights;
+    }
+
+    /** */
+    public double getBias() {
+        return bias;
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapData.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapData.java
new file mode 100644 (file)
index 0000000..9dbc1a9
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset.primitive;
+
+/**
+ * A partition {@code data} containing a matrix of features and a vector of labels stored in heap.
+ */
+public class FeatureMatrixWithLabelsOnHeapData implements AutoCloseable {
+    /** Matrix with features. */
+    private final double[][] features;
+
+    /** Vector with labels. */
+    private final double[] labels;
+
+    /**
+     * Constructs an instance of FeatureMatrixWithLabelsOnHeapData.
+     *
+     * @param features Features.
+     * @param labels Labels.
+     */
+    public FeatureMatrixWithLabelsOnHeapData(double[][] features, double[] labels) {
+        assert features.length == labels.length : "Features and labels have to be the same length";
+
+        this.features = features;
+        this.labels = labels;
+    }
+
+    /** */
+    public double[][] getFeatures() {
+        return features;
+    }
+
+    /** */
+    public double[] getLabels() {
+        return labels;
+    }
+
+    /** {@inheritDoc} */
+    @Override public void close() {
+        // Do nothing, GC will clean up.
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
new file mode 100644 (file)
index 0000000..be1724c
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset.primitive;
+
+import java.io.Serializable;
+import java.util.Iterator;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+
+/**
+ * A partition {@code data} builder that makes {@link FeatureMatrixWithLabelsOnHeapData}.
+ *
+ * @param <K> Type of a key in <tt>upstream</tt> data.
+ * @param <V> Type of a value in <tt>upstream</tt> data.
+ * @param <C> Type of a partition <tt>context</tt>.
+ */
+public class FeatureMatrixWithLabelsOnHeapDataBuilder<K, V, C extends Serializable>
+    implements PartitionDataBuilder<K, V, C, FeatureMatrixWithLabelsOnHeapData> {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 6273736987424171813L;
+
+    /** Function that extracts features from an {@code upstream} data. */
+    private final IgniteBiFunction<K, V, Vector> featureExtractor;
+
+    /** Function that extracts labels from an {@code upstream} data. */
+    private final IgniteBiFunction<K, V, Double> lbExtractor;
+
+    /**
+     * Constructs a new instance of FeatureMatrixWithLabelsOnHeapDataBuilder.
+     *
+     * @param featureExtractor Function that extracts features from an {@code upstream} data.
+     * @param lbExtractor Function that extracts labels from an {@code upstream} data.
+     */
+    public FeatureMatrixWithLabelsOnHeapDataBuilder(IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        this.featureExtractor = featureExtractor;
+        this.lbExtractor = lbExtractor;
+    }
+
+    /** {@inheritDoc} */
+    @Override public FeatureMatrixWithLabelsOnHeapData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+        double[][] features = new double[Math.toIntExact(upstreamDataSize)][];
+        double[] labels = new double[Math.toIntExact(upstreamDataSize)];
+
+        int ptr = 0;
+        while (upstreamData.hasNext()) {
+            UpstreamEntry<K, V> entry = upstreamData.next();
+
+            features[ptr] = featureExtractor.apply(entry.getKey(), entry.getValue()).asArray();
+
+            labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue());
+
+            ptr++;
+        }
+
+        return new FeatureMatrixWithLabelsOnHeapData(features, labels);
+    }
+}
index 8589a79..6ebbda1 100644 (file)
 
 package org.apache.ignite.ml.tree.boosting;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
+import org.apache.ignite.ml.composition.boosting.GDBTrainer;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
 import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -54,22 +55,30 @@ public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
     }
 
     /** {@inheritDoc} */
-    @Override public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+    @Override public <K, V> List<Model<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
 
         DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
         assert trainer instanceof DecisionTree;
         DecisionTree decisionTreeTrainer = (DecisionTree) trainer;
 
-        List<Model<Vector, Double>> models = new ArrayList<>();
+        List<Model<Vector, Double>> models = initLearningState(mdlToUpdate);
+
+        ConvergenceChecker<K,V> convCheck = checkConvergenceStgyFactory.create(sampleSize,
+            externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor);
+
         try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
             new EmptyContextBuilder<>(),
             new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex)
         )) {
             for (int i = 0; i < cntOfIterations; i++) {
-                double[] weights = Arrays.copyOf(compositionWeights, i);
+                double[] weights = Arrays.copyOf(compositionWeights, models.size());
                 WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue);
-                Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator);
+                ModelsComposition currComposition = new ModelsComposition(models, aggregator);
+
+                if(convCheck.isConverged(dataset, currComposition))
+                    break;
 
                 dataset.compute(part -> {
                     if(part.getCopyOfOriginalLabels() == null)
@@ -78,7 +87,7 @@ public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
                     for(int j = 0; j < part.getLabels().length; j++) {
                         double mdlAnswer = currComposition.apply(VectorUtils.of(part.getFeatures()[j]));
                         double originalLbVal = externalLbToInternalMapping.apply(part.getCopyOfOriginalLabels()[j]);
-                        part.getLabels()[j] = -lossGradient.apply(sampleSize, originalLbVal, mdlAnswer);
+                        part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                     }
                 });
 
@@ -92,6 +101,7 @@ public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
             throw new RuntimeException(e);
         }
 
+        compositionWeights = Arrays.copyOf(compositionWeights, models.size());
         return models;
     }
 }
index d5750ea..b8a16dc 100644 (file)
@@ -19,18 +19,14 @@ package org.apache.ignite.ml.tree.data;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
 import org.apache.ignite.ml.tree.TreeFilter;
 
 /**
- * A partition {@code data} of the containing matrix of features and vector of labels stored in heap.
+ * A partition {@code data} of the containing matrix of features and vector of labels stored in heap
+ * with index on features.
  */
-public class DecisionTreeData implements AutoCloseable {
-    /** Matrix with features. */
-    private final double[][] features;
-
-    /** Vector with labels. */
-    private final double[] labels;
-
+public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implements AutoCloseable {
     /** Copy of vector with original labels. Auxiliary for Gradient Boosting on Trees.*/
     private double[] copyOfOriginalLabels;
 
@@ -48,10 +44,7 @@ public class DecisionTreeData implements AutoCloseable {
      * @param buildIdx Build index.
      */
     public DecisionTreeData(double[][] features, double[] labels, boolean buildIdx) {
-        assert features.length == labels.length : "Features and labels have to be the same length";
-
-        this.features = features;
-        this.labels = labels;
+        super(features, labels);
         this.buildIndex = buildIdx;
 
         indexesCache = new ArrayList<>();
@@ -68,6 +61,8 @@ public class DecisionTreeData implements AutoCloseable {
     public DecisionTreeData filter(TreeFilter filter) {
         int size = 0;
 
+        double[][] features = getFeatures();
+        double[] labels = getLabels();
         for (int i = 0; i < features.length; i++)
             if (filter.test(features[i]))
                 size++;
@@ -95,12 +90,15 @@ public class DecisionTreeData implements AutoCloseable {
      * @param col Column.
      */
     public void sort(int col) {
-        sort(col, 0, features.length - 1);
+        sort(col, 0, getFeatures().length - 1);
     }
 
     /** */
     private void sort(int col, int from, int to) {
         if (from < to) {
+            double[][] features = getFeatures();
+            double[] labels = getLabels();
+
             double pivot = features[(from + to) / 2][col];
 
             int i = from, j = to;
@@ -131,19 +129,11 @@ public class DecisionTreeData implements AutoCloseable {
     }
 
     /** */
-    public double[][] getFeatures() {
-        return features;
-    }
-
-    /** */
-    public double[] getLabels() {
-        return labels;
-    }
-
     public double[] getCopyOfOriginalLabels() {
         return copyOfOriginalLabels;
     }
 
+    /** */
     public void setCopyOfOriginalLabels(double[] copyOfOriginalLabels) {
         this.copyOfOriginalLabels = copyOfOriginalLabels;
     }
@@ -170,7 +160,7 @@ public class DecisionTreeData implements AutoCloseable {
 
         if (depth == indexesCache.size()) {
             if (depth == 0)
-                indexesCache.add(new TreeDataIndex(features, labels));
+                indexesCache.add(new TreeDataIndex(getFeatures(), getLabels()));
             else {
                 TreeDataIndex lastIndex = indexesCache.get(depth - 1);
                 indexesCache.add(lastIndex.filter(filter));
index 3e340f6..89b8c9c 100644 (file)
@@ -22,11 +22,13 @@ import java.util.Map;
 import java.util.function.BiFunction;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
+import org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory;
 import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
 import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
 import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
@@ -54,8 +56,8 @@ public class GDBTrainerTest {
             learningSample.put(i, new double[] {xs[i], ys[i]});
         }
 
-        DatasetTrainer<ModelsComposition, Double> trainer
-            = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0).withUseIndex(true);
+        GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0)
+            .withUseIndex(true);
 
         Model<Vector, Double> mdl = trainer.fit(
             learningSample, 1,
@@ -74,7 +76,6 @@ public class GDBTrainerTest {
 
         assertEquals(0.0, mse, 0.0001);
 
-        assertTrue(mdl instanceof ModelsComposition);
         ModelsComposition composition = (ModelsComposition)mdl;
         assertTrue(composition.toString().length() > 0);
         assertTrue(composition.toString(true).length() > 0);
@@ -84,6 +85,13 @@ public class GDBTrainerTest {
 
         assertEquals(2000, composition.getModels().size());
         assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);
+
+        trainer = trainer.withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1));
+        assertTrue(trainer.fit(
+            learningSample, 1,
+            (k, v) -> VectorUtils.of(v[0]),
+            (k, v) -> v[1]
+        ).getModels().size() < 2000);
     }
 
     /** */
@@ -107,7 +115,7 @@ public class GDBTrainerTest {
     }
 
     /** */
-    private void testClassifier(BiFunction<GDBBinaryClassifierOnTreesTrainer, Map<Integer, double[]>,
+    private void testClassifier(BiFunction<GDBTrainer, Map<Integer, double[]>,
         Model<Vector, Double>> fitter) {
         int sampleSize = 100;
         double[] xs = new double[sampleSize];
@@ -122,8 +130,9 @@ public class GDBTrainerTest {
         for (int i = 0; i < sampleSize; i++)
             learningSample.put(i, new double[] {xs[i], ys[i]});
 
-        GDBBinaryClassifierOnTreesTrainer trainer
-            = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0).withUseIndex(true);
+        GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0)
+            .withUseIndex(true)
+            .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3));
 
         Model<Vector, Double> mdl = fitter.apply(trainer, learningSample);
 
@@ -132,7 +141,7 @@ public class GDBTrainerTest {
             double x = xs[j];
             double y = ys[j];
             double p = mdl.apply(VectorUtils.of(x));
-            if(p != y)
+            if (p != y)
                 errorsCnt++;
         }
 
@@ -142,7 +151,61 @@ public class GDBTrainerTest {
         ModelsComposition composition = (ModelsComposition)mdl;
         composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode));
 
-        assertEquals(500, composition.getModels().size());
+        assertTrue(composition.getModels().size() < 500);
         assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);
+
+        trainer = trainer.withCheckConvergenceStgyFactory(new ConvergenceCheckerStubFactory());
+        assertEquals(500, ((ModelsComposition)fitter.apply(trainer, learningSample)).getModels().size());
+    }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        int sampleSize = 100;
+        double[] xs = new double[sampleSize];
+        double[] ys = new double[sampleSize];
+
+        for (int i = 0; i < sampleSize; i++) {
+            xs[i] = i;
+            ys[i] = ((int)(xs[i] / 10.0) % 2) == 0 ? -1.0 : 1.0;
+        }
+
+        Map<Integer, double[]> learningSample = new HashMap<>();
+        for (int i = 0; i < sampleSize; i++)
+            learningSample.put(i, new double[] {xs[i], ys[i]});
+        IgniteBiFunction<Integer, double[], Vector> fExtr = (k, v) -> VectorUtils.of(v[0]);
+        IgniteBiFunction<Integer, double[], Double> lExtr = (k, v) -> v[1];
+
+        GDBTrainer classifTrainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0)
+            .withUseIndex(true)
+            .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3));
+        GDBTrainer regressTrainer = new GDBRegressionOnTreesTrainer(0.3, 500, 3, 0.0)
+            .withUseIndex(true)
+            .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3));
+
+        testUpdate(learningSample, fExtr, lExtr, classifTrainer);
+        testUpdate(learningSample, fExtr, lExtr, regressTrainer);
+    }
+
+    /** */
+    private void testUpdate(Map<Integer, double[]> dataset, IgniteBiFunction<Integer, double[], Vector> fExtr,
+        IgniteBiFunction<Integer, double[], Double> lExtr, GDBTrainer trainer) {
+
+        ModelsComposition originalMdl = trainer.fit(dataset, 1, fExtr, lExtr);
+        ModelsComposition updatedOnSameDataset = trainer.update(originalMdl, dataset, 1, fExtr, lExtr);
+
+        LocalDatasetBuilder<Integer, double[]> epmtyDataset = new LocalDatasetBuilder<>(new HashMap<>(), 1);
+        ModelsComposition updatedOnEmptyDataset = trainer.updateModel(originalMdl, epmtyDataset, fExtr, lExtr);
+
+        dataset.forEach((k,v) -> {
+            Vector features = fExtr.apply(k, v);
+
+            Double originalAnswer = originalMdl.apply(features);
+            Double updatedMdlAnswer1 = updatedOnSameDataset.apply(features);
+            Double updatedMdlAnswer2 = updatedOnEmptyDataset.apply(features);
+
+            assertEquals(originalAnswer, updatedMdlAnswer1, 0.01);
+            assertEquals(originalAnswer, updatedMdlAnswer2, 0.01);
+        });
     }
 }
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java
new file mode 100644 (file)
index 0000000..50fdf8b
--- /dev/null
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Before;
+
+/** */
+public abstract class ConvergenceCheckerTest {
+    /** Not converged model. */
+    protected ModelsComposition notConvergedMdl = new ModelsComposition(Collections.emptyList(), null) {
+        @Override public Double apply(Vector features) {
+            return 2.1 * features.get(0);
+        }
+    };
+
+    /** Converged model. */
+    protected ModelsComposition convergedMdl = new ModelsComposition(Collections.emptyList(), null) {
+        @Override public Double apply(Vector features) {
+            return 2 * (features.get(0) + 1);
+        }
+    };
+
+    /** Features extractor. */
+    protected IgniteBiFunction<double[], Double, Vector> fExtr = (x, y) -> VectorUtils.of(x);
+
+    /** Label extractor. */
+    protected IgniteBiFunction<double[], Double, Double> lbExtr = (x, y) -> y;
+
+    /** Data. */
+    protected Map<double[], Double> data;
+
+    /** */
+    @Before
+    public void setUp() throws Exception {
+        data = new HashMap<>();
+        for(int i = 0; i < 10; i ++)
+            data.put(new double[]{i, i + 1}, (double)(2 * (i + 1)));
+    }
+
+    /** */
+    public ConvergenceChecker<double[], Double> createChecker(ConvergenceCheckerFactory factory,
+        LocalDatasetBuilder<double[], Double> datasetBuilder) {
+
+        return factory.create(data.size(),
+            x -> x,
+            new Loss() {
+                @Override public double error(long sampleSize, double lb, double mdlAnswer) {
+                    return mdlAnswer - lb;
+                }
+
+                @Override public double gradient(long sampleSize, double lb, double mdlAnswer) {
+                    return mdlAnswer - lb;
+                }
+            },
+            datasetBuilder, fExtr, lbExtr
+        );
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java
new file mode 100644 (file)
index 0000000..0b42db8
--- /dev/null
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.mean;
+
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDataset;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
+import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Tests for {@link MeanAbsValueConvergenceChecker}. */
+public class MeanAbsValueConvergenceCheckerTest extends ConvergenceCheckerTest {
+    /** Checks per-sample error, convergence detection and mean error on a clean dataset. */
+    @Test
+    public void testConvergenceChecking() {
+        LocalDatasetBuilder<double[], Double> datasetBuilder = new LocalDatasetBuilder<>(data, 1);
+        ConvergenceChecker<double[], Double> checker = createChecker(
+            new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder);
+
+        double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
+        Assert.assertEquals(1.9, error, 0.01);
+        Assert.assertFalse(checker.isConverged(datasetBuilder, notConvergedMdl));
+        Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl));
+
+        try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
+
+            double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
+            Assert.assertEquals(1.55, onDSError, 0.01);
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /** Mean absolute error is more sensitive to anomalies in the data than median-based error. */
+    @Test
+    public void testConvergenceCheckingWithAnomaliesInData() {
+        data.put(new double[]{10, 11}, 100000.0);
+        LocalDatasetBuilder<double[], Double> datasetBuilder = new LocalDatasetBuilder<>(data, 1);
+        ConvergenceChecker<double[], Double> checker = createChecker(
+            new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder);
+
+        try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
+
+            double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
+            Assert.assertEquals(9090.41, onDSError, 0.01);
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java
new file mode 100644 (file)
index 0000000..d6880b4
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.median;
+
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDataset;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
+import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
+import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Tests for {@link MedianOfMedianConvergenceChecker}. */
+public class MedianOfMedianConvergenceCheckerTest extends ConvergenceCheckerTest {
+    /** Checks that the median-of-median error estimate stays robust in the presence of an anomalous label. */
+    @Test
+    public void testConvergenceChecking() {
+        data.put(new double[]{10, 11}, 100000.0);
+        LocalDatasetBuilder<double[], Double> datasetBuilder = new LocalDatasetBuilder<>(data, 1);
+
+        ConvergenceChecker<double[], Double> checker = createChecker(
+            new MedianOfMedianConvergenceCheckerFactory(0.1), datasetBuilder);
+
+        double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
+        Assert.assertEquals(1.9, error, 0.01);
+        Assert.assertFalse(checker.isConverged(datasetBuilder, notConvergedMdl));
+        Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl));
+
+        try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
+
+            double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
+            Assert.assertEquals(1.6, onDSError, 0.01);
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
index f88fd3e..b06fd67 100644 (file)
 
 package org.apache.ignite.ml.environment;
 
-import java.util.Arrays;
-import java.util.UUID;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-import javax.cache.Cache;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
 import org.apache.ignite.ml.environment.logging.ConsoleLogger;
 import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy;
 import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer;
 import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
-import org.apache.ignite.thread.IgniteThread;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
 
 /**
  * Tests for {@link LearningEnvironment} that require to start the whole Ignite infrastructure. IMPL NOTE based on
  * RandomForestRegressionExample example.
  */
-public class LearningEnvironmentTest extends GridCommonAbstractTest {
-    /** Number of nodes in grid */
-    private static final int NODE_COUNT = 1;
-
-    /** Ignite instance. */
-    private Ignite ignite;
-
-    /** {@inheritDoc} */
-    @Override protected void beforeTestsStarted() throws Exception {
-        for (int i = 1; i <= NODE_COUNT; i++)
-            startGrid(i);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void afterTestsStopped() {
-        stopAllGrids();
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-    @Override protected void beforeTest() {
-        /* Grid instance. */
-        ignite = grid(NODE_COUNT);
-        ignite.configuration().setPeerClassLoadingEnabled(true);
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-    }
-
+public class LearningEnvironmentTest {
     /** */
+    @Test
     public void testBasic() throws InterruptedException {
-        AtomicReference<Integer> actualAmount = new AtomicReference<>(null);
-        AtomicReference<Double> actualMse = new AtomicReference<>(null);
-        AtomicReference<Double> actualMae = new AtomicReference<>(null);
-
-        IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
-            LearningEnvironmentTest.class.getSimpleName(), () -> {
-            IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
-
-            AtomicInteger idx = new AtomicInteger(0);
-            RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(
-                IntStream.range(0, data[0].length - 1).mapToObj(
-                    x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
-            ).withCountOfTrees(101)
-                .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
-                .withMaxDepth(4)
-                .withMinImpurityDelta(0.)
-                .withSubsampleSize(0.3)
-                .withSeed(0);
-
-            trainer.setEnvironment(LearningEnvironment.builder()
-                .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL)
-                .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW))
-                .build()
-            );
-
-            ModelsComposition randomForest = trainer.fit(ignite, dataCache,
-                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
-                (k, v) -> v[v.length - 1]
-            );
-
-            double mse = 0.0;
-            double mae = 0.0;
-            int totalAmount = 0;
-
-            try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, double[]> observation : observations) {
-                    double difference = estimatePrediction(randomForest, observation);
-
-                    mse += Math.pow(difference, 2.0);
-                    mae += Math.abs(difference);
-
-                    totalAmount++;
-                }
-            }
-
-            actualAmount.set(totalAmount);
-
-            mse = mse / totalAmount;
-            actualMse.set(mse);
-
-            mae = mae / totalAmount;
-            actualMae.set(mae);
-        });
-
-        igniteThread.start();
-        igniteThread.join();
-
-        assertEquals("Total amount", 23, (int)actualAmount.get());
-        assertTrue("Mean squared error (MSE)", actualMse.get() > 0);
-        assertTrue("Mean absolute error (MAE)", actualMae.get() > 0);
+        RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(
+            IntStream.range(0, 0).mapToObj( // empty feature-meta list: no actual training is performed in this test
+                x -> new FeatureMeta("", 0, false)).collect(Collectors.toList())
+        ).withCountOfTrees(101)
+            .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
+            .withMaxDepth(4)
+            .withMinImpurityDelta(0.)
+            .withSubsampleSize(0.3)
+            .withSeed(0);
+
+        LearningEnvironment environment = LearningEnvironment.builder()
+            .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL)
+            .withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW))
+            .build();
+        trainer.setEnvironment(environment); // only verifies that the configured environment is accepted and exposes the expected strategy/logger
+        assertEquals(DefaultParallelismStrategy.class, environment.parallelismStrategy().getClass());
+        assertEquals(ConsoleLogger.class, environment.logger().getClass());
     }
-
-    /** */
-    private double estimatePrediction(ModelsComposition randomForest, Cache.Entry<Integer, double[]> observation) {
-        double[] val = observation.getValue();
-        double[] inputs = Arrays.copyOfRange(val, 0, val.length - 1);
-        double groundTruth = val[val.length - 1];
-
-        double prediction = randomForest.apply(VectorUtils.of(inputs));
-
-        return prediction - groundTruth;
-    }
-
-    /**
-     * Fills cache with data and returns it.
-     *
-     * @param ignite Ignite instance.
-     * @return Filled Ignite Cache.
-     */
-    private IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
-        CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
-        cacheConfiguration.setName(UUID.randomUUID().toString());
-        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
-
-        IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
-
-        for (int i = 0; i < data.length; i++)
-            cache.put(i, data[i]);
-
-        return cache;
-    }
-
-    /**
-     * Part of the Boston housing dataset.
-     */
-    private static final double[][] data = {
-        {0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14,21.60},
-        {0.02729,0.00,7.070,0,0.4690,7.1850,61.10,4.9671,2,242.0,17.80,392.83,4.03,34.70},
-        {0.03237,0.00,2.180,0,0.4580,6.9980,45.80,6.0622,3,222.0,18.70,394.63,2.94,33.40},
-        {0.06905,0.00,2.180,0,0.4580,7.1470,54.20,6.0622,3,222.0,18.70,396.90,5.33,36.20},
-        {0.02985,0.00,2.180,0,0.4580,6.4300,58.70,6.0622,3,222.0,18.70,394.12,5.21,28.70},
-        {0.08829,12.50,7.870,0,0.5240,6.0120,66.60,5.5605,5,311.0,15.20,395.60,12.43,22.90},
-        {0.14455,12.50,7.870,0,0.5240,6.1720,96.10,5.9505,5,311.0,15.20,396.90,19.15,27.10},
-        {0.21124,12.50,7.870,0,0.5240,5.6310,100.00,6.0821,5,311.0,15.20,386.63,29.93,16.50},
-        {0.17004,12.50,7.870,0,0.5240,6.0040,85.90,6.5921,5,311.0,15.20,386.71,17.10,18.90},
-        {0.22489,12.50,7.870,0,0.5240,6.3770,94.30,6.3467,5,311.0,15.20,392.52,20.45,15.00},
-        {0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27,18.90},
-        {0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71,21.70},
-        {0.62976,0.00,8.140,0,0.5380,5.9490,61.80,4.7075,4,307.0,21.00,396.90,8.26,20.40},
-        {0.63796,0.00,8.140,0,0.5380,6.0960,84.50,4.4619,4,307.0,21.00,380.02,10.26,18.20},
-        {0.62739,0.00,8.140,0,0.5380,5.8340,56.50,4.4986,4,307.0,21.00,395.62,8.47,19.90},
-        {1.05393,0.00,8.140,0,0.5380,5.9350,29.30,4.4986,4,307.0,21.00,386.85,6.58,23.10},
-        {0.78420,0.00,8.140,0,0.5380,5.9900,81.70,4.2579,4,307.0,21.00,386.75,14.67,17.50},
-        {0.80271,0.00,8.140,0,0.5380,5.4560,36.60,3.7965,4,307.0,21.00,288.99,11.69,20.20},
-        {0.72580,0.00,8.140,0,0.5380,5.7270,69.50,3.7965,4,307.0,21.00,390.95,11.28,18.20},
-        {1.25179,0.00,8.140,0,0.5380,5.5700,98.10,3.7979,4,307.0,21.00,376.57,21.02,13.60},
-        {0.85204,0.00,8.140,0,0.5380,5.9650,89.20,4.0123,4,307.0,21.00,392.53,13.83,19.60},
-        {1.23247,0.00,8.140,0,0.5380,6.1420,91.70,3.9769,4,307.0,21.00,396.90,18.72,15.20},
-        {0.98843,0.00,8.140,0,0.5380,5.8130,100.00,4.0952,4,307.0,21.00,394.54,19.88,14.50}
-    };
-
 }
 
index d8fb620..199644b 100644 (file)
@@ -93,17 +93,21 @@ public class ANNClassificationTest extends TrainerTest {
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(NNStrategy.SIMPLE);
 
-        ANNClassificationModel updatedOnSameDataset = trainer.withSeed(1234L).update(originalMdl,
+        ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl,
             cacheMock, parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
-        );
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
 
-        ANNClassificationModel updatedOnEmptyDataset = trainer.withSeed(1234L).update(originalMdl,
+        ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl,
             new HashMap<Integer, double[]>(), parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
-        );
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector v1 = VectorUtils.of(550, 550);
         Vector v2 = VectorUtils.of(-550, -550);