IGNITE-8907: [ML] Using vectors in featureExtractor
modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.nn;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.junit.Before;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests for {@link MLPTrainer} that don't require starting the whole Ignite infrastructure.
 */
@RunWith(Enclosed.class)
public class MLPTrainerTest {
    /**
     * Parameterized tests.
     */
    @RunWith(Parameterized.class)
    public static class ComponentParamTests {
        /** Numbers of partitions to be tested. */
        private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};

        /** Batch sizes to be tested. */
        private static final int[] batchSizesToBeTested = new int[] {1, 2, 3, 4};

        /** Parameters. */
        @Parameterized.Parameters(name = "Data divided into {0} partitions, training with batch size {1}")
        public static Iterable<Integer[]> data() {
            List<Integer[]> res = new ArrayList<>();
            for (int part : partsToBeTested)
                for (int batchSize : batchSizesToBeTested)
                    res.add(new Integer[] {part, batchSize});

            return res;
        }

        /** Number of partitions. */
        @Parameterized.Parameter
        public int parts;

        /** Batch size. */
        @Parameterized.Parameter(1)
        public int batchSize;

        /**
         * Test 'XOR' operation training with {@link SimpleGDUpdateCalculator} updater.
         */
        @Test
        public void testXORSimpleGD() {
            xorTest(new UpdatesStrategy<>(
                new SimpleGDUpdateCalculator(0.2),
                SimpleGDParameterUpdate::sumLocal,
                SimpleGDParameterUpdate::avg
            ));
        }

        /**
         * Test 'XOR' operation training with {@link RPropUpdateCalculator}.
         */
        @Test
        public void testXORRProp() {
            xorTest(new UpdatesStrategy<>(
                new RPropUpdateCalculator(),
                RPropParameterUpdate::sumLocal,
                RPropParameterUpdate::avg
            ));
        }

        /**
         * Test 'XOR' operation training with {@link NesterovUpdateCalculator}.
         */
        @Test
        public void testXORNesterov() {
            xorTest(new UpdatesStrategy<>(
                new NesterovUpdateCalculator<MultilayerPerceptron>(0.1, 0.7),
                NesterovParameterUpdate::sum,
                NesterovParameterUpdate::avg
            ));
        }

        /**
         * Common method for testing 'XOR' with various updaters.
         *
         * @param updatesStgy Update strategy.
         * @param <P> Updater parameters type.
         */
        private <P extends Serializable> void xorTest(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) {
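            // Training set: each map entry holds a two-element array {features, label};
            // the four entries together enumerate the XOR truth table.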
            Map<Integer, double[][]> xorData = new HashMap<>();
            xorData.put(0, new double[][] {{0.0, 0.0}, {0.0}});
            xorData.put(1, new double[][] {{0.0, 1.0}, {1.0}});
            xorData.put(2, new double[][] {{1.0, 0.0}, {1.0}});
            xorData.put(3, new double[][] {{1.0, 1.0}, {0.0}});

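            // Network topology: 2 inputs -> 10-neuron ReLU hidden layer (with bias) -> 1 sigmoid output.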
            MLPArchitecture arch = new MLPArchitecture(2).
                withAddedLayer(10, true, Activators.RELU).
                withAddedLayer(1, false, Activators.SIGMOID);

            MLPTrainer<P> trainer = new MLPTrainer<>(
                arch,
                LossFunctions.MSE,
                updatesStgy,
                3000,       // Max iterations.
                batchSize,
                50,         // Local iterations.
                123L        // Random seed.
            );

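            // The feature extractor wraps the raw double[] features into a Vector via VectorUtils
            // (the change under IGNITE-8907); the label extractor returns the raw label array.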
            MultilayerPerceptron mlp = trainer.fit(
                xorData,
                parts,
                (k, v) -> VectorUtils.of(v[0]),
                (k, v) -> v[1]
            );

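            // Run the trained perceptron over all four XOR inputs at once.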
            Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][] {
                {0.0, 0.0},
                {0.0, 1.0},
                {1.0, 0.0},
                {1.0, 1.0}
            }));

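            // XOR(0, 0) = 0: only the first output row is asserted here.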
            TestUtils.checkIsInEpsilonNeighbourhood(new DenseLocalOnHeapVector(new double[] {0.0}), predict.getRow(0), 1E-1);
        }
    }

    /**
     * Non-parameterized tests.
     */
    public static class ComponentSingleTests {
        /** Data. */
        private double[] data;

        /** Initialization. */
        @Before
        public void init() {
            data = new double[10];
            for (int i = 0; i < 10; i++)
                data[i] = i;
        }

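        // Note on MLPTrainer.batch semantics (inferred from the expected values below):
        // the flat data array is treated as a column-major matrix with the given total
        // number of rows, and batch extracts the requested rows across all columns.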
        /** */
        @Test
        public void testBatchWithSingleColumnAndSingleRow() {
            // 10 rows total: data forms a 10x1 matrix, so row 1 is {1.0}.
            double[] res = MLPTrainer.batch(data, new int[] {1}, 10);

            TestUtils.assertEquals(new double[] {1.0}, res, 1e-12);
        }

        /** */
        @Test
        public void testBatchWithMultiColumnAndSingleRow() {
            // 5 rows total: data forms a 5x2 column-major matrix, so row 1 is {1.0, 6.0}.
            double[] res = MLPTrainer.batch(data, new int[] {1}, 5);

            TestUtils.assertEquals(new double[] {1.0, 6.0}, res, 1e-12);
        }

        /** */
        @Test
        public void testBatchWithMultiColumnAndMultiRow() {
            // Rows 1 and 3 of the 5x2 matrix, flattened column-major: {1.0, 3.0} then {6.0, 8.0}.
            double[] res = MLPTrainer.batch(data, new int[] {1, 3}, 5);

            TestUtils.assertEquals(new double[] {1.0, 3.0, 6.0, 8.0}, res, 1e-12);
        }
    }
}