IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / tree / DecisionTreeClassificationTrainerTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.tree;
19
20 import java.util.ArrayList;
21 import java.util.Arrays;
22 import java.util.HashMap;
23 import java.util.List;
24 import java.util.Map;
25 import java.util.Random;
26 import org.apache.ignite.ml.math.VectorUtils;
27 import org.junit.Test;
28 import org.junit.runner.RunWith;
29 import org.junit.runners.Parameterized;
30
31 import static junit.framework.TestCase.assertEquals;
32 import static junit.framework.TestCase.assertTrue;
33
34 /**
35 * Tests for {@link DecisionTreeClassificationTrainer}.
36 */
37 @RunWith(Parameterized.class)
38 public class DecisionTreeClassificationTrainerTest {
39 /** Number of parts to be tested. */
40 private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
41
42 /** Number of partitions. */
43 @Parameterized.Parameter
44 public int parts;
45
46
47 @Parameterized.Parameters(name = "Data divided on {0} partitions")
48 public static Iterable<Integer[]> data() {
49 List<Integer[]> res = new ArrayList<>();
50 for (int part : partsToBeTested)
51 res.add(new Integer[] {part});
52
53 return res;
54 }
55
56 /** */
57 @Test
58 public void testFit() {
59 int size = 100;
60
61 Map<Integer, double[]> data = new HashMap<>();
62
63 Random rnd = new Random(0);
64 for (int i = 0; i < size; i++) {
65 double x = rnd.nextDouble() - 0.5;
66 data.put(i, new double[]{x, x > 0 ? 1 : 0});
67 }
68
69 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
70
71 DecisionTreeNode tree = trainer.fit(
72 data,
73 parts,
74 (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
75 (k, v) -> v[v.length - 1]
76 );
77
78 assertTrue(tree instanceof DecisionTreeConditionalNode);
79
80 DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
81
82 assertEquals(0, node.getThreshold(), 1e-3);
83
84 assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode);
85 assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode);
86
87 DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode();
88 DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode();
89
90 assertEquals(1, thenNode.getVal(), 1e-10);
91 assertEquals(0, elseNode.getVal(), 1e-10);
92 }
93 }