IGNITE-7451: Make Linear SVM for multi-classification
author zaleslaw <zaleslaw.sin@gmail.com>
Mon, 12 Feb 2018 19:30:22 +0000 (22:30 +0300)
committer YuriBabak <y.chief@gmail.com>
Mon, 12 Feb 2018 19:30:22 +0000 (22:30 +0300)
this closes #3484

15 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java [moved from modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java with 85% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java
modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/DistributedLinearSVMBinaryClassificationTrainerTest.java [moved from modules/ml/src/test/java/org/apache/ignite/ml/svm/DistributedLinearSVMClassificationSCDATrainerTest.java with 79% similarity]
modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/GenericLinearSVMBinaryClassificationTrainerTest.java [moved from modules/ml/src/test/java/org/apache/ignite/ml/svm/GenericLinearSVMTrainerTest.java with 88% similarity]
modules/ml/src/test/java/org/apache/ignite/ml/svm/binary/LocalLinearSVMBinaryClassificationTrainerTest.java [moved from modules/ml/src/test/java/org/apache/ignite/ml/svm/LocalLinearSVMClassificationSCDATrainerTest.java with 81% similarity]
modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/DistributedLinearSVMMultiClassClassificationTrainerTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/GenericLinearSVMMultiClassClassificationTrainerTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/LocalLinearSVMMultiClassClassificationTrainerTest.java [new file with mode: 0644]

index 979afe2..e256276 100644 (file)
@@ -19,10 +19,7 @@ package org.apache.ignite.examples.ml.svm;
 
 import java.io.File;
 import java.io.IOException;
-import java.net.URISyntaxException;
-import java.net.URL;
 import java.nio.file.Path;
-import java.nio.file.Paths;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.Ignition;
 import org.apache.ignite.examples.ExampleNodeStartup;
@@ -33,13 +30,13 @@ import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
 import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
 import org.apache.ignite.ml.structures.preprocessing.LabellingMachine;
 import org.apache.ignite.ml.structures.preprocessing.Normalizer;
+import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
 import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
-import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
 import org.apache.ignite.thread.IgniteThread;
 
 /**
  * <p>
- * Example of using {@link org.apache.ignite.ml.svm.SVMLinearClassificationModel} with Titanic dataset.</p>
+ * Example of using {@link org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel} with Titanic dataset.</p>
  * <p>
  * Note that in this example we cannot guarantee order in which nodes return results of intermediate
  * computations and therefore algorithm can return different results.</p>
@@ -95,10 +92,10 @@ public class SVMBinaryClassificationExample {
                     LabeledDataset train = split.train();
 
                     System.out.println("\n>>> Create new linear binary SVM trainer object.");
-                    Trainer<SVMLinearClassificationModel, LabeledDataset> trainer = new SVMLinearBinaryClassificationTrainer();
+                    Trainer<SVMLinearBinaryClassificationModel, LabeledDataset> trainer = new SVMLinearBinaryClassificationTrainer();
 
                     System.out.println("\n>>> Perform the training to get the model.");
-                    SVMLinearClassificationModel mdl = trainer.train(train);
+                    SVMLinearBinaryClassificationModel mdl = trainer.train(train);
 
                     System.out.println("\n>>> SVM classification model: " + mdl);
 
index 15e8c40..0028a16 100644 (file)
@@ -191,10 +191,20 @@ public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> {
 
     /** */
     public static Vector emptyVector(int size, boolean isDistributed) {
-
         if(isDistributed)
             return new SparseDistributedVector(size);
         else
             return new DenseLocalOnHeapVector(size);
     }
+
+    /**
+     * Builds a dataset from this dataset's {@code data} and {@code colSize}, carries over the
+     * {@code isDistributed} flag and the {@code meta} reference, and then re-assigns every label by row index.
+     * <p>
+     * NOTE(review): the result is constructed from this dataset's own {@code data} reference; whether the rows end
+     * up shared with the copy depends on the constructor - confirm before relying on label isolation.</p>
+     *
+     * @return The newly built dataset.
+     */
+    public LabeledDataset copy() {
+        final LabeledDataset cp = new LabeledDataset(data, colSize);
+
+        cp.meta = meta;
+        cp.isDistributed = isDistributed;
+
+        int rowIdx = 0;
+
+        while (rowIdx < rowSize) {
+            cp.setLabel(rowIdx, label(rowIdx));
+
+            rowIdx++;
+        }
+
+        return cp;
+    }
 }
@@ -27,7 +27,7 @@ import org.apache.ignite.ml.math.Vector;
 /**
  * Base class for SVM linear classification model.
  */
-public class SVMLinearClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearClassificationModel>, Serializable {
+public class SVMLinearBinaryClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearBinaryClassificationModel>, Serializable {
     /** Output label format. -1 and +1 for false value and raw distances from the separating hyperplane otherwise. */
     private boolean isKeepingRawLabels = false;
 
@@ -41,47 +41,51 @@ public class SVMLinearClassificationModel implements Model<Vector, Double>, Expo
     private double intercept;
 
     /** */
-    public SVMLinearClassificationModel(Vector weights, double intercept) {
+    public SVMLinearBinaryClassificationModel(Vector weights, double intercept) {
         this.weights = weights;
         this.intercept = intercept;
     }
 
     /**
      * Set up the output label format.
+     *
      * @param isKeepingRawLabels The parameter value.
      * @return Model with new isKeepingRawLabels parameter value.
      */
-    public SVMLinearClassificationModel withRawLabels(boolean isKeepingRawLabels) {
+    public SVMLinearBinaryClassificationModel withRawLabels(boolean isKeepingRawLabels) {
         this.isKeepingRawLabels = isKeepingRawLabels;
         return this;
     }
 
     /**
      * Set up the threshold.
+     *
      * @param threshold The parameter value.
      * @return Model with new threshold parameter value.
      */
-    public SVMLinearClassificationModel withThreshold(double threshold) {
+    public SVMLinearBinaryClassificationModel withThreshold(double threshold) {
         this.threshold = threshold;
         return this;
     }
 
     /**
      * Set up the weights.
+     *
      * @param weights The parameter value.
      * @return Model with new weights parameter value.
      */
-    public SVMLinearClassificationModel withWeights(Vector weights) {
+    public SVMLinearBinaryClassificationModel withWeights(Vector weights) {
         this.weights = weights;
         return this;
     }
 
     /**
      * Set up the intercept.
+     *
      * @param intercept The parameter value.
      * @return Model with new intercept parameter value.
      */
-    public SVMLinearClassificationModel withIntercept(double intercept) {
+    public SVMLinearBinaryClassificationModel withIntercept(double intercept) {
         this.intercept = intercept;
         return this;
     }
@@ -97,6 +101,7 @@ public class SVMLinearClassificationModel implements Model<Vector, Double>, Expo
 
     /**
      * Gets the output label format mode.
+     *
      * @return The parameter value.
      */
     public boolean isKeepingRawLabels() {
@@ -105,6 +110,7 @@ public class SVMLinearClassificationModel implements Model<Vector, Double>, Expo
 
     /**
      * Gets the threshold.
+     *
      * @return The parameter value.
      */
     public double threshold() {
@@ -113,6 +119,7 @@ public class SVMLinearClassificationModel implements Model<Vector, Double>, Expo
 
     /**
      * Gets the weights.
+     *
      * @return The parameter value.
      */
     public Vector weights() {
@@ -121,6 +128,7 @@ public class SVMLinearClassificationModel implements Model<Vector, Double>, Expo
 
     /**
      * Gets the intercept.
+     *
      * @return The parameter value.
      */
     public double intercept() {
@@ -128,7 +136,7 @@ public class SVMLinearClassificationModel implements Model<Vector, Double>, Expo
     }
 
     /** {@inheritDoc} */
-    @Override public <P> void saveModel(Exporter<SVMLinearClassificationModel, P> exporter, P path) {
+    @Override public <P> void saveModel(Exporter<SVMLinearBinaryClassificationModel, P> exporter, P path) {
         exporter.save(this, path);
     }
 
@@ -138,7 +146,7 @@ public class SVMLinearClassificationModel implements Model<Vector, Double>, Expo
             return true;
         if (o == null || getClass() != o.getClass())
             return false;
-        SVMLinearClassificationModel mdl = (SVMLinearClassificationModel)o;
+        SVMLinearBinaryClassificationModel mdl = (SVMLinearBinaryClassificationModel)o;
         return Double.compare(mdl.intercept, intercept) == 0
             && Double.compare(mdl.threshold, threshold) == 0
             && Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0
index a14694b..ee3b6e8 100644 (file)
@@ -32,7 +32,7 @@ import org.jetbrains.annotations.NotNull;
  * and +1 labels for two classes and makes binary classification. </p> The paper about this algorithm could be found
  * here https://arxiv.org/abs/1409.1458.
  */
-public class SVMLinearBinaryClassificationTrainer implements Trainer<SVMLinearClassificationModel, LabeledDataset> {
+public class SVMLinearBinaryClassificationTrainer implements Trainer<SVMLinearBinaryClassificationModel, LabeledDataset> {
     /** Amount of outer SDCA algorithm iterations. */
     private int amountOfIterations = 20;
 
@@ -51,7 +51,7 @@ public class SVMLinearBinaryClassificationTrainer implements Trainer<SVMLinearCl
      * @param data data to build model
      * @return model
      */
-    @Override public SVMLinearClassificationModel train(LabeledDataset data) {
+    @Override public SVMLinearBinaryClassificationModel train(LabeledDataset data) {
         isDistributed = data.isDistributed();
 
         final int weightVectorSizeWithIntercept = data.colSize() + 1;
@@ -62,7 +62,7 @@ public class SVMLinearBinaryClassificationTrainer implements Trainer<SVMLinearCl
             weights = weights.plus(deltaWeights); // creates new vector
         }
 
-        return new SVMLinearClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
+        return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
     }
 
     /** */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
new file mode 100644 (file)
index 0000000..fd91595
--- /dev/null
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.svm;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * Multi-class classification model made of linear SVM binary classifiers: it keeps one
+ * {@link SVMLinearBinaryClassificationModel} per class label and predicts the label whose classifier yields the
+ * largest margin.
+ */
+public class SVMLinearMultiClassClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearMultiClassClassificationModel>, Serializable {
+    /** Binary classifiers keyed by the class label they were registered for. */
+    private final Map<Double, SVMLinearBinaryClassificationModel> models;
+
+    /** Creates a model without classifiers; they are registered afterwards through {@link #add}. */
+    public SVMLinearMultiClassClassificationModel() {
+        this.models = new HashMap<>();
+    }
+
+    /**
+     * Predicts the class label of the given observation.
+     *
+     * @param input Observation to classify.
+     * @return Label of the classifier with the greatest value of {@code input.dot(weights) + intercept}.
+     * @throws IllegalStateException If no classifier has been added to this model.
+     */
+    @Override public Double apply(Vector input) {
+        // An empty model has no candidate label; fail with a clear message instead of an NPE on lastEntry().
+        if (models.isEmpty())
+            throw new IllegalStateException("Multi-class SVM model has no binary classifiers to apply.");
+
+        // Keyed by margin, so the last (greatest) entry holds the winning label.
+        TreeMap<Double, Double> maxMargins = new TreeMap<>();
+
+        models.forEach((k, v) -> maxMargins.put(input.dot(v.weights()) + v.intercept(), k));
+
+        return maxMargins.lastEntry().getValue();
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<SVMLinearMultiClassClassificationModel, P> exporter, P path) {
+        exporter.save(this, path);
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+        SVMLinearMultiClassClassificationModel mdl = (SVMLinearMultiClassClassificationModel)o;
+        return Objects.equals(models, mdl.models);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        return Objects.hash(models);
+    }
+
+    /**
+     * Renders one line per registered classifier, each terminated by the platform line separator.
+     *
+     * @return Textual description of all registered classifiers.
+     */
+    @Override public String toString() {
+        StringBuilder wholeStr = new StringBuilder();
+
+        // Chained appends produce the same text as the concatenation without building a temporary string.
+        models.forEach((clsLb, mdl) -> wholeStr
+            .append("The class with label ").append(clsLb)
+            .append(" has classifier: ").append(mdl)
+            .append(System.lineSeparator()));
+
+        return wholeStr.toString();
+    }
+
+    /**
+     * Registers a binary classifier for the given class label. A classifier previously registered for the same
+     * label is replaced.
+     *
+     * @param clsLb The class label for the added model.
+     * @param mdl The model, must not be {@code null}.
+     * @throws NullPointerException If {@code mdl} is {@code null}.
+     */
+    public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) {
+        // Reject null here: it would otherwise surface only later, inside apply() or toString().
+        models.put(clsLb, Objects.requireNonNull(mdl, "mdl"));
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
new file mode 100644 (file)
index 0000000..669e2e3
--- /dev/null
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.svm;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.structures.LabeledDataset;
+
+/**
+ * Trainer of a soft-margin linear SVM for multi-class classification. For every distinct label found in the
+ * dataset it relabels the data to +1 (that label) / -1 (any other label) and trains a
+ * {@link SVMLinearBinaryClassificationTrainer} on the result.
+ *
+ * The iteration counts and the regularization parameter configured here are passed to each binary trainer.
+ */
+public class SVMLinearMultiClassClassificationTrainer implements Trainer<SVMLinearMultiClassClassificationModel, LabeledDataset> {
+    /** Amount of outer SDCA algorithm iterations. */
+    private int amountOfIterations = 20;
+
+    /** Amount of local SDCA algorithm iterations. */
+    private int amountOfLocIterations = 50;
+
+    /** Regularization parameter. */
+    private double lambda = 0.2;
+
+    /**
+     * Trains one binary classifier per class label and gathers them into a single multi-class model.
+     *
+     * @param data Labeled data to learn from.
+     * @return Model holding a binary classifier for every label present in {@code data}.
+     */
+    @Override public SVMLinearMultiClassClassificationModel train(LabeledDataset data) {
+        List<Double> clsLabels = getClassLabels(data);
+
+        SVMLinearMultiClassClassificationModel res = new SVMLinearMultiClassClassificationModel();
+
+        for (Double clsLb : clsLabels) {
+            LabeledDataset relabeled = binarizeLabels(data, clsLb);
+
+            // A fresh binary trainer per class, configured with this trainer's parameters.
+            SVMLinearBinaryClassificationTrainer binaryTrainer = new SVMLinearBinaryClassificationTrainer()
+                .withAmountOfIterations(amountOfIterations())
+                .withAmountOfLocIterations(amountOfLocIterations())
+                .withLambda(lambda());
+
+            res.add(clsLb, binaryTrainer.train(relabeled));
+        }
+
+        return res;
+    }
+
+    /**
+     * Takes a copy of the dataset and rewrites its labels: +1 for rows of the chosen class, -1 for all other rows.
+     *
+     * @param data Dataset to copy and relabel.
+     * @param clsLb Class treated as the positive one.
+     * @return The relabeled copy.
+     */
+    private LabeledDataset binarizeLabels(LabeledDataset data, double clsLb) {
+        final LabeledDataset relabeled = data.copy();
+
+        int rowIdx = 0;
+
+        while (rowIdx < relabeled.rowSize()) {
+            relabeled.setLabel(rowIdx, relabeled.label(rowIdx) == clsLb ? 1.0 : -1.0);
+
+            rowIdx++;
+        }
+
+        return relabeled;
+    }
+
+    /**
+     * Collects the distinct labels that occur in the dataset.
+     *
+     * @param data Dataset to scan.
+     * @return List with every distinct label exactly once.
+     */
+    private List<Double> getClassLabels(LabeledDataset data) {
+        final Set<Double> distinct = new HashSet<>();
+
+        for (int rowIdx = 0; rowIdx < data.rowSize(); rowIdx++)
+            distinct.add(data.label(rowIdx));
+
+        return new ArrayList<>(distinct);
+    }
+
+    /**
+     * Sets the regularization parameter.
+     *
+     * @param lambda New value of the regularization parameter; expected to be greater than 0.0.
+     * @return This trainer, for call chaining.
+     */
+    public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
+        assert lambda > 0.0;
+
+        this.lambda = lambda;
+
+        return this;
+    }
+
+    /**
+     * Returns the regularization parameter.
+     *
+     * @return Current lambda value.
+     */
+    public double lambda() {
+        return lambda;
+    }
+
+    /**
+     * Returns the amount of outer iterations.
+     *
+     * @return Current amount of outer iterations.
+     */
+    public int amountOfIterations() {
+        return amountOfIterations;
+    }
+
+    /**
+     * Sets the amount of outer iterations.
+     *
+     * @param amountOfIterations New amount of outer iterations.
+     * @return This trainer, for call chaining.
+     */
+    public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
+        this.amountOfIterations = amountOfIterations;
+
+        return this;
+    }
+
+    /**
+     * Returns the amount of local iterations.
+     *
+     * @return Current amount of local iterations.
+     */
+    public int amountOfLocIterations() {
+        return amountOfLocIterations;
+    }
+
+    /**
+     * Sets the amount of local iterations.
+     *
+     * @param amountOfLocIterations New amount of local iterations.
+     * @return This trainer, for call chaining.
+     */
+    public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
+        this.amountOfLocIterations = amountOfLocIterations;
+
+        return this;
+    }
+}
+
+
+
index bcbc1fc..57d93d6 100644 (file)
@@ -31,7 +31,8 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
 import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
+import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
+import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -80,13 +81,41 @@ public class LocalModelsTest {
 
     /** */
     @Test
-    public void importExportSVMClassificationModelTest() throws IOException {
+    public void importExportSVMBinaryClassificationModelTest() throws IOException {
         executeModelTest(mdlFilePath -> {
-            SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(new DenseLocalOnHeapVector(new double[]{1, 2}), 3);
-            Exporter<SVMLinearClassificationModel, String> exporter = new FileExporter<>();
+            SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(new DenseLocalOnHeapVector(new double[]{1, 2}), 3);
+            Exporter<SVMLinearBinaryClassificationModel, String> exporter = new FileExporter<>();
             mdl.saveModel(exporter, mdlFilePath);
 
-            SVMLinearClassificationModel load = exporter.load(mdlFilePath);
+            SVMLinearBinaryClassificationModel load = exporter.load(mdlFilePath);
+
+            Assert.assertNotNull(load);
+            Assert.assertEquals("", mdl, load);
+
+            return null;
+        });
+    }
+
+
+    /** */
+    @Test
+    public void importExportSVMMulticlassClassificationModelTest() throws IOException {
+        executeModelTest(mdlFilePath -> {
+
+
+            SVMLinearBinaryClassificationModel binaryMdl1 = new SVMLinearBinaryClassificationModel(new DenseLocalOnHeapVector(new double[]{1, 2}), 3);
+            SVMLinearBinaryClassificationModel binaryMdl2 = new SVMLinearBinaryClassificationModel(new DenseLocalOnHeapVector(new double[]{2, 3}), 4);
+            SVMLinearBinaryClassificationModel binaryMdl3 = new SVMLinearBinaryClassificationModel(new DenseLocalOnHeapVector(new double[]{3, 4}), 5);
+
+            SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel();
+            mdl.add(1, binaryMdl1);
+            mdl.add(2, binaryMdl2);
+            mdl.add(3, binaryMdl3);
+
+            Exporter<SVMLinearMultiClassClassificationModel, String> exporter = new FileExporter<>();
+            mdl.saveModel(exporter, mdlFilePath);
+
+            SVMLinearMultiClassClassificationModel load = exporter.load(mdlFilePath);
 
             Assert.assertNotNull(load);
             Assert.assertEquals("", mdl, load);
index 1742593..35b6644 100644 (file)
@@ -36,7 +36,7 @@ public class SVMModelTest {
     @Test
     public void testPredictWithRawLabels() {
         Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0});
-        SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0).withRawLabels(true);
+        SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withRawLabels(true);
 
         Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0});
         TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION);
@@ -56,11 +56,27 @@ public class SVMModelTest {
         Assert.assertEquals(true, mdl.isKeepingRawLabels());
     }
 
+
+    /**
+     * Checks that a multi-class model built from three binary models predicts, for each observation, the label
+     * whose binary model gives the largest margin.
+     */
+    @Test
+    public void testPredictWithMultiClasses() {
+        Vector weights1 = new DenseLocalOnHeapVector(new double[]{10.0, 0.0});
+        Vector weights2 = new DenseLocalOnHeapVector(new double[]{0.0, 10.0});
+        Vector weights3 = new DenseLocalOnHeapVector(new double[]{-1.0, -1.0});
+        SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel();
+        mdl.add(1, new SVMLinearBinaryClassificationModel(weights1, 0.0).withRawLabels(true));
+        mdl.add(2, new SVMLinearBinaryClassificationModel(weights2, 0.0).withRawLabels(true));
+        // Each model needs its own label: adding under label 2 again would replace the previous model.
+        mdl.add(3, new SVMLinearBinaryClassificationModel(weights3, 0.0).withRawLabels(true));
+
+        // Margins 10, 5 and -1.5 for classes 1, 2 and 3: class 1 wins without a tie.
+        Vector observation1 = new DenseLocalOnHeapVector(new double[]{1.0, 0.5});
+        TestUtils.assertEquals(1.0, mdl.apply(observation1), PRECISION);
+
+        // Margins 5, 10 and -1.5: class 2 wins.
+        Vector observation2 = new DenseLocalOnHeapVector(new double[]{0.5, 1.0});
+        TestUtils.assertEquals(2.0, mdl.apply(observation2), PRECISION);
+
+        // Margins -10, -10 and 2: class 3 wins.
+        Vector observation3 = new DenseLocalOnHeapVector(new double[]{-1.0, -1.0});
+        TestUtils.assertEquals(3.0, mdl.apply(observation3), PRECISION);
+    }
+
     /** */
     @Test
     public void testPredictWithErasedLabels() {
         Vector weights = new DenseLocalOnHeapVector(new double[]{1.0, 1.0});
-        SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0);
+        SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0);
 
         Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0});
         TestUtils.assertEquals(1.0, mdl.apply(observation), PRECISION);
@@ -87,7 +103,7 @@ public class SVMModelTest {
     @Test
     public void testPredictWithErasedLabelsAndChangedThreshold() {
         Vector weights = new DenseLocalOnHeapVector(new double[]{1.0, 1.0});
-        SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0).withThreshold(5);
+        SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0).withThreshold(5);
 
         Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0});
         TestUtils.assertEquals(-1.0, mdl.apply(observation), PRECISION);
@@ -103,7 +119,7 @@ public class SVMModelTest {
     public void testPredictOnAnObservationWithWrongCardinality() {
         Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0});
 
-        SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(weights, 1.0);
+        SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(weights, 1.0);
 
         Vector observation = new DenseLocalOnHeapVector(new double[]{1.0});
 
index bfc341c..853a43f 100644 (file)
 
 package org.apache.ignite.ml.svm;
 
+import org.apache.ignite.ml.svm.binary.DistributedLinearSVMBinaryClassificationTrainerTest;
+import org.apache.ignite.ml.svm.binary.LocalLinearSVMBinaryClassificationTrainerTest;
+import org.apache.ignite.ml.svm.multi.DistributedLinearSVMMultiClassClassificationTrainerTest;
+import org.apache.ignite.ml.svm.multi.LocalLinearSVMMultiClassClassificationTrainerTest;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
@@ -25,8 +29,10 @@ import org.junit.runners.Suite;
  */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
-    LocalLinearSVMClassificationSCDATrainerTest.class,
-    DistributedLinearSVMClassificationSCDATrainerTest.class,
+    LocalLinearSVMBinaryClassificationTrainerTest.class,
+    DistributedLinearSVMBinaryClassificationTrainerTest.class,
+    LocalLinearSVMMultiClassClassificationTrainerTest.class,
+    DistributedLinearSVMMultiClassClassificationTrainerTest.class,
     SVMModelTest.class
 })
 public class SVMTestSuite {
  * limitations under the License.
  */
 
-package org.apache.ignite.ml.svm;
+package org.apache.ignite.ml.svm.binary;
 
 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
+import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
 
 /**
  * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}.
  */
-public class DistributedLinearSVMClassificationSCDATrainerTest extends GenericLinearSVMTrainerTest {
+public class DistributedLinearSVMBinaryClassificationTrainerTest extends GenericLinearSVMBinaryClassificationTrainerTest {
     /** */
-    public DistributedLinearSVMClassificationSCDATrainerTest() {
+    public DistributedLinearSVMBinaryClassificationTrainerTest() {
         super(
             new SVMLinearBinaryClassificationTrainer(),
             true,
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.ignite.ml.svm;
+package org.apache.ignite.ml.svm.binary;
 
 import java.util.concurrent.ThreadLocalRandom;
 import org.apache.ignite.internal.util.IgniteUtils;
@@ -24,12 +24,14 @@ import org.apache.ignite.ml.Trainer;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.svm.BaseSVMTest;
+import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
 import org.junit.Test;
 
 /**
  * Base class for all linear regression trainers.
  */
-public class GenericLinearSVMTrainerTest extends BaseSVMTest {
+public class GenericLinearSVMBinaryClassificationTrainerTest extends BaseSVMTest {
     /** Fixed size of Dataset. */
     private static final int AMOUNT_OF_OBSERVATIONS = 100;
 
@@ -37,7 +39,7 @@ public class GenericLinearSVMTrainerTest extends BaseSVMTest {
     private static final int AMOUNT_OF_FEATURES = 2;
 
     /** */
-    private final Trainer<SVMLinearClassificationModel, LabeledDataset> trainer;
+    private final Trainer<SVMLinearBinaryClassificationModel, LabeledDataset> trainer;
 
     /** */
     private boolean isDistributed;
@@ -46,8 +48,8 @@ public class GenericLinearSVMTrainerTest extends BaseSVMTest {
     private final double precision;
 
     /** */
-    GenericLinearSVMTrainerTest(
-        Trainer<SVMLinearClassificationModel, LabeledDataset> trainer,
+    GenericLinearSVMBinaryClassificationTrainerTest(
+        Trainer<SVMLinearBinaryClassificationModel, LabeledDataset> trainer,
         boolean isDistributed,
         double precision) {
         super();
@@ -77,7 +79,7 @@ public class GenericLinearSVMTrainerTest extends BaseSVMTest {
             dataset.setLabel(i, lb);
         }
 
-        SVMLinearClassificationModel mdl = trainer.train(dataset);
+        SVMLinearBinaryClassificationModel mdl = trainer.train(dataset);
 
         TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[] {100, 10})), precision);
         TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[] {10, 100})), precision);
@@ -104,7 +106,7 @@ public class GenericLinearSVMTrainerTest extends BaseSVMTest {
             dataset.setLabel(i, lb);
         }
 
-        SVMLinearClassificationModel mdl = trainer.train(dataset);
+        SVMLinearBinaryClassificationModel mdl = trainer.train(dataset);
 
         TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[] {100, 10})), precision);
         TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[] {10, 100})), precision);
@@ -131,7 +133,7 @@ public class GenericLinearSVMTrainerTest extends BaseSVMTest {
             dataset.setLabel(i, lb);
         }
 
-        SVMLinearClassificationModel mdl = trainer.train(dataset);
+        SVMLinearBinaryClassificationModel mdl = trainer.train(dataset);
 
         TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[] {100, 10})), precision);
         TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[] {10, 100})), precision);
  * limitations under the License.
  */
 
-package org.apache.ignite.ml.svm;
+package org.apache.ignite.ml.svm.binary;
 
 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
+import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
 
 /**
  * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}.
  */
-public class LocalLinearSVMClassificationSCDATrainerTest extends GenericLinearSVMTrainerTest {
+public class LocalLinearSVMBinaryClassificationTrainerTest extends GenericLinearSVMBinaryClassificationTrainerTest {
     /** */
-    public LocalLinearSVMClassificationSCDATrainerTest() {
+    public LocalLinearSVMBinaryClassificationTrainerTest() {
         super(
             new SVMLinearBinaryClassificationTrainer()
                 .withLambda(0.2)
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/DistributedLinearSVMMultiClassClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/DistributedLinearSVMMultiClassClassificationTrainerTest.java
new file mode 100644 (file)
index 0000000..6806e0b
--- /dev/null
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.svm.multi;
+
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
+import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer;
+
+/**
+ * Tests for {@link SVMLinearMultiClassClassificationTrainer} with the distributed-dataset flag enabled.
+ */
+public class DistributedLinearSVMMultiClassClassificationTrainerTest extends GenericLinearSVMMultiClassClassificationTrainerTest {
+    /** Uses a default-configured trainer, {@code isDistributed = true} and a precision of 1e-2. */
+    public DistributedLinearSVMMultiClassClassificationTrainerTest() {
+        super(
+            new SVMLinearMultiClassClassificationTrainer(),
+            true,
+            1e-2);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/GenericLinearSVMMultiClassClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/GenericLinearSVMMultiClassClassificationTrainerTest.java
new file mode 100644 (file)
index 0000000..8c6083d
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.svm.multi;
+
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.svm.BaseSVMTest;
+import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
+import org.junit.Test;
+
+/**
+ * Base class for all linear regression trainers.
+ */
+public class GenericLinearSVMMultiClassClassificationTrainerTest extends BaseSVMTest {
+    /** */
+    private final Trainer<SVMLinearMultiClassClassificationModel, LabeledDataset> trainer;
+
+    /** */
+    private boolean isDistributed;
+
+    /** */
+    private final double precision;
+
+    /** */
+    GenericLinearSVMMultiClassClassificationTrainerTest(
+        Trainer<SVMLinearMultiClassClassificationModel, LabeledDataset> trainer,
+        boolean isDistributed,
+        double precision) {
+        super();
+        this.trainer = trainer;
+        this.precision = precision;
+        this.isDistributed = isDistributed;
+    }
+
+    /**
+     * Test trainer on classification model y = x.
+     */
+    @Test
+    public void testTrainWithTheLinearlySeparableCase() {
+        if (isDistributed)
+            IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        double[][] mtx =
+            new double[][] {
+                {-10.0, 12.0},
+                {-5.0, 14.0},
+                {-3.0, 18.0},
+                {13.0, -1.0},
+                {10.0, -2.0},
+                {15.0, -3.0}};
+        double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+        LabeledDataset dataset = new LabeledDataset(mtx, lbs, null, isDistributed);
+
+
+        SVMLinearMultiClassClassificationModel mdl = trainer.train(dataset);
+        TestUtils.assertEquals(1.0, mdl.apply(new DenseLocalOnHeapVector(new double[] {-2.0, 15})), precision);
+        TestUtils.assertEquals(2.0, mdl.apply(new DenseLocalOnHeapVector(new double[] {12, -5})), precision);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/LocalLinearSVMMultiClassClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/multi/LocalLinearSVMMultiClassClassificationTrainerTest.java
new file mode 100644 (file)
index 0000000..a239c95
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.svm.multi;
+
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
+import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer;
+
+/**
+ * Tests for {@link SVMLinearMultiClassClassificationTrainer} with the distributed-dataset flag disabled.
+ */
+public class LocalLinearSVMMultiClassClassificationTrainerTest extends GenericLinearSVMMultiClassClassificationTrainerTest {
+    /** Uses a trainer with lambda 0.2, 10 iterations and 20 local iterations; precision is 1e-2. */
+    public LocalLinearSVMMultiClassClassificationTrainerTest() {
+        super(
+            new SVMLinearMultiClassClassificationTrainer()
+                .withLambda(0.2)
+                .withAmountOfIterations(10)
+                .withAmountOfLocIterations(20),
+            false,
+            1e-2);
+    }
+}