IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / nn / MLPTrainerIntegrationTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.nn;
19
20 import java.io.Serializable;
21 import org.apache.ignite.Ignite;
22 import org.apache.ignite.IgniteCache;
23 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
24 import org.apache.ignite.configuration.CacheConfiguration;
25 import org.apache.ignite.internal.util.IgniteUtils;
26 import org.apache.ignite.internal.util.typedef.X;
27 import org.apache.ignite.ml.TestUtils;
28 import org.apache.ignite.ml.math.Matrix;
29 import org.apache.ignite.ml.math.Tracer;
30 import org.apache.ignite.ml.math.VectorUtils;
31 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
32 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
33 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
34 import org.apache.ignite.ml.optimization.LossFunctions;
35 import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
36 import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
37 import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
38 import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
39 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
40 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
41 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
42
43 /**
44 * Tests for {@link MLPTrainer} that require to start the whole Ignite infrastructure.
45 */
46 public class MLPTrainerIntegrationTest extends GridCommonAbstractTest {
47 /** Number of nodes in grid */
48 private static final int NODE_COUNT = 3;
49
50 /** Ignite instance. */
51 private Ignite ignite;
52
53 /** {@inheritDoc} */
54 @Override protected void beforeTestsStarted() throws Exception {
55 for (int i = 1; i <= NODE_COUNT; i++)
56 startGrid(i);
57 }
58
59 /** {@inheritDoc} */
60 @Override protected void afterTestsStopped() {
61 stopAllGrids();
62 }
63
64 /**
65 * {@inheritDoc}
66 */
67 @Override protected void beforeTest() throws Exception {
68 /* Grid instance. */
69 ignite = grid(NODE_COUNT);
70 ignite.configuration().setPeerClassLoadingEnabled(true);
71 IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
72 }
73
74 /**
75 * Test 'XOR' operation training with {@link SimpleGDUpdateCalculator}.
76 */
77 public void testXORSimpleGD() {
78 xorTest(new UpdatesStrategy<>(
79 new SimpleGDUpdateCalculator(0.3),
80 SimpleGDParameterUpdate::sumLocal,
81 SimpleGDParameterUpdate::avg
82 ));
83 }
84
85 /**
86 * Test 'XOR' operation training with {@link RPropUpdateCalculator}.
87 */
88 public void testXORRProp() {
89 xorTest(new UpdatesStrategy<>(
90 new RPropUpdateCalculator(),
91 RPropParameterUpdate::sumLocal,
92 RPropParameterUpdate::avg
93 ));
94 }
95
96 /**
97 * Test 'XOR' operation training with {@link NesterovUpdateCalculator}.
98 */
99 public void testXORNesterov() {
100 xorTest(new UpdatesStrategy<>(
101 new NesterovUpdateCalculator<MultilayerPerceptron>(0.1, 0.7),
102 NesterovParameterUpdate::sum,
103 NesterovParameterUpdate::avg
104 ));
105 }
106
107 /**
108 * Common method for testing 'XOR' with various updaters.
109 * @param updatesStgy Update strategy.
110 * @param <P> Updater parameters type.
111 */
112 private <P extends Serializable> void xorTest(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) {
113 CacheConfiguration<Integer, LabeledPoint> xorCacheCfg = new CacheConfiguration<>();
114 xorCacheCfg.setName("XorData");
115 xorCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 5));
116 IgniteCache<Integer, LabeledPoint> xorCache = ignite.createCache(xorCacheCfg);
117
118 try {
119 xorCache.put(0, new LabeledPoint(0.0, 0.0, 0.0));
120 xorCache.put(1, new LabeledPoint(0.0, 1.0, 1.0));
121 xorCache.put(2, new LabeledPoint(1.0, 0.0, 1.0));
122 xorCache.put(3, new LabeledPoint(1.0, 1.0, 0.0));
123
124 MLPArchitecture arch = new MLPArchitecture(2).
125 withAddedLayer(10, true, Activators.RELU).
126 withAddedLayer(1, false, Activators.SIGMOID);
127
128 MLPTrainer<P> trainer = new MLPTrainer<>(
129 arch,
130 LossFunctions.MSE,
131 updatesStgy,
132 2500,
133 4,
134 50,
135 123L
136 );
137
138 MultilayerPerceptron mlp = trainer.fit(
139 ignite,
140 xorCache,
141 (k, v) -> VectorUtils.of(v.x, v.y ),
142 (k, v) -> new double[]{ v.lb}
143 );
144
145 Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][]{
146 {0.0, 0.0},
147 {0.0, 1.0},
148 {1.0, 0.0},
149 {1.0, 1.0}
150 }));
151
152 Tracer.showAscii(predict);
153
154 X.println(new DenseLocalOnHeapVector(new double[]{0.0}).minus(predict.getRow(0)).kNorm(2) + "");
155
156 TestUtils.checkIsInEpsilonNeighbourhood(new DenseLocalOnHeapVector(new double[]{0.0}), predict.getRow(0), 1E-1);
157 }
158 finally {
159 xorCache.destroy();
160 }
161 }
162
163 /** Labeled point data class. */
164 private static class LabeledPoint {
165 /** X coordinate. */
166 private final double x;
167
168 /** Y coordinate. */
169 private final double y;
170
171 /** Point label. */
172 private final double lb;
173
174 /**
175 * Constructs a new instance of labeled point data.
176 *
177 * @param x X coordinate.
178 * @param y Y coordinate.
179 * @param lb Point label.
180 */
181 public LabeledPoint(double x, double y, double lb) {
182 this.x = x;
183 this.y = y;
184 this.lb = lb;
185 }
186 }
187 }