IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / selection / cv / CrossValidation.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.selection.cv;
19
20 import java.util.Map;
21 import java.util.function.BiFunction;
22 import java.util.function.Function;
23 import org.apache.ignite.Ignite;
24 import org.apache.ignite.IgniteCache;
25 import org.apache.ignite.lang.IgniteBiPredicate;
26 import org.apache.ignite.ml.Model;
27 import org.apache.ignite.ml.dataset.DatasetBuilder;
28 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
29 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
30 import org.apache.ignite.ml.math.Vector;
31 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
32 import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
33 import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
34 import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor;
35 import org.apache.ignite.ml.selection.scoring.metric.Metric;
36 import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
37 import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
38 import org.apache.ignite.ml.trainers.DatasetTrainer;
39
40 /**
41 * Cross validation score calculator. Cross validation is an approach that allows to avoid overfitting that is made the
42 * following way: the training set is split into k smaller sets. The following procedure is followed for each of the k
43 * “folds”:
44 * <ul>
45 * <li>A model is trained using k-1 of the folds as training data;</li>
46 * <li>the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute
47 * a performance measure such as accuracy).</li>
48 * </ul>
49 *
50 * @param <M> Type of model.
51 * @param <L> Type of a label (truth or prediction).
52 * @param <K> Type of a key in {@code upstream} data.
53 * @param <V> Type of a value in {@code upstream} data.
54 */
55 public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
56 /**
57 * Computes cross-validated metrics.
58 *
59 * @param trainer Trainer of the model.
60 * @param scoreCalculator Score calculator.
61 * @param ignite Ignite instance.
62 * @param upstreamCache Ignite cache with {@code upstream} data.
63 * @param featureExtractor Feature extractor.
64 * @param lbExtractor Label extractor.
65 * @param cv Number of folds.
66 * @return Array of scores of the estimator for each run of the cross validation.
67 */
68 public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
69 IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor,
70 IgniteBiFunction<K, V, L> lbExtractor, int cv) {
71 return score(trainer, scoreCalculator, ignite, upstreamCache, (k, v) -> true, featureExtractor, lbExtractor,
72 new SHA256UniformMapper<>(), cv);
73 }
74
75 /**
76 * Computes cross-validated metrics.
77 *
78 * @param trainer Trainer of the model.
79 * @param scoreCalculator Base score calculator.
80 * @param ignite Ignite instance.
81 * @param upstreamCache Ignite cache with {@code upstream} data.
82 * @param filter Base {@code upstream} data filter.
83 * @param featureExtractor Feature extractor.
84 * @param lbExtractor Label extractor.
85 * @param cv Number of folds.
86 * @return Array of scores of the estimator for each run of the cross validation.
87 */
88 public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
89 IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
90 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
91 return score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor,
92 new SHA256UniformMapper<>(), cv);
93 }
94
95 /**
96 * Computes cross-validated metrics.
97 *
98 * @param trainer Trainer of the model.
99 * @param scoreCalculator Base score calculator.
100 * @param ignite Ignite instance.
101 * @param upstreamCache Ignite cache with {@code upstream} data.
102 * @param filter Base {@code upstream} data filter.
103 * @param featureExtractor Feature extractor.
104 * @param lbExtractor Label extractor.
105 * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
106 * @param cv Number of folds.
107 * @return Array of scores of the estimator for each run of the cross validation.
108 */
109 public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator,
110 Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
111 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
112 UniformMapper<K, V> mapper, int cv) {
113
114 return score(
115 trainer,
116 predicate -> new CacheBasedDatasetBuilder<>(
117 ignite,
118 upstreamCache,
119 (k, v) -> filter.apply(k, v) && predicate.apply(k, v)
120 ),
121 (predicate, mdl) -> new CacheBasedLabelPairCursor<>(
122 upstreamCache,
123 (k, v) -> filter.apply(k, v) && !predicate.apply(k, v),
124 featureExtractor,
125 lbExtractor,
126 mdl
127 ),
128 featureExtractor,
129 lbExtractor,
130 scoreCalculator,
131 mapper,
132 cv
133 );
134 }
135
136 /**
137 * Computes cross-validated metrics.
138 *
139 * @param trainer Trainer of the model.
140 * @param scoreCalculator Base score calculator.
141 * @param upstreamMap Map with {@code upstream} data.
142 * @param parts Number of partitions.
143 * @param featureExtractor Feature extractor.
144 * @param lbExtractor Label extractor.
145 * @param cv Number of folds.
146 * @return Array of scores of the estimator for each run of the cross validation.
147 */
148 public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
149 int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
150 return score(trainer, scoreCalculator, upstreamMap, (k, v) -> true, parts, featureExtractor, lbExtractor,
151 new SHA256UniformMapper<>(), cv);
152 }
153
154 /**
155 * Computes cross-validated metrics.
156 *
157 * @param trainer Trainer of the model.
158 * @param scoreCalculator Base score calculator.
159 * @param upstreamMap Map with {@code upstream} data.
160 * @param filter Base {@code upstream} data filter.
161 * @param parts Number of partitions.
162 * @param featureExtractor Feature extractor.
163 * @param lbExtractor Label extractor.
164 * @param cv Number of folds.
165 * @return Array of scores of the estimator for each run of the cross validation.
166 */
167 public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
168 IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
169 IgniteBiFunction<K, V, L> lbExtractor, int cv) {
170 return score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor,
171 new SHA256UniformMapper<>(), cv);
172 }
173
174 /**
175 * Computes cross-validated metrics.
176 *
177 * @param trainer Trainer of the model.
178 * @param scoreCalculator Base score calculator.
179 * @param upstreamMap Map with {@code upstream} data.
180 * @param filter Base {@code upstream} data filter.
181 * @param parts Number of partitions.
182 * @param featureExtractor Feature extractor.
183 * @param lbExtractor Label extractor.
184 * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
185 * @param cv Number of folds.
186 * @return Array of scores of the estimator for each run of the cross validation.
187 */
188 public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
189 IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
190 IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) {
191 return score(
192 trainer,
193 predicate -> new LocalDatasetBuilder<>(
194 upstreamMap,
195 (k, v) -> filter.apply(k, v) && predicate.apply(k, v),
196 parts
197 ),
198 (predicate, mdl) -> new LocalLabelPairCursor<>(
199 upstreamMap,
200 (k, v) -> filter.apply(k, v) && !predicate.apply(k, v),
201 featureExtractor,
202 lbExtractor,
203 mdl
204 ),
205 featureExtractor,
206 lbExtractor,
207 scoreCalculator,
208 mapper,
209 cv
210 );
211 }
212
213 /**
214 * Computes cross-validated metrics.
215 *
216 * @param trainer Trainer of the model.
217 * @param datasetBuilderSupplier Dataset builder supplier.
218 * @param testDataIterSupplier Test data iterator supplier.
219 * @param featureExtractor Feature extractor.
220 * @param lbExtractor Label extractor.
221 * @param scoreCalculator Base score calculator.
222 * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
223 * @param cv Number of folds.
224 * @return Array of scores of the estimator for each run of the cross validation.
225 */
226 private double[] score(DatasetTrainer<M, L> trainer, Function<IgniteBiPredicate<K, V>,
227 DatasetBuilder<K, V>> datasetBuilderSupplier,
228 BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier,
229 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
230 Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) {
231
232 double[] scores = new double[cv];
233
234 double foldSize = 1.0 / cv;
235 for (int i = 0; i < cv; i++) {
236 double from = foldSize * i;
237 double to = foldSize * (i + 1);
238
239 IgniteBiPredicate<K, V> trainSetFilter = (k, v) -> {
240 double pnt = mapper.map(k, v);
241 return pnt < from || pnt > to;
242 };
243
244 DatasetBuilder<K, V> datasetBuilder = datasetBuilderSupplier.apply(trainSetFilter);
245 M mdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
246
247 try (LabelPairCursor<L> cursor = testDataIterSupplier.apply(trainSetFilter, mdl)) {
248 scores[i] = scoreCalculator.score(cursor.iterator());
249 }
250 catch (Exception e) {
251 throw new RuntimeException(e);
252 }
253 }
254
255 return scores;
256 }
257 }