/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.ignite.ml.svm;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
38 * Base class for a soft-margin SVM linear multiclass-classification trainer based on the communication-efficient
39 * distributed dual coordinate ascent algorithm (CoCoA) with hinge-loss function.
41 * All common parameters are shared with bunch of binary classification trainers.
43 public class SVMLinearMultiClassClassificationTrainer
44 implements SingleLabelDatasetTrainer
<SVMLinearMultiClassClassificationModel
> {
45 /** Amount of outer SDCA algorithm iterations. */
46 private int amountOfIterations
= 20;
48 /** Amount of local SDCA algorithm iterations. */
49 private int amountOfLocIterations
= 50;
51 /** Regularization parameter. */
52 private double lambda
= 0.2;
55 * Trains model based on the specified data.
57 * @param datasetBuilder Dataset builder.
58 * @param featureExtractor Feature extractor.
59 * @param lbExtractor Label extractor.
62 @Override public <K
, V
> SVMLinearMultiClassClassificationModel
fit(DatasetBuilder
<K
, V
> datasetBuilder
,
63 IgniteBiFunction
<K
, V
, Vector
> featureExtractor
,
64 IgniteBiFunction
<K
, V
, Double
> lbExtractor
) {
65 List
<Double
> classes
= extractClassLabels(datasetBuilder
, lbExtractor
);
67 SVMLinearMultiClassClassificationModel multiClsMdl
= new SVMLinearMultiClassClassificationModel();
69 classes
.forEach(clsLb
-> {
70 SVMLinearBinaryClassificationTrainer trainer
= new SVMLinearBinaryClassificationTrainer()
71 .withAmountOfIterations(this.amountOfIterations())
72 .withAmountOfLocIterations(this.amountOfLocIterations())
73 .withLambda(this.lambda());
75 IgniteBiFunction
<K
, V
, Double
> lbTransformer
= (k
, v
) -> {
76 Double lb
= lbExtractor
.apply(k
, v
);
83 multiClsMdl
.add(clsLb
, trainer
.fit(datasetBuilder
, featureExtractor
, lbTransformer
));
89 /** Iterates among dataset and collects class labels. */
90 private <K
, V
> List
<Double
> extractClassLabels(DatasetBuilder
<K
, V
> datasetBuilder
, IgniteBiFunction
<K
, V
, Double
> lbExtractor
) {
91 assert datasetBuilder
!= null
;
93 PartitionDataBuilder
<K
, V
, EmptyContext
, LabelPartitionDataOnHeap
> partDataBuilder
= new LabelPartitionDataBuilderOnHeap
<>(lbExtractor
);
95 List
<Double
> res
= new ArrayList
<>();
97 try (Dataset
<EmptyContext
, LabelPartitionDataOnHeap
> dataset
= datasetBuilder
.build(
98 (upstream
, upstreamSize
) -> new EmptyContext(),
101 final Set
<Double
> clsLabels
= dataset
.compute(data
-> {
102 final Set
<Double
> locClsLabels
= new HashSet
<>();
104 final double[] lbs
= data
.getY();
106 for (double lb
: lbs
) locClsLabels
.add(lb
);
109 }, (a
, b
) -> a
== null ? b
: Stream
.of(a
, b
).flatMap(Collection
::stream
).collect(Collectors
.toSet()));
111 res
.addAll(clsLabels
);
113 } catch (Exception e
) {
114 throw new RuntimeException(e
);
120 * Set up the regularization parameter.
122 * @param lambda The regularization parameter. Should be more than 0.0.
123 * @return Trainer with new lambda parameter value.
125 public SVMLinearMultiClassClassificationTrainer
withLambda(double lambda
) {
127 this.lambda
= lambda
;
132 * Gets the regularization lambda.
134 * @return The parameter value.
136 public double lambda() {
141 * Gets the amount of outer iterations of SCDA algorithm.
143 * @return The parameter value.
145 public int amountOfIterations() {
146 return amountOfIterations
;
150 * Set up the amount of outer iterations of SCDA algorithm.
152 * @param amountOfIterations The parameter value.
153 * @return Trainer with new amountOfIterations parameter value.
155 public SVMLinearMultiClassClassificationTrainer
withAmountOfIterations(int amountOfIterations
) {
156 this.amountOfIterations
= amountOfIterations
;
161 * Gets the amount of local iterations of SCDA algorithm.
163 * @return The parameter value.
165 public int amountOfLocIterations() {
166 return amountOfLocIterations
;
170 * Set up the amount of local iterations of SCDA algorithm.
172 * @param amountOfLocIterations The parameter value.
173 * @return Trainer with new amountOfLocIterations parameter value.
175 public SVMLinearMultiClassClassificationTrainer
withAmountOfLocIterations(int amountOfLocIterations
) {
176 this.amountOfLocIterations
= amountOfLocIterations
;