IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / preprocessing / encoding / stringencoder / StringEncoderTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.preprocessing.encoding.stringencoder;
19
20 import java.util.HashMap;
21 import java.util.HashSet;
22 import java.util.LinkedHashMap;
23 import java.util.Map;
24 import java.util.Set;
25 import java.util.stream.Collectors;
26 import org.apache.ignite.ml.dataset.Dataset;
27 import org.apache.ignite.ml.dataset.DatasetBuilder;
28 import org.apache.ignite.ml.dataset.UpstreamEntry;
29 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
30 import org.apache.ignite.ml.math.Vector;
31 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
32 import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
33 import org.jetbrains.annotations.NotNull;
34
35 /**
36 * Trainer of the String Encoder preprocessor.
37 * The String Encoder encodes string values (categories) to double values in range [0.0, amountOfCategories)
38 * where the most popular value will be presented as 0.0 and the least popular value presented with amountOfCategories-1 value.
39 *
40 * @param <K> Type of a key in {@code upstream} data.
41 * @param <V> Type of a value in {@code upstream} data.
42 */
43 public class StringEncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[], Vector> {
44 /** Indices of features which should be encoded. */
45 private Set<Integer> handledIndices = new HashSet<>();
46
47 /** {@inheritDoc} */
48 @Override public StringEncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
49 IgniteBiFunction<K, V, Object[]> basePreprocessor) {
50 if(handledIndices.isEmpty())
51 throw new RuntimeException("Add indices of handled features");
52
53 try (Dataset<EmptyContext, StringEncoderPartitionData> dataset = datasetBuilder.build(
54 (upstream, upstreamSize) -> new EmptyContext(),
55 (upstream, upstreamSize, ctx) -> {
56 // This array will contain not null values for handled indices
57 Map<String, Integer>[] categoryFrequencies = null;
58
59 while (upstream.hasNext()) {
60 UpstreamEntry<K, V> entity = upstream.next();
61 Object[] row = basePreprocessor.apply(entity.getKey(), entity.getValue());
62 categoryFrequencies = calculateFrequencies(row, categoryFrequencies);
63 }
64 return new StringEncoderPartitionData()
65 .withCategoryFrequencies(categoryFrequencies);
66 }
67 )) {
68 Map<String, Integer>[] encodingValues = calculateEncodingValuesByFrequencies(dataset);
69
70 return new StringEncoderPreprocessor<>(encodingValues, basePreprocessor, handledIndices);
71 }
72 catch (Exception e) {
73 throw new RuntimeException(e);
74 }
75 }
76
77 /**
78 * Calculates the encoding values values by frequencies keeping in the given dataset.
79 *
80 * @param dataset The dataset of frequencies for each feature aggregated in each partition.
81 * @return Encoding values for each feature.
82 */
83 private Map<String, Integer>[] calculateEncodingValuesByFrequencies(
84 Dataset<EmptyContext, StringEncoderPartitionData> dataset) {
85 Map<String, Integer>[] frequencies = dataset.compute(
86 StringEncoderPartitionData::categoryFrequencies,
87 (a, b) -> {
88 if (a == null)
89 return b;
90
91 if (b == null)
92 return a;
93
94 assert a.length == b.length;
95
96 for (int i = 0; i < a.length; i++) {
97 if(handledIndices.contains(i)){
98 int finalI = i;
99 a[i].forEach((k, v) -> b[finalI].merge(k, v, (f1, f2) -> f1 + f2));
100 }
101 }
102 return b;
103 }
104 );
105
106 Map<String, Integer>[] res = new HashMap[frequencies.length];
107
108 for (int i = 0; i < frequencies.length; i++)
109 if(handledIndices.contains(i))
110 res[i] = transformFrequenciesToEncodingValues(frequencies[i]);
111
112 return res;
113 }
114
115 /**
116 * Transforms frequencies to the encoding values.
117 *
118 * @param frequencies Frequencies of categories for the specific feature.
119 * @return Encoding values.
120 */
121 private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) {
122 final HashMap<String, Integer> resMap = frequencies.entrySet()
123 .stream()
124 .sorted(Map.Entry.comparingByValue())
125 .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
126 (oldValue, newValue) -> oldValue, LinkedHashMap::new));
127
128 int amountOfLabels = frequencies.size();
129
130 for (Map.Entry<String, Integer> m : resMap.entrySet())
131 m.setValue(--amountOfLabels);
132
133 return resMap;
134 }
135
136 /**
137 * Updates frequencies by values and features.
138 *
139 * @param row Feature vector.
140 * @param categoryFrequencies Holds the frequencies of categories by values and features.
141 * @return Updated frequencies by values and features.
142 */
143 private Map<String, Integer>[] calculateFrequencies(Object[] row, Map<String, Integer>[] categoryFrequencies) {
144 if (categoryFrequencies == null)
145 categoryFrequencies = initializeCategoryFrequencies(row);
146 else
147 assert categoryFrequencies.length == row.length : "Base preprocessor must return exactly "
148 + categoryFrequencies.length + " features";
149
150 for (int i = 0; i < categoryFrequencies.length; i++) {
151 if(handledIndices.contains(i)){
152 String strVal;
153 Object featureVal = row[i];
154
155 if(featureVal.equals(Double.NaN)) {
156 strVal = "";
157 row[i] = strVal;
158 }
159 else strVal = (String)featureVal;
160
161 Map<String, Integer> map = categoryFrequencies[i];
162
163 if (map.containsKey(strVal))
164 map.put(strVal, (map.get(strVal)) + 1);
165 else
166 map.put(strVal, 1);
167 }
168 }
169 return categoryFrequencies;
170 }
171
172 /**
173 * Initialize frequencies for handled indices only.
174 * @param row Feature vector.
175 * @return The array contains not null values for handled indices.
176 */
177 @NotNull private Map<String, Integer>[] initializeCategoryFrequencies(Object[] row) {
178 Map<String, Integer>[] categoryFrequencies = new HashMap[row.length];
179
180 for (int i = 0; i < categoryFrequencies.length; i++)
181 if(handledIndices.contains(i))
182 categoryFrequencies[i] = new HashMap<>();
183
184 return categoryFrequencies;
185 }
186
187 /**
188 * Add the index of encoded feature.
189 * @param idx The index of encoded feature.
190 * @return The changed trainer.
191 */
192 public StringEncoderTrainer<K, V> encodeFeature(int idx){
193 handledIndices.add(idx);
194 return this;
195 }
196 }