IGNITE-9065: Gradient boosting optimization
author: Alexey Platonov <aplatonovv@gmail.com>
Wed, 8 Aug 2018 10:22:26 +0000 (13:22 +0300)
committer: Yury Babak <ybabak@gridgain.com>
Wed, 8 Aug 2018 10:22:26 +0000 (13:22 +0300)
this closes #4486

examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java [moved from examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GRBOnTreesRegressionTrainerExample.java with 97% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java

@@ -37,7 +37,7 @@ import org.jetbrains.annotations.NotNull;
  *
  * In this example dataset is creating automatically by parabolic function f(x) = x^2.
  */
-public class GRBOnTreesRegressionTrainerExample {
+public class GDBOnTreesRegressionTrainerExample {
     /**
      * Run example.
      *
@@ -49,7 +49,7 @@ public class GRBOnTreesRegressionTrainerExample {
             System.out.println(">>> Ignite grid started.");
 
             IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
-                GRBOnTreesRegressionTrainerExample.class.getSimpleName(), () -> {
+                GDBOnTreesRegressionTrainerExample.class.getSimpleName(), () -> {
 
                 // Create cache with training data.
                 CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration();
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
new file mode 100644 (file)
index 0000000..375748a
--- /dev/null
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
/**
 * Learning strategy for gradient boosting. Builds an ensemble of regression models where each next
 * model is fitted to the negative gradient of the loss function evaluated on the predictions of the
 * composition learned so far. All parameters are injected through the {@code with*} chain-setters
 * before {@link #learnModels} is called.
 */
public class GDBLearningStrategy {
    /** Learning environment (provides logging; must be set before {@link #learnModels}). */
    protected LearningEnvironment environment;

    /** Count of boosting iterations, i.e. the number of base models in the final composition. */
    protected int cntOfIterations;

    /**
     * Gradient of the loss function. Arguments: sample size, valid answer, current model prediction
     * (see the loss-gradient contract documented in GDBTrainer).
     */
    protected IgniteTriFunction<Long, Double, Double, Double> lossGradient;

    /** External label to internal representation mapping (e.g. class labels to {-1, 1}). */
    protected IgniteFunction<Double, Double> externalLbToInternalMapping;

    /** Builder of the base regressor trainer used at each boosting step. */
    protected IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> baseMdlTrainerBuilder;

    /** Mean label value, used as the composition's bias term. */
    protected double meanLabelValue;

    /** Sample size, passed as the first argument of {@link #lossGradient}. */
    protected long sampleSize;

    /** Composition weights, one per boosting iteration. */
    protected double[] compositionWeights;

    /**
     * Implementation of gradient boosting iterations. At each step of iterations this algorithm
     * build a regression model based on gradient of loss-function for current models composition.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return list of learned models.
     */
    public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {

        List<Model<Vector, Double>> models = new ArrayList<>();
        DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
        for (int i = 0; i < cntOfIterations; i++) {
            // Composition of the i models learned so far (first i weights only).
            double[] weights = Arrays.copyOf(compositionWeights, i);

            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue);
            Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator);

            // The next base model is fitted to the negative loss gradient ("pseudo-residuals")
            // of the current composition, computed lazily per (key, value) pair.
            IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> {
                Double realAnswer = externalLbToInternalMapping.apply(lbExtractor.apply(k, v));
                Double mdlAnswer = currComposition.apply(featureExtractor.apply(k, v));
                return -lossGradient.apply(sampleSize, realAnswer, mdlAnswer);
            };

            long startTs = System.currentTimeMillis();
            models.add(trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrap));
            double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
            environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }

        return models;
    }

    /**
     * Sets learning environment.
     *
     * @param environment Learning Environment.
     * @return This strategy for chaining.
     */
    public GDBLearningStrategy withEnvironment(LearningEnvironment environment) {
        this.environment = environment;
        return this;
    }

    /**
     * Sets count of iterations.
     *
     * @param cntOfIterations Count of iterations.
     * @return This strategy for chaining.
     */
    public GDBLearningStrategy withCntOfIterations(int cntOfIterations) {
        this.cntOfIterations = cntOfIterations;
        return this;
    }

    /**
     * Sets gradient of loss function.
     *
     * @param lossGradient Loss gradient.
     * @return This strategy for chaining.
     */
    public GDBLearningStrategy withLossGradient(IgniteTriFunction<Long, Double, Double, Double> lossGradient) {
        this.lossGradient = lossGradient;
        return this;
    }

    /**
     * Sets external to internal label representation mapping.
     *
     * @param externalLbToInternal External label to internal.
     * @return This strategy for chaining.
     */
    public GDBLearningStrategy withExternalLabelToInternal(IgniteFunction<Double, Double> externalLbToInternal) {
        this.externalLbToInternalMapping = externalLbToInternal;
        return this;
    }

    /**
     * Sets base model builder.
     *
     * @param buildBaseMdlTrainer Build base model trainer.
     * @return This strategy for chaining.
     */
    public GDBLearningStrategy withBaseModelTrainerBuilder(IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> buildBaseMdlTrainer) {
        this.baseMdlTrainerBuilder = buildBaseMdlTrainer;
        return this;
    }

    /**
     * Sets mean label value.
     *
     * @param meanLabelValue Mean label value.
     * @return This strategy for chaining.
     */
    public GDBLearningStrategy withMeanLabelValue(double meanLabelValue) {
        this.meanLabelValue = meanLabelValue;
        return this;
    }

    /**
     * Sets sample size.
     *
     * @param sampleSize Sample size.
     * @return This strategy for chaining.
     */
    public GDBLearningStrategy withSampleSize(long sampleSize) {
        this.sampleSize = sampleSize;
        return this;
    }

    /**
     * Sets composition weights vector.
     *
     * @param compositionWeights Composition weights.
     * @return This strategy for chaining.
     */
    public GDBLearningStrategy withCompositionWeights(double[] compositionWeights) {
        this.compositionWeights = compositionWeights;
        return this;
    }
}
index 8663d3d..5a0f52a 100644 (file)
@@ -17,7 +17,6 @@
 
 package org.apache.ignite.ml.composition.boosting;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import org.apache.ignite.lang.IgniteBiTuple;
@@ -53,16 +52,18 @@ import org.jetbrains.annotations.NotNull;
  *
  * But in practice Decision Trees is most used regressors (see: {@link DecisionTreeRegressionTrainer}).
  */
-abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> {
+public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> {
     /** Gradient step. */
     private final double gradientStep;
+
     /** Count of iterations. */
     private final int cntOfIterations;
+
     /**
      * Gradient of loss function. First argument is sample size, second argument is valid answer, third argument is
      * current model prediction.
      */
-    private final IgniteTriFunction<Long, Double, Double, Double> lossGradient;
+    protected final IgniteTriFunction<Long, Double, Double, Double> lossGradient;
 
     /**
      * Constructs GDBTrainer instance.
@@ -91,28 +92,23 @@ abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double>
         Double mean = initAndSampleSize.get1();
         Long sampleSize = initAndSampleSize.get2();
 
-        List<Model<Vector, Double>> models = new ArrayList<>();
         double[] compositionWeights = new double[cntOfIterations];
         Arrays.fill(compositionWeights, gradientStep);
         WeightedPredictionsAggregator resAggregator = new WeightedPredictionsAggregator(compositionWeights, mean);
 
         long learningStartTs = System.currentTimeMillis();
-        for (int i = 0; i < cntOfIterations; i++) {
-            double[] weights = Arrays.copyOf(compositionWeights, i);
-            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, mean);
-            Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator);
-
-            IgniteBiFunction<K, V, Double> lbExtractorWrap = (k, v) -> {
-                Double realAnswer = externalLabelToInternal(lbExtractor.apply(k, v));
-                Double mdlAnswer = currComposition.apply(featureExtractor.apply(k, v));
-                return -lossGradient.apply(sampleSize, realAnswer, mdlAnswer);
-            };
-
-            long startTs = System.currentTimeMillis();
-            models.add(buildBaseModelTrainer().fit(datasetBuilder, featureExtractor, lbExtractorWrap));
-            double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
-            environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
-        }
+
+        List<Model<Vector, Double>> models = getLearningStrategy()
+            .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
+            .withExternalLabelToInternal(this::externalLabelToInternal)
+            .withCntOfIterations(cntOfIterations)
+            .withCompositionWeights(compositionWeights)
+            .withEnvironment(environment)
+            .withLossGradient(lossGradient)
+            .withSampleSize(sampleSize)
+            .withMeanLabelValue(mean)
+            .learnModels(datasetBuilder, featureExtractor, lbExtractor);
+
         double learningTime = (double)(System.currentTimeMillis() - learningStartTs) / 1000.0;
         environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);
 
@@ -136,7 +132,8 @@ abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double>
     /**
      * Returns regressor model trainer for one step of GDB.
      */
-    @NotNull protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer();
+    @NotNull
+    protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer();
 
     /**
      * Maps external representation of label to internal.
@@ -191,4 +188,13 @@ abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double>
             throw new RuntimeException(e);
         }
     }
+
    /**
     * Returns the learning strategy implementing the boosting iterations. Subclasses may override
     * this to substitute an optimized strategy (e.g. a tree-specific one reusing the learning
     * dataset between iterations).
     *
     * @return Learning strategy.
     */
    protected GDBLearningStrategy getLearningStrategy() {
        return new GDBLearningStrategy();
    }
 }
index 270f14a..de8994a 100644 (file)
@@ -79,20 +79,24 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
             new EmptyContextBuilder<>(),
             new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex)
         )) {
-            return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
+            return fit(dataset);
         }
         catch (Exception e) {
             throw new RuntimeException(e);
         }
     }
 
+    public <K,V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> dataset) {
+        return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
+    }
+
     /**
      * Returns impurity measure calculator.
      *
      * @param dataset Dataset.
      * @return Impurity measure calculator.
      */
-    abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset);
+    protected abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset);
 
     /**
      * Splits the node specified by the given dataset and predicate and returns decision tree node.
index f371334..f8fc769 100644 (file)
@@ -96,7 +96,7 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity
     }
 
     /** {@inheritDoc} */
-    @Override ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator(
+    @Override protected ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator(
         Dataset<EmptyContext, DecisionTreeData> dataset) {
         Set<Double> labels = dataset.compute(part -> {
 
index 7446237..4c9aac9 100644 (file)
@@ -64,7 +64,7 @@ public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasu
     }
 
     /** {@inheritDoc} */
-    @Override ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator(
+    @Override protected ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator(
         Dataset<EmptyContext, DecisionTreeData> dataset) {
 
         return new MSEImpurityMeasureCalculator(useIndex);
index 631e848..4d87b47 100644 (file)
 
 package org.apache.ignite.ml.tree.boosting;
 
-import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.boosting.GDBBinaryClassifierTrainer;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
 import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
 import org.jetbrains.annotations.NotNull;
 
@@ -54,7 +52,7 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine
     }
 
     /** {@inheritDoc} */
-    @NotNull @Override protected DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer() {
+    @NotNull @Override protected DecisionTreeRegressionTrainer buildBaseModelTrainer() {
         return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex);
     }
 
@@ -68,4 +66,9 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine
         this.useIndex = useIndex;
         return this;
     }
+
+    /** {@inheritDoc} */
+    @Override protected GDBLearningStrategy getLearningStrategy() {
+        return new GDBOnTreesLearningStrategy(useIndex);
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
new file mode 100644 (file)
index 0000000..8589a79
--- /dev/null
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.boosting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
+import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.tree.DecisionTree;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
+
+/**
+ * Gradient boosting on trees specific learning strategy reusing learning dataset with index between
+ * several learning iterations.
+ */
+public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
+    private boolean useIndex;
+
+    /**
+     * Create an instance of learning strategy.
+     *
+     * @param useIndex Use index.
+     */
+    public GDBOnTreesLearningStrategy(boolean useIndex) {
+        this.useIndex = useIndex;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> List<Model<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
+        assert trainer instanceof DecisionTree;
+        DecisionTree decisionTreeTrainer = (DecisionTree) trainer;
+
+        List<Model<Vector, Double>> models = new ArrayList<>();
+        try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(),
+            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex)
+        )) {
+            for (int i = 0; i < cntOfIterations; i++) {
+                double[] weights = Arrays.copyOf(compositionWeights, i);
+                WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue);
+                Model<Vector, Double> currComposition = new ModelsComposition(models, aggregator);
+
+                dataset.compute(part -> {
+                    if(part.getCopyOfOriginalLabels() == null)
+                        part.setCopyOfOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
+
+                    for(int j = 0; j < part.getLabels().length; j++) {
+                        double mdlAnswer = currComposition.apply(VectorUtils.of(part.getFeatures()[j]));
+                        double originalLbVal = externalLbToInternalMapping.apply(part.getCopyOfOriginalLabels()[j]);
+                        part.getLabels()[j] = -lossGradient.apply(sampleSize, originalLbVal, mdlAnswer);
+                    }
+                });
+
+                long startTs = System.currentTimeMillis();
+                models.add(decisionTreeTrainer.fit(dataset));
+                double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
+                environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
+            }
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return models;
+    }
+}
index 450dae3..e2a183c 100644 (file)
 
 package org.apache.ignite.ml.tree.boosting;
 
-import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
 import org.apache.ignite.ml.composition.boosting.GDBRegressionTrainer;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
 import org.jetbrains.annotations.NotNull;
 
@@ -54,7 +52,7 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer {
     }
 
     /** {@inheritDoc} */
-    @NotNull @Override protected DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer() {
+    @NotNull @Override protected DecisionTreeRegressionTrainer buildBaseModelTrainer() {
         return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex);
     }
 
@@ -68,4 +66,9 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer {
         this.useIndex = useIndex;
         return this;
     }
+
+    /** {@inheritDoc} */
+    @Override protected GDBLearningStrategy getLearningStrategy() {
+        return new GDBOnTreesLearningStrategy(useIndex);
+    }
 }
index c017e5c..d5750ea 100644 (file)
@@ -31,6 +31,9 @@ public class DecisionTreeData implements AutoCloseable {
     /** Vector with labels. */
     private final double[] labels;
 
+    /** Copy of vector with original labels. Auxiliary for Gradient Boosting on Trees.*/
+    private double[] copyOfOriginalLabels;
+
     /** Indexes cache. */
     private final List<TreeDataIndex> indexesCache;
 
@@ -137,6 +140,14 @@ public class DecisionTreeData implements AutoCloseable {
         return labels;
     }
 
    /**
     * Returns the preserved copy of the original labels (auxiliary for gradient boosting on trees).
     * May be {@code null} until {@link #setCopyOfOriginalLabels(double[])} has been called.
     *
     * @return Copy of original labels, or {@code null} if not set.
     */
    public double[] getCopyOfOriginalLabels() {
        return copyOfOriginalLabels;
    }
+
    /**
     * Sets the preserved copy of the original labels (auxiliary for gradient boosting on trees).
     *
     * @param copyOfOriginalLabels Copy of original labels.
     */
    public void setCopyOfOriginalLabels(double[] copyOfOriginalLabels) {
        this.copyOfOriginalLabels = copyOfOriginalLabels;
    }
+
     /** {@inheritDoc} */
     @Override public void close() {
         // Do nothing, GC will clean up.
index 709f68e..0c67535 100644 (file)
@@ -18,6 +18,8 @@
 package org.apache.ignite.ml.tree.impurity;
 
 import java.io.Serializable;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.tree.TreeFilter;
 import org.apache.ignite.ml.tree.data.DecisionTreeData;
 import org.apache.ignite.ml.tree.data.TreeDataIndex;
@@ -98,4 +100,8 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im
     protected double getFeatureValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) {
         return useIndex ? idx.featureInSortedOrder(k, featureId) : data.getFeatures()[k][featureId];
     }
+
    /**
     * Returns the full feature vector of the k-th observation, either through the index
     * (presumably in order sorted by {@code featureId} — mirrors {@code getFeatureValue} above)
     * or directly from the dataset when the index is not used.
     *
     * @param data Decision tree data.
     * @param idx Index over the data.
     * @param featureId Feature id defining the sort order when the index is used.
     * @param k Observation position.
     * @return Feature vector of the k-th observation.
     */
    protected Vector getFeatureValues(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) {
        return VectorUtils.of(useIndex ? idx.featuresInSortedOrder(k, featureId) : data.getFeatures()[k]);
    }
 }