IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / regressions / logistic / binomial / LogisticRegressionSGDTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.regressions.logistic.binomial;
19
20 import java.io.Serializable;
21 import java.util.Arrays;
22 import org.apache.ignite.ml.dataset.Dataset;
23 import org.apache.ignite.ml.dataset.DatasetBuilder;
24 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
25 import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
26 import org.apache.ignite.ml.math.Vector;
27 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
28 import org.apache.ignite.ml.math.functions.IgniteFunction;
29 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
30 import org.apache.ignite.ml.nn.Activators;
31 import org.apache.ignite.ml.nn.MLPTrainer;
32 import org.apache.ignite.ml.nn.MultilayerPerceptron;
33 import org.apache.ignite.ml.nn.UpdatesStrategy;
34 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
35 import org.apache.ignite.ml.optimization.LossFunctions;
36 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
37
38 /**
39 * Trainer of the logistic regression model based on stochastic gradient descent algorithm.
40 */
41 public class LogisticRegressionSGDTrainer<P extends Serializable> implements SingleLabelDatasetTrainer<LogisticRegressionModel> {
42 /** Update strategy. */
43 private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
44
45 /** Max number of iteration. */
46 private final int maxIterations;
47
48 /** Batch size. */
49 private final int batchSize;
50
51 /** Number of local iterations. */
52 private final int locIterations;
53
54 /** Seed for random generator. */
55 private final long seed;
56
57 /**
58 * Constructs a new instance of linear regression SGD trainer.
59 *
60 * @param updatesStgy Update strategy.
61 * @param maxIterations Max number of iteration.
62 * @param batchSize Batch size.
63 * @param locIterations Number of local iterations.
64 * @param seed Seed for random generator.
65 */
66 public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations,
67 int batchSize, int locIterations, long seed) {
68 this.updatesStgy = updatesStgy;
69 this.maxIterations = maxIterations;
70 this.batchSize = batchSize;
71 this.locIterations = locIterations;
72 this.seed = seed;
73 }
74
75 /** {@inheritDoc} */
76 @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
77 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
78
79 IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
80
81 int cols = dataset.compute(data -> {
82 if (data.getFeatures() == null)
83 return null;
84 return data.getFeatures().length / data.getRows();
85 }, (a, b) -> a == null ? b : a);
86
87 MLPArchitecture architecture = new MLPArchitecture(cols);
88 architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);
89
90 return architecture;
91 };
92
93 MLPTrainer<?> trainer = new MLPTrainer<>(
94 archSupplier,
95 LossFunctions.L2,
96 updatesStgy,
97 maxIterations,
98 batchSize,
99 locIterations,
100 seed
101 );
102
103 MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[]{lbExtractor.apply(k, v)});
104
105 double[] params = mlp.parameters().getStorage().data();
106
107 return new LogisticRegressionModel(new DenseLocalOnHeapVector(Arrays.copyOf(params, params.length - 1)),
108 params[params.length - 1]
109 );
110 }
111 }