IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / nn / performance / MLPTrainerMnistTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.nn.performance;
19
20 import java.io.IOException;
21 import java.util.HashMap;
22 import java.util.Map;
23 import org.apache.ignite.ml.math.Matrix;
24 import org.apache.ignite.ml.math.VectorUtils;
25 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
26 import org.apache.ignite.ml.nn.Activators;
27 import org.apache.ignite.ml.nn.MLPTrainer;
28 import org.apache.ignite.ml.nn.MultilayerPerceptron;
29 import org.apache.ignite.ml.nn.UpdatesStrategy;
30 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
31 import org.apache.ignite.ml.optimization.LossFunctions;
32 import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
33 import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
34 import org.apache.ignite.ml.util.MnistUtils;
35 import org.junit.Test;
36
37 import static org.junit.Assert.assertTrue;
38
39 /**
40 * Tests {@link MLPTrainer} on the MNIST dataset using locally stored data.
41 */
42 public class MLPTrainerMnistTest {
43 /** Tests on the MNIST dataset. */
44 @Test
45 public void testMNIST() throws IOException {
46 int featCnt = 28 * 28;
47 int hiddenNeuronsCnt = 100;
48
49 Map<Integer, MnistUtils.MnistLabeledImage> trainingSet = new HashMap<>();
50
51 int i = 0;
52 for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTrainingSet(60_000))
53 trainingSet.put(i++, e);
54
55 MLPArchitecture arch = new MLPArchitecture(featCnt).
56 withAddedLayer(hiddenNeuronsCnt, true, Activators.SIGMOID).
57 withAddedLayer(10, false, Activators.SIGMOID);
58
59 MLPTrainer<?> trainer = new MLPTrainer<>(
60 arch,
61 LossFunctions.MSE,
62 new UpdatesStrategy<>(
63 new RPropUpdateCalculator(),
64 RPropParameterUpdate::sum,
65 RPropParameterUpdate::avg
66 ),
67 200,
68 2000,
69 10,
70 123L
71 );
72
73 System.out.println("Start training...");
74 long start = System.currentTimeMillis();
75 MultilayerPerceptron mdl = trainer.fit(
76 trainingSet,
77 1,
78 (k, v) -> VectorUtils.of(v.getPixels()),
79 (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data()
80 );
81 System.out.println("Training completed in " + (System.currentTimeMillis() - start) + "ms");
82
83 int correctAnswers = 0;
84 int incorrectAnswers = 0;
85
86 for (MnistUtils.MnistLabeledImage e : MnistMLPTestUtil.loadTestSet(10_000)) {
87 Matrix input = new DenseLocalOnHeapMatrix(new double[][]{e.getPixels()});
88 Matrix outputMatrix = mdl.apply(input);
89
90 int predicted = (int) VectorUtils.vec2Num(outputMatrix.getRow(0));
91
92 if (predicted == e.getLabel())
93 correctAnswers++;
94 else
95 incorrectAnswers++;
96 }
97
98 double accuracy = 1.0 * correctAnswers / (correctAnswers + incorrectAnswers);
99 assertTrue("Accuracy should be >= 80% (not " + accuracy * 100 + "%)", accuracy >= 0.8);
100 }
101 }