IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / svm / SVMLinearMultiClassClassificationTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.svm;
19
20 import java.util.ArrayList;
21 import java.util.Collection;
22 import java.util.HashSet;
23 import java.util.List;
24 import java.util.Set;
25 import java.util.stream.Collectors;
26 import java.util.stream.Stream;
27 import org.apache.ignite.ml.dataset.Dataset;
28 import org.apache.ignite.ml.dataset.DatasetBuilder;
29 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
30 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
31 import org.apache.ignite.ml.math.Vector;
32 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
33 import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
34 import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
35 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
36
37 /**
38 * Base class for a soft-margin SVM linear multiclass-classification trainer based on the communication-efficient
39 * distributed dual coordinate ascent algorithm (CoCoA) with hinge-loss function.
40 *
41 * All common parameters are shared with bunch of binary classification trainers.
42 */
43 public class SVMLinearMultiClassClassificationTrainer
44 implements SingleLabelDatasetTrainer<SVMLinearMultiClassClassificationModel> {
45 /** Amount of outer SDCA algorithm iterations. */
46 private int amountOfIterations = 20;
47
48 /** Amount of local SDCA algorithm iterations. */
49 private int amountOfLocIterations = 50;
50
51 /** Regularization parameter. */
52 private double lambda = 0.2;
53
54 /**
55 * Trains model based on the specified data.
56 *
57 * @param datasetBuilder Dataset builder.
58 * @param featureExtractor Feature extractor.
59 * @param lbExtractor Label extractor.
60 * @return Model.
61 */
62 @Override public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
63 IgniteBiFunction<K, V, Vector> featureExtractor,
64 IgniteBiFunction<K, V, Double> lbExtractor) {
65 List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
66
67 SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();
68
69 classes.forEach(clsLb -> {
70 SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
71 .withAmountOfIterations(this.amountOfIterations())
72 .withAmountOfLocIterations(this.amountOfLocIterations())
73 .withLambda(this.lambda());
74
75 IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> {
76 Double lb = lbExtractor.apply(k, v);
77
78 if (lb.equals(clsLb))
79 return 1.0;
80 else
81 return -1.0;
82 };
83 multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer));
84 });
85
86 return multiClsMdl;
87 }
88
89 /** Iterates among dataset and collects class labels. */
90 private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
91 assert datasetBuilder != null;
92
93 PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
94
95 List<Double> res = new ArrayList<>();
96
97 try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
98 (upstream, upstreamSize) -> new EmptyContext(),
99 partDataBuilder
100 )) {
101 final Set<Double> clsLabels = dataset.compute(data -> {
102 final Set<Double> locClsLabels = new HashSet<>();
103
104 final double[] lbs = data.getY();
105
106 for (double lb : lbs) locClsLabels.add(lb);
107
108 return locClsLabels;
109 }, (a, b) -> a == null ? b : Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()));
110
111 res.addAll(clsLabels);
112
113 } catch (Exception e) {
114 throw new RuntimeException(e);
115 }
116 return res;
117 }
118
119 /**
120 * Set up the regularization parameter.
121 *
122 * @param lambda The regularization parameter. Should be more than 0.0.
123 * @return Trainer with new lambda parameter value.
124 */
125 public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
126 assert lambda > 0.0;
127 this.lambda = lambda;
128 return this;
129 }
130
131 /**
132 * Gets the regularization lambda.
133 *
134 * @return The parameter value.
135 */
136 public double lambda() {
137 return lambda;
138 }
139
140 /**
141 * Gets the amount of outer iterations of SCDA algorithm.
142 *
143 * @return The parameter value.
144 */
145 public int amountOfIterations() {
146 return amountOfIterations;
147 }
148
149 /**
150 * Set up the amount of outer iterations of SCDA algorithm.
151 *
152 * @param amountOfIterations The parameter value.
153 * @return Trainer with new amountOfIterations parameter value.
154 */
155 public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
156 this.amountOfIterations = amountOfIterations;
157 return this;
158 }
159
160 /**
161 * Gets the amount of local iterations of SCDA algorithm.
162 *
163 * @return The parameter value.
164 */
165 public int amountOfLocIterations() {
166 return amountOfLocIterations;
167 }
168
169 /**
170 * Set up the amount of local iterations of SCDA algorithm.
171 *
172 * @param amountOfLocIterations The parameter value.
173 * @return Trainer with new amountOfLocIterations parameter value.
174 */
175 public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
176 this.amountOfLocIterations = amountOfLocIterations;
177 return this;
178 }
179 }