IGNITE-10605: [ML] Add binary metrics calculations to Cross-Validation
author:    zaleslaw <zaleslaw.sin@gmail.com>
           Fri, 21 Dec 2018 11:43:21 +0000 (14:43 +0300)
committer: Yury Babak <ybabak@gridgain.com>
           Fri, 21 Dec 2018 11:43:21 +0000 (14:43 +0300)
This closes #5712
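In short, the patch lets cross-validation (and the evaluator) score a model with any single metric
derived from the binary confusion matrix, instead of requiring a dedicated Metric implementation per
measure. A minimal sketch of the new fluent API, assuming an Ignite node, trainer, training cache and
extractors are already set up as in the example touched below:

    BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
        .withNegativeClsLb(0.0)
        .withPositiveClsLb(1.0)
        // Any accessor of BinaryClassificationMetricValues may be plugged in here.
        .withMetric(BinaryClassificationMetricValues::balancedAccuracy);

    double[] scores = new CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint>()
        .score(trainer, metrics, ignite, trainingSet,
            (k, v) -> VectorUtils.of(v.x, v.y), (k, v) -> v.lb, 4);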

examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
index 552bcd2..462186c 100644
@@ -27,6 +27,8 @@ import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.selection.cv.CrossValidation;
 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
 
@@ -71,7 +73,7 @@ public class CrossValidationExample {
             CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator
                 = new CrossValidation<>();
 
-            double[] scores = scoreCalculator.score(
+            double[] accuracyScores = scoreCalculator.score(
                 trainer,
                 new Accuracy<>(),
                 ignite,
@@ -81,7 +83,24 @@ public class CrossValidationExample {
                 4
             );
 
-            System.out.println(">>> Accuracy: " + Arrays.toString(scores));
+            System.out.println(">>> Accuracy: " + Arrays.toString(accuracyScores));
+
+            BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
+                .withNegativeClsLb(0.0)
+                .withPositiveClsLb(1.0)
+                .withMetric(BinaryClassificationMetricValues::balancedAccuracy);
+
+            double[] balancedAccuracyScores = scoreCalculator.score(
+                trainer,
+                metrics,
+                ignite,
+                trainingSet,
+                (k, v) -> VectorUtils.of(v.x, v.y),
+                (k, v) -> v.lb,
+                4
+            );
+
+            System.out.println(">>> Balanced Accuracy: " + Arrays.toString(balancedAccuracyScores));
 
             System.out.println(">>> Cross validation score calculator example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
new file mode 100644
index 0000000..0ea0ca2
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.tutorial;
+
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
+import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
+import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
+import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
+import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
+import org.apache.ignite.ml.selection.cv.CrossValidation;
+import org.apache.ignite.ml.selection.cv.CrossValidationResult;
+import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
+import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
+import org.apache.ignite.ml.selection.split.TrainTestSplit;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+
+/**
+ * In this example, cross-validation with a {@link ParamGrid} is used to choose the best hyperparameters.
+ * <p>
+ * Code in this example launches an Ignite grid and fills the cache with test data (based on the Titanic passengers dataset).</p>
+ * <p>
+ * After that it defines how to split the data into train and test sets and configures preprocessors that extract
+ * features from the upstream data and perform other desired changes over the extracted data.</p>
+ * <p>
+ * Then, it tunes hyperparameters with k-fold cross-validation on the split training set and trains a decision tree
+ * classification model on the processed data with the obtained hyperparameters.</p>
+ * <p>
+ * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ * <p>
+ * The purpose of cross-validation is model checking, not model building.</p>
+ * <p>
+ * We train {@code k} different models.</p>
+ * <p>
+ * They differ in that {@code 1/(k-1)}th of the training data is exchanged against other cases (e.g., with {@code k = 3}
+ * any two surrogate training sets share one fold, so half of their cases differ).</p>
+ * <p>
+ * These models are sometimes called surrogate models because the (average) performance measured for these models
+ * is taken as a surrogate of the performance of the model trained on all cases.</p>
+ * <p>
+ * All scenarios are described here: https://sebastianraschka.com/faq/docs/evaluate-a-model.html</p>
+ */
+public class Step_8_CV_with_Param_Grid_and_metrics {
+    /** Run example. */
+    public static void main(String[] args) {
+        System.out.println();
+        System.out.println(">>> Tutorial step 8 (cross-validation with param grid) example started.");
+
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            try {
+                IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+
+                // Defines the first preprocessor that extracts features from the upstream data.
+                // Extracts "pclass", "sex", "age", "sibsp", "parch", "fare", "embarked".
+                IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
+                    = (k, v) -> new Object[] {v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
+
+                IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+
+                TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
+                    .split(0.75);
+
+                IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
+                    .withEncoderType(EncoderType.STRING_ENCODER)
+                    .withEncodedFeature(1)
+                    .withEncodedFeature(6) // "embarked" is a string feature and must be encoded as well.
+                    .fit(ignite,
+                        dataCache,
+                        featureExtractor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
+                    .fit(ignite,
+                        dataCache,
+                        strEncoderPreprocessor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
+                    .fit(
+                        ignite,
+                        dataCache,
+                        imputingPreprocessor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
+                    .withP(2)
+                    .fit(
+                        ignite,
+                        dataCache,
+                        minMaxScalerPreprocessor
+                    );
+
+                // Tune hyperparams with K-fold Cross-Validation on the split training set.
+
+                DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
+
+                CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator
+                    = new CrossValidation<>();
+
+                ParamGrid paramGrid = new ParamGrid()
+                    .addHyperParam("maxDeep", new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 10.0})
+                    .addHyperParam("minImpurityDecrease", new Double[] {0.0, 0.25, 0.5});
+
+                BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
+                    .withNegativeClsLb(0.0)
+                    .withPositiveClsLb(1.0)
+                    .withMetric(BinaryClassificationMetricValues::accuracy);
+
+                CrossValidationResult crossValidationRes = scoreCalculator.score(
+                    trainerCV,
+                    metrics,
+                    ignite,
+                    dataCache,
+                    split.getTrainFilter(),
+                    normalizationPreprocessor,
+                    lbExtractor,
+                    3,
+                    paramGrid
+                );
+
+                System.out.println("Train with maxDeep: " + crossValidationRes.getBest("maxDeep")
+                    + " and minImpurityDecrease: " + crossValidationRes.getBest("minImpurityDecrease"));
+
+                DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer()
+                    .withMaxDeep(crossValidationRes.getBest("maxDeep"))
+                    .withMinImpurityDecrease(crossValidationRes.getBest("minImpurityDecrease"));
+
+                System.out.println(crossValidationRes);
+
+                System.out.println("Best score: " + Arrays.toString(crossValidationRes.getBestScore()));
+                System.out.println("Best hyper params: " + crossValidationRes.getBestHyperParams());
+                System.out.println("Best average score: " + crossValidationRes.getBestAvgScore());
+
+                crossValidationRes.getScoringBoard().forEach((hyperParams, score)
+                    -> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams));
+
+                // Train decision tree model.
+                DecisionTreeNode bestMdl = trainer.fit(
+                    ignite,
+                    dataCache,
+                    split.getTrainFilter(),
+                    normalizationPreprocessor,
+                    lbExtractor
+                );
+
+                System.out.println("\n>>> Trained model: " + bestMdl);
+
+                double accuracy = BinaryClassificationEvaluator.evaluate(
+                    dataCache,
+                    split.getTestFilter(),
+                    bestMdl,
+                    normalizationPreprocessor,
+                    lbExtractor,
+                    new Accuracy<>()
+                );
+
+                System.out.println("\n>>> Accuracy " + accuracy);
+                System.out.println("\n>>> Test Error " + (1 - accuracy));
+
+                System.out.println(">>> Tutorial step 8 (cross-validation with param grid) example started.");
+            }
+            catch (FileNotFoundException e) {
+                e.printStackTrace();
+            }
+        }
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
index 30adc5c..9642bce 100644
@@ -49,7 +49,6 @@ public class BinaryClassificationEvaluator {
                                             IgniteBiFunction<K, V, Vector> featureExtractor,
                                             IgniteBiFunction<K, V, L> lbExtractor,
                                             Metric<L> metric) {
-
         return calculateMetric(dataCache, null, mdl, featureExtractor, lbExtractor, metric);
     }
 
@@ -72,7 +71,6 @@ public class BinaryClassificationEvaluator {
                                             IgniteBiFunction<K, V, Vector> featureExtractor,
                                             IgniteBiFunction<K, V, L> lbExtractor,
                                             Metric<L> metric) {
-
         return calculateMetric(dataCache, filter, mdl, featureExtractor, lbExtractor, metric);
     }
 
@@ -140,7 +138,7 @@ public class BinaryClassificationEvaluator {
             lbExtractor,
             mdl
         )) {
-            metricValues = binaryMetrics.score(cursor.iterator());
+            metricValues = binaryMetrics.scoreAll(cursor.iterator());
         } catch (Exception e) {
             throw new RuntimeException(e);
         }
@@ -163,8 +161,8 @@ public class BinaryClassificationEvaluator {
      * @return Computed metric.
      */
     private static <L, K, V> double calculateMetric(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter,
-                                                    Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
-                                                    Metric<L> metric) {
+                                                    Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor,
+                                                    IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) {
         double metricRes;
 
         try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<>(
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
index 0b15d04..bd4067a 100644
 package org.apache.ignite.ml.selection.scoring.metric;
 
 import java.util.Iterator;
+import java.util.function.Function;
 import org.apache.ignite.ml.selection.scoring.LabelPair;
 
 /**
  * Binary classification metrics calculator.
+ * It can be used in two ways: to calculate all binary classification metrics at once or a single specific metric.
  */
-public class BinaryClassificationMetrics {
+public class BinaryClassificationMetrics implements Metric<Double> {
     /** Positive class label. */
     private double positiveClsLb = 1.0;
 
     /** Negative class label. Default value is 0.0. */
     private double negativeClsLb;
 
+    /** The metric used to compute the individual score returned by {@code score}. */
+    private Function<BinaryClassificationMetricValues, Double> metric = BinaryClassificationMetricValues::accuracy;
+
     /**
      * Calculates binary metrics values.
      *
     * @param iter Iterator that supplies pairs of truth values and predicted values.
      * @return Scores for all binary metrics.
      */
-    public BinaryClassificationMetricValues score(Iterator<LabelPair<Double>> iter) {
+    public BinaryClassificationMetricValues scoreAll(Iterator<LabelPair<Double>> iter) {
         long tp = 0;
         long tn = 0;
         long fp = 0;
@@ -83,4 +88,20 @@ public class BinaryClassificationMetrics {
         this.negativeClsLb = negativeClsLb;
         return this;
     }
+
+    /** Sets the metric that {@code score(...)} should return. */
+    public BinaryClassificationMetrics withMetric(Function<BinaryClassificationMetricValues, Double> metric) {
+        this.metric = metric;
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @Override public double score(Iterator<LabelPair<Double>> iter) {
+        return metric.apply(scoreAll(iter));
+    }
+
+    /** {@inheritDoc} */
+    @Override public String name() {
+        return "Binary classification metrics";
+    }
 }
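Since BinaryClassificationMetrics now implements Metric<Double>, a single configured instance serves
both use cases: scoreAll(...) returns the full set of confusion-matrix based values, while score(...)
reduces them to the one number selected via withMetric. A short sketch, assuming truthWithPred is some
Iterable<LabelPair<Double>> of truth/prediction pairs (e.g. collected from a LabelPairCursor as in the
evaluator above):

    BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
        .withPositiveClsLb(1.0)
        .withNegativeClsLb(0.0)
        .withMetric(BinaryClassificationMetricValues::accuracy);

    // All confusion-matrix derived values at once...
    BinaryClassificationMetricValues vals = metrics.scoreAll(truthWithPred.iterator());

    // ...or just the configured one, through the generic Metric<Double> contract.
    double accuracy = metrics.score(truthWithPred.iterator());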
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
index 3e8b9dd..e64aa7a 100644
@@ -21,6 +21,8 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
+import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
 import org.junit.Test;
@@ -71,6 +73,46 @@ public class CrossValidationTest {
 
     /** */
     @Test
+    public void testScoreWithGoodDatasetAndBinaryMetrics() {
+        Map<Integer, Double> data = new HashMap<>();
+
+        for (int i = 0; i < 1000; i++)
+            data.put(i, i > 500 ? 1.0 : 0.0);
+
+        DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
+
+        CrossValidation<DecisionTreeNode, Double, Integer, Double> scoreCalculator =
+            new CrossValidation<>();
+
+        int folds = 4;
+
+        BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
+            .withMetric(BinaryClassificationMetricValues::accuracy);
+
+        verifyScores(folds, scoreCalculator.score(
+            trainer,
+            metrics,
+            data,
+            1,
+            (k, v) -> VectorUtils.of(k),
+            (k, v) -> v,
+            folds
+        ));
+
+        verifyScores(folds, scoreCalculator.score(
+            trainer,
+            new Accuracy<>(),
+            data,
+            (e1, e2) -> true,
+            1,
+            (k, v) -> VectorUtils.of(k),
+            (k, v) -> v,
+            folds
+        ));
+    }
+
+    /** */
+    @Test
     public void testScoreWithBadDataset() {
         Map<Integer, Double> data = new HashMap<>();