IGNITE-10713: [ML] Refactor examples with accuracy calculation and another metrics usage
author     zaleslaw <zaleslaw.sin@gmail.com>
           Fri, 11 Jan 2019 10:56:16 +0000 (13:56 +0300)
committer  Yury Babak <ybabak@gridgain.com>
           Fri, 11 Jan 2019 10:56:16 +0000 (13:56 +0300)

This closes #5787

examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
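
All seven examples follow the same refactoring: the inline feature/label extractor lambdas are hoisted into named IgniteBiFunction variables shared by fit() and evaluate(), and the hand-written ScanQuery accuracy loop (plus, where present, the confusion matrix) is replaced by BinaryClassificationEvaluator. The kNN example also switches to the two-class Iris dataset, presumably because the binary evaluator expects two-class labels. Below is a consolidated sketch of the resulting pattern, assembled from the hunks that follow; the class name, main() wrapper and config path are illustrative only, and any of the other trainers slots in the same way.

    package org.apache.ignite.examples.ml.knn;

    import java.io.FileNotFoundException;
    import org.apache.ignite.Ignite;
    import org.apache.ignite.IgniteCache;
    import org.apache.ignite.Ignition;
    import org.apache.ignite.ml.knn.NNClassificationModel;
    import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
    import org.apache.ignite.ml.knn.classification.NNStrategy;
    import org.apache.ignite.ml.math.distances.EuclideanDistance;
    import org.apache.ignite.ml.math.functions.IgniteBiFunction;
    import org.apache.ignite.ml.math.primitives.vector.Vector;
    import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
    import org.apache.ignite.ml.util.MLSandboxDatasets;
    import org.apache.ignite.ml.util.SandboxMLCache;

    public class EvaluatorPatternSketch {
        public static void main(String[] args) throws FileNotFoundException {
            // Config path is the one the Ignite examples normally use; adjust as needed.
            try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
                // Two-class dataset, so the binary evaluator can be applied.
                IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
                    .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);

                // Extractors are defined once and shared by fit() and evaluate().
                IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
                IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);

                NNClassificationModel mdl = new KNNClassificationTrainer()
                    .fit(ignite, dataCache, featureExtractor, lbExtractor)
                    .withK(3)
                    .withDistanceMeasure(new EuclideanDistance())
                    .withStrategy(NNStrategy.WEIGHTED);

                // One evaluator call replaces the manual ScanQuery accuracy loop.
                double accuracy = BinaryClassificationEvaluator.evaluate(
                    dataCache, mdl, featureExtractor, lbExtractor
                ).accuracy();

                System.out.println("\n>>> Accuracy " + accuracy);
            }
        }
    }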

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index 4a475a0..ec25006 100644
 package org.apache.ignite.examples.ml.knn;
 
 import java.io.FileNotFoundException;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.ml.knn.NNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
 import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -57,48 +55,30 @@ public class KNNClassificationExample {
             System.out.println(">>> Ignite grid started.");
 
             IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
-                .fillCacheWith(MLSandboxDatasets.IRIS);
+                .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
 
             KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-            NNClassificationModel knnMdl = trainer.fit(
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+            NNClassificationModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             ).withK(3)
                 .withDistanceMeasure(new EuclideanDistance())
                 .withStrategy(NNStrategy.WEIGHTED);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = knnMdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-                System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example completed.");
-            }
+            System.out.println("\n>>> Accuracy " + accuracy);
         }
     }
 }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index 8615b6c..e1791df 100644
@@ -98,10 +98,10 @@ public class KNNRegressionExample {
 
                 System.out.println(">>> ---------------------------------");
 
-                mse = mse / totalAmount;
+                mse /= totalAmount;
                 System.out.println("\n>>> Mean squared error (MSE) " + mse);
 
-                mae = mae / totalAmount;
+                mae /= totalAmount;
                 System.out.println("\n>>> Mean absolute error (MAE) " + mae);
 
                 System.out.println(">>> kNN regression over cached dataset usage example completed.");
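>> kNN reg">

The hunk above only simplifies the final averaging; for context, the accumulation it finalizes is the usual sum of squared and absolute residuals. A sketch of that loop body, reconstructed from the standard MSE/MAE definitions rather than copied from the file (variable names follow the surrounding example):

    // Inside the ScanQuery loop, for each (groundTruth, prediction) pair:
    double err = groundTruth - prediction;

    mse += err * err;      // accumulate squared error
    mae += Math.abs(err);  // accumulate absolute error
    totalAmount++;

    // After the loop, as in the hunk above:
    // mse /= totalAmount;  -> mean squared error
    // mae /= totalAmount;  -> mean absolute error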
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
index 54c9ce0..fff298b 100644
 package org.apache.ignite.examples.ml.naivebayes;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
 import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -64,49 +62,26 @@ public class DiscreteNaiveBayesTrainerExample {
                 .setBucketThresholds(thresholds);
 
             System.out.println(">>> Perform the training to get the model.");
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             DiscreteNaiveBayesModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Discrete Naive Bayes model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (groundTruth != prediction)
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Discrete Naive bayes model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
index 74e0bfd..4459566 100644
 package org.apache.ignite.examples.ml.naivebayes;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
 import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -63,49 +60,27 @@ public class GaussianNaiveBayesTrainerExample {
             GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
 
             System.out.println(">>> Perform the training to get the model.");
+
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             GaussianNaiveBayesModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Naive Bayes model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Naive bayes model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 059f810..2dab1af 100644
 package org.apache.ignite.examples.ml.regression.logistic.binary;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
 import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -75,49 +72,26 @@ public class LogisticRegressionSGDTrainerExample {
                 .withSeed(123L);
 
             System.out.println(">>> Perform the training to get the model.");
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             LogisticRegressionModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> Logistic regression model: " + mdl);
 
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = (int)prediction;
-                    int idx2 = (int)groundTruth;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-            System.out.println(">>> ---------------------------------");
+            System.out.println("\n>>> Accuracy " + accuracy);
 
             System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
         }
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
index a6a989b..c556e11 100644
@@ -77,7 +77,15 @@ public class EvaluatorExample {
             );
 
             System.out.println("\n>>> Accuracy " + accuracy);
-            System.out.println("\n>>> Test Error " + (1 - accuracy));
+
+            double f1Score = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).f1Score();
+
+            System.out.println("\n>>> F1-Score " + f1Score);
         }
     }
 }
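
The F1 score requested here is the harmonic mean of precision and recall for the positive class. How the evaluator computes it internally is not shown in this diff; the conventional definition it corresponds to, written over raw counts, is:

    // tp = true positives, fp = false positives, fn = false negatives,
    // counted over the test set for the positive class.
    static double f1Score(int tp, int fp, int fn) {
        double precision = tp / (double) (tp + fp);
        double recall    = tp / (double) (tp + fn);
        return 2 * precision * recall / (precision + recall);
    }

Note that the metrics object returned by evaluate(...) exposes both accuracy() and f1Score(), so both numbers could also be read from a single evaluation call instead of evaluating twice.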
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
index f057386..291c7f8 100644
 package org.apache.ignite.examples.ml.svm;
 
 import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
-import org.apache.commons.math3.util.Precision;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
 import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
 import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
@@ -60,54 +57,28 @@ public class SVMBinaryClassificationExample {
 
             SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
 
+            IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+            IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
             SVMLinearClassificationModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                featureExtractor,
+                lbExtractor
             );
 
             System.out.println(">>> SVM model " + mdl);
 
-            System.out.println(">>> ---------------------------------");
-            System.out.println(">>> | Prediction\t| Ground Truth\t|");
-            System.out.println(">>> ---------------------------------");
-
-            int amountOfErrors = 0;
-            int totalAmount = 0;
-
-            // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
-            int[][] confusionMtx = {{0, 0}, {0, 0}};
-
-            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, Vector> observation : observations) {
-                    Vector val = observation.getValue();
-                    Vector inputs = val.copyOfRange(1, val.size());
-                    double groundTruth = val.get(0);
-
-                    double prediction = mdl.predict(inputs);
-
-                    totalAmount++;
-                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
-                        amountOfErrors++;
-
-                    int idx1 = prediction == 0.0 ? 0 : 1;
-                    int idx2 = groundTruth == 0.0 ? 0 : 1;
-
-                    confusionMtx[idx1][idx2]++;
-
-                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
-                }
-
-                System.out.println(">>> ---------------------------------");
-
-                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
-                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-            }
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                featureExtractor,
+                lbExtractor
+            ).accuracy();
 
-            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+            System.out.println("\n>>> Accuracy " + accuracy);
 
-            System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+            System.out.println(">>> SVM Binary classification model over cache based dataset usage example completed.");
         }
     }
 }
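
The confusion-matrix printout dropped across these examples is easy to keep alongside the evaluator if needed. A minimal sketch based on the deleted loops, assuming a binary model mdl, the cache layout used above (label in column 0), and labels encoded as 0.0/1.0:

    // Additional imports: javax.cache.Cache, org.apache.ignite.cache.query.QueryCursor,
    // org.apache.ignite.cache.query.ScanQuery, java.util.Arrays.
    int[][] confusionMtx = {{0, 0}, {0, 0}};

    try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
        for (Cache.Entry<Integer, Vector> observation : observations) {
            Vector val = observation.getValue();
            double groundTruth = val.get(0);
            double prediction = mdl.predict(val.copyOfRange(1, val.size()));

            // Rows: predicted class, columns: actual class.
            confusionMtx[prediction == 0.0 ? 0 : 1][groundTruth == 0.0 ? 0 : 1]++;
        }
    }

    System.out.println(">>> Confusion matrix is " + Arrays.deepToString(confusionMtx));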