IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / svm / SVMLinearBinaryClassificationTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.svm;
19
20 import java.util.concurrent.ThreadLocalRandom;
21 import org.apache.ignite.ml.dataset.Dataset;
22 import org.apache.ignite.ml.dataset.DatasetBuilder;
23 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
24 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
25 import org.apache.ignite.ml.math.Vector;
26 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
27 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
28 import org.apache.ignite.ml.structures.LabeledDataset;
29 import org.apache.ignite.ml.structures.LabeledVector;
30 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
31 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
32 import org.jetbrains.annotations.NotNull;
33
34 /**
35 * Base class for a soft-margin SVM linear classification trainer based on the communication-efficient distributed dual
36 * coordinate ascent algorithm (CoCoA) with hinge-loss function. <p> This trainer takes input as Labeled Dataset with -1
37 * and +1 labels for two classes and makes binary classification. </p> The paper about this algorithm could be found
38 * here https://arxiv.org/abs/1409.1458.
39 */
40 public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetTrainer<SVMLinearBinaryClassificationModel> {
41 /** Amount of outer SDCA algorithm iterations. */
42 private int amountOfIterations = 200;
43
44 /** Amount of local SDCA algorithm iterations. */
45 private int amountOfLocIterations = 100;
46
47 /** Regularization parameter. */
48 private double lambda = 0.4;
49
50 /**
51 * Trains model based on the specified data.
52 *
53 * @param datasetBuilder Dataset builder.
54 * @param featureExtractor Feature extractor.
55 * @param lbExtractor Label extractor.
56 * @return Model.
57 */
58 @Override public <K, V> SVMLinearBinaryClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
59 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
60
61 assert datasetBuilder != null;
62
63 PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
64 featureExtractor,
65 lbExtractor
66 );
67
68 Vector weights;
69
70 try(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build(
71 (upstream, upstreamSize) -> new EmptyContext(),
72 partDataBuilder
73 )) {
74 final int cols = dataset.compute(data -> data.colSize(), (a, b) -> a == null ? b : a);
75 final int weightVectorSizeWithIntercept = cols + 1;
76 weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept);
77
78 for (int i = 0; i < this.getAmountOfIterations(); i++) {
79 Vector deltaWeights = calculateUpdates(weights, dataset);
80 weights = weights.plus(deltaWeights); // creates new vector
81 }
82 } catch (Exception e) {
83 throw new RuntimeException(e);
84 }
85 return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
86 }
87
88 /** */
89 @NotNull private Vector initializeWeightsWithZeros(int vectorSize) {
90 return new DenseLocalOnHeapVector(vectorSize);
91 }
92
93 /** */
94 private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
95 return dataset.compute(data -> {
96 Vector copiedWeights = weights.copy();
97 Vector deltaWeights = initializeWeightsWithZeros(weights.size());
98 final int amountOfObservation = data.rowSize();
99
100 Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation);
101 Vector deltaAlphas = initializeWeightsWithZeros(amountOfObservation);
102
103 for (int i = 0; i < this.getAmountOfLocIterations(); i++) {
104 int randomIdx = ThreadLocalRandom.current().nextInt(amountOfObservation);
105
106 Deltas deltas = getDeltas(data, copiedWeights, amountOfObservation, tmpAlphas, randomIdx);
107
108 copiedWeights = copiedWeights.plus(deltas.deltaWeights); // creates new vector
109 deltaWeights = deltaWeights.plus(deltas.deltaWeights); // creates new vector
110
111 tmpAlphas.set(randomIdx, tmpAlphas.get(randomIdx) + deltas.deltaAlpha);
112 deltaAlphas.set(randomIdx, deltaAlphas.get(randomIdx) + deltas.deltaAlpha);
113 }
114 return deltaWeights;
115 }, (a, b) -> a == null ? b : a.plus(b));
116 }
117
118 /** */
119 private Deltas getDeltas(LabeledDataset data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
120 int randomIdx) {
121 LabeledVector row = (LabeledVector)data.getRow(randomIdx);
122 Double lb = (Double)row.label();
123 Vector v = makeVectorWithInterceptElement(row);
124
125 double alpha = tmpAlphas.get(randomIdx);
126
127 return maximize(lb, v, alpha, copiedWeights, amountOfObservation);
128 }
129
130 /** */
131 private Vector makeVectorWithInterceptElement(LabeledVector row) {
132 Vector vec = row.features().like(row.features().size() + 1);
133
134 vec.set(0, 1); // set intercept element
135
136 for (int j = 0; j < row.features().size(); j++)
137 vec.set(j + 1, row.features().get(j));
138
139 return vec;
140 }
141
142 /** */
143 private Deltas maximize(double lb, Vector v, double alpha, Vector weights, int amountOfObservation) {
144 double gradient = calcGradient(lb, v, weights, amountOfObservation);
145 double prjGrad = calculateProjectionGradient(alpha, gradient);
146
147 return calcDeltas(lb, v, alpha, prjGrad, weights.size(), amountOfObservation);
148 }
149
150 /** */
151 private Deltas calcDeltas(double lb, Vector v, double alpha, double gradient, int vectorSize,
152 int amountOfObservation) {
153 if (gradient != 0.0) {
154
155 double qii = v.dot(v);
156 double newAlpha = calcNewAlpha(alpha, gradient, qii);
157
158 Vector deltaWeights = v.times(lb * (newAlpha - alpha) / (this.lambda() * amountOfObservation));
159
160 return new Deltas(newAlpha - alpha, deltaWeights);
161 }
162 else
163 return new Deltas(0.0, initializeWeightsWithZeros(vectorSize));
164 }
165
166 /** */
167 private double calcNewAlpha(double alpha, double gradient, double qii) {
168 if (qii != 0.0)
169 return Math.min(Math.max(alpha - (gradient / qii), 0.0), 1.0);
170 else
171 return 1.0;
172 }
173
174 /** */
175 private double calcGradient(double lb, Vector v, Vector weights, int amountOfObservation) {
176 double dotProduct = v.dot(weights);
177 return (lb * dotProduct - 1.0) * (this.lambda() * amountOfObservation);
178 }
179
180 /** */
181 private double calculateProjectionGradient(double alpha, double gradient) {
182 if (alpha <= 0.0)
183 return Math.min(gradient, 0.0);
184
185 else if (alpha >= 1.0)
186 return Math.max(gradient, 0.0);
187
188 else
189 return gradient;
190 }
191
192 /**
193 * Set up the regularization parameter.
194 * @param lambda The regularization parameter. Should be more than 0.0.
195 * @return Trainer with new lambda parameter value.
196 */
197 public SVMLinearBinaryClassificationTrainer withLambda(double lambda) {
198 assert lambda > 0.0;
199 this.lambda = lambda;
200 return this;
201 }
202
203 /**
204 * Gets the regularization lambda.
205 * @return The parameter value.
206 */
207 public double lambda() {
208 return lambda;
209 }
210
211 /**
212 * Gets the amount of outer iterations of SCDA algorithm.
213 * @return The parameter value.
214 */
215 public int getAmountOfIterations() {
216 return amountOfIterations;
217 }
218
219 /**
220 * Set up the amount of outer iterations of SCDA algorithm.
221 * @param amountOfIterations The parameter value.
222 * @return Trainer with new amountOfIterations parameter value.
223 */
224 public SVMLinearBinaryClassificationTrainer withAmountOfIterations(int amountOfIterations) {
225 this.amountOfIterations = amountOfIterations;
226 return this;
227 }
228
229 /**
230 * Gets the amount of local iterations of SCDA algorithm.
231 * @return The parameter value.
232 */
233 public int getAmountOfLocIterations() {
234 return amountOfLocIterations;
235 }
236
237 /**
238 * Set up the amount of local iterations of SCDA algorithm.
239 * @param amountOfLocIterations The parameter value.
240 * @return Trainer with new amountOfLocIterations parameter value.
241 */
242 public SVMLinearBinaryClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
243 this.amountOfLocIterations = amountOfLocIterations;
244 return this;
245 }
246
247 }
248
249 /** This is a helper class to handle pair results which are returned from the calculation method. */
250 class Deltas {
251 /** */
252 public double deltaAlpha;
253
254 /** */
255 public Vector deltaWeights;
256
257 /** */
258 public Deltas(double deltaAlpha, Vector deltaWeights) {
259 this.deltaAlpha = deltaAlpha;
260 this.deltaWeights = deltaWeights;
261 }
262 }
263
264