IGNITE-9261: [ML] Add ANN algorithm based on ACD concept
author zaleslaw <zaleslaw.sin@gmail.com>
Wed, 15 Aug 2018 16:09:26 +0000 (19:09 +0300)
committer Yury Babak <ybabak@gridgain.com>
Wed, 15 Aug 2018 16:09:26 +0000 (19:09 +0300)
this closes #4534

31 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java [new file with mode: 0644]
examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNModelFormat.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/NNStrategy.java [moved from modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNStrategy.java with 98% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java
modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java [moved from modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java with 88% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java [moved from modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java with 89% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java [moved from modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java with 80% similarity]

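The new ANNClassificationTrainer first clusters the data with KMeans, labels each centroid
with the distribution of class labels among its closest points (the ACD part), and answers
queries by running kNN over that small candidate set. A minimal usage sketch, distilled from
ANNClassificationExample below (cache setup elided; featureExtractor and lbExtractor stand
for the extractor lambdas from the example):

    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withDistance(new ManhattanDistance()) // Distance used by the KMeans clustering step.
        .withK(50)                             // Amount of centroids (candidates).
        .withMaxIterations(1000)
        .withEpsilon(1e-2);

    NNClassificationModel mdl = trainer
        .fit(ignite, dataCache, featureExtractor, lbExtractor)
        .withK(5)                              // Amount of neighbors among the candidates.
        .withDistanceMeasure(new EuclideanDistance())
        .withStrategy(NNStrategy.WEIGHTED);

    double prediction = mdl.apply(observation);
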
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
new file mode 100644
index 0000000..9a68207
--- /dev/null
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.knn;
+
+import java.util.Arrays;
+import java.util.UUID;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.knn.NNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
+import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.distances.ManhattanDistance;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Runs the ANN multi-class classification trainer over a distributed dataset.
+ *
+ * @see ANNClassificationTrainer
+ */
+public class ANNClassificationExample {
+    /** Run example. */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                ANNClassificationExample.class.getSimpleName(), () -> {
+                IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+
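+                // Note: the trainer's distance below (Manhattan) drives the KMeans clustering step,
+                // while the distance measure set on the resulting model is used at prediction time.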
+                ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+                    .withDistance(new ManhattanDistance())
+                    .withK(50)
+                    .withMaxIterations(1000)
+                    .withSeed(1234L)
+                    .withEpsilon(1e-2);
+
+                long startTrainingTime = System.currentTimeMillis();
+
+                NNClassificationModel knnMdl = trainer.fit(
+                    ignite,
+                    dataCache,
+                    (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+                    (k, v) -> v[0]
+                ).withK(5)
+                    .withDistanceMeasure(new EuclideanDistance())
+                    .withStrategy(NNStrategy.WEIGHTED);
+
+                long endTrainingTime = System.currentTimeMillis();
+
+                System.out.println(">>> ---------------------------------");
+                System.out.println(">>> | Prediction\t| Ground Truth\t|");
+                System.out.println(">>> ---------------------------------");
+
+                int amountOfErrors = 0;
+                int totalAmount = 0;
+
+                long totalPredictionTime = 0L;
+
+                try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+                    for (Cache.Entry<Integer, double[]> observation : observations) {
+                        double[] val = observation.getValue();
+                        double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+                        double groundTruth = val[0];
+
+                        long startPredictionTime = System.currentTimeMillis();
+                        double prediction = knnMdl.apply(new DenseVector(inputs));
+                        long endPredictionTime = System.currentTimeMillis();
+
+                        totalPredictionTime += (endPredictionTime - startPredictionTime);
+
+                        totalAmount++;
+                        if (groundTruth != prediction)
+                            amountOfErrors++;
+
+                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+                    }
+
+                    System.out.println(">>> ---------------------------------");
+
+                    System.out.println("Training time (ms) = " + (endTrainingTime - startTrainingTime));
+                    System.out.println("Prediction time (ms) = " + totalPredictionTime);
+
+                    System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+                    System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
+                    System.out.println("\n>>> Total amount of observations " + totalAmount);
+                }
+            });
+
+            igniteThread.start();
+            igniteThread.join();
+        }
+    }
+
+    /**
+     * Fills cache with data and returns it.
+     *
+     * @param ignite Ignite instance.
+     * @return Filled Ignite Cache.
+     */
+    private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+        CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
+        cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+        IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
+
+        for (int k = 0; k < 10; k++) { // Replicates the Iris dataset 10 times with slight shifts.
+            for (int i = 0; i < data.length; i++)
+                cache.put(k * 10000 + i, mutate(data[i], k));
+        }
+
+        return cache;
+    }
+
+    /**
+     * Slightly shifts the vector data depending on the {@code k} parameter.
+     * @param datum The vector data.
+     * @param k The replication index.
+     * @return A shifted copy of the vector data.
+     */
+    private static double[] mutate(double[] datum, int k) {
+        double[] res = Arrays.copyOf(datum, datum.length); // Copy to keep the source dataset intact.
+        for (int i = 0; i < res.length; i++)
+            res[i] += k / 100_000.0; // Floating-point division; integer division would always add 0.
+        return res;
+    }
+
+    /** The Iris dataset. */
+    private static final double[][] data = {
+        {1, 5.1, 3.5, 1.4, 0.2},
+        {1, 4.9, 3, 1.4, 0.2},
+        {1, 4.7, 3.2, 1.3, 0.2},
+        {1, 4.6, 3.1, 1.5, 0.2},
+        {1, 5, 3.6, 1.4, 0.2},
+        {1, 5.4, 3.9, 1.7, 0.4},
+        {1, 4.6, 3.4, 1.4, 0.3},
+        {1, 5, 3.4, 1.5, 0.2},
+        {1, 4.4, 2.9, 1.4, 0.2},
+        {1, 4.9, 3.1, 1.5, 0.1},
+        {1, 5.4, 3.7, 1.5, 0.2},
+        {1, 4.8, 3.4, 1.6, 0.2},
+        {1, 4.8, 3, 1.4, 0.1},
+        {1, 4.3, 3, 1.1, 0.1},
+        {1, 5.8, 4, 1.2, 0.2},
+        {1, 5.7, 4.4, 1.5, 0.4},
+        {1, 5.4, 3.9, 1.3, 0.4},
+        {1, 5.1, 3.5, 1.4, 0.3},
+        {1, 5.7, 3.8, 1.7, 0.3},
+        {1, 5.1, 3.8, 1.5, 0.3},
+        {1, 5.4, 3.4, 1.7, 0.2},
+        {1, 5.1, 3.7, 1.5, 0.4},
+        {1, 4.6, 3.6, 1, 0.2},
+        {1, 5.1, 3.3, 1.7, 0.5},
+        {1, 4.8, 3.4, 1.9, 0.2},
+        {1, 5, 3, 1.6, 0.2},
+        {1, 5, 3.4, 1.6, 0.4},
+        {1, 5.2, 3.5, 1.5, 0.2},
+        {1, 5.2, 3.4, 1.4, 0.2},
+        {1, 4.7, 3.2, 1.6, 0.2},
+        {1, 4.8, 3.1, 1.6, 0.2},
+        {1, 5.4, 3.4, 1.5, 0.4},
+        {1, 5.2, 4.1, 1.5, 0.1},
+        {1, 5.5, 4.2, 1.4, 0.2},
+        {1, 4.9, 3.1, 1.5, 0.1},
+        {1, 5, 3.2, 1.2, 0.2},
+        {1, 5.5, 3.5, 1.3, 0.2},
+        {1, 4.9, 3.1, 1.5, 0.1},
+        {1, 4.4, 3, 1.3, 0.2},
+        {1, 5.1, 3.4, 1.5, 0.2},
+        {1, 5, 3.5, 1.3, 0.3},
+        {1, 4.5, 2.3, 1.3, 0.3},
+        {1, 4.4, 3.2, 1.3, 0.2},
+        {1, 5, 3.5, 1.6, 0.6},
+        {1, 5.1, 3.8, 1.9, 0.4},
+        {1, 4.8, 3, 1.4, 0.3},
+        {1, 5.1, 3.8, 1.6, 0.2},
+        {1, 4.6, 3.2, 1.4, 0.2},
+        {1, 5.3, 3.7, 1.5, 0.2},
+        {1, 5, 3.3, 1.4, 0.2},
+        {2, 7, 3.2, 4.7, 1.4},
+        {2, 6.4, 3.2, 4.5, 1.5},
+        {2, 6.9, 3.1, 4.9, 1.5},
+        {2, 5.5, 2.3, 4, 1.3},
+        {2, 6.5, 2.8, 4.6, 1.5},
+        {2, 5.7, 2.8, 4.5, 1.3},
+        {2, 6.3, 3.3, 4.7, 1.6},
+        {2, 4.9, 2.4, 3.3, 1},
+        {2, 6.6, 2.9, 4.6, 1.3},
+        {2, 5.2, 2.7, 3.9, 1.4},
+        {2, 5, 2, 3.5, 1},
+        {2, 5.9, 3, 4.2, 1.5},
+        {2, 6, 2.2, 4, 1},
+        {2, 6.1, 2.9, 4.7, 1.4},
+        {2, 5.6, 2.9, 3.6, 1.3},
+        {2, 6.7, 3.1, 4.4, 1.4},
+        {2, 5.6, 3, 4.5, 1.5},
+        {2, 5.8, 2.7, 4.1, 1},
+        {2, 6.2, 2.2, 4.5, 1.5},
+        {2, 5.6, 2.5, 3.9, 1.1},
+        {2, 5.9, 3.2, 4.8, 1.8},
+        {2, 6.1, 2.8, 4, 1.3},
+        {2, 6.3, 2.5, 4.9, 1.5},
+        {2, 6.1, 2.8, 4.7, 1.2},
+        {2, 6.4, 2.9, 4.3, 1.3},
+        {2, 6.6, 3, 4.4, 1.4},
+        {2, 6.8, 2.8, 4.8, 1.4},
+        {2, 6.7, 3, 5, 1.7},
+        {2, 6, 2.9, 4.5, 1.5},
+        {2, 5.7, 2.6, 3.5, 1},
+        {2, 5.5, 2.4, 3.8, 1.1},
+        {2, 5.5, 2.4, 3.7, 1},
+        {2, 5.8, 2.7, 3.9, 1.2},
+        {2, 6, 2.7, 5.1, 1.6},
+        {2, 5.4, 3, 4.5, 1.5},
+        {2, 6, 3.4, 4.5, 1.6},
+        {2, 6.7, 3.1, 4.7, 1.5},
+        {2, 6.3, 2.3, 4.4, 1.3},
+        {2, 5.6, 3, 4.1, 1.3},
+        {2, 5.5, 2.5, 4, 1.3},
+        {2, 5.5, 2.6, 4.4, 1.2},
+        {2, 6.1, 3, 4.6, 1.4},
+        {2, 5.8, 2.6, 4, 1.2},
+        {2, 5, 2.3, 3.3, 1},
+        {2, 5.6, 2.7, 4.2, 1.3},
+        {2, 5.7, 3, 4.2, 1.2},
+        {2, 5.7, 2.9, 4.2, 1.3},
+        {2, 6.2, 2.9, 4.3, 1.3},
+        {2, 5.1, 2.5, 3, 1.1},
+        {2, 5.7, 2.8, 4.1, 1.3},
+        {3, 6.3, 3.3, 6, 2.5},
+        {3, 5.8, 2.7, 5.1, 1.9},
+        {3, 7.1, 3, 5.9, 2.1},
+        {3, 6.3, 2.9, 5.6, 1.8},
+        {3, 6.5, 3, 5.8, 2.2},
+        {3, 7.6, 3, 6.6, 2.1},
+        {3, 4.9, 2.5, 4.5, 1.7},
+        {3, 7.3, 2.9, 6.3, 1.8},
+        {3, 6.7, 2.5, 5.8, 1.8},
+        {3, 7.2, 3.6, 6.1, 2.5},
+        {3, 6.5, 3.2, 5.1, 2},
+        {3, 6.4, 2.7, 5.3, 1.9},
+        {3, 6.8, 3, 5.5, 2.1},
+        {3, 5.7, 2.5, 5, 2},
+        {3, 5.8, 2.8, 5.1, 2.4},
+        {3, 6.4, 3.2, 5.3, 2.3},
+        {3, 6.5, 3, 5.5, 1.8},
+        {3, 7.7, 3.8, 6.7, 2.2},
+        {3, 7.7, 2.6, 6.9, 2.3},
+        {3, 6, 2.2, 5, 1.5},
+        {3, 6.9, 3.2, 5.7, 2.3},
+        {3, 5.6, 2.8, 4.9, 2},
+        {3, 7.7, 2.8, 6.7, 2},
+        {3, 6.3, 2.7, 4.9, 1.8},
+        {3, 6.7, 3.3, 5.7, 2.1},
+        {3, 7.2, 3.2, 6, 1.8},
+        {3, 6.2, 2.8, 4.8, 1.8},
+        {3, 6.1, 3, 4.9, 1.8},
+        {3, 6.4, 2.8, 5.6, 2.1},
+        {3, 7.2, 3, 5.8, 1.6},
+        {3, 7.4, 2.8, 6.1, 1.9},
+        {3, 7.9, 3.8, 6.4, 2},
+        {3, 6.4, 2.8, 5.6, 2.2},
+        {3, 6.3, 2.8, 5.1, 1.5},
+        {3, 6.1, 2.6, 5.6, 1.4},
+        {3, 7.7, 3, 6.1, 2.3},
+        {3, 6.3, 3.4, 5.6, 2.4},
+        {3, 6.4, 3.1, 5.5, 1.8},
+        {3, 6, 3, 4.8, 1.8},
+        {3, 6.9, 3.1, 5.4, 2.1},
+        {3, 6.7, 3.1, 5.6, 2.4},
+        {3, 6.9, 3.1, 5.1, 2.3},
+        {3, 5.8, 2.7, 5.1, 1.9},
+        {3, 6.8, 3.2, 5.9, 2.3},
+        {3, 6.7, 3.3, 5.7, 2.5},
+        {3, 6.7, 3, 5.2, 2.3},
+        {3, 6.3, 2.5, 5, 1.9},
+        {3, 6.5, 3, 5.2, 2},
+        {3, 6.2, 3.4, 5.4, 2.3},
+        {3, 5.9, 3, 5.1, 1.8}
+    };
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index d12fc1d..b4602cc 100644
@@ -27,9 +27,9 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.cache.query.QueryCursor;
 import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
+import org.apache.ignite.ml.knn.NNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
@@ -55,14 +55,14 @@ public class KNNClassificationExample {
 
                 KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-                KNNClassificationModel knnMdl = trainer.fit(
+                NNClassificationModel knnMdl = trainer.fit(
                     ignite,
                     dataCache,
                     (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
                     (k, v) -> v[0]
                 ).withK(3)
                     .withDistanceMeasure(new EuclideanDistance())
-                    .withStrategy(KNNStrategy.WEIGHTED);
+                    .withStrategy(NNStrategy.WEIGHTED);
 
                 System.out.println(">>> ---------------------------------");
                 System.out.println(">>> | Prediction\t| Ground Truth\t|");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index d1e5055..7c84949 100644
@@ -28,7 +28,7 @@ import org.apache.ignite.cache.query.QueryCursor;
 import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
 import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
 import org.apache.ignite.ml.math.distances.ManhattanDistance;
@@ -63,7 +63,7 @@ public class KNNRegressionExample {
                     (k, v) -> v[0]
                 ).withK(5)
                     .withDistanceMeasure(new ManhattanDistance())
-                    .withStrategy(KNNStrategy.WEIGHTED);
+                    .withStrategy(NNStrategy.WEIGHTED);
 
                 int totalAmount = 0;
                 // Calculate mean squared error (MSE)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
index e07e9f8..56e70f1 100644
@@ -21,9 +21,9 @@ import java.io.FileNotFoundException;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
+import org.apache.ignite.ml.knn.NNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -88,12 +88,12 @@ public class Step_6_KNN {
                     KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
                     // Train decision tree model.
-                    KNNClassificationModel mdl = trainer.fit(
+                    NNClassificationModel mdl = trainer.fit(
                         ignite,
                         dataCache,
                         normalizationPreprocessor,
                         lbExtractor
-                    ).withK(1).withStrategy(KNNStrategy.WEIGHTED);
+                    ).withK(1).withStrategy(NNStrategy.WEIGHTED);
 
                     double accuracy = Evaluator.evaluate(
                         dataCache,
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
index 7dbc78a..c005312 100644
 package org.apache.ignite.ml.clustering.kmeans;
 
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -35,8 +38,8 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
 import org.apache.ignite.ml.math.util.MapUtil;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 
@@ -62,23 +65,23 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
     /**
      * Trains model based on the specified data.
      *
-     * @param datasetBuilder Dataset builder.
+     * @param datasetBuilder   Dataset builder.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor Label extractor.
+     * @param lbExtractor      Label extractor.
      * @return Model.
      */
     @Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+                                            IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
         assert datasetBuilder != null;
 
-        PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
+        PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
             featureExtractor,
             lbExtractor
         );
 
         Vector[] centers;
 
-        try (Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build(
+        try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
@@ -105,10 +108,12 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
                 }
 
                 iteration++;
-                centers = newCentroids;
+                for (int i = 0; i < centers.length; i++) {
+                    if (newCentroids[i] != null)
+                        centers[i] = newCentroids[i];
+                }
             }
-        }
-        catch (Exception e) {
+        } catch (Exception e) {
             throw new RuntimeException(e);
         }
         return new KMeansModel(centers, distance);
@@ -119,11 +124,11 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
      *
      * @param centers Current centers on the current iteration.
      * @param dataset Dataset.
-     * @param cols Amount of columns.
+     * @param cols    Amount of columns.
      * @return Helper data to calculate the new centroids.
      */
     private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers,
-        Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset, int cols) {
+                                                       Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset, int cols) {
         final Vector[] finalCenters = centers;
 
         return dataset.compute(data -> {
@@ -142,10 +147,10 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
 
                 int finalI = i;
                 res.sums.compute(centroidIdx,
-                    (IgniteBiFunction<Integer, Vector, Vector>)(ind, v) -> v.plus(data.getRow(finalI).features()));
+                    (IgniteBiFunction<Integer, Vector, Vector>) (ind, v) -> v.plus(data.getRow(finalI).features()));
 
                 res.counts.merge(centroidIdx, 1,
-                    (IgniteBiFunction<Integer, Integer, Integer>)(i1, i2) -> i1 + i2);
+                    (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
             }
             return res;
         }, (a, b) -> a == null ? b : a.merge(b));
@@ -155,7 +160,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
      * Find the closest cluster center index and distance to it from a given point.
      *
      * @param centers Centers to look in.
-     * @param pnt Point.
+     * @param pnt     Point.
      */
     private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
         double bestDistance = Double.POSITIVE_INFINITY;
@@ -175,31 +180,62 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
      * K cluster centers are initialized randomly.
      *
      * @param dataset The dataset to pick up random centers.
-     * @param k Amount of clusters.
+     * @param k       Amount of clusters.
      * @return K cluster centers.
      */
-    private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset,
-        int k) {
+    private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset,
+                                                int k) {
 
         Vector[] initCenters = new DenseVector[k];
 
+        // Gets k or fewer vectors from each partition.
         List<LabeledVector> rndPnts = dataset.compute(data -> {
             List<LabeledVector> rndPnt = new ArrayList<>();
-            rndPnt.add(data.getRow(new Random(seed).nextInt(data.rowSize())));
+
+            if (data.rowSize() != 0) {
+                if (data.rowSize() > k) { // If there are enough rows in the partition to pick k vectors.
+                    final Random random = new Random(seed);
+                    // Keep the set outside the loop so that indices picked earlier are actually detected.
+                    Set<Integer> uniqueIndices = new HashSet<>();
+
+                    for (int i = 0; i < k; i++) {
+                        int nextIdx = random.nextInt(data.rowSize());
+                        int maxRandomSearch = k; // Required to keep the next loop finite.
+                        int cntr = 0;
+
+                        // Repeat nextIdx generation if it was picked earlier.
+                        while (uniqueIndices.contains(nextIdx) && cntr < maxRandomSearch) {
+                            nextIdx = random.nextInt(data.rowSize());
+                            cntr++;
+                        }
+                        uniqueIndices.add(nextIdx);
+
+                        rndPnt.add(data.getRow(nextIdx));
+                    }
+                } else // Not enough rows to pick k vectors: take them all.
+                    for (int i = 0; i < data.rowSize(); i++)
+                        rndPnt.add(data.getRow(i));
+            }
             return rndPnt;
         }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));
 
-        for (int i = 0; i < k; i++) {
-            final LabeledVector rndPnt = rndPnts.get(new Random(seed).nextInt(rndPnts.size()));
-            rndPnts.remove(rndPnt);
-            initCenters[i] = rndPnt.features();
-        }
+        // Shuffle them.
+        Collections.shuffle(rndPnts);
+
+        // Pick k vectors randomly.
+        if (rndPnts.size() >= k) {
+            for (int i = 0; i < k; i++) {
+                final LabeledVector rndPnt = rndPnts.get(new Random(seed).nextInt(rndPnts.size()));
+                rndPnts.remove(rndPnt);
+                initCenters[i] = rndPnt.features();
+            }
+        } else
+            throw new RuntimeException("The KMeans trainer requires at least " + k + " vectors to find " + k + " clusters");
 
         return initCenters;
     }
 
     /** Service class used for statistics. */
-    private static class TotalCostAndCounts {
+    public static class TotalCostAndCounts {
         /** */
         double totalCost;
 
@@ -209,13 +245,24 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
         /** Count of points closest to the center with a given index. */
         ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();
 
+
+        /** Distribution of class labels among the points closest to the center with a given index. */
+        ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();
+
         /** Merges this statistics object with the given one. */
         TotalCostAndCounts merge(TotalCostAndCounts other) {
             this.totalCost += other.totalCost;
             this.sums = MapUtil.mergeMaps(sums, other.sums, Vector::plus, ConcurrentHashMap::new);
             this.counts = MapUtil.mergeMaps(counts, other.counts, (i1, i2) -> i1 + i2, ConcurrentHashMap::new);
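+            // Merge the nested per-centroid label counters by summing the counts for each label.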
+            this.centroidStat = MapUtil.mergeMaps(centroidStat, other.centroidStat, (m1, m2) ->
+                MapUtil.mergeMaps(m1, m2, (i1, i2) -> i1 + i2, ConcurrentHashMap::new), ConcurrentHashMap::new);
             return this;
         }
+
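+        /** Gets the per-centroid distribution of class labels. */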
+        public ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> getCentroidStat() {
+            return centroidStat;
+        }
+
     }
 
     /**
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java
index 53a6219..3701557 100644
@@ -29,8 +29,8 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.functions.IgniteTriFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
 
 /**
@@ -76,7 +76,7 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
 
         List<Double> uniqLabels = new ArrayList<Double>(
             builder.build(new EmptyContextBuilder<>(), new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor))
-                .compute((IgniteFunction<LabeledDataset<Double,LabeledVector>, Set<Double>>) x ->
+                .compute((IgniteFunction<LabeledVectorSet<Double,LabeledVector>, Set<Double>>) x ->
                     Arrays.stream(x.labels()).boxed().collect(Collectors.toSet()), (a, b) -> {
                         if (a == null)
                             return b;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java
index cdf2d50..e7228f8 100644
@@ -26,9 +26,13 @@ import org.apache.ignite.ml.math.functions.IgniteSupplier;
  * bagging, learning submodels for One-vs-All model, Cross-Validation etc.
  */
 public interface ParallelismStrategy {
+
+    /**
+     * The type of parallelism.
+     */
     public enum Type {
-        NO_PARALLELISM,
-        ON_DEFAULT_POOL
+        /** No parallelism. */NO_PARALLELISM,
+        /** On default pool. */ON_DEFAULT_POOL
     }
 
     /**
@@ -38,6 +42,13 @@ public interface ParallelismStrategy {
      */
     public <T> Promise<T> submit(IgniteSupplier<T> task);
 
+    /**
+     * Submits the list of tasks.
+     *
+     * @param tasks The task list.
+     * @param <T> The type of return value.
+     * @return The list of promises, one per submitted task.
+     */
     public default <T> List<Promise<T>> submit(List<IgniteSupplier<T>> tasks) {
         List<Promise<T>> results = new ArrayList<>();
         for(IgniteSupplier<T> task : tasks)
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
index b5a0cdb..d7bccd8 100644
@@ -23,8 +23,8 @@ import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
 import org.jetbrains.annotations.Nullable;
 
@@ -40,14 +40,14 @@ public class KNNUtils {
      * @param lbExtractor Label extractor.
      * @return Dataset.
      */
-    @Nullable public static <K, V> Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
-        PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder
+    @Nullable public static <K, V> Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+        PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder
             = new LabeledDatasetPartitionDataBuilderOnHeap<>(
             featureExtractor,
             lbExtractor
         );
 
-        Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = null;
+        Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = null;
 
         if (datasetBuilder != null) {
             dataset = datasetBuilder.build(
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
new file mode 100644
index 0000000..b7a57f5
--- /dev/null
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.knn.classification.KNNModelFormat;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.math.distances.DistanceMeasure;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
+import org.apache.ignite.ml.util.ModelTrace;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Common methods and fields for all kNN and ANN models
+ * that predict a label based on the neighbors' labels.
+ */
+public abstract class NNClassificationModel implements Model<Vector, Double>, Exportable<KNNModelFormat> {
+    /** Amount of nearest neighbors. */
+    protected int k = 5;
+
+    /** Distance measure. */
+    protected DistanceMeasure distanceMeasure = new EuclideanDistance();
+
+    /** kNN strategy. */
+    protected NNStrategy stgy = NNStrategy.SIMPLE;
+
+    /**
+     * Sets the number of nearest neighbors.
+     * @param k Amount of nearest neighbors.
+     * @return Model.
+     */
+    public NNClassificationModel withK(int k) {
+        this.k = k;
+        return this;
+    }
+
+    /**
+     * Sets the strategy of calculations.
+     * @param stgy Strategy of calculations.
+     * @return Model.
+     */
+    public NNClassificationModel withStrategy(NNStrategy stgy) {
+        this.stgy = stgy;
+        return this;
+    }
+
+    /**
+     * Sets the distance measure.
+     * @param distanceMeasure Distance measure.
+     * @return Model.
+     */
+    public NNClassificationModel withDistanceMeasure(DistanceMeasure distanceMeasure) {
+        this.distanceMeasure = distanceMeasure;
+        return this;
+    }
+
+    /** Builds a labeled vector set from the neighbor vectors collected from the partitions. */
+    protected LabeledVectorSet<Double, LabeledVector> buildLabeledDatasetOnListOfVectors(
+        List<LabeledVector> neighborsFromPartitions) {
+        LabeledVector[] arr = new LabeledVector[neighborsFromPartitions.size()];
+        for (int i = 0; i < arr.length; i++)
+            arr[i] = neighborsFromPartitions.get(i);
+
+        return new LabeledVectorSet<Double, LabeledVector>(arr);
+    }
+
+    /**
+     * Iterates over the entries of the distance map and fills the resulting k-element array.
+     *
+     * @param trainingData The training data.
+     * @param distanceIdxPairs The distance map.
+     * @return K-nearest neighbors.
+     */
+    @NotNull protected LabeledVector[] getKClosestVectors(LabeledVectorSet<Double, LabeledVector> trainingData,
+                                                          TreeMap<Double, Set<Integer>> distanceIdxPairs) {
+        LabeledVector[] res;
+
+        if (trainingData.rowSize() <= k) {
+            res = new LabeledVector[trainingData.rowSize()];
+            for (int i = 0; i < trainingData.rowSize(); i++)
+                res[i] = trainingData.getRow(i);
+        }
+        else {
+            res = new LabeledVector[k];
+            int i = 0;
+            final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
+            while (i < k) {
+                double key = iter.next();
+                Set<Integer> idxs = distanceIdxPairs.get(key);
+                for (Integer idx : idxs) {
+                    res[i] = trainingData.getRow(idx);
+                    i++;
+                    if (i >= k)
+                        break; // go to next while-loop iteration
+                }
+            }
+        }
+
+        return res;
+    }
+
+    /**
+     * Computes distances between the given vector and each vector in the training dataset.
+     *
+     * @param v The given vector.
+     * @param trainingData The training dataset.
+     * @return Map from distance to the set of indices of training vectors at that distance from the given
+     * vector; a Set is used because several vectors can share the same distance.
+     */
+    @NotNull protected TreeMap<Double, Set<Integer>> getDistances(Vector v, LabeledVectorSet<Double, LabeledVector> trainingData) {
+        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
+
+        for (int i = 0; i < trainingData.rowSize(); i++) {
+
+            LabeledVector labeledVector = trainingData.getRow(i);
+            if (labeledVector != null) {
+                double distance = distanceMeasure.compute(v, labeledVector.features());
+                putDistanceIdxPair(distanceIdxPairs, i, distance);
+            }
+        }
+        return distanceIdxPairs;
+    }
+
+    /** Adds the given (distance, index) pair to the distance map. */
+    protected void putDistanceIdxPair(Map<Double, Set<Integer>> distanceIdxPairs, int i, double distance) {
+        if (distanceIdxPairs.containsKey(distance)) {
+            Set<Integer> idxs = distanceIdxPairs.get(distance);
+            idxs.add(i);
+        }
+        else {
+            Set<Integer> idxs = new HashSet<>();
+            idxs.add(i);
+            distanceIdxPairs.put(distance, idxs);
+        }
+    }
+
+    /** Returns the class label with the maximum accumulated vote. */
+    protected double getClassWithMaxVotes(Map<Double, Double> clsVotes) {
+        return Collections.max(clsVotes.entrySet(), Map.Entry.comparingByValue()).getKey();
+    }
+
+    /** Returns the vote weight of one neighbor: 1/distance for the WEIGHTED strategy, 1.0 for SIMPLE. */
+    protected double getClassVoteForVector(NNStrategy stgy, double distance) {
+        if (stgy.equals(NNStrategy.WEIGHTED))
+            return 1 / distance; // strategy.WEIGHTED
+        else
+            return 1.0; // strategy.SIMPLE
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        int res = 1;
+
+        res = res * 37 + k;
+        res = res * 37 + distanceMeasure.hashCode();
+        res = res * 37 + stgy.hashCode();
+
+        return res;
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object obj) {
+        if (this == obj)
+            return true;
+
+        if (obj == null || getClass() != obj.getClass())
+            return false;
+
+        NNClassificationModel that = (NNClassificationModel)obj;
+
+        return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        return toString(false);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString(boolean pretty) {
+        return ModelTrace.builder("KNNClassificationModel", pretty)
+            .addField("k", String.valueOf(k))
+            .addField("measure", distanceMeasure.getClass().getSimpleName())
+            .addField("strategy", stgy.name())
+            .toString();
+    }
+
+    /** Saves the model to the given path via the specified exporter. */
+    public abstract <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path);
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
new file mode 100644
index 0000000..e8c0b4a
--- /dev/null
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.ann;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.knn.NNClassificationModel;
+import org.apache.ignite.ml.knn.classification.KNNModelFormat;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
+import org.apache.ignite.ml.util.ModelTrace;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * ANN model to predict labels in multi-class classification task.
+ */
+public class ANNClassificationModel extends NNClassificationModel {
+    /** */
+    private static final long serialVersionUID = -127312378991350345L;
+
+    /** The labeled set of candidates. */
+    private final LabeledVectorSet<ProbableLabel, LabeledVector> candidates;
+
+    /**
+     * Build the model based on a candidates set.
+     * @param centers The candidates set.
+     */
+    public ANNClassificationModel(LabeledVectorSet<ProbableLabel, LabeledVector> centers) {
+        this.candidates = centers;
+    }
+
+    /** Returns the labeled set of candidates (the centroids). */
+    public LabeledVectorSet<ProbableLabel, LabeledVector> getCandidates() {
+        return candidates;
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double apply(Vector v) {
+        List<LabeledVector> neighbors = findKNearestNeighbors(v);
+        return classify(neighbors, v, stgy);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) {
+        ANNModelFormat mdlData = new ANNModelFormat(k, distanceMeasure, stgy, candidates);
+        exporter.save(mdlData, path);
+    }
+
+    /**
+     * Computes the distances between the given vector and all centroids in the candidate set,
+     * then picks the k candidates with the minimum distance to it.
+     *
+     * @param v The given vector.
+     * @return K-nearest neighbors.
+     */
+    private List<LabeledVector> findKNearestNeighbors(Vector v) {
+        return Arrays.asList(getKClosestVectors(getDistances(v)));
+    }
+
+    /**
+     * Iterates over the entries of the distance map and fills the resulting k-element array.
+     * @param distanceIdxPairs The distance map.
+     * @return K-nearest neighbors.
+     */
+    @NotNull private LabeledVector[] getKClosestVectors(
+        TreeMap<Double, Set<Integer>> distanceIdxPairs) {
+        LabeledVector[] res;
+
+        if (candidates.rowSize() <= k) {
+            res = new LabeledVector[candidates.rowSize()];
+            for (int i = 0; i < candidates.rowSize(); i++)
+                res[i] = candidates.getRow(i);
+        }
+        else {
+            res = new LabeledVector[k];
+            int i = 0;
+            final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
+            while (i < k) {
+                double key = iter.next();
+                Set<Integer> idxs = distanceIdxPairs.get(key);
+                for (Integer idx : idxs) {
+                    res[i] = candidates.getRow(idx);
+                    i++;
+                    if (i >= k)
+                        break; // go to next while-loop iteration
+                }
+            }
+        }
+
+        return res;
+    }
+
+    /**
+     * Computes distances between the given vector and each centroid in the candidate set.
+     *
+     * @param v The given vector.
+     * @return Map from distance to the set of indices of candidates at that distance from the given
+     * vector; a Set is used because several candidates can share the same distance.
+     */
+    @NotNull private TreeMap<Double, Set<Integer>> getDistances(Vector v) {
+        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
+
+        for (int i = 0; i < candidates.rowSize(); i++) {
+
+            LabeledVector labeledVector = candidates.getRow(i);
+            if (labeledVector != null) {
+                double distance = distanceMeasure.compute(v, labeledVector.features());
+                putDistanceIdxPair(distanceIdxPairs, i, distance);
+            }
+        }
+        return distanceIdxPairs;
+    }
+
+    /** Classifies the given vector by accumulating the probability-weighted votes of the neighbor centroids. */
+    private double classify(List<LabeledVector> neighbors, Vector v, NNStrategy stgy) {
+        Map<Double, Double> clsVotes = new HashMap<>();
+
+        for (LabeledVector neighbor : neighbors) {
+            TreeMap<Double, Double> probableClsLb = ((ProbableLabel)neighbor.label()).clsLbls;
+
+            double distance = distanceMeasure.compute(v, neighbor.features());
+
+            // We predict the class label, not the probability vector (that would require different vote counting).
+            probableClsLb.forEach((label, probability) -> {
+                double cnt = clsVotes.containsKey(label) ? clsVotes.get(label) : 0;
+                clsVotes.put(label, cnt + probability * getClassVoteForVector(stgy, distance));
+            });
+        }
+        return getClassWithMaxVotes(clsVotes);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        int res = 1;
+
+        res = res * 37 + k;
+        res = res * 37 + distanceMeasure.hashCode();
+        res = res * 37 + stgy.hashCode();
+        res = res * 37 + candidates.hashCode();
+
+        return res;
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object obj) {
+        if (this == obj)
+            return true;
+
+        if (obj == null || getClass() != obj.getClass())
+            return false;
+
+        ANNClassificationModel that = (ANNClassificationModel)obj;
+
+        return k == that.k
+            && distanceMeasure.equals(that.distanceMeasure)
+            && stgy.equals(that.stgy)
+            && candidates.equals(that.candidates);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        return toString(false);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString(boolean pretty) {
+        return ModelTrace.builder("ANNClassificationModel", pretty)
+            .addField("k", String.valueOf(k))
+            .addField("measure", distanceMeasure.getClass().getSimpleName())
+            .addField("strategy", stgy.name())
+            .addField("amount of candidates", String.valueOf(candidates.rowSize()))
+            .toString();
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
new file mode 100644
index 0000000..282be3c
--- /dev/null
@@ -0,0 +1,343 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.ann;
+
+import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentSkipListSet;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
+import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.distances.DistanceMeasure;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.util.MapUtil;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
+import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * ANN algorithm trainer to solve multi-class classification task.
+ * This trainer is based on the ACD strategy and uses the KMeans clustering algorithm to find centroids.
+ */
+public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClassificationModel> {
+    /** Amount of clusters. */
+    private int k = 2;
+
+    /** Amount of iterations. */
+    private int maxIterations = 10;
+
+    /** Delta of convergence. */
+    private double epsilon = 1e-4;
+
+    /** Distance measure. */
+    private DistanceMeasure distance = new EuclideanDistance();
+
+    /** KMeans initializer. */
+    private long seed;
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param datasetBuilder   Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor      Label extractor.
+     * @return Model.
+     */
+    @Override public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
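+        // Step 1: run KMeans to find k centroids over the whole dataset.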
+        final Vector[] centers = getCentroids(featureExtractor, lbExtractor, datasetBuilder);
+
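+        // Step 2: collect, for each centroid, the class label distribution of its closest points.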
+        final CentroidStat centroidStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);
+
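+        // Step 3: turn each distribution into a probable label attached to its centroid.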
+        final LabeledVectorSet<ProbableLabel, LabeledVector> dataset = buildLabelsForCandidates(centers, centroidStat);
+
+        return new ANNClassificationModel(dataset);
+    }
+
+    /** Builds the labeled candidate set: attaches a probable label to each centroid. */
+    @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> buildLabelsForCandidates(Vector[] centers, CentroidStat centroidStat) {
+        // Initialize the resulting array of labeled centroids.
+        final LabeledVector<Vector, ProbableLabel>[] arr = new LabeledVector[centers.length];
+
+        // Fill the label for each centroid.
+        for (int i = 0; i < centers.length; i++)
+            arr[i] = new LabeledVector<>(centers[i], fillProbableLabel(i, centroidStat));
+
+        return new LabeledVectorSet<>(arr);
+    }
+
+    /**
+     * Performs the KMeans clustering algorithm to find centroids.
+     *
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor      Label extractor.
+     * @param datasetBuilder   The dataset builder.
+     * @param <K>              Type of a key in {@code upstream} data.
+     * @param <V>              Type of a value in {@code upstream} data.
+     * @return The array of centroid vectors.
+     */
+    private <K, V> Vector[] getCentroids(IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> datasetBuilder) {
+        KMeansTrainer trainer = new KMeansTrainer()
+            .withK(k)
+            .withMaxIterations(maxIterations)
+            .withSeed(seed)
+            .withDistance(distance)
+            .withEpsilon(epsilon);
+
+        KMeansModel mdl = trainer.fit(
+            datasetBuilder,
+            featureExtractor,
+            lbExtractor
+        );
+
+        return mdl.centers();
+    }
+
+    /**
+     * Builds the probable label for the given centroid: the probability of each class label is the share of
+     * the cluster's points carrying that label, i.e. count(label in cluster) / cluster size.
+     */
+    private ProbableLabel fillProbableLabel(int centroidIdx, CentroidStat centroidStat) {
+        TreeMap<Double, Double> clsLbls = new TreeMap<>();
+
+        // add all class labels as keys
+        centroidStat.clsLblsSet.forEach(t -> clsLbls.put(t, 0.0));
+
+        ConcurrentHashMap<Double, Integer> centroidLbDistribution
+            = centroidStat.centroidStat().get(centroidIdx);
+
+        if (centroidStat.counts.containsKey(centroidIdx)) {
+
+            int clusterSize = centroidStat
+                .counts
+                .get(centroidIdx);
+
+            clsLbls.keySet().forEach(
+                (label) -> clsLbls.put(label, centroidLbDistribution.containsKey(label) ? ((double) (centroidLbDistribution.get(label)) / clusterSize) : 0.0)
+            );
+        }
+        return new ProbableLabel(clsLbls);
+    }
+
+    /** Computes, for each centroid, the distribution of class labels among the points closest to it. */
+    private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, Vector[] centers) {
+        PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
+            featureExtractor,
+            lbExtractor
+        );
+
+        try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new EmptyContext(),
+            partDataBuilder
+        )) {
+
+            return dataset.compute(data -> {
+
+                CentroidStat res = new CentroidStat();
+
+                for (int i = 0; i < data.rowSize(); i++) {
+                    final IgniteBiTuple<Integer, Double> closestCentroid = findClosestCentroid(centers, data.getRow(i));
+
+                    int centroidIdx = closestCentroid.get1();
+
+                    double lb = data.label(i);
+
+                    // add new label to label set
+                    res.labels().add(lb);
+
+                    ConcurrentHashMap<Double, Integer> centroidStat = res.centroidStat.get(centroidIdx);
+
+                    if (centroidStat == null) {
+                        centroidStat = new ConcurrentHashMap<>();
+                        centroidStat.put(lb, 1);
+                        res.centroidStat.put(centroidIdx, centroidStat);
+                    } else {
+                        int cnt = centroidStat.containsKey(lb) ? centroidStat.get(lb) : 0;
+                        centroidStat.put(lb, cnt + 1);
+                    }
+
+                    res.counts.merge(centroidIdx, 1,
+                        (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
+                }
+                return res;
+            }, (a, b) -> a == null ? b : a.merge(b));
+
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Find the closest cluster center index and distance to it from a given point.
+     *
+     * @param centers Centers to look in.
+     * @param pnt     Point.
+     * @return Pair of the closest centroid index and the distance to it.
+     */
+    private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
+        double bestDistance = Double.POSITIVE_INFINITY;
+        int bestInd = 0;
+
+        for (int i = 0; i < centers.length; i++) {
+            if (centers[i] != null) {
+                double dist = distance.compute(centers[i], pnt.features());
+                if (dist < bestDistance) {
+                    bestDistance = dist;
+                    bestInd = i;
+                }
+            }
+        }
+        return new IgniteBiTuple<>(bestInd, bestDistance);
+    }
+
+    /**
+     * Gets the number of clusters.
+     *
+     * @return The parameter value.
+     */
+    public int getK() {
+        return k;
+    }
+
+    /**
+     * Sets the number of clusters.
+     *
+     * @param k The parameter value.
+     * @return Trainer with the new number of clusters.
+     */
+    public ANNClassificationTrainer withK(int k) {
+        this.k = k;
+        return this;
+    }
+
+    /**
+     * Gets the max number of iterations before convergence.
+     *
+     * @return The parameter value.
+     */
+    public int getMaxIterations() {
+        return maxIterations;
+    }
+
+    /**
+     * Sets the max number of iterations before convergence.
+     *
+     * @param maxIterations The parameter value.
+     * @return Trainer with the new max number of iterations.
+     */
+    public ANNClassificationTrainer withMaxIterations(int maxIterations) {
+        this.maxIterations = maxIterations;
+        return this;
+    }
+
+    /**
+     * Gets the epsilon.
+     *
+     * @return The parameter value.
+     */
+    public double getEpsilon() {
+        return epsilon;
+    }
+
+    /**
+     * Sets the epsilon.
+     *
+     * @param epsilon The parameter value.
+     * @return Trainer with the new epsilon value.
+     */
+    public ANNClassificationTrainer withEpsilon(double epsilon) {
+        this.epsilon = epsilon;
+        return this;
+    }
+
+    /**
+     * Gets the distance.
+     *
+     * @return The parameter value.
+     */
+    public DistanceMeasure getDistance() {
+        return distance;
+    }
+
+    /**
+     * Sets the distance measure.
+     *
+     * @param distance The parameter value.
+     * @return Trainer with the new distance measure.
+     */
+    public ANNClassificationTrainer withDistance(DistanceMeasure distance) {
+        this.distance = distance;
+        return this;
+    }
+
+    /**
+     * Gets the seed number.
+     *
+     * @return The parameter value.
+     */
+    public long getSeed() {
+        return seed;
+    }
+
+    /**
+     * Sets the seed.
+     *
+     * @param seed The parameter value.
+     * @return Trainer with the new seed.
+     */
+    public ANNClassificationTrainer withSeed(long seed) {
+        this.seed = seed;
+        return this;
+    }
+
+    /** Service class used for statistics. */
+    public static class CentroidStat {
+
+        /** Per-label counts of points closest to the center with a given index. */
+        ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();
+
+        /** Count of points closest to the center with a given index. */
+        ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();
+
+        /** Set of unique labels. */
+        ConcurrentSkipListSet<Double> clsLblsSet = new ConcurrentSkipListSet<>();
+
+        /** Merges this statistics object with the given one and returns the result. */
+        CentroidStat merge(CentroidStat other) {
+            this.counts = MapUtil.mergeMaps(counts, other.counts, (i1, i2) -> i1 + i2, ConcurrentHashMap::new);
+            this.centroidStat = MapUtil.mergeMaps(centroidStat, other.centroidStat, (m1, m2) ->
+                MapUtil.mergeMaps(m1, m2, (i1, i2) -> i1 + i2, ConcurrentHashMap::new), ConcurrentHashMap::new);
+            this.clsLblsSet.addAll(other.clsLblsSet);
+            return this;
+        }
+
+        /** */
+        public ConcurrentSkipListSet<Double> labels() {
+            return clsLblsSet;
+        }
+
+        /** */
+        ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat() {
+            return centroidStat;
+        }
+    }
+}
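
Note: taken together, the private methods above implement the ACD flow — k-means produces the centroids, a single pass over the labeled partitions accumulates CentroidStat, and fillProbableLabel turns each centroid into a candidate carrying a ProbableLabel. A minimal usage sketch over a local map, mirroring the ANNClassificationTest added later in this patch (the column layout with the label in position 0 is illustrative; imports as in that test):

    Map<Integer, double[]> data = new HashMap<>(); // Label in column 0, features after it.
    data.put(0, new double[] {0.0, 550.0, 550.0});
    data.put(1, new double[] {1.0, -550.0, -550.0});

    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withK(10)                      // Number of centroids, i.e. candidates.
        .withMaxIterations(100)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    NNClassificationModel mdl = trainer.fit(
        data, 2,                        // Upstream map and number of partitions.
        (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
        (k, v) -> v[0]
    ).withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withStrategy(NNStrategy.SIMPLE);
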
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java
new file mode 100644 (file)
index 0000000..e10f3b2
--- /dev/null
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.ann;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.knn.classification.KNNModelFormat;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.math.distances.DistanceMeasure;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
+
+/**
+ * ANN model representation.
+ *
+ * @see ANNClassificationModel
+ */
+public class ANNModelFormat extends KNNModelFormat implements Serializable {
+    /** The labeled set of candidates. */
+    private LabeledVectorSet<ProbableLabel, LabeledVector> candidates;
+
+    /**
+     * Creates an instance.
+     * @param k Amount of nearest neighbors.
+     * @param measure Distance measure.
+     * @param stgy kNN strategy.
+     * @param candidates The labeled set of candidate centroids.
+     */
+    public ANNModelFormat(int k,
+                          DistanceMeasure measure,
+                          NNStrategy stgy,
+                          LabeledVectorSet<ProbableLabel, LabeledVector> candidates) {
+        this.k = k;
+        this.distanceMeasure = measure;
+        this.stgy = stgy;
+        this.candidates = candidates;
+    }
+
+    /** */
+    public LabeledVectorSet<ProbableLabel, LabeledVector> getCandidates() {
+        return candidates;
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        int res = 1;
+
+        res = res * 37 + k;
+        res = res * 37 + distanceMeasure.hashCode();
+        res = res * 37 + stgy.hashCode();
+        res = res * 37 + candidates.hashCode();
+
+        return res;
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object obj) {
+        if (this == obj)
+            return true;
+
+        if (obj == null || getClass() != obj.getClass())
+            return false;
+
+        ANNModelFormat that = (ANNModelFormat)obj;
+
+        return k == that.k
+            && distanceMeasure.equals(that.distanceMeasure)
+            && stgy.equals(that.stgy)
+            && candidates.equals(that.candidates);
+    }
+}
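
Unlike the plain KNNModelFormat, this format also carries the candidate set, so a saved ANN model can be fully reconstructed on load. A condensed sketch of the round-trip that LocalModelsTest.importExportANNModelTest exercises later in this patch (the file path is illustrative):

    LabeledVectorSet<ProbableLabel, LabeledVector> candidates = new LabeledVectorSet<>();

    NNClassificationModel mdl = new ANNClassificationModel(candidates)
        .withK(4)
        .withDistanceMeasure(new ManhattanDistance())
        .withStrategy(NNStrategy.WEIGHTED);

    Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
    mdl.saveModel(exporter, "/tmp/ann.mdl");

    ANNModelFormat fmt = (ANNModelFormat)exporter.load("/tmp/ann.mdl");

    NNClassificationModel restored = new ANNClassificationModel(fmt.getCandidates())
        .withK(fmt.getK())
        .withDistanceMeasure(fmt.getDistanceMeasure())
        .withStrategy(fmt.getStgy());
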
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java
new file mode 100644 (file)
index 0000000..1fee123
--- /dev/null
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.ann;
+
+import java.util.TreeMap;
+
+/**
+ * A special class for fuzzy labels representing the probability distribution
+ * over the class labels.
+ */
+public class ProbableLabel {
+    /** Key is the class label, value is the probability of belonging to this class. */
+    TreeMap<Double, Double> clsLbls;
+
+    /**
+     * Creates an instance from the given distribution: the key is the class label,
+     * the value is the probability of belonging to this class.
+     *
+     * @param clsLbls Class label distribution.
+     */
+    public ProbableLabel(TreeMap<Double, Double> clsLbls) {
+        this.clsLbls = clsLbls;
+    }
+}
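
For intuition, a ProbableLabel holds exactly the normalized distribution that fillProbableLabel builds in the trainer above. E.g. for a centroid whose cluster held 6 points of class 0.0 and 2 points of class 1.0 (assumed counts, not taken from the patch):

    TreeMap<Double, Double> dist = new TreeMap<>();
    dist.put(0.0, 6.0 / 8); // 0.75
    dist.put(1.0, 2.0 / 8); // 0.25

    ProbableLabel lbl = new ProbableLabel(dist);
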
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/package-info.java
new file mode 100644 (file)
index 0000000..c18867e
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains main APIs for ANN classification algorithms.
+ */
+package org.apache.ignite.ml.knn.ann;
index c2c1c43..3404ae8 100644 (file)
 package org.apache.ignite.ml.knn.classification;
 
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -30,41 +27,28 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import org.apache.ignite.ml.Exportable;
 import org.apache.ignite.ml.Exporter;
-import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
-import org.apache.ignite.ml.math.distances.DistanceMeasure;
-import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.knn.NNClassificationModel;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.util.ModelTrace;
-import org.jetbrains.annotations.NotNull;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 
 /**
  * kNN algorithm model to solve multi-class classification task.
  */
-public class KNNClassificationModel implements Model<Vector, Double>, Exportable<KNNModelFormat> {
+public class KNNClassificationModel extends NNClassificationModel implements Exportable<KNNModelFormat> {
     /** */
     private static final long serialVersionUID = -127386523291350345L;
 
-    /** Amount of nearest neighbors. */
-    protected int k = 5;
-
-    /** Distance measure. */
-    protected DistanceMeasure distanceMeasure = new EuclideanDistance();
-
-    /** kNN strategy. */
-    protected KNNStrategy stgy = KNNStrategy.SIMPLE;
-
     /** Dataset. */
-    private Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset;
+    private Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset;
 
     /**
      * Builds the model via prepared dataset.
      * @param dataset Specially prepared object to run algorithm over it.
      */
-    public KNNClassificationModel(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
+    public KNNClassificationModel(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
         this.dataset = dataset;
     }
 
@@ -85,36 +69,6 @@ public class KNNClassificationModel implements Model<Vector, Double>, Exportable
     }
 
     /**
-     * Set up parameter of the kNN model.
-     * @param k Amount of nearest neighbors.
-     * @return Model.
-     */
-    public KNNClassificationModel withK(int k) {
-        this.k = k;
-        return this;
-    }
-
-    /**
-     * Set up parameter of the kNN model.
-     * @param stgy Strategy of calculations.
-     * @return Model.
-     */
-    public KNNClassificationModel withStrategy(KNNStrategy stgy) {
-        this.stgy = stgy;
-        return this;
-    }
-
-    /**
-     * Set up parameter of the kNN model.
-     * @param distanceMeasure Distance measure.
-     * @return Model.
-     */
-    public KNNClassificationModel withDistanceMeasure(DistanceMeasure distanceMeasure) {
-        this.distanceMeasure = distanceMeasure;
-        return this;
-    }
-
-    /**
      * The main idea is to calculate all distance pairs between the given vector and all vectors in the training set,
      * sort them, and find the k vectors with the minimum distance to the given vector.
      *
@@ -127,94 +81,14 @@ public class KNNClassificationModel implements Model<Vector, Double>, Exportable
             return Arrays.asList(getKClosestVectors(data, distanceIdxPairs));
         }, (a, b) -> a == null ? b : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));
 
-        LabeledDataset<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
+        LabeledVectorSet<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
 
         return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter)));
     }
 
 
     /** */
-    private LabeledDataset<Double, LabeledVector> buildLabeledDatasetOnListOfVectors(
-        List<LabeledVector> neighborsFromPartitions) {
-        LabeledVector[] arr = new LabeledVector[neighborsFromPartitions.size()];
-        for (int i = 0; i < arr.length; i++)
-            arr[i] = neighborsFromPartitions.get(i);
-
-        return new LabeledDataset<Double, LabeledVector>(arr);
-    }
-
-    /**
-     * Iterates along entries in distance map and fill the resulting k-element array.
-     *
-     * @param trainingData The training data.
-     * @param distanceIdxPairs The distance map.
-     * @return K-nearest neighbors.
-     */
-    @NotNull private LabeledVector[] getKClosestVectors(LabeledDataset<Double, LabeledVector> trainingData,
-        TreeMap<Double, Set<Integer>> distanceIdxPairs) {
-        LabeledVector[] res;
-
-        if (trainingData.rowSize() <= k) {
-            res = new LabeledVector[trainingData.rowSize()];
-            for (int i = 0; i < trainingData.rowSize(); i++)
-                res[i] = trainingData.getRow(i);
-        }
-        else {
-            res = new LabeledVector[k];
-            int i = 0;
-            final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
-            while (i < k) {
-                double key = iter.next();
-                Set<Integer> idxs = distanceIdxPairs.get(key);
-                for (Integer idx : idxs) {
-                    res[i] = trainingData.getRow(idx);
-                    i++;
-                    if (i >= k)
-                        break; // go to next while-loop iteration
-                }
-            }
-        }
-
-        return res;
-    }
-
-    /**
-     * Computes distances between given vector and each vector in training dataset.
-     *
-     * @param v The given vector.
-     * @param trainingData The training dataset.
-     * @return Key - distanceMeasure from given features before features with idx stored in value. Value is presented
-     * with Set because there can be a few vectors with the same distance.
-     */
-    @NotNull private TreeMap<Double, Set<Integer>> getDistances(Vector v, LabeledDataset<Double, LabeledVector> trainingData) {
-        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
-
-        for (int i = 0; i < trainingData.rowSize(); i++) {
-
-            LabeledVector labeledVector = trainingData.getRow(i);
-            if (labeledVector != null) {
-                double distance = distanceMeasure.compute(v, labeledVector.features());
-                putDistanceIdxPair(distanceIdxPairs, i, distance);
-            }
-        }
-        return distanceIdxPairs;
-    }
-
-    /** */
-    private void putDistanceIdxPair(Map<Double, Set<Integer>> distanceIdxPairs, int i, double distance) {
-        if (distanceIdxPairs.containsKey(distance)) {
-            Set<Integer> idxs = distanceIdxPairs.get(distance);
-            idxs.add(i);
-        }
-        else {
-            Set<Integer> idxs = new HashSet<>();
-            idxs.add(i);
-            distanceIdxPairs.put(distance, idxs);
-        }
-    }
-
-    /** */
-    private double classify(List<LabeledVector> neighbors, Vector v, KNNStrategy stgy) {
+    private double classify(List<LabeledVector> neighbors, Vector v, NNStrategy stgy) {
         Map<Double, Double> clsVotes = new HashMap<>();
 
         for (LabeledVector neighbor : neighbors) {
@@ -235,54 +109,5 @@ public class KNNClassificationModel implements Model<Vector, Double>, Exportable
         return getClassWithMaxVotes(clsVotes);
     }
 
-    /** */
-    private double getClassWithMaxVotes(Map<Double, Double> clsVotes) {
-        return Collections.max(clsVotes.entrySet(), Map.Entry.comparingByValue()).getKey();
-    }
-
-    /** */
-    private double getClassVoteForVector(KNNStrategy stgy, double distance) {
-        if (stgy.equals(KNNStrategy.WEIGHTED))
-            return 1 / distance; // strategy.WEIGHTED
-        else
-            return 1.0; // strategy.SIMPLE
-    }
 
-    /** {@inheritDoc} */
-    @Override public int hashCode() {
-        int res = 1;
-
-        res = res * 37 + k;
-        res = res * 37 + distanceMeasure.hashCode();
-        res = res * 37 + stgy.hashCode();
-
-        return res;
-    }
-
-    /** {@inheritDoc} */
-    @Override public boolean equals(Object obj) {
-        if (this == obj)
-            return true;
-
-        if (obj == null || getClass() != obj.getClass())
-            return false;
-
-        KNNClassificationModel that = (KNNClassificationModel)obj;
-
-        return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy);
-    }
-
-    /** {@inheritDoc} */
-    @Override public String toString() {
-        return toString(false);
-    }
-
-    /** {@inheritDoc} */
-    @Override public String toString(boolean pretty) {
-        return ModelTrace.builder("KNNClassificationModel", pretty)
-            .addField("k", String.valueOf(k))
-            .addField("measure", distanceMeasure.getClass().getSimpleName())
-            .addField("strategy", stgy.name())
-            .toString();
-    }
 }
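
The refactoring above moves the shared neighbor search, distance bookkeeping and vote counting into the new NNClassificationModel base class, leaving only classify in this file. As a reminder of how the two NNStrategy values affect the vote counting, a small sketch with assumed neighbor distances (not taken from the patch):

    // Neighbor (label, distance) pairs, assumed for illustration.
    double[][] neighbors = {{1.0, 4.0}, {1.0, 5.0}, {2.0, 0.5}};

    Map<Double, Double> simple = new HashMap<>();
    Map<Double, Double> weighted = new HashMap<>();

    for (double[] n : neighbors) {
        simple.merge(n[0], 1.0, Double::sum);        // SIMPLE: every neighbor votes 1.
        weighted.merge(n[0], 1 / n[1], Double::sum); // WEIGHTED: the vote is the inverse distance.
    }

    // simple:   {1.0=2.0,  2.0=1.0} -> class 1.0 wins.
    // weighted: {1.0=0.45, 2.0=2.0} -> class 2.0 wins.
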
index a2efe7f..a588b6e 100644 (file)
@@ -27,13 +27,13 @@ import org.apache.ignite.ml.math.distances.DistanceMeasure;
  */
 public class KNNModelFormat implements Serializable {
     /** Amount of nearest neighbors. */
-    private int k;
+    protected int k;
 
     /** Distance measure. */
-    private DistanceMeasure distanceMeasure;
+    protected DistanceMeasure distanceMeasure;
 
     /** kNN strategy. */
-    private KNNStrategy stgy;
+    protected NNStrategy stgy;
 
     /** Gets amount of nearest neighbors.*/
     public int getK() {
@@ -46,17 +46,21 @@ public class KNNModelFormat implements Serializable {
     }
 
     /** Gets kNN strategy.*/
-    public KNNStrategy getStgy() {
+    public NNStrategy getStgy() {
         return stgy;
     }
 
+    /** */
+    public KNNModelFormat() {
+    }
+
     /**
      * Creates an instance.
      * @param k Amount of nearest neighbors.
      * @param measure Distance measure.
      * @param stgy kNN strategy.
      */
-    public KNNModelFormat(int k, DistanceMeasure measure, KNNStrategy stgy) {
+    public KNNModelFormat(int k, DistanceMeasure measure, NNStrategy stgy) {
         this.k = k;
         this.distanceMeasure = measure;
         this.stgy = stgy;
@@ -18,7 +18,7 @@
 package org.apache.ignite.ml.knn.classification;
 
 /** This enum contains settings for the kNN algorithm. */
-public enum KNNStrategy {
+public enum NNStrategy {
     /** The default strategy. All k neighbors have the same weight, which is independent
      * of their distance to the query point. */
     SIMPLE,
index 16dcd8a..c0d6680 100644 (file)
@@ -22,8 +22,8 @@ import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.util.ModelTrace;
 
 /**
@@ -45,7 +45,7 @@ public class KNNRegressionModel extends KNNClassificationModel {
      * Builds the model via prepared dataset.
      * @param dataset Specially prepared object to run algorithm over it.
      */
-    public KNNRegressionModel(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
+    public KNNRegressionModel(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
         super(dataset);
     }
 
@@ -24,13 +24,13 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
 
 /**
- * Class for set of labeled vectors.
+ * The set of labeled vectors used in local partition calculations.
  */
-public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> implements AutoCloseable {
+public class LabeledVectorSet<L, Row extends LabeledVector> extends Dataset<Row> implements AutoCloseable {
     /**
      * Default constructor (required by Externalizable).
      */
-    public LabeledDataset() {
+    public LabeledVectorSet() {
         super();
     }
 
@@ -41,7 +41,7 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> i
      * @param colSize Amount of attributes. Should be > 0.
      * @param isDistributed Use distributed data structures to keep data.
      */
-    public LabeledDataset(int rowSize, int colSize,  boolean isDistributed){
+    public LabeledVectorSet(int rowSize, int colSize, boolean isDistributed){
         this(rowSize, colSize, null, isDistributed);
     }
 
@@ -51,7 +51,7 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> i
      * @param rowSize Amount of instances. Should be > 0.
      * @param colSize Amount of attributes. Should be > 0.
      */
-    public LabeledDataset(int rowSize, int colSize){
+    public LabeledVectorSet(int rowSize, int colSize){
         this(rowSize, colSize, null, false);
     }
 
@@ -63,7 +63,7 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> i
      * @param featureNames Column names.
      * @param isDistributed Use distributed data structures to keep data.
      */
-    public LabeledDataset(int rowSize, int colSize, String[] featureNames, boolean isDistributed){
+    public LabeledVectorSet(int rowSize, int colSize, String[] featureNames, boolean isDistributed){
         super(rowSize, colSize, featureNames, isDistributed);
 
         initializeDataWithLabeledVectors();
@@ -74,7 +74,7 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> i
      *
      * @param data Should be initialized with one vector at least.
      */
-    public LabeledDataset(Row[] data) {
+    public LabeledVectorSet(Row[] data) {
         super(data);
     }
 
@@ -91,7 +91,7 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> i
      * @param data Should be initialized with one vector at least.
      * @param colSize Amount of observed attributes in each vector.
      */
-    public LabeledDataset(Row[] data, int colSize) {
+    public LabeledVectorSet(Row[] data, int colSize) {
         super(data, colSize);
     }
 
@@ -102,7 +102,7 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> i
      * @param mtx Given matrix with rows as observations.
      * @param lbs Labels of observations.
      */
-    public LabeledDataset(double[][] mtx, double[] lbs) {
+    public LabeledVectorSet(double[][] mtx, double[] lbs) {
        this(mtx, lbs, null, false);
     }
 
@@ -114,7 +114,7 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> i
      * @param featureNames Column names.
      * @param isDistributed Use distributed data structures to keep data.
      */
-    public LabeledDataset(double[][] mtx, double[] lbs, String[] featureNames, boolean isDistributed) {
+    public LabeledVectorSet(double[][] mtx, double[] lbs, String[] featureNames, boolean isDistributed) {
         super();
         assert mtx != null;
         assert lbs != null;
@@ -203,8 +203,8 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> i
     }
 
     /** Makes copy with new Label objects and old features and Metadata objects. */
-    public LabeledDataset copy(){
-        LabeledDataset res = new LabeledDataset(this.data, this.colSize);
+    public LabeledVectorSet copy(){
+        LabeledVectorSet res = new LabeledVectorSet(this.data, this.colSize);
         res.isDistributed = this.isDistributed;
         res.meta = this.meta;
         for (int i = 0; i < rowSize; i++)
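
The renamed LabeledVectorSet keeps every LabeledDataset constructor; e.g. building a set from a feature matrix and a label array, as the updated tests below do:

    double[][] mtx = {{1.0, 1.0}, {-1.0, -1.0}};
    double[] lbs = {1.0, 2.0};

    LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs); // Raw type, as used in the tests.
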
@@ -27,12 +27,12 @@ import org.jetbrains.annotations.NotNull;
 /**
  * Class for splitting a labeled dataset into train and test sets.
  */
-public class LabeledDatasetTestTrainPair implements Serializable {
+public class LabeledVectorSetTestTrainPair implements Serializable {
     /** Data to keep train set. */
-    private LabeledDataset train;
+    private LabeledVectorSet train;
 
     /** Data to keep test set. */
-    private LabeledDataset test;
+    private LabeledVectorSet test;
 
     /**
      * Creates two subsets of given dataset.
@@ -42,7 +42,7 @@ public class LabeledDatasetTestTrainPair implements Serializable {
      * @param dataset The dataset to split on train and test subsets.
      * @param testPercentage The percentage of the test subset.
      */
-    public LabeledDatasetTestTrainPair(LabeledDataset dataset, double testPercentage) {
+    public LabeledVectorSetTestTrainPair(LabeledVectorSet dataset, double testPercentage) {
         assert testPercentage > 0.0;
         assert testPercentage < 1.0;
         final int datasetSize = dataset.rowSize();
@@ -78,8 +78,8 @@ public class LabeledDatasetTestTrainPair implements Serializable {
             }
         }
 
-        test = new LabeledDataset(testVectors, dataset.colSize());
-        train = new LabeledDataset(trainVectors, dataset.colSize());
+        test = new LabeledVectorSet(testVectors, dataset.colSize());
+        train = new LabeledVectorSet(trainVectors, dataset.colSize());
     }
 
     /** This method generates "random double, integer" pairs, sorts them, gets the first "testSize" elements and returns the appropriate indices */
@@ -100,7 +100,7 @@ public class LabeledDatasetTestTrainPair implements Serializable {
      * Train subset of the whole dataset.
      * @return Train subset.
      */
-    public LabeledDataset train() {
+    public LabeledVectorSet train() {
         return train;
     }
 
@@ -108,7 +108,7 @@ public class LabeledDatasetTestTrainPair implements Serializable {
      * Test subset of the whole dataset.
      * @return Test subset.
      */
-    public LabeledDataset test() {
+    public LabeledVectorSet test() {
         return test;
     }
 }
index b4e552b..0351037 100644 (file)
@@ -23,18 +23,18 @@ import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.UpstreamEntry;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 
 /**
- * Partition data builder that builds {@link LabeledDataset}.
+ * Partition data builder that builds {@link LabeledVectorSet}.
  *
  * @param <K> Type of a key in <tt>upstream</tt> data.
  * @param <V> Type of a value in <tt>upstream</tt> data.
  * @param <C> Type of a partition <tt>context</tt>.
  */
 public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializable>
-    implements PartitionDataBuilder<K, V, C, LabeledDataset<Double, LabeledVector>> {
+    implements PartitionDataBuilder<K, V, C, LabeledVectorSet<Double, LabeledVector>> {
     /** */
     private static final long serialVersionUID = -7820760153954269227L;
 
@@ -57,8 +57,8 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab
     }
 
     /** {@inheritDoc} */
-    @Override public LabeledDataset<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
-        long upstreamDataSize, C ctx) {
+    @Override public LabeledVectorSet<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
+                                                                   long upstreamDataSize, C ctx) {
         int xCols = -1;
         double[][] x = null;
         double[] y = new double[Math.toIntExact(upstreamDataSize)];
@@ -82,6 +82,6 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab
 
             ptr++;
         }
-        return new LabeledDataset<>(x, y);
+        return new LabeledVectorSet<>(x, y);
     }
 }
index 5c20d9c..f370cbd 100644 (file)
@@ -28,8 +28,8 @@ import org.apache.ignite.ml.math.exceptions.NoDataException;
 import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
 import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.jetbrains.annotations.NotNull;
 
 /** Data pre-processing step which loads data from different file types. */
@@ -43,8 +43,8 @@ public class LabeledDatasetLoader {
      * @param isFallOnBadData Fall on incorrect data if true.
      * @return Labeled Dataset parsed from file.
      */
-    public static LabeledDataset loadFromTxtFile(Path pathToFile, String separator, boolean isDistributed,
-        boolean isFallOnBadData) throws IOException {
+    public static LabeledVectorSet loadFromTxtFile(Path pathToFile, String separator, boolean isDistributed,
+                                                   boolean isFallOnBadData) throws IOException {
         Stream<String> stream = Files.lines(pathToFile);
         List<String> list = new ArrayList<>();
         stream.forEach(list::add);
@@ -81,7 +81,7 @@ public class LabeledDatasetLoader {
                 for (int i = 0; i < vectors.size(); i++)
                     data[i] = new LabeledVector(vectors.get(i), labels.get(i));
 
-                return new LabeledDataset(data, colSize);
+                return new LabeledVectorSet(data, colSize);
             }
             else
                 throw new NoDataException("File should contain first row with data");
@@ -93,7 +93,7 @@ public class LabeledDatasetLoader {
     /** */
     @NotNull private static Vector parseFeatures(Path pathToFile, boolean isDistributed, boolean isFallOnBadData,
         int colSize, int rowIdx, String[] rowData) {
-        final Vector vec = LabeledDataset.emptyVector(colSize, isDistributed);
+        final Vector vec = LabeledVectorSet.emptyVector(colSize, isDistributed);
 
         if (isFallOnBadData && rowData.length != colSize + 1)
             throw new CardinalityException(colSize + 1, rowData.length);
index 1ae896f..4f11318 100644 (file)
@@ -25,8 +25,8 @@ import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.jetbrains.annotations.NotNull;
@@ -60,14 +60,14 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
 
         assert datasetBuilder != null;
 
-        PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
+        PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
             featureExtractor,
             lbExtractor
         );
 
         Vector weights;
 
-        try(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build(
+        try(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
@@ -91,7 +91,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
     }
 
     /** */
-    private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
+    private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
         return dataset.compute(data -> {
             Vector copiedWeights = weights.copy();
             Vector deltaWeights = initializeWeightsWithZeros(weights.size());
@@ -116,8 +116,8 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
     }
 
     /** */
-    private Deltas getDeltas(LabeledDataset data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
-        int randomIdx) {
+    private Deltas getDeltas(LabeledVectorSet data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
+                             int randomIdx) {
         LabeledVector row = (LabeledVector)data.getRow(randomIdx);
         Double lb = (Double)row.label();
         Vector v = makeVectorWithInterceptElement(row);
index 3e3bab5..42f5dec 100644 (file)
@@ -28,13 +28,20 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
 import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.knn.NNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNModelFormat;
+import org.apache.ignite.ml.knn.ann.ProbableLabel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNModelFormat;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.distances.ManhattanDistance;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
 import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
 import org.junit.Assert;
@@ -165,10 +172,10 @@ public class LocalModelsTest {
     @Test
     public void importExportKNNModelTest() throws IOException {
         executeModelTest(mdlFilePath -> {
-            KNNClassificationModel mdl = new KNNClassificationModel(null)
+            NNClassificationModel mdl = new KNNClassificationModel(null)
                 .withK(3)
                 .withDistanceMeasure(new EuclideanDistance())
-                .withStrategy(KNNStrategy.SIMPLE);
+                .withStrategy(NNStrategy.SIMPLE);
 
             Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
             mdl.saveModel(exporter, mdlFilePath);
@@ -177,7 +184,37 @@ public class LocalModelsTest {
 
             Assert.assertNotNull(load);
 
-            KNNClassificationModel importedMdl = new KNNClassificationModel(null)
+            NNClassificationModel importedMdl = new KNNClassificationModel(null)
+                .withK(load.getK())
+                .withDistanceMeasure(load.getDistanceMeasure())
+                .withStrategy(load.getStgy());
+
+            Assert.assertTrue("", mdl.equals(importedMdl));
+
+            return null;
+        });
+    }
+
+    /** */
+    @Test
+    public void importExportANNModelTest() throws IOException {
+        executeModelTest(mdlFilePath -> {
+            final LabeledVectorSet<ProbableLabel, LabeledVector> centers = new LabeledVectorSet<>();
+
+            NNClassificationModel mdl = new ANNClassificationModel(centers)
+                .withK(4)
+                .withDistanceMeasure(new ManhattanDistance())
+                .withStrategy(NNStrategy.WEIGHTED);
+
+            Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
+            mdl.saveModel(exporter, mdlFilePath);
+
+            ANNModelFormat load = (ANNModelFormat) exporter.load(mdlFilePath);
+
+            Assert.assertNotNull(load);
+
+            NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates())
                 .withK(load.getK())
                 .withDistanceMeasure(load.getDistanceMeasure())
                 .withStrategy(load.getStgy());
index c4d896c..552c478 100644 (file)
@@ -23,7 +23,7 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNModelFormat;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.distances.HammingDistance;
 import org.apache.ignite.ml.math.distances.ManhattanDistance;
@@ -83,8 +83,8 @@ public class CollectionsTest {
         test(new KMeansModel(new Vector[] {}, new ManhattanDistance()),
             new KMeansModel(new Vector[] {}, new HammingDistance()));
 
-        test(new KNNModelFormat(1, new ManhattanDistance(), KNNStrategy.SIMPLE),
-            new KNNModelFormat(2, new ManhattanDistance(), KNNStrategy.SIMPLE));
+        test(new KNNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE),
+            new KNNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE));
 
         test(new KNNClassificationModel(null).withK(1), new KNNClassificationModel(null).withK(2));
 
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
new file mode 100644 (file)
index 0000000..ea602cd
--- /dev/null
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/** Tests behaviour of ANNClassificationTrainer. */
+@RunWith(Parameterized.class)
+public class ANNClassificationTest {
+    /** Number of parts to be tested. */
+    private static final int[] partsToBeTested = new int[]{1, 2, 3, 4, 5, 7, 100};
+
+    /** Fixed size of Dataset. */
+    private static final int AMOUNT_OF_OBSERVATIONS = 1000;
+
+    /** Fixed size of columns in Dataset. */
+    private static final int AMOUNT_OF_FEATURES = 2;
+
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions")
+    public static Iterable<Integer[]> data() {
+        List<Integer[]> res = new ArrayList<>();
+
+        for (int part : partsToBeTested)
+            res.add(new Integer[]{part});
+
+        return res;
+    }
+
+    /** */
+    @Test
+    public void testBinaryClassification() {
+        Map<Integer, double[]> data = new HashMap<>();
+
+        ThreadLocalRandom rnd = ThreadLocalRandom.current();
+
+        for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
+            double x = rnd.nextDouble(500, 600);
+            double y = rnd.nextDouble(500, 600);
+            double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+            vec[0] = 0; // assign label.
+            vec[1] = x;
+            vec[2] = y;
+            data.put(i, vec);
+        }
+
+        for (int i = AMOUNT_OF_OBSERVATIONS; i < AMOUNT_OF_OBSERVATIONS * 2; i++) {
+            double x = rnd.nextDouble(-600, -500);
+            double y = rnd.nextDouble(-600, -500);
+            double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+            vec[0] = 1; // assign label.
+            vec[1] = x;
+            vec[2] = y;
+            data.put(i, vec);
+        }
+
+        ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+            .withK(10);
+
+        NNClassificationModel mdl = trainer.fit(
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
+
+        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(550, 550)), PRECISION);
+        TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-550, -550)), PRECISION);
+    }
+}
\ No newline at end of file
index aeb2414..c176682 100644 (file)
@@ -22,9 +22,8 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -70,14 +69,14 @@ public class KNNClassificationTest {
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-        KNNClassificationModel knnMdl = trainer.fit(
+        NNClassificationModel knnMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.SIMPLE);
+            .withStrategy(NNStrategy.SIMPLE);
 
         assertTrue(knnMdl.toString().length() > 0);
         assertTrue(knnMdl.toString(true).length() > 0);
@@ -102,14 +101,14 @@ public class KNNClassificationTest {
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-        KNNClassificationModel knnMdl = trainer.fit(
+        NNClassificationModel knnMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         ).withK(1)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.SIMPLE);
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
         assertEquals(knnMdl.apply(firstVector), 1.0);
@@ -130,14 +129,14 @@ public class KNNClassificationTest {
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-        KNNClassificationModel knnMdl = trainer.fit(
+        NNClassificationModel knnMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.SIMPLE);
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector vector = new DenseVector(new double[] {-1.01, -1.01});
         assertEquals(knnMdl.apply(vector), 2.0);
@@ -156,14 +155,14 @@ public class KNNClassificationTest {
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
-        KNNClassificationModel knnMdl = trainer.fit(
+        NNClassificationModel knnMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.WEIGHTED);
+            .withStrategy(NNStrategy.WEIGHTED);
 
         Vector vector = new DenseVector(new double[] {-1.01, -1.01});
         assertEquals(knnMdl.apply(vector), 1.0);
index 7d57ec9..e05903e 100644 (file)
@@ -23,7 +23,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
 import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
@@ -77,7 +77,7 @@ public class KNNRegressionTest {
             (k, v) -> v[0]
         ).withK(1)
             .withDistanceMeasure(new EuclideanDistance())
-            .withStrategy(KNNStrategy.SIMPLE);
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector vector = new DenseVector(new double[] {0, 0, 0, 5.0, 0.0});
         System.out.println(knnMdl.apply(vector));
@@ -87,17 +87,17 @@ public class KNNRegressionTest {
     /** */
     @Test
     public void testLongly() {
-        testLongly(KNNStrategy.SIMPLE);
+        testLongly(NNStrategy.SIMPLE);
     }
 
     /** */
     @Test
     public void testLonglyWithWeightedStrategy() {
-        testLongly(KNNStrategy.WEIGHTED);
+        testLongly(NNStrategy.WEIGHTED);
     }
 
     /** */
-    private void testLongly(KNNStrategy stgy) {
+    private void testLongly(NNStrategy stgy) {
         Map<Integer, double[]> data = new HashMap<>();
         data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
         data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
@@ -123,16 +123,12 @@ public class KNNRegressionTest {
             (k, v) -> v[0]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(stgy);
 
         Vector vector = new DenseVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
 
         Assert.assertNotNull(knnMdl.apply(vector));
 
         Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
-
-        Assert.assertTrue(knnMdl.toString().contains(stgy.name()));
-        Assert.assertTrue(knnMdl.toString(true).contains(stgy.name()));
-        Assert.assertTrue(knnMdl.toString(false).contains(stgy.name()));
     }
 }
index 55ef24e..0303d26 100644 (file)
@@ -25,9 +25,10 @@ import org.junit.runners.Suite;
  */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
+    ANNClassificationTest.class,
     KNNClassificationTest.class,
     KNNRegressionTest.class,
-    LabeledDatasetTest.class
+    LabeledVectorSetTest.class
 })
 public class KNNTestSuite {
 }
index dbcdb99..f3b8b3a 100644 (file)
@@ -21,7 +21,7 @@ import java.io.IOException;
 import java.net.URISyntaxException;
 import java.nio.file.Path;
 import java.nio.file.Paths;
-import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
 
 /**
@@ -37,7 +37,7 @@ public class LabeledDatasetHelper {
      * @param rsrcPath path to dataset.
      * @return null if path is incorrect.
      */
-    public static LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
+    public static LabeledVectorSet loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
         try {
             Path path = Paths.get(LabeledDatasetHelper.class.getClassLoader().getResource(rsrcPath).toURI());
             try {
@@ -29,17 +29,17 @@ import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
 import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
 import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
+import org.apache.ignite.ml.structures.LabeledVectorSetTestTrainPair;
 import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
 import org.junit.Test;
 
 import static junit.framework.TestCase.assertEquals;
 import static junit.framework.TestCase.fail;
 
-/** Tests behaviour of LabeledDataset. */
-public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
+/** Tests behaviour of LabeledVectorSet. */
+public class LabeledVectorSetTest implements ExternalizableTest<LabeledVectorSet> {
     /** */
     private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";
 
@@ -69,7 +69,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
         double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
 
         String[] featureNames = new String[] {"x", "y"};
-        final LabeledDataset dataset = new LabeledDataset(mtx, lbs, featureNames, false);
+        final LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs, featureNames, false);
 
         assertEquals(dataset.getFeatureName(0), "x");
     }
@@ -87,7 +87,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
                 {-2.0, -1.0}};
         double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
 
-        final LabeledDataset dataset = new LabeledDataset(mtx, lbs, null, false);
+        final LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs, null, false);
 
         assertEquals(dataset.colSize(), 2);
         assertEquals(dataset.rowSize(), 6);
@@ -104,10 +104,10 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
         dataset.setLabel(0, 2.0);
         assertEquals(row.label(), 2.0);
 
-        assertEquals(0, new LabeledDataset().rowSize());
-        assertEquals(1, new LabeledDataset(1, 2).rowSize());
-        assertEquals(1, new LabeledDataset(1, 2, true).rowSize());
-        assertEquals(1, new LabeledDataset(1, 2, null, true).rowSize());
+        assertEquals(0, new LabeledVectorSet().rowSize());
+        assertEquals(1, new LabeledVectorSet(1, 2).rowSize());
+        assertEquals(1, new LabeledVectorSet(1, 2, true).rowSize());
+        assertEquals(1, new LabeledVectorSet(1, 2, null, true).rowSize());
     }
 
     /** */
@@ -124,7 +124,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
         double[] lbs = new double[] {};
 
         try {
-            new LabeledDataset(mtx, lbs);
+            new LabeledVectorSet(mtx, lbs);
             fail("CardinalityException");
         }
         catch (CardinalityException e) {
@@ -141,7 +141,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
         double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
 
         try {
-            new LabeledDataset(mtx, lbs);
+            new LabeledVectorSet(mtx, lbs);
             fail("CardinalityException");
         }
         catch (CardinalityException e) {
@@ -153,8 +153,8 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
     /** */
     @Test
     public void testLoadingCorrectTxtFile() {
-        LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false);
-        assertEquals(Objects.requireNonNull(training).rowSize(), 150);
+        LabeledVectorSet training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false);
+        assertEquals(150, training.rowSize());
     }
 
     /** */
@@ -186,8 +186,8 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
     /** */
     @Test
     public void testLoadingFileWithIncorrectData() {
-        LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
-        assertEquals(149, Objects.requireNonNull(training).rowSize());
+        LabeledVectorSet training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
+        assertEquals(149, training.rowSize());
     }
 
     /** */
@@ -209,7 +209,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
     public void testLoadingFileWithMissedData() throws URISyntaxException, IOException {
         Path path = Paths.get(Objects.requireNonNull(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA)).toURI());
 
-        LabeledDataset training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false);
+        LabeledVectorSet training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false);
 
         assertEquals(training.features(2).get(1), 0.0);
     }
@@ -227,24 +227,24 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
                 {-2.0, -1.0}};
         double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
 
-        LabeledDataset training = new LabeledDataset(mtx, lbs);
+        LabeledVectorSet training = new LabeledVectorSet(mtx, lbs);
 
-        LabeledDatasetTestTrainPair split1 = new LabeledDatasetTestTrainPair(training, 0.67);
+        LabeledVectorSetTestTrainPair split1 = new LabeledVectorSetTestTrainPair(training, 0.67);
 
         assertEquals(4, split1.test().rowSize());
         assertEquals(2, split1.train().rowSize());
 
-        LabeledDatasetTestTrainPair split2 = new LabeledDatasetTestTrainPair(training, 0.65);
+        LabeledVectorSetTestTrainPair split2 = new LabeledVectorSetTestTrainPair(training, 0.65);
 
         assertEquals(3, split2.test().rowSize());
         assertEquals(3, split2.train().rowSize());
 
-        LabeledDatasetTestTrainPair split3 = new LabeledDatasetTestTrainPair(training, 0.4);
+        LabeledVectorSetTestTrainPair split3 = new LabeledVectorSetTestTrainPair(training, 0.4);
 
         assertEquals(2, split3.test().rowSize());
         assertEquals(4, split3.train().rowSize());
 
-        LabeledDatasetTestTrainPair split4 = new LabeledDatasetTestTrainPair(training, 0.3);
+        LabeledVectorSetTestTrainPair split4 = new LabeledVectorSetTestTrainPair(training, 0.3);
 
         assertEquals(1, split4.test().rowSize());
         assertEquals(5, split4.train().rowSize());
@@ -263,7 +263,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
                 {-2.0, -1.0}};
         double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
 
-        LabeledDataset dataset = new LabeledDataset(mtx, lbs);
+        LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs);
         final double[] labels = dataset.labels();
         for (int i = 0; i < lbs.length; i++)
             assertEquals(lbs[i], labels[i]);
@@ -273,7 +273,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
     @Test(expected = NoLabelVectorException.class)
     @SuppressWarnings("unchecked")
     public void testSetLabelInvalid() {
-        new LabeledDataset(new LabeledVector[1]).setLabel(0, 2.0);
+        new LabeledVectorSet(new LabeledVector[1]).setLabel(0, 2.0);
     }
 
     /** */
@@ -288,7 +288,7 @@ public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
                 {-2.0, -1.0}};
         double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
 
-        LabeledDataset dataset = new LabeledDataset(mtx, lbs);
+        LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs);
         this.externalizeTest(dataset);
     }
 }