IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.regression.logistic.multiclass;

import java.util.Arrays;
import java.util.UUID;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer;
import org.apache.ignite.thread.IgniteThread;

/**
 * Run Logistic Regression multi-class classification trainer over distributed dataset to build two models:
 * one with min-max scaling applied to the features and one without it.
 *
 * @see LogRegressionMultiClassModel
 */
public class LogRegressionMultiClassClassificationExample {
    /** Run example. */
    public static void main(String[] args) throws InterruptedException {
        System.out.println();
        System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example started.");
        // Start Ignite grid.
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            System.out.println(">>> Ignite grid started.");

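            // Run the example logic in a dedicated thread bound to the started Ignite instance.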
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                LogRegressionMultiClassClassificationExample.class.getSimpleName(), () -> {
                IgniteCache<Integer, Vector> dataCache = getTestCache(ignite);

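                // Configure the trainer: plain gradient descent updates with learning rate 0.2,
                // summed locally and averaged across nodes, with the iteration counts, batch size
                // and seed used by this example.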
                LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>()
                    .withUpdatesStgy(new UpdatesStrategy<>(
                        new SimpleGDUpdateCalculator(0.2),
                        SimpleGDParameterUpdate::sumLocal,
                        SimpleGDParameterUpdate::avg
                    ))
                    .withAmountOfIterations(100000)
                    .withAmountOfLocIterations(10)
                    .withBatchSize(100)
                    .withSeed(123L);

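                // Fit the model: the feature extractor drops the first element of each cached
                // vector (the class label), and the label extractor returns that first element.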
                LogRegressionMultiClassModel mdl = trainer.fit(
                    ignite,
                    dataCache,
                    (k, v) -> {
                        double[] arr = v.asArray();
                        return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
                    },
                    (k, v) -> v.get(0)
                );

                System.out.println(">>> Logistic Regression Multi-class model");
                System.out.println(mdl.toString());

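                // Fit a min-max scaler on the same features; the resulting preprocessor both
                // extracts and rescales features, so it can be passed to fit() below in place
                // of the plain feature extractor.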
                MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>();

                IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
                    ignite,
                    dataCache,
                    (k, v) -> {
                        double[] arr = v.asArray();
                        return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
                    }
                );

                LogRegressionMultiClassModel mdlWithNormalization = trainer.fit(
                    ignite,
                    dataCache,
                    preprocessor,
                    (k, v) -> v.get(0)
                );

                System.out.println(">>> Logistic Regression Multi-class model with min-max scaling");
                System.out.println(mdlWithNormalization.toString());

                System.out.println(">>> ----------------------------------------------------------------");
                System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|");
                System.out.println(">>> ----------------------------------------------------------------");

                int amountOfErrors = 0;
                int amountOfErrorsWithNormalization = 0;
                int totalAmount = 0;

                // Build the confusion matrices. See https://en.wikipedia.org/wiki/Confusion_matrix
                int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
                int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};

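                // Scan the whole cache and evaluate both models on every observation.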
                try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
                    for (Cache.Entry<Integer, Vector> observation : observations) {
                        double[] val = observation.getValue().asArray();
                        double[] inputs = Arrays.copyOfRange(val, 1, val.length);
                        double groundTruth = val[0];

                        double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
                        double predictionWithNormalization = mdlWithNormalization.apply(new DenseLocalOnHeapVector(inputs));

                        totalAmount++;

                        // Collect stats for the model without min-max scaling.
                        if (groundTruth != prediction)
                            amountOfErrors++;

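                        // Map the class labels {1, 3, 7} onto confusion matrix indices {0, 1, 2}.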
                        int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2);
                        int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);

                        confusionMtx[idx1][idx2]++;

                        // Collect stats for the model with min-max scaling.
                        if (groundTruth != predictionWithNormalization)
                            amountOfErrorsWithNormalization++;

                        idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2);
                        idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);

                        confusionMtxWithNormalization[idx1][idx2]++;

                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth);
                    }
                    System.out.println(">>> ----------------------------------------------------------------");
                    System.out.println("\n>>> -----------------Logistic Regression model-------------");
                    System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                    System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
                    System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));

                    System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------");
                    System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization);
                    System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount));
                    System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization));
                }
            });

            igniteThread.start();
            igniteThread.join();
        }
    }

    /**
     * Fills cache with data and returns it.
     *
     * @param ignite Ignite instance.
     * @return Filled Ignite Cache.
     */
    private static IgniteCache<Integer, Vector> getTestCache(Ignite ignite) {
        CacheConfiguration<Integer, Vector> cacheConfiguration = new CacheConfiguration<>();
        cacheConfiguration.setName("TEST_" + UUID.randomUUID());
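        // Spread the data over 10 partitions so the trainer works against a partitioned dataset.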
        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));

        IgniteCache<Integer, Vector> cache = ignite.createCache(cacheConfiguration);

        for (int i = 0; i < data.length; i++)
            cache.put(i, VectorUtils.of(data[i]));

        return cache;
    }

    /**
     * The preprocessed Glass dataset from the UCI Machine Learning Repository:
     * https://archive.ics.uci.edu/ml/datasets/Glass+Identification
     * The first element of each row is the class label; there are 3 classes:
     * 1 {building_windows_float_processed}, 3 {vehicle_windows_float_processed}, 7 {headlamps}.
     * The remaining elements are the features: 'RI' (refractive index), 'Mg-Magnesium',
     * 'Al-Aluminum', 'Ba-Barium', 'Fe-Iron'.
     */
    private static final double[][] data = {
        {1, 1.52101, 4.49, 1.10, 0.00, 0.00},
        {1, 1.51761, 3.60, 1.36, 0.00, 0.00},
        {1, 1.51618, 3.55, 1.54, 0.00, 0.00},
        {1, 1.51766, 3.69, 1.29, 0.00, 0.00},
        {1, 1.51742, 3.62, 1.24, 0.00, 0.00},
        {1, 1.51596, 3.61, 1.62, 0.00, 0.26},
        {1, 1.51743, 3.60, 1.14, 0.00, 0.00},
        {1, 1.51756, 3.61, 1.05, 0.00, 0.00},
        {1, 1.51918, 3.58, 1.37, 0.00, 0.00},
        {1, 1.51755, 3.60, 1.36, 0.00, 0.11},
        {1, 1.51571, 3.46, 1.56, 0.00, 0.24},
        {1, 1.51763, 3.66, 1.27, 0.00, 0.00},
        {1, 1.51589, 3.43, 1.40, 0.00, 0.24},
        {1, 1.51748, 3.56, 1.27, 0.00, 0.17},
        {1, 1.51763, 3.59, 1.31, 0.00, 0.00},
        {1, 1.51761, 3.54, 1.23, 0.00, 0.00},
        {1, 1.51784, 3.67, 1.16, 0.00, 0.00},
        {1, 1.52196, 3.85, 0.89, 0.00, 0.00},
        {1, 1.51911, 3.73, 1.18, 0.00, 0.00},
        {1, 1.51735, 3.54, 1.69, 0.00, 0.07},
        {1, 1.51750, 3.55, 1.49, 0.00, 0.19},
        {1, 1.51966, 3.75, 0.29, 0.00, 0.00},
        {1, 1.51736, 3.62, 1.29, 0.00, 0.00},
        {1, 1.51751, 3.57, 1.35, 0.00, 0.00},
        {1, 1.51720, 3.50, 1.15, 0.00, 0.00},
        {1, 1.51764, 3.54, 1.21, 0.00, 0.00},
        {1, 1.51793, 3.48, 1.41, 0.00, 0.00},
        {1, 1.51721, 3.48, 1.33, 0.00, 0.00},
        {1, 1.51768, 3.52, 1.43, 0.00, 0.00},
        {1, 1.51784, 3.49, 1.28, 0.00, 0.00},
        {1, 1.51768, 3.56, 1.30, 0.00, 0.14},
        {1, 1.51747, 3.50, 1.14, 0.00, 0.00},
        {1, 1.51775, 3.48, 1.23, 0.09, 0.22},
        {1, 1.51753, 3.47, 1.38, 0.00, 0.06},
        {1, 1.51783, 3.54, 1.34, 0.00, 0.00},
        {1, 1.51567, 3.45, 1.21, 0.00, 0.00},
        {1, 1.51909, 3.53, 1.32, 0.11, 0.00},
        {1, 1.51797, 3.48, 1.35, 0.00, 0.00},
        {1, 1.52213, 3.82, 0.47, 0.00, 0.00},
        {1, 1.52213, 3.82, 0.47, 0.00, 0.00},
        {1, 1.51793, 3.50, 1.12, 0.00, 0.00},
        {1, 1.51755, 3.42, 1.20, 0.00, 0.00},
        {1, 1.51779, 3.39, 1.33, 0.00, 0.00},
        {1, 1.52210, 3.84, 0.72, 0.00, 0.00},
        {1, 1.51786, 3.43, 1.19, 0.00, 0.30},
        {1, 1.51900, 3.48, 1.35, 0.00, 0.00},
        {1, 1.51869, 3.37, 1.18, 0.00, 0.16},
        {1, 1.52667, 3.70, 0.71, 0.00, 0.10},
        {1, 1.52223, 3.77, 0.79, 0.00, 0.00},
        {1, 1.51898, 3.35, 1.23, 0.00, 0.00},
        {1, 1.52320, 3.72, 0.51, 0.00, 0.16},
        {1, 1.51926, 3.33, 1.28, 0.00, 0.11},
        {1, 1.51808, 2.87, 1.19, 0.00, 0.00},
        {1, 1.51837, 2.84, 1.28, 0.00, 0.00},
        {1, 1.51778, 2.81, 1.29, 0.00, 0.09},
        {1, 1.51769, 2.71, 1.29, 0.00, 0.24},
        {1, 1.51215, 3.47, 1.12, 0.00, 0.31},
        {1, 1.51824, 3.48, 1.29, 0.00, 0.00},
        {1, 1.51754, 3.74, 1.17, 0.00, 0.00},
        {1, 1.51754, 3.66, 1.19, 0.00, 0.11},
        {1, 1.51905, 3.62, 1.11, 0.00, 0.00},
        {1, 1.51977, 3.58, 1.32, 0.69, 0.00},
        {1, 1.52172, 3.86, 0.88, 0.00, 0.11},
        {1, 1.52227, 3.81, 0.78, 0.00, 0.00},
        {1, 1.52172, 3.74, 0.90, 0.00, 0.07},
        {1, 1.52099, 3.59, 1.12, 0.00, 0.00},
        {1, 1.52152, 3.65, 0.87, 0.00, 0.17},
        {1, 1.52152, 3.65, 0.87, 0.00, 0.17},
        {1, 1.52152, 3.58, 0.90, 0.00, 0.16},
        {1, 1.52300, 3.58, 0.82, 0.00, 0.03},
        {3, 1.51769, 3.66, 1.11, 0.00, 0.00},
        {3, 1.51610, 3.53, 1.34, 0.00, 0.00},
        {3, 1.51670, 3.57, 1.38, 0.00, 0.10},
        {3, 1.51643, 3.52, 1.35, 0.00, 0.00},
        {3, 1.51665, 3.45, 1.76, 0.00, 0.17},
        {3, 1.52127, 3.90, 0.83, 0.00, 0.00},
        {3, 1.51779, 3.65, 0.65, 0.00, 0.00},
        {3, 1.51610, 3.40, 1.22, 0.00, 0.00},
        {3, 1.51694, 3.58, 1.31, 0.00, 0.00},
        {3, 1.51646, 3.40, 1.26, 0.00, 0.00},
        {3, 1.51655, 3.39, 1.28, 0.00, 0.00},
        {3, 1.52121, 3.76, 0.58, 0.00, 0.00},
        {3, 1.51776, 3.41, 1.52, 0.00, 0.00},
        {3, 1.51796, 3.36, 1.63, 0.00, 0.09},
        {3, 1.51832, 3.34, 1.54, 0.00, 0.00},
        {3, 1.51934, 3.54, 0.75, 0.15, 0.24},
        {3, 1.52211, 3.78, 0.91, 0.00, 0.37},
        {7, 1.51131, 3.20, 1.81, 1.19, 0.00},
        {7, 1.51838, 3.26, 2.22, 1.63, 0.00},
        {7, 1.52315, 3.34, 1.23, 0.00, 0.00},
        {7, 1.52247, 2.20, 2.06, 0.00, 0.00},
        {7, 1.52365, 1.83, 1.31, 1.68, 0.00},
        {7, 1.51613, 1.78, 1.79, 0.76, 0.00},
        {7, 1.51602, 0.00, 2.38, 0.64, 0.09},
        {7, 1.51623, 0.00, 2.79, 0.40, 0.09},
        {7, 1.51719, 0.00, 2.00, 1.59, 0.08},
        {7, 1.51683, 0.00, 1.98, 1.57, 0.07},
        {7, 1.51545, 0.00, 2.68, 0.61, 0.05},
        {7, 1.51556, 0.00, 2.54, 0.81, 0.01},
        {7, 1.51727, 0.00, 2.34, 0.66, 0.00},
        {7, 1.51531, 0.00, 2.66, 0.64, 0.00},
        {7, 1.51609, 0.00, 2.51, 0.53, 0.00},
        {7, 1.51508, 0.00, 2.25, 0.63, 0.00},
        {7, 1.51653, 0.00, 1.19, 0.00, 0.00},
        {7, 1.51514, 0.00, 2.42, 0.56, 0.00},
        {7, 1.51658, 0.00, 1.99, 1.71, 0.00},
        {7, 1.51617, 0.00, 2.27, 0.67, 0.00},
        {7, 1.51732, 0.00, 1.80, 1.55, 0.00},
        {7, 1.51645, 0.00, 1.87, 1.38, 0.00},
        {7, 1.51831, 0.00, 1.82, 2.88, 0.00},
        {7, 1.51640, 0.00, 2.74, 0.54, 0.00},
        {7, 1.51623, 0.00, 2.88, 1.06, 0.00},
        {7, 1.51685, 0.00, 1.99, 1.59, 0.00},
        {7, 1.52065, 0.00, 2.02, 1.64, 0.00},
        {7, 1.51651, 0.00, 1.94, 1.57, 0.00},
        {7, 1.51711, 0.00, 2.08, 1.67, 0.00},
    };
}