IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / nn / MLPTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.nn;
19
20 import java.io.Serializable;
21 import java.util.ArrayList;
22 import java.util.List;
23 import java.util.Random;
24 import org.apache.ignite.ml.dataset.Dataset;
25 import org.apache.ignite.ml.dataset.DatasetBuilder;
26 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
27 import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
28 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
29 import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
30 import org.apache.ignite.ml.math.Matrix;
31 import org.apache.ignite.ml.math.Vector;
32 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
33 import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
34 import org.apache.ignite.ml.math.functions.IgniteFunction;
35 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
36 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
37 import org.apache.ignite.ml.nn.initializers.RandomInitializer;
38 import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
39 import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer;
40 import org.apache.ignite.ml.util.Utils;
41
42 /**
43 * Multilayer perceptron trainer based on partition based {@link Dataset}.
44 *
45 * @param <P> Type of model update used in this trainer.
46 */
47 public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrainer<MultilayerPerceptron> {
48 /** Multilayer perceptron architecture supplier that defines layers and activators. */
49 private final IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier;
50
51 /** Loss function to be minimized during the training. */
52 private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
53
54 /** Update strategy that defines how to update model parameters during the training. */
55 private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
56
57 /** Maximal number of iterations before the training will be stopped. */
58 private final int maxIterations;
59
60 /** Batch size (per every partition). */
61 private final int batchSize;
62
63 /** Maximal number of local iterations before synchronization. */
64 private final int locIterations;
65
66 /** Multilayer perceptron model initializer. */
67 private final long seed;
68
69 /**
70 * Constructs a new instance of multilayer perceptron trainer.
71 *
72 * @param arch Multilayer perceptron architecture that defines layers and activators.
73 * @param loss Loss function to be minimized during the training.
74 * @param updatesStgy Update strategy that defines how to update model parameters during the training.
75 * @param maxIterations Maximal number of iterations before the training will be stopped.
76 * @param batchSize Batch size (per every partition).
77 * @param locIterations Maximal number of local iterations before synchronization.
78 * @param seed Random initializer seed.
79 */
80 public MLPTrainer(MLPArchitecture arch, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
81 UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize,
82 int locIterations, long seed) {
83 this(dataset -> arch, loss, updatesStgy, maxIterations, batchSize, locIterations, seed);
84 }
85
86 /**
87 * Constructs a new instance of multilayer perceptron trainer.
88 *
89 * @param archSupplier Multilayer perceptron architecture supplier that defines layers and activators.
90 * @param loss Loss function to be minimized during the training.
91 * @param updatesStgy Update strategy that defines how to update model parameters during the training.
92 * @param maxIterations Maximal number of iterations before the training will be stopped.
93 * @param batchSize Batch size (per every partition).
94 * @param locIterations Maximal number of local iterations before synchronization.
95 * @param seed Random initializer seed.
96 */
97 public MLPTrainer(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier,
98 IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
99 UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize,
100 int locIterations, long seed) {
101 this.archSupplier = archSupplier;
102 this.loss = loss;
103 this.updatesStgy = updatesStgy;
104 this.maxIterations = maxIterations;
105 this.batchSize = batchSize;
106 this.locIterations = locIterations;
107 this.seed = seed;
108 }
109
110 /** {@inheritDoc} */
111 public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
112 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
113
114 try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
115 new EmptyContextBuilder<>(),
116 new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
117 )) {
118 MLPArchitecture arch = archSupplier.apply(dataset);
119 MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
120 ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();
121
122 for (int i = 0; i < maxIterations; i += locIterations) {
123
124 MultilayerPerceptron finalMdl = mdl;
125 int finalI = i;
126
127 List<P> totUp = dataset.compute(
128 data -> {
129 P update = updater.init(finalMdl, loss);
130
131 MultilayerPerceptron mlp = Utils.copy(finalMdl);
132
133 if (data.getFeatures() != null) {
134 List<P> updates = new ArrayList<>();
135
136 for (int locStep = 0; locStep < locIterations; locStep++) {
137 int[] rows = Utils.selectKDistinct(
138 data.getRows(),
139 Math.min(batchSize, data.getRows()),
140 new Random(seed ^ (finalI * locStep))
141 );
142
143 double[] inputsBatch = batch(data.getFeatures(), rows, data.getRows());
144 double[] groundTruthBatch = batch(data.getLabels(), rows, data.getRows());
145
146 Matrix inputs = new DenseLocalOnHeapMatrix(inputsBatch, rows.length, 0);
147 Matrix groundTruth = new DenseLocalOnHeapMatrix(groundTruthBatch, rows.length, 0);
148
149 update = updater.calculateNewUpdate(
150 mlp,
151 update,
152 locStep,
153 inputs.transpose(),
154 groundTruth.transpose()
155 );
156
157 mlp = updater.update(mlp, update);
158 updates.add(update);
159 }
160
161 List<P> res = new ArrayList<>();
162 res.add(updatesStgy.locStepUpdatesReducer().apply(updates));
163
164 return res;
165 }
166
167 return null;
168 },
169 (a, b) -> {
170 if (a == null)
171 return b;
172 else if (b == null)
173 return a;
174 else {
175 a.addAll(b);
176 return a;
177 }
178 }
179 );
180
181 P update = updatesStgy.allUpdatesReducer().apply(totUp);
182 mdl = updater.update(mdl, update);
183 }
184
185 return mdl;
186 }
187 catch (Exception e) {
188 throw new RuntimeException(e);
189 }
190 }
191
192 /**
193 * Builds a batch of the data by fetching specified rows.
194 *
195 * @param data All data.
196 * @param rows Rows to be fetched from the data.
197 * @param totalRows Total number of rows in all data.
198 * @return Batch data.
199 */
200 static double[] batch(double[] data, int[] rows, int totalRows) {
201 int cols = data.length / totalRows;
202
203 double[] res = new double[cols * rows.length];
204
205 for (int i = 0; i < rows.length; i++)
206 for (int j = 0; j < cols; j++)
207 res[j * rows.length + i] = data[j * totalRows + rows[i]];
208
209 return res;
210 }
211 }