IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / preprocessing / imputing / ImputerTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.preprocessing.imputing;
19
20 import java.util.Comparator;
21 import java.util.HashMap;
22 import java.util.Map;
23 import java.util.Optional;
24 import org.apache.ignite.ml.dataset.Dataset;
25 import org.apache.ignite.ml.dataset.DatasetBuilder;
26 import org.apache.ignite.ml.dataset.UpstreamEntry;
27 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
28 import org.apache.ignite.ml.math.Vector;
29 import org.apache.ignite.ml.math.VectorUtils;
30 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
31 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
32
33 /**
34 * Trainer of the imputing preprocessor.
35 * The imputing fills the missed values according the imputing strategy (default: mean value for each feature).
36 * It supports double values in features only.
37 *
38 * @param <K> Type of a key in {@code upstream} data.
39 * @param <V> Type of a value in {@code upstream} data.
40 */
41 public class ImputerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> {
42 /** The imputing strategy. */
43 private ImputingStrategy imputingStgy = ImputingStrategy.MEAN;
44
45 /** {@inheritDoc} */
46 @Override public ImputerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
47 IgniteBiFunction<K, V, Vector> basePreprocessor) {
48 try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build(
49 (upstream, upstreamSize) -> new EmptyContext(),
50 (upstream, upstreamSize, ctx) -> {
51 double[] sums = null;
52 int[] counts = null;
53 Map<Double, Integer>[] valuesByFreq = null;
54
55 while (upstream.hasNext()) {
56 UpstreamEntry<K, V> entity = upstream.next();
57 Vector row = basePreprocessor.apply(entity.getKey(), entity.getValue());
58
59 switch (imputingStgy) {
60 case MEAN:
61 sums = calculateTheSums(row, sums);
62 counts = calculateTheCounts(row, counts);
63 break;
64 case MOST_FREQUENT:
65 valuesByFreq = calculateFrequencies(row, valuesByFreq);
66 break;
67 default: throw new UnsupportedOperationException("The chosen strategy is not supported");
68 }
69 }
70
71 ImputerPartitionData partData;
72
73 switch (imputingStgy) {
74 case MEAN:
75 partData = new ImputerPartitionData().withSums(sums).withCounts(counts);
76 break;
77 case MOST_FREQUENT:
78 partData = new ImputerPartitionData().withValuesByFrequency(valuesByFreq);
79 break;
80 default: throw new UnsupportedOperationException("The chosen strategy is not supported");
81 }
82 return partData;
83 }
84 )) {
85
86 Vector imputingValues;
87
88 switch (imputingStgy) {
89 case MEAN:
90 imputingValues = VectorUtils.of(calculateImputingValuesBySumsAndCounts(dataset));
91 break;
92 case MOST_FREQUENT:
93 imputingValues = VectorUtils.of(calculateImputingValuesByFrequencies(dataset));
94 break;
95 default: throw new UnsupportedOperationException("The chosen strategy is not supported");
96 }
97
98 return new ImputerPreprocessor<>(imputingValues, basePreprocessor);
99
100 }
101 catch (Exception e) {
102 throw new RuntimeException(e);
103 }
104 }
105
106 /**
107 * Calculates the imputing values by frequencies keeping in the given dataset.
108 *
109 * @param dataset The dataset of frequencies for each feature aggregated in each partition..
110 * @return Most frequent value for each feature.
111 */
112 private double[] calculateImputingValuesByFrequencies(
113 Dataset<EmptyContext, ImputerPartitionData> dataset) {
114 Map<Double, Integer>[] frequencies = dataset.compute(
115 ImputerPartitionData::valuesByFrequency,
116 (a, b) -> {
117 if (a == null)
118 return b;
119
120 if (b == null)
121 return a;
122
123 assert a.length == b.length;
124
125 for (int i = 0; i < a.length; i++) {
126 int finalI = i;
127 a[i].forEach((k, v) -> b[finalI].merge(k, v, (f1, f2) -> f1 + f2));
128 }
129 return b;
130 }
131 );
132
133 double[] res = new double[frequencies.length];
134
135 for (int i = 0; i < frequencies.length; i++) {
136 Optional<Map.Entry<Double, Integer>> max = frequencies[i].entrySet()
137 .stream()
138 .max(Comparator.comparingInt(Map.Entry::getValue));
139
140 if(max.isPresent())
141 res[i] = max.get().getKey();
142 }
143
144 return res;
145 }
146
147 /**
148 * Calculates the imputing values by sums and counts keeping in the given dataset.
149 *
150 * @param dataset The dataset with sums and counts for each feature aggregated in each partition.
151 * @return The mean value for each feature.
152 */
153 private double[] calculateImputingValuesBySumsAndCounts(Dataset<EmptyContext, ImputerPartitionData> dataset) {
154 double[] sums = dataset.compute(
155 ImputerPartitionData::sums,
156 (a, b) -> {
157 if (a == null)
158 return b;
159
160 if (b == null)
161 return a;
162
163 assert a.length == b.length;
164
165 for (int i = 0; i < a.length; i++)
166 a[i] += b[i];
167
168 return a;
169 }
170 );
171
172 int[] counts = dataset.compute(
173 ImputerPartitionData::counts,
174 (a, b) -> {
175 if (a == null)
176 return b;
177
178 if (b == null)
179 return a;
180
181 assert a.length == b.length;
182
183 for (int i = 0; i < a.length; i++)
184 a[i] += b[i];
185
186 return a;
187 }
188 );
189
190 double[] means = new double[sums.length];
191
192 for (int i = 0; i < means.length; i++)
193 means[i] = sums[i]/counts[i];
194
195 return means;
196 }
197
198 /**
199 * Updates frequencies by values and features.
200 *
201 * @param row Feature vector.
202 * @param valuesByFreq Holds the sums by values and features.
203 * @return Updated sums by values and features.
204 */
205 private Map<Double, Integer>[] calculateFrequencies(Vector row, Map<Double, Integer>[] valuesByFreq) {
206 if (valuesByFreq == null) {
207 valuesByFreq = new HashMap[row.size()];
208 for (int i = 0; i < valuesByFreq.length; i++) valuesByFreq[i] = new HashMap<>();
209 }
210 else
211 assert valuesByFreq.length == row.size() : "Base preprocessor must return exactly " + valuesByFreq.length
212 + " features";
213
214 for (int i = 0; i < valuesByFreq.length; i++) {
215 double v = row.get(i);
216
217 if(!Double.valueOf(v).equals(Double.NaN)) {
218 Map<Double, Integer> map = valuesByFreq[i];
219
220 if (map.containsKey(v))
221 map.put(v, (map.get(v)) + 1);
222 else
223 map.put(v, 1);
224 }
225 }
226 return valuesByFreq;
227 }
228
229 /**
230 * Updates sums by features.
231 *
232 * @param row Feature vector.
233 * @param sums Holds the sums by features.
234 * @return Updated sums by features.
235 */
236 private double[] calculateTheSums(Vector row, double[] sums) {
237 if (sums == null)
238 sums = new double[row.size()];
239 else
240 assert sums.length == row.size() : "Base preprocessor must return exactly " + sums.length
241 + " features";
242
243 for (int i = 0; i < sums.length; i++){
244 if(!Double.valueOf(row.get(i)).equals(Double.NaN))
245 sums[i] += row.get(i);
246 }
247
248 return sums;
249 }
250
251 /**
252 * Updates counts by features.
253 *
254 * @param row Feature vector.
255 * @param counts Holds the counts by features.
256 * @return Updated counts by features.
257 */
258 private int[] calculateTheCounts(Vector row, int[] counts) {
259 if (counts == null)
260 counts = new int[row.size()];
261 else
262 assert counts.length == row.size() : "Base preprocessor must return exactly " + counts.length
263 + " features";
264
265 for (int i = 0; i < counts.length; i++){
266 if(!Double.valueOf(row.get(i)).equals(Double.NaN))
267 counts[i]++;
268 }
269
270 return counts;
271 }
272
273 /**
274 * Sets the imputing strategy.
275 *
276 * @param imputingStgy The given value.
277 * @return The updated imputing trainer.
278 */
279 public ImputerTrainer<K, V> withImputingStrategy(ImputingStrategy imputingStgy){
280 this.imputingStgy = imputingStgy;
281 return this;
282 }
283 }