IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / composition / boosting / GDBBinaryClassifierTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.composition.boosting;
19
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteTriFunction;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
35
36 /**
37 * Trainer for binary classifier using Gradient Boosting.
38 * As preparing stage this algorithm learn labels in dataset and create mapping dataset labels to 0 and 1.
39 * This algorithm uses gradient of Logarithmic Loss metric [LogLoss] by default in each step of learning.
40 */
41 public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
42 /** External representation of first class. */
43 private double externalFirstCls; //internal 0.0
44 /** External representation of second class. */
45 private double externalSecondCls; //internal 1.0
46
47 /**
48 * Constructs instance of GDBBinaryClassifierTrainer.
49 *
50 * @param gradStepSize Grad step size.
51 * @param cntOfIterations Count of learning iterations.
52 */
53 public GDBBinaryClassifierTrainer(double gradStepSize, Integer cntOfIterations) {
54 super(gradStepSize,
55 cntOfIterations,
56 LossGradientPerPredictionFunctions.LOG_LOSS);
57 }
58
59 /**
60 * Constructs instance of GDBBinaryClassifierTrainer.
61 *
62 * @param gradStepSize Grad step size.
63 * @param cntOfIterations Count of learning iterations.
64 * @param lossGradient Gradient of loss function. First argument is sample size, second argument is valid answer, third argument is current model prediction.
65 */
66 public GDBBinaryClassifierTrainer(double gradStepSize,
67 Integer cntOfIterations,
68 IgniteTriFunction<Long, Double, Double, Double> lossGradient) {
69
70 super(gradStepSize, cntOfIterations, lossGradient);
71 }
72
73 /** {@inheritDoc} */
74 @Override protected <V, K> void learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor,
75 IgniteBiFunction<K, V, Double> lExtractor) {
76
77 List<Double> uniqLabels = new ArrayList<Double>(
78 builder.build(new EmptyContextBuilder<>(), new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor))
79 .compute((IgniteFunction<LabeledDataset<Double,LabeledVector>, Set<Double>>) x -> {
80 return Arrays.stream(x.labels()).boxed().collect(Collectors.toSet());
81 }, (a, b) -> {
82 if (a == null)
83 return b;
84 if (b == null)
85 return a;
86 a.addAll(b);
87 return a;
88 }
89 ));
90
91 A.ensure(uniqLabels.size() == 2, "Binary classifier expects two types of labels in learning dataset");
92 externalFirstCls = uniqLabels.get(0);
93 externalSecondCls = uniqLabels.get(1);
94 }
95
96 /** {@inheritDoc} */
97 @Override protected double externalLabelToInternal(double x) {
98 return x == externalFirstCls ? 0.0 : 1.0;
99 }
100
101 /** {@inheritDoc} */
102 @Override protected double internalLabelToExternal(double indent) {
103 double sigma = 1.0 / (1.0 + Math.exp(-indent));
104 double internalCls = sigma < 0.5 ? 0.0 : 1.0;
105 return internalCls == 0.0 ? externalFirstCls : externalSecondCls;
106 }
107 }