2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 package org
.apache
.ignite
.ml
.svm
;
20 import java
.util
.Arrays
;
21 import java
.util
.HashMap
;
23 import java
.util
.concurrent
.ThreadLocalRandom
;
24 import org
.apache
.ignite
.ml
.TestUtils
;
25 import org
.apache
.ignite
.ml
.math
.VectorUtils
;
26 import org
.apache
.ignite
.ml
.math
.impls
.vector
.DenseLocalOnHeapVector
;
27 import org
.junit
.Test
;
 * Tests for {@link SVMLinearMultiClassClassificationTrainer}.
32 public class SVMMultiClassTrainerTest
{
    /** Number of observations (rows) in the generated dataset. */
34 private static final int AMOUNT_OF_OBSERVATIONS
= 1000;
    /** Number of feature columns per observation; the label is stored in an extra leading column. */
37 private static final int AMOUNT_OF_FEATURES
= 2;
39 /** Precision in test checks. */
40 private static final double PRECISION
= 1e-2;
     * Tests the trainer on linearly separable data whose decision boundary is the line y = x.
46 public void testTrainWithTheLinearlySeparableCase() {
47 Map
<Integer
, double[]> data
= new HashMap
<>();
49 ThreadLocalRandom rndX
= ThreadLocalRandom
.current();
50 ThreadLocalRandom rndY
= ThreadLocalRandom
.current();
52 for (int i
= 0; i
< AMOUNT_OF_OBSERVATIONS
; i
++) {
53 double x
= rndX
.nextDouble(-1000, 1000);
54 double y
= rndY
.nextDouble(-1000, 1000);
55 double[] vec
= new double[AMOUNT_OF_FEATURES
+ 1];
56 vec
[0] = y
- x
> 0 ?
1 : -1; // assign label.
62 SVMLinearMultiClassClassificationTrainer trainer
= new SVMLinearMultiClassClassificationTrainer()
64 .withAmountOfLocIterations(100)
65 .withAmountOfIterations(20);
67 SVMLinearMultiClassClassificationModel mdl
= trainer
.fit(
70 (k
, v
) -> VectorUtils
.of(Arrays
.copyOfRange(v
, 1, v
.length
)),
74 TestUtils
.assertEquals(-1, mdl
.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION
);
75 TestUtils
.assertEquals(1, mdl
.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION
);