IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / tree / DecisionTree.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.tree;
19
20 import java.io.Serializable;
21 import java.util.Arrays;
22 import org.apache.ignite.ml.dataset.Dataset;
23 import org.apache.ignite.ml.dataset.DatasetBuilder;
24 import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
25 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
26 import org.apache.ignite.ml.math.Vector;
27 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
28 import org.apache.ignite.ml.trainers.DatasetTrainer;
29 import org.apache.ignite.ml.tree.data.DecisionTreeData;
30 import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
31 import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
32 import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
33 import org.apache.ignite.ml.tree.impurity.util.StepFunction;
34 import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
35 import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder;
36
37 /**
38 * Distributed decision tree trainer that allows to fit trees using row-partitioned dataset.
39 *
40 * @param <T> Type of impurity measure.
41 */
42 public abstract class DecisionTree<T extends ImpurityMeasure<T>> implements DatasetTrainer<DecisionTreeNode, Double> {
43 /** Max tree deep. */
44 private final int maxDeep;
45
46 /** Min impurity decrease. */
47 private final double minImpurityDecrease;
48
49 /** Step function compressor. */
50 private final StepFunctionCompressor<T> compressor;
51
52 /** Decision tree leaf builder. */
53 private final DecisionTreeLeafBuilder decisionTreeLeafBuilder;
54
55 /**
56 * Constructs a new distributed decision tree trainer.
57 *
58 * @param maxDeep Max tree deep.
59 * @param minImpurityDecrease Min impurity decrease.
60 * @param compressor Impurity function compressor.
61 * @param decisionTreeLeafBuilder Decision tree leaf builder.
62 */
63 DecisionTree(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<T> compressor, DecisionTreeLeafBuilder decisionTreeLeafBuilder) {
64 this.maxDeep = maxDeep;
65 this.minImpurityDecrease = minImpurityDecrease;
66 this.compressor = compressor;
67 this.decisionTreeLeafBuilder = decisionTreeLeafBuilder;
68 }
69
70 /** {@inheritDoc} */
71 @Override public <K, V> DecisionTreeNode fit(DatasetBuilder<K, V> datasetBuilder,
72 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
73 try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
74 new EmptyContextBuilder<>(),
75 new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor)
76 )) {
77 return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
78 }
79 catch (Exception e) {
80 throw new RuntimeException(e);
81 }
82 }
83
84 /**
85 * Returns impurity measure calculator.
86 *
87 * @param dataset Dataset.
88 * @return Impurity measure calculator.
89 */
90 abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset);
91
92 /**
93 * Splits the node specified by the given dataset and predicate and returns decision tree node.
94 *
95 * @param dataset Dataset.
96 * @param filter Decision tree node predicate.
97 * @param deep Current tree deep.
98 * @param impurityCalc Impurity measure calculator.
99 * @return Decision tree node.
100 */
101 private DecisionTreeNode split(Dataset<EmptyContext, DecisionTreeData> dataset, TreeFilter filter, int deep,
102 ImpurityMeasureCalculator<T> impurityCalc) {
103 if (deep >= maxDeep)
104 return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
105
106 StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc);
107
108 if (criterionFunctions == null)
109 return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
110
111 SplitPoint splitPnt = calculateBestSplitPoint(criterionFunctions);
112
113 if (splitPnt == null)
114 return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
115
116 return new DecisionTreeConditionalNode(
117 splitPnt.col,
118 splitPnt.threshold,
119 split(dataset, updatePredicateForThenNode(filter, splitPnt), deep + 1, impurityCalc),
120 split(dataset, updatePredicateForElseNode(filter, splitPnt), deep + 1, impurityCalc)
121 );
122 }
123
124 /**
125 * Calculates impurity measure functions for all columns for the node specified by the given dataset and predicate.
126 *
127 * @param dataset Dataset.
128 * @param filter Decision tree node predicate.
129 * @param impurityCalc Impurity measure calculator.
130 * @return Array of impurity measure functions for all columns.
131 */
132 private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset,
133 TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc) {
134 return dataset.compute(
135 part -> {
136 if (compressor != null)
137 return compressor.compress(impurityCalc.calculate(part.filter(filter)));
138 else
139 return impurityCalc.calculate(part.filter(filter));
140 }, this::reduce
141 );
142 }
143
144 /**
145 * Calculates best split point.
146 *
147 * @param criterionFunctions Array of impurity measure functions for all columns.
148 * @return Best split point.
149 */
150 private SplitPoint calculateBestSplitPoint(StepFunction<T>[] criterionFunctions) {
151 SplitPoint<T> res = null;
152
153 for (int col = 0; col < criterionFunctions.length; col++) {
154 StepFunction<T> criterionFunctionForCol = criterionFunctions[col];
155
156 double[] arguments = criterionFunctionForCol.getX();
157 T[] values = criterionFunctionForCol.getY();
158
159 for (int leftSize = 1; leftSize < values.length - 1; leftSize++) {
160 if ((values[0].impurity() - values[leftSize].impurity()) > minImpurityDecrease
161 && (res == null || values[leftSize].compareTo(res.val) < 0))
162 res = new SplitPoint<>(values[leftSize], col, calculateThreshold(arguments, leftSize));
163 }
164 }
165
166 return res;
167 }
168
169 /**
170 * Merges two arrays gotten from two partitions.
171 *
172 * @param a First step function.
173 * @param b Second step function.
174 * @return Merged step function.
175 */
176 private StepFunction<T>[] reduce(StepFunction<T>[] a, StepFunction<T>[] b) {
177 if (a == null)
178 return b;
179 if (b == null)
180 return a;
181 else {
182 StepFunction<T>[] res = Arrays.copyOf(a, a.length);
183
184 for (int i = 0; i < res.length; i++)
185 res[i] = res[i].add(b[i]);
186
187 return res;
188 }
189 }
190
191 /**
192 * Calculates threshold based on the given step function arguments and split point (specified left size).
193 *
194 * @param arguments Step function arguments.
195 * @param leftSize Split point (left size).
196 * @return Threshold.
197 */
198 private double calculateThreshold(double[] arguments, int leftSize) {
199 return (arguments[leftSize] + arguments[leftSize + 1]) / 2.0;
200 }
201
202 /**
203 * Constructs a new predicate for "then" node based on the parent node predicate and split point.
204 *
205 * @param filter Parent node predicate.
206 * @param splitPnt Split point.
207 * @return Predicate for "then" node.
208 */
209 private TreeFilter updatePredicateForThenNode(TreeFilter filter, SplitPoint splitPnt) {
210 return filter.and(f -> f[splitPnt.col] > splitPnt.threshold);
211 }
212
213 /**
214 * Constructs a new predicate for "else" node based on the parent node predicate and split point.
215 *
216 * @param filter Parent node predicate.
217 * @param splitPnt Split point.
218 * @return Predicate for "else" node.
219 */
220 private TreeFilter updatePredicateForElseNode(TreeFilter filter, SplitPoint splitPnt) {
221 return filter.and(f -> f[splitPnt.col] <= splitPnt.threshold);
222 }
223
224 /**
225 * Util class that represents split point.
226 */
227 private static class SplitPoint<T extends ImpurityMeasure<T>> implements Serializable {
228 /** */
229 private static final long serialVersionUID = -1758525953544425043L;
230
231 /** Split point impurity measure value. */
232 private final T val;
233
234 /** Column. */
235 private final int col;
236
237 /** Threshold. */
238 private final double threshold;
239
240 /**
241 * Constructs a new instance of split point.
242 *
243 * @param val Split point impurity measure value.
244 * @param col Column.
245 * @param threshold Threshold.
246 */
247 SplitPoint(T val, int col, double threshold) {
248 this.val = val;
249 this.col = col;
250 this.threshold = threshold;
251 }
252 }
253 }