/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.tree.randomforest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.composition.ModelOnFeaturesSubspace;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

36 @RunWith(Parameterized
.class)
37 public class RandomForestRegressionTrainerTest
{
39 * Number of parts to be tested.
41 private static final int[] partsToBeTested
= new int[] {1, 2, 3, 4, 5, 7};
44 * Number of partitions.
46 @Parameterized.Parameter
49 @Parameterized.Parameters(name
= "Data divided on {0} partitions")
50 public static Iterable
<Integer
[]> data() {
51 List
<Integer
[]> res
= new ArrayList
<>();
52 for (int part
: partsToBeTested
)
53 res
.add(new Integer
[] {part
});
59 @Test public void testFit() {
60 int sampleSize
= 1000;
61 Map
<Double
, double[]> sample
= new HashMap
<>();
62 for (int i
= 0; i
< sampleSize
; i
++) {
64 double x2
= x1
/ 10.0;
65 double x3
= x2
/ 10.0;
66 double x4
= x3
/ 10.0;
68 sample
.put(x1
* x2
+ x3
* x4
, new double[] {x1
, x2
, x3
, x4
});
71 RandomForestRegressionTrainer trainer
= new RandomForestRegressionTrainer(4, 3, 5, 0.3, 4, 0.1);
72 ModelsComposition model
= trainer
.fit(sample
, parts
, (k
, v
) -> VectorUtils
.of(v
), (k
, v
) -> k
);
73 model
.getModels().forEach(m
-> {
74 assertTrue(m
instanceof ModelOnFeaturesSubspace
);
75 assertTrue(((ModelOnFeaturesSubspace
) m
).getMdl() instanceof DecisionTreeConditionalNode
);
78 assertTrue(model
.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator
);
79 assertEquals(5, model
.getModels().size());