IGNITE-10606: [ML] Add new tests for BinaryClassificationMetrics and
author zaleslaw <zaleslaw.sin@gmail.com>
Thu, 27 Dec 2018 13:23:21 +0000 (16:23 +0300)
committer Yury Babak <ybabak@gridgain.com>
Thu, 27 Dec 2018 13:23:21 +0000 (16:23 +0300)
Evaluator

This closes #5751

modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java
modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java [new file with mode: 0644]

index f135450..d8c2240 100644 (file)
@@ -100,7 +100,14 @@ public class LocalLabelPairCursor<L, K, V, T> implements LabelPairCursor<L> {
 
         /** {@inheritDoc} */
         @Override public boolean hasNext() {
-            findNext();
+            if (filter == null) {
+                Map.Entry<K, V> entry = iter.next();
+                this.nextEntry = entry;
+                return iter.hasNext();
+            }
+
+            else
+                findNext();
 
             return nextEntry != null;
         }
index 9642bce..5cbe10f 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.selection.scoring.evaluator;
 
+import java.util.Map;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.lang.IgniteBiPredicate;
 import org.apache.ignite.ml.Model;
@@ -24,6 +25,7 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
 import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
+import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor;
 import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
 import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
 import org.apache.ignite.ml.selection.scoring.metric.Metric;
@@ -35,99 +37,179 @@ public class BinaryClassificationEvaluator {
     /**
      * Computes the given metric on the given cache.
      *
-     * @param dataCache        The given cache.
-     * @param mdl              The model.
+     * @param dataCache The given cache.
+     * @param mdl The model.
      * @param featureExtractor The feature extractor.
-     * @param lbExtractor      The label extractor.
-     * @param metric           The binary classification metric.
-     * @param <K>              The type of cache entry key.
-     * @param <V>              The type of cache entry value.
+     * @param lbExtractor The label extractor.
+     * @param metric The binary classification metric.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
      * @return Computed metric.
      */
     public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache,
-                                            Model<Vector, L> mdl,
-                                            IgniteBiFunction<K, V, Vector> featureExtractor,
-                                            IgniteBiFunction<K, V, L> lbExtractor,
-                                            Metric<L> metric) {
+        Model<Vector, L> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor,
+        Metric<L> metric) {
         return calculateMetric(dataCache, null, mdl, featureExtractor, lbExtractor, metric);
     }
 
     /**
      * Computes the given metric on the given cache.
      *
-     * @param dataCache        The given cache.
-     * @param filter           The given filter.
-     * @param mdl              The model.
+     * @param dataCache The given local data.
+     * @param mdl The model.
      * @param featureExtractor The feature extractor.
-     * @param lbExtractor      The label extractor.
-     * @param metric           The binary classification metric.
-     * @param <L>              The type of label.
-     * @param <K>              The type of cache entry key.
-     * @param <V>              The type of cache entry value.
+     * @param lbExtractor The label extractor.
+     * @param metric The binary classification metric.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
+     * @return Computed metric.
+     */
+    public static <L, K, V> double evaluate(Map<K, V> dataCache,
+        Model<Vector, L> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor,
+        Metric<L> metric) {
+        return calculateMetric(dataCache, null, mdl, featureExtractor, lbExtractor, metric);
+    }
+
+    /**
+     * Computes the given metric on the given cache.
+     *
+     * @param dataCache The given cache.
+     * @param filter The given filter.
+     * @param mdl The model.
+     * @param featureExtractor The feature extractor.
+     * @param lbExtractor The label extractor.
+     * @param metric The binary classification metric.
+     * @param <L> The type of label.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
      * @return Computed metric.
      */
     public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter,
-                                            Model<Vector, L> mdl,
-                                            IgniteBiFunction<K, V, Vector> featureExtractor,
-                                            IgniteBiFunction<K, V, L> lbExtractor,
-                                            Metric<L> metric) {
+        Model<Vector, L> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor,
+        Metric<L> metric) {
+        return calculateMetric(dataCache, filter, mdl, featureExtractor, lbExtractor, metric);
+    }
+
+    /**
+     * Computes the given metric on the given local data.
+     *
+     * @param dataCache The given local data.
+     * @param filter The given filter.
+     * @param mdl The model.
+     * @param featureExtractor The feature extractor.
+     * @param lbExtractor The label extractor.
+     * @param metric The binary classification metric.
+     * @param <L> The type of label.
+     * @param <K> The type of map entry key.
+     * @param <V> The type of map entry value.
+     * @return Computed metric.
+     */
+    public static <L, K, V> double evaluate(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter,
+        Model<Vector, L> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor,
+        Metric<L> metric) {
         return calculateMetric(dataCache, filter, mdl, featureExtractor, lbExtractor, metric);
     }
 
     /**
      * Computes the given metrics on the given cache.
      *
-     * @param dataCache        The given cache.
-     * @param mdl              The model.
+     * @param dataCache The given cache.
+     * @param mdl The model.
      * @param featureExtractor The feature extractor.
-     * @param lbExtractor      The label extractor.
-     * @param <K>              The type of cache entry key.
-     * @param <V>              The type of cache entry value.
+     * @param lbExtractor The label extractor.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
      * @return Computed metric.
      */
     public static <K, V> BinaryClassificationMetricValues evaluate(IgniteCache<K, V> dataCache,
-                                                                   Model<Vector, Double> mdl,
-                                                                   IgniteBiFunction<K, V, Vector> featureExtractor,
-                                                                   IgniteBiFunction<K, V, Double> lbExtractor) {
+        Model<Vector, Double> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        return calcMetricValues(dataCache, null, mdl, featureExtractor, lbExtractor);
+    }
+
+    /**
+     * Computes the given metrics on the given local data.
+     *
+     * @param dataCache The given local data.
+     * @param mdl The model.
+     * @param featureExtractor The feature extractor.
+     * @param lbExtractor The label extractor.
+     * @param <K> The type of map entry key.
+     * @param <V> The type of map entry value.
+     * @return Computed metric values.
+     */
+    public static <K, V> BinaryClassificationMetricValues evaluate(Map<K, V> dataCache,
+        Model<Vector, Double> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
         return calcMetricValues(dataCache, null, mdl, featureExtractor, lbExtractor);
     }
 
     /**
      * Computes the given metrics on the given cache.
      *
-     * @param dataCache        The given cache.
-     * @param filter           The given filter.
-     * @param mdl              The model.
+     * @param dataCache The given cache.
+     * @param filter The given filter.
+     * @param mdl The model.
      * @param featureExtractor The feature extractor.
-     * @param lbExtractor      The label extractor.
-     * @param <K>              The type of cache entry key.
-     * @param <V>              The type of cache entry value.
+     * @param lbExtractor The label extractor.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
      * @return Computed metric.
      */
-    public static <K, V> BinaryClassificationMetricValues evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter,
-                                                                   Model<Vector, Double> mdl,
-                                                                   IgniteBiFunction<K, V, Vector> featureExtractor,
-                                                                   IgniteBiFunction<K, V, Double> lbExtractor) {
+    public static <K, V> BinaryClassificationMetricValues evaluate(IgniteCache<K, V> dataCache,
+        IgniteBiPredicate<K, V> filter,
+        Model<Vector, Double> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
         return calcMetricValues(dataCache, filter, mdl, featureExtractor, lbExtractor);
     }
 
     /**
      * Computes the given metrics on the given cache.
      *
-     * @param dataCache        The given cache.
-     * @param filter           The given filter.
-     * @param mdl              The model.
+     * @param dataCache The given local data.
+     * @param filter The given filter.
+     * @param mdl The model.
      * @param featureExtractor The feature extractor.
-     * @param lbExtractor      The label extractor.
-     * @param <K>              The type of cache entry key.
-     * @param <V>              The type of cache entry value.
+     * @param lbExtractor The label extractor.
+     * @param <K> The type of map entry key.
+     * @param <V> The type of map entry value.
+     * @return Computed metric values.
+     */
+    public static <K, V> BinaryClassificationMetricValues evaluate(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter,
+        Model<Vector, Double> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        return calcMetricValues(dataCache, filter, mdl, featureExtractor, lbExtractor);
+    }
+
+    /**
+     * Computes the given metrics on the given cache.
+     *
+     * @param dataCache The given cache.
+     * @param filter The given filter.
+     * @param mdl The model.
+     * @param featureExtractor The feature extractor.
+     * @param lbExtractor The label extractor.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
      * @return Computed metric.
      */
     private static <K, V> BinaryClassificationMetricValues calcMetricValues(IgniteCache<K, V> dataCache,
-                                                                            IgniteBiPredicate<K, V> filter,
-                                                                            Model<Vector, Double> mdl,
-                                                                            IgniteBiFunction<K, V, Vector> featureExtractor,
-                                                                            IgniteBiFunction<K, V, Double> lbExtractor) {
+        IgniteBiPredicate<K, V> filter,
+        Model<Vector, Double> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
         BinaryClassificationMetricValues metricValues;
         BinaryClassificationMetrics binaryMetrics = new BinaryClassificationMetrics();
 
@@ -139,7 +221,44 @@ public class BinaryClassificationEvaluator {
             mdl
         )) {
             metricValues = binaryMetrics.scoreAll(cursor.iterator());
-        } catch (Exception e) {
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return metricValues;
+    }
+
+    /**
+     * Computes the given metrics on the given cache.
+     *
+     * @param dataCache The given cache.
+     * @param filter The given filter.
+     * @param mdl The model.
+     * @param featureExtractor The feature extractor.
+     * @param lbExtractor The label extractor.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
+     * @return Computed metric.
+     */
+    private static <K, V> BinaryClassificationMetricValues calcMetricValues(Map<K, V> dataCache,
+        IgniteBiPredicate<K, V> filter,
+        Model<Vector, Double> mdl,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        BinaryClassificationMetricValues metricValues;
+        BinaryClassificationMetrics binaryMetrics = new BinaryClassificationMetrics();
+
+        try (LabelPairCursor<Double> cursor = new LocalLabelPairCursor<>(
+            dataCache,
+            filter,
+            featureExtractor,
+            lbExtractor,
+            mdl
+        )) {
+            metricValues = binaryMetrics.scoreAll(cursor.iterator());
+        }
+        catch (Exception e) {
             throw new RuntimeException(e);
         }
 
@@ -149,20 +268,20 @@ public class BinaryClassificationEvaluator {
     /**
      * Computes the given metric on the given cache.
      *
-     * @param dataCache        The given cache.
-     * @param filter           The given filter.
-     * @param mdl              The model.
+     * @param dataCache The given cache.
+     * @param filter The given filter.
+     * @param mdl The model.
      * @param featureExtractor The feature extractor.
-     * @param lbExtractor      The label extractor.
-     * @param metric           The binary classification metric.
-     * @param <L>              The type of label.
-     * @param <K>              The type of cache entry key.
-     * @param <V>              The type of cache entry value.
+     * @param lbExtractor The label extractor.
+     * @param metric The binary classification metric.
+     * @param <L> The type of label.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
      * @return Computed metric.
      */
     private static <L, K, V> double calculateMetric(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter,
-                                                    Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor,
-                                                    IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) {
+        Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) {
         double metricRes;
 
         try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<>(
@@ -173,7 +292,43 @@ public class BinaryClassificationEvaluator {
             mdl
         )) {
             metricRes = metric.score(cursor.iterator());
-        } catch (Exception e) {
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return metricRes;
+    }
+
+    /**
+     * Computes the given metric on the given cache.
+     *
+     * @param dataCache The given cache.
+     * @param filter The given filter.
+     * @param mdl The model.
+     * @param featureExtractor The feature extractor.
+     * @param lbExtractor The label extractor.
+     * @param metric The binary classification metric.
+     * @param <L> The type of label.
+     * @param <K> The type of cache entry key.
+     * @param <V> The type of cache entry value.
+     * @return Computed metric.
+     */
+    private static <L, K, V> double calculateMetric(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter,
+        Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) {
+        double metricRes;
+
+        try (LabelPairCursor<L> cursor = new LocalLabelPairCursor<>(
+            dataCache,
+            filter,
+            featureExtractor,
+            lbExtractor,
+            mdl
+        )) {
+            metricRes = metric.score(cursor.iterator());
+        }
+        catch (Exception e) {
             throw new RuntimeException(e);
         }
 
index bd4067a..35da9fa 100644 (file)
@@ -74,7 +74,8 @@ public class BinaryClassificationMetrics implements Metric<Double> {
 
     /** */
     public BinaryClassificationMetrics withPositiveClsLb(double positiveClsLb) {
-        this.positiveClsLb = positiveClsLb;
+        if (Double.isFinite(positiveClsLb)) // NaN and infinite values are ignored; the current label is kept.
+            this.positiveClsLb = positiveClsLb;
         return this;
     }
 
@@ -85,13 +86,15 @@ public class BinaryClassificationMetrics implements Metric<Double> {
 
     /** */
     public BinaryClassificationMetrics withNegativeClsLb(double negativeClsLb) {
-        this.negativeClsLb = negativeClsLb;
+        if (Double.isFinite(negativeClsLb)) // NaN and infinite values are ignored; the current label is kept.
+            this.negativeClsLb = negativeClsLb;
         return this;
     }
 
     /** */
     public BinaryClassificationMetrics withMetric(Function<BinaryClassificationMetricValues, Double> metric) {
-        this.metric = metric;
+        if (metric != null) // A null metric is ignored; the current metric is kept.
+            this.metric = metric;
         return this;
     }
 
index 161a40c..3f715dc 100644 (file)
@@ -302,7 +302,7 @@ public abstract class DatasetTrainer<M extends Model, L> {
     // TODO: IGNITE-10441 Think about more elegant ways to perform fluent API.
     public DatasetTrainer<M, L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
         this.envBuilder  = envBuilder;
-        this.environment = envBuilder.buildForTrainer();
+        environment = envBuilder.buildForTrainer();
 
         return this;
     }
index 748123a..6fe8a63 100644 (file)
@@ -85,9 +85,9 @@ public class KNNClassificationTest {
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(NNStrategy.SIMPLE);
 
-        assertTrue(knnMdl.toString().length() > 0);
-        assertTrue(knnMdl.toString(true).length() > 0);
-        assertTrue(knnMdl.toString(false).length() > 0);
+        assertTrue(!knnMdl.toString().isEmpty());
+        assertTrue(!knnMdl.toString(true).isEmpty());
+        assertTrue(!knnMdl.toString(false).isEmpty());
 
         Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
         assertEquals(knnMdl.apply(firstVector), 1.0);
index e2f8feb..0f62c92 100644 (file)
@@ -23,8 +23,11 @@ import org.apache.ignite.ml.selection.cv.CrossValidationTest;
 import org.apache.ignite.ml.selection.paramgrid.ParameterSetGeneratorTest;
 import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursorTest;
 import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursorTest;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluatorTest;
 import org.apache.ignite.ml.selection.scoring.evaluator.EvaluatorTest;
 import org.apache.ignite.ml.selection.scoring.metric.AccuracyTest;
+import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricsTest;
+import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricsValuesTest;
 import org.apache.ignite.ml.selection.scoring.metric.FmeasureTest;
 import org.apache.ignite.ml.selection.scoring.metric.PrecisionTest;
 import org.apache.ignite.ml.selection.scoring.metric.RecallTest;
@@ -53,6 +56,10 @@ public class SelectionTestSuite {
         suite.addTest(new JUnit4TestAdapter(TrainTestDatasetSplitterTest.class));
         suite.addTest(new JUnit4TestAdapter(EvaluatorTest.class));
         suite.addTest(new JUnit4TestAdapter(CacheBasedLabelPairCursorTest.class));
+        suite.addTest(new JUnit4TestAdapter(BinaryClassificationMetricsTest.class));
+        suite.addTest(new JUnit4TestAdapter(BinaryClassificationMetricsValuesTest.class));
+        suite.addTest(new JUnit4TestAdapter(BinaryClassificationEvaluatorTest.class));
+
 
         return suite;
     }
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java
new file mode 100644 (file)
index 0000000..c6222c8
--- /dev/null
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.evaluator;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.knn.NNClassificationModel;
+import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
+import org.apache.ignite.ml.selection.split.TrainTestSplit;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link BinaryClassificationEvaluator}.
+ */
+public class BinaryClassificationEvaluatorTest extends TrainerTest {
+    /**
+     * Test evaluator and trainer on classification model y = x.
+     */
+    @Test
+    public void testEvaluatorWithoutFilter() {
+        Map<Integer, Vector> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, VectorUtils.of(twoLinearlySeparableClasses[i]));
+
+        KNNClassificationTrainer trainer = new KNNClassificationTrainer();
+
+        IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+        IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+        NNClassificationModel mdl = trainer.fit(
+            cacheMock,
+            parts,
+            featureExtractor,
+            lbExtractor
+        ).withK(3);
+
+        double score = BinaryClassificationEvaluator.evaluate(cacheMock, mdl, featureExtractor, lbExtractor, new Accuracy<>());
+
+        assertEquals(0.9839357429718876, score, 1e-12);
+    }
+
+    /**
+     * Test evaluator and trainer on classification model y = x.
+     */
+    @Test
+    public void testEvaluatorWithFilter() {
+        Map<Integer, Vector> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, VectorUtils.of(twoLinearlySeparableClasses[i]));
+
+        KNNClassificationTrainer trainer = new KNNClassificationTrainer();
+
+        IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
+        IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+
+        TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
+            .split(0.75);
+
+        NNClassificationModel mdl = trainer.fit(
+            cacheMock,
+            split.getTrainFilter(),
+            parts,
+            featureExtractor,
+            lbExtractor
+        ).withK(3);
+
+        double score = BinaryClassificationEvaluator.evaluate(cacheMock, mdl, featureExtractor, lbExtractor, new Accuracy<>()); // NOTE(review): no filter is passed here, so the whole map is scored - confirm whether a test filter from 'split' was intended.
+
+        assertEquals(0.9, score, 1); // NOTE(review): a delta of 1 accepts any score in [-0.1, 1.9], so this assertion is effectively vacuous - tighten it.
+    }
+}
index 5025460..9ce35a0 100644 (file)
@@ -203,7 +203,7 @@ public class EvaluatorTest extends GridCommonAbstractTest {
 
     /** */
     private void assertResults(CrossValidationResult res, List<double[]> scores, double accuracy, double accuracy2) {
-        assertTrue(res.toString().length() > 0);
+        assertTrue(!res.toString().isEmpty());
         assertEquals("Best maxDeep", 1.0, res.getBest("maxDeep"));
         assertEquals("Best minImpurityDecrease", 0.0, res.getBest("minImpurityDecrease"));
         assertArrayEquals("Best score", new double[] {0.6666666666666666, 0.6, 0}, res.getBestScore(), 0);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java
new file mode 100644 (file)
index 0000000..a173f5e
--- /dev/null
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.metric;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor;
+import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link BinaryClassificationMetrics}.
+ */
+public class BinaryClassificationMetricsTest {
+    /** */
+    @Test
+    public void testDefaultBehaviour() {
+        Metric scoreCalculator = new BinaryClassificationMetrics();
+
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 1.0, 1.0),
+            Arrays.asList(1.0, 1.0, 0.0, 1.0)
+        );
+
+        double score = scoreCalculator.score(cursor.iterator());
+
+        assertEquals(0.75, score, 1e-12);
+    }
+
+    /** */
+    @Test
+    public void testDefaultBehaviourForScoreAll() {
+        BinaryClassificationMetrics scoreCalculator = new BinaryClassificationMetrics();
+
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 1.0, 1.0),
+            Arrays.asList(1.0, 1.0, 0.0, 1.0)
+        );
+
+        BinaryClassificationMetricValues metricValues = scoreCalculator.scoreAll(cursor.iterator());
+
+        assertEquals(0.75, metricValues.accuracy(), 1e-12);
+    }
+
+    /** */
+    @Test
+    public void testAccuracy() {
+        Metric scoreCalculator = new BinaryClassificationMetrics()
+            .withNegativeClsLb(1.0)
+            .withPositiveClsLb(2.0);
+
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(2.0, 2.0, 2.0, 2.0),
+            Arrays.asList(2.0, 2.0, 1.0, 2.0)
+        );
+
+        double score = scoreCalculator.score(cursor.iterator());
+
+        assertEquals(0.75, score, 1e-12);
+    }
+
+    /** */
+    @Test
+    public void testCustomMetric() {
+        Metric scoreCalculator = new BinaryClassificationMetrics()
+            .withNegativeClsLb(1.0)
+            .withPositiveClsLb(2.0)
+            .withMetric(BinaryClassificationMetricValues::tp);
+
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(2.0, 2.0, 2.0, 2.0),
+            Arrays.asList(2.0, 2.0, 1.0, 2.0)
+        );
+
+        double score = scoreCalculator.score(cursor.iterator());
+
+        assertEquals(3, score, 1e-12);
+    }
+
+    /** */
+    @Test
+    public void testNullCustomMetric() {
+        Metric scoreCalculator = new BinaryClassificationMetrics()
+            .withNegativeClsLb(1.0)
+            .withPositiveClsLb(2.0)
+            .withMetric(null);
+
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(2.0, 2.0, 2.0, 2.0),
+            Arrays.asList(2.0, 2.0, 1.0, 2.0)
+        );
+
+        double score = scoreCalculator.score(cursor.iterator());
+
+        // accuracy as default metric
+        assertEquals(0.75, score, 1e-12);
+    }
+
+    /** */
+    @Test
+    public void testNaNinClassLabels() {
+        Metric scoreCalculator = new BinaryClassificationMetrics()
+            .withNegativeClsLb(Double.NaN)
+            .withPositiveClsLb(Double.POSITIVE_INFINITY);
+
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 1.0, 1.0),
+            Arrays.asList(1.0, 1.0, 0.0, 1.0)
+        );
+
+        double score = scoreCalculator.score(cursor.iterator());
+
+        // accuracy as default metric
+        assertEquals(0.75, score, 1e-12);
+    }
+
+    /** */
+    @Test(expected = org.apache.ignite.ml.selection.scoring.metric.UnknownClassLabelException.class)
+    public void testFailWithIncorrectClassLabelsInData() {
+        Metric scoreCalculator = new BinaryClassificationMetrics();
+
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(2.0, 2.0, 2.0, 2.0),
+            Arrays.asList(2.0, 2.0, 1.0, 2.0)
+        );
+
+        scoreCalculator.score(cursor.iterator());
+    }
+
+    /** */
+    @Test(expected = org.apache.ignite.ml.selection.scoring.metric.UnknownClassLabelException.class)
+    public void testFailWithIncorrectClassLabelsInMetrics() {
+        Metric scoreCalculator = new BinaryClassificationMetrics()
+            .withPositiveClsLb(42);
+
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 1.0, 1.0),
+            Arrays.asList(1.0, 1.0, 0.0, 1.0)
+        );
+
+        scoreCalculator.score(cursor.iterator());
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java
new file mode 100644 (file)
index 0000000..75a8183
--- /dev/null
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.metric;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link BinaryClassificationMetricValues}.
+ */
+public class BinaryClassificationMetricsValuesTest {
+    /** Checks the counters and derived metrics for the constructor arguments (10, 10, 5, 5). */
+    @Test
+    public void testDefaultBehaviour() {
+        BinaryClassificationMetricValues metricValues = new BinaryClassificationMetricValues(10, 10, 5, 5);
+
+        assertEquals(10, metricValues.tp(), 1e-2);
+        assertEquals(10, metricValues.tn(), 1e-2);
+        assertEquals(5, metricValues.fn(), 1e-2);
+        assertEquals(5, metricValues.fp(), 1e-2);
+        assertEquals(0.66, metricValues.accuracy(), 1e-2);
+        assertEquals(0.66, metricValues.balancedAccuracy(), 1e-2);
+        assertEquals(0.66, metricValues.f1Score(), 1e-2);
+        assertEquals(0.33, metricValues.fallOut(), 1e-2);
+        assertEquals(0.33, metricValues.fdr(), 1e-2);
+        assertEquals(0.33, metricValues.missRate(), 1e-2);
+        assertEquals(0.66, metricValues.npv(), 1e-2);
+        assertEquals(0.66, metricValues.precision(), 1e-2);
+        assertEquals(0.66, metricValues.recall(), 1e-2);
+        assertEquals(0.66, metricValues.specificity(), 1e-2);
+    }
+}