IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / tree / DecisionTreeRegressionTrainerIntegrationTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.tree;
19
20 import java.util.Arrays;
21 import java.util.Random;
22 import org.apache.ignite.Ignite;
23 import org.apache.ignite.IgniteCache;
24 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
25 import org.apache.ignite.configuration.CacheConfiguration;
26 import org.apache.ignite.internal.util.IgniteUtils;
27 import org.apache.ignite.ml.math.VectorUtils;
28 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
29
30 /**
31 * Tests for {@link DecisionTreeRegressionTrainer} that require to start the whole Ignite infrastructure.
32 */
33 public class DecisionTreeRegressionTrainerIntegrationTest extends GridCommonAbstractTest {
34 /** Number of nodes in grid */
35 private static final int NODE_COUNT = 3;
36
37 /** Ignite instance. */
38 private Ignite ignite;
39
40 /** {@inheritDoc} */
41 @Override protected void beforeTestsStarted() throws Exception {
42 for (int i = 1; i <= NODE_COUNT; i++)
43 startGrid(i);
44 }
45
46 /** {@inheritDoc} */
47 @Override protected void afterTestsStopped() {
48 stopAllGrids();
49 }
50
51 /**
52 * {@inheritDoc}
53 */
54 @Override protected void beforeTest() throws Exception {
55 /* Grid instance. */
56 ignite = grid(NODE_COUNT);
57 ignite.configuration().setPeerClassLoadingEnabled(true);
58 IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
59 }
60
61 /** */
62 public void testFit() {
63 int size = 100;
64
65 CacheConfiguration<Integer, double[]> trainingSetCacheCfg = new CacheConfiguration<>();
66 trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
67 trainingSetCacheCfg.setName("TRAINING_SET");
68
69 IgniteCache<Integer, double[]> data = ignite.createCache(trainingSetCacheCfg);
70
71 Random rnd = new Random(0);
72 for (int i = 0; i < size; i++) {
73 double x = rnd.nextDouble() - 0.5;
74 data.put(i, new double[]{x, x > 0 ? 1 : 0});
75 }
76
77 DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0);
78
79 DecisionTreeNode tree = trainer.fit(
80 ignite,
81 data,
82 (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
83 (k, v) -> v[v.length - 1]
84 );
85
86 assertTrue(tree instanceof DecisionTreeConditionalNode);
87
88 DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
89
90 assertEquals(0, node.getThreshold(), 1e-3);
91
92 assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode);
93 assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode);
94
95 DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode();
96 DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode();
97
98 assertEquals(1, thenNode.getVal(), 1e-10);
99 assertEquals(0, elseNode.getVal(), 1e-10);
100 }
101 }