IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / svm / SVMMultiClassTrainerTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.svm;
19
20 import java.util.Arrays;
21 import java.util.HashMap;
22 import java.util.Map;
23 import java.util.concurrent.ThreadLocalRandom;
24 import org.apache.ignite.ml.TestUtils;
25 import org.apache.ignite.ml.math.VectorUtils;
26 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
27 import org.junit.Test;
28
29 /**
30 * Tests for {@link SVMLinearBinaryClassificationTrainer}.
31 */
32 public class SVMMultiClassTrainerTest {
33 /** Fixed size of Dataset. */
34 private static final int AMOUNT_OF_OBSERVATIONS = 1000;
35
36 /** Fixed size of columns in Dataset. */
37 private static final int AMOUNT_OF_FEATURES = 2;
38
39 /** Precision in test checks. */
40 private static final double PRECISION = 1e-2;
41
42 /**
43 * Test trainer on classification model y = x.
44 */
45 @Test
46 public void testTrainWithTheLinearlySeparableCase() {
47 Map<Integer, double[]> data = new HashMap<>();
48
49 ThreadLocalRandom rndX = ThreadLocalRandom.current();
50 ThreadLocalRandom rndY = ThreadLocalRandom.current();
51
52 for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
53 double x = rndX.nextDouble(-1000, 1000);
54 double y = rndY.nextDouble(-1000, 1000);
55 double[] vec = new double[AMOUNT_OF_FEATURES + 1];
56 vec[0] = y - x > 0 ? 1 : -1; // assign label.
57 vec[1] = x;
58 vec[2] = y;
59 data.put(i, vec);
60 }
61
62 SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer()
63 .withLambda(0.3)
64 .withAmountOfLocIterations(100)
65 .withAmountOfIterations(20);
66
67 SVMLinearMultiClassClassificationModel mdl = trainer.fit(
68 data,
69 10,
70 (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
71 (k, v) -> v[0]
72 );
73
74 TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION);
75 TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION);
76 }
77 }