IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / tree / randomforest / RandomForestRegressionTrainerTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.tree.randomforest;
19
20 import java.util.ArrayList;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import org.apache.ignite.ml.composition.ModelOnFeaturesSubspace;
25 import org.apache.ignite.ml.composition.ModelsComposition;
26 import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
27 import org.apache.ignite.ml.math.VectorUtils;
28 import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
29 import org.junit.Test;
30 import org.junit.runner.RunWith;
31 import org.junit.runners.Parameterized;
32
33 import static org.junit.Assert.assertEquals;
34 import static org.junit.Assert.assertTrue;
35
36 @RunWith(Parameterized.class)
37 public class RandomForestRegressionTrainerTest {
38 /**
39 * Number of parts to be tested.
40 */
41 private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
42
43 /**
44 * Number of partitions.
45 */
46 @Parameterized.Parameter
47 public int parts;
48
49 @Parameterized.Parameters(name = "Data divided on {0} partitions")
50 public static Iterable<Integer[]> data() {
51 List<Integer[]> res = new ArrayList<>();
52 for (int part : partsToBeTested)
53 res.add(new Integer[] {part});
54
55 return res;
56 }
57
58 /** */
59 @Test public void testFit() {
60 int sampleSize = 1000;
61 Map<Double, double[]> sample = new HashMap<>();
62 for (int i = 0; i < sampleSize; i++) {
63 double x1 = i;
64 double x2 = x1 / 10.0;
65 double x3 = x2 / 10.0;
66 double x4 = x3 / 10.0;
67
68 sample.put(x1 * x2 + x3 * x4, new double[] {x1, x2, x3, x4});
69 }
70
71 RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(4, 3, 5, 0.3, 4, 0.1);
72 ModelsComposition model = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(v), (k, v) -> k);
73 model.getModels().forEach(m -> {
74 assertTrue(m instanceof ModelOnFeaturesSubspace);
75 assertTrue(((ModelOnFeaturesSubspace) m).getMdl() instanceof DecisionTreeConditionalNode);
76 });
77
78 assertTrue(model.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator);
79 assertEquals(5, model.getModels().size());
80 }
81 }