IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / composition / boosting / GDBTrainerTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.composition.boosting;
19
20 import java.util.HashMap;
21 import java.util.Map;
22 import org.apache.ignite.ml.Model;
23 import org.apache.ignite.ml.composition.ModelsComposition;
24 import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
25 import org.apache.ignite.ml.math.Vector;
26 import org.apache.ignite.ml.math.VectorUtils;
27 import org.apache.ignite.ml.trainers.DatasetTrainer;
28 import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
29 import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
30 import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
31 import org.junit.Test;
32
33 import static org.junit.Assert.assertEquals;
34 import static org.junit.Assert.assertTrue;
35
36 /** */
37 public class GDBTrainerTest {
38 /** */
39 @Test public void testFitRegression() {
40 int size = 100;
41 double[] xs = new double[size];
42 double[] ys = new double[size];
43 double from = -5.0;
44 double to = 5.0;
45 double step = Math.abs(from - to) / size;
46
47 Map<Integer, double[]> learningSample = new HashMap<>();
48 for (int i = 0; i < size; i++) {
49 xs[i] = from + step * i;
50 ys[i] = 2 * xs[i];
51 learningSample.put(i, new double[] {xs[i], ys[i]});
52 }
53
54 DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0);
55 Model<Vector, Double> model = trainer.fit(
56 learningSample, 1,
57 (k, v) -> VectorUtils.of(v[0]),
58 (k, v) -> v[1]
59 );
60
61 double mse = 0.0;
62 for (int j = 0; j < size; j++) {
63 double x = xs[j];
64 double y = ys[j];
65 double p = model.apply(VectorUtils.of(x));
66 mse += Math.pow(y - p, 2);
67 }
68 mse /= size;
69
70 assertEquals(0.0, mse, 0.0001);
71
72 assertTrue(model instanceof ModelsComposition);
73 ModelsComposition composition = (ModelsComposition) model;
74 composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode));
75
76 assertEquals(2000, composition.getModels().size());
77 assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);
78 }
79
80 /** */
81 @Test public void testFitClassifier() {
82 int sampleSize = 100;
83 double[] xs = new double[sampleSize];
84 double[] ys = new double[sampleSize];
85
86 for (int i = 0; i < sampleSize; i++) {
87 xs[i] = i;
88 ys[i] = ((int)(xs[i] / 10.0) % 2) == 0 ? -1.0 : 1.0;
89 }
90
91 Map<Integer, double[]> learningSample = new HashMap<>();
92 for (int i = 0; i < sampleSize; i++)
93 learningSample.put(i, new double[] {xs[i], ys[i]});
94
95 DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0);
96 Model<Vector, Double> model = trainer.fit(
97 learningSample, 1,
98 (k, v) -> VectorUtils.of(v[0]),
99 (k, v) -> v[1]
100 );
101
102 int errorsCount = 0;
103 for (int j = 0; j < sampleSize; j++) {
104 double x = xs[j];
105 double y = ys[j];
106 double p = model.apply(VectorUtils.of(x));
107 if(p != y)
108 errorsCount++;
109 }
110
111 assertEquals(0, errorsCount);
112
113 assertTrue(model instanceof ModelsComposition);
114 ModelsComposition composition = (ModelsComposition) model;
115 composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode));
116
117 assertEquals(500, composition.getModels().size());
118 assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);
119 }
120 }