IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / svm / multiclass / SVMMultiClassClassificationExample.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.svm.multiclass;
19
20 import java.util.Arrays;
21 import java.util.UUID;
22 import javax.cache.Cache;
23 import org.apache.ignite.Ignite;
24 import org.apache.ignite.IgniteCache;
25 import org.apache.ignite.Ignition;
26 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
27 import org.apache.ignite.cache.query.QueryCursor;
28 import org.apache.ignite.cache.query.ScanQuery;
29 import org.apache.ignite.configuration.CacheConfiguration;
30 import org.apache.ignite.ml.math.Vector;
31 import org.apache.ignite.ml.math.VectorUtils;
32 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
33 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
34 import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
35 import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
36 import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer;
37 import org.apache.ignite.thread.IgniteThread;
38
39 /**
40 * Run SVM multi-class classification trainer over distributed dataset to build two models:
41 * one with minmaxscaling and one without minmaxscaling.
42 *
43 * @see SVMLinearMultiClassClassificationModel
44 */
45 public class SVMMultiClassClassificationExample {
46 /** Run example. */
47 public static void main(String[] args) throws InterruptedException {
48 System.out.println();
49 System.out.println(">>> SVM Multi-class classification model over cached dataset usage example started.");
50 // Start ignite grid.
51 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
52 System.out.println(">>> Ignite grid started.");
53
54 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
55 SVMMultiClassClassificationExample.class.getSimpleName(), () -> {
56 IgniteCache<Integer, Vector> dataCache = getTestCache(ignite);
57
58 SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer();
59
60 SVMLinearMultiClassClassificationModel mdl = trainer.fit(
61 ignite,
62 dataCache,
63 (k, v) -> {
64 double[] arr = v.asArray();
65 return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
66 },
67 (k, v) -> v.get(0)
68 );
69
70 System.out.println(">>> SVM Multi-class model");
71 System.out.println(mdl.toString());
72
73 MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>();
74
75 IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
76 ignite,
77 dataCache,
78 (k, v) -> {
79 double[] arr = v.asArray();
80 return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
81 }
82 );
83
84 SVMLinearMultiClassClassificationModel mdlWithNormalization = trainer.fit(
85 ignite,
86 dataCache,
87 preprocessor,
88 (k, v) -> v.get(0)
89 );
90
91 System.out.println(">>> SVM Multi-class model with minmaxscaling");
92 System.out.println(mdlWithNormalization.toString());
93
94 System.out.println(">>> ----------------------------------------------------------------");
95 System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|");
96 System.out.println(">>> ----------------------------------------------------------------");
97
98 int amountOfErrors = 0;
99 int amountOfErrorsWithNormalization = 0;
100 int totalAmount = 0;
101
102 // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
103 int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
104 int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
105
106 try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
107 for (Cache.Entry<Integer, Vector> observation : observations) {
108 double[] val = observation.getValue().asArray();
109 double[] inputs = Arrays.copyOfRange(val, 1, val.length);
110 double groundTruth = val[0];
111
112 double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
113 double predictionWithNormalization = mdlWithNormalization.apply(new DenseLocalOnHeapVector(inputs));
114
115 totalAmount++;
116
117 // Collect data for model
118 if(groundTruth != prediction)
119 amountOfErrors++;
120
121 int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2);
122 int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
123
124 confusionMtx[idx1][idx2]++;
125
126 // Collect data for model with minmaxscaling
127 if(groundTruth != predictionWithNormalization)
128 amountOfErrorsWithNormalization++;
129
130 idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2);
131 idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
132
133 confusionMtxWithNormalization[idx1][idx2]++;
134
135 System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth);
136 }
137 System.out.println(">>> ----------------------------------------------------------------");
138 System.out.println("\n>>> -----------------SVM model-------------");
139 System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
140 System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
141 System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
142
143 System.out.println("\n>>> -----------------SVM model with Normalization-------------");
144 System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization);
145 System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount));
146 System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization));
147 }
148 });
149
150 igniteThread.start();
151 igniteThread.join();
152 }
153 }
154
155 /**
156 * Fills cache with data and returns it.
157 *
158 * @param ignite Ignite instance.
159 * @return Filled Ignite Cache.
160 */
161 private static IgniteCache<Integer, Vector> getTestCache(Ignite ignite) {
162 CacheConfiguration<Integer, Vector> cacheConfiguration = new CacheConfiguration<>();
163 cacheConfiguration.setName("TEST_" + UUID.randomUUID());
164 cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
165
166 IgniteCache<Integer, Vector> cache = ignite.createCache(cacheConfiguration);
167
168 for (int i = 0; i < data.length; i++)
169 cache.put(i, VectorUtils.of(data[i]));
170
171 return cache;
172 }
173
174 /** The preprocessed Glass dataset from the Machine Learning Repository https://archive.ics.uci.edu/ml/datasets/Glass+Identification
175 * There are 3 classes with labels: 1 {building_windows_float_processed}, 3 {vehicle_windows_float_processed}, 7 {headlamps}.
176 * Feature names: 'Na-Sodium', 'Mg-Magnesium', 'Al-Aluminum', 'Ba-Barium', 'Fe-Iron'.
177 */
178 private static final double[][] data = {
179 {1, 1.52101, 4.49, 1.10, 0.00, 0.00},
180 {1, 1.51761, 3.60, 1.36, 0.00, 0.00},
181 {1, 1.51618, 3.55, 1.54, 0.00, 0.00},
182 {1, 1.51766, 3.69, 1.29, 0.00, 0.00},
183 {1, 1.51742, 3.62, 1.24, 0.00, 0.00},
184 {1, 1.51596, 3.61, 1.62, 0.00, 0.26},
185 {1, 1.51743, 3.60, 1.14, 0.00, 0.00},
186 {1, 1.51756, 3.61, 1.05, 0.00, 0.00},
187 {1, 1.51918, 3.58, 1.37, 0.00, 0.00},
188 {1, 1.51755, 3.60, 1.36, 0.00, 0.11},
189 {1, 1.51571, 3.46, 1.56, 0.00, 0.24},
190 {1, 1.51763, 3.66, 1.27, 0.00, 0.00},
191 {1, 1.51589, 3.43, 1.40, 0.00, 0.24},
192 {1, 1.51748, 3.56, 1.27, 0.00, 0.17},
193 {1, 1.51763, 3.59, 1.31, 0.00, 0.00},
194 {1, 1.51761, 3.54, 1.23, 0.00, 0.00},
195 {1, 1.51784, 3.67, 1.16, 0.00, 0.00},
196 {1, 1.52196, 3.85, 0.89, 0.00, 0.00},
197 {1, 1.51911, 3.73, 1.18, 0.00, 0.00},
198 {1, 1.51735, 3.54, 1.69, 0.00, 0.07},
199 {1, 1.51750, 3.55, 1.49, 0.00, 0.19},
200 {1, 1.51966, 3.75, 0.29, 0.00, 0.00},
201 {1, 1.51736, 3.62, 1.29, 0.00, 0.00},
202 {1, 1.51751, 3.57, 1.35, 0.00, 0.00},
203 {1, 1.51720, 3.50, 1.15, 0.00, 0.00},
204 {1, 1.51764, 3.54, 1.21, 0.00, 0.00},
205 {1, 1.51793, 3.48, 1.41, 0.00, 0.00},
206 {1, 1.51721, 3.48, 1.33, 0.00, 0.00},
207 {1, 1.51768, 3.52, 1.43, 0.00, 0.00},
208 {1, 1.51784, 3.49, 1.28, 0.00, 0.00},
209 {1, 1.51768, 3.56, 1.30, 0.00, 0.14},
210 {1, 1.51747, 3.50, 1.14, 0.00, 0.00},
211 {1, 1.51775, 3.48, 1.23, 0.09, 0.22},
212 {1, 1.51753, 3.47, 1.38, 0.00, 0.06},
213 {1, 1.51783, 3.54, 1.34, 0.00, 0.00},
214 {1, 1.51567, 3.45, 1.21, 0.00, 0.00},
215 {1, 1.51909, 3.53, 1.32, 0.11, 0.00},
216 {1, 1.51797, 3.48, 1.35, 0.00, 0.00},
217 {1, 1.52213, 3.82, 0.47, 0.00, 0.00},
218 {1, 1.52213, 3.82, 0.47, 0.00, 0.00},
219 {1, 1.51793, 3.50, 1.12, 0.00, 0.00},
220 {1, 1.51755, 3.42, 1.20, 0.00, 0.00},
221 {1, 1.51779, 3.39, 1.33, 0.00, 0.00},
222 {1, 1.52210, 3.84, 0.72, 0.00, 0.00},
223 {1, 1.51786, 3.43, 1.19, 0.00, 0.30},
224 {1, 1.51900, 3.48, 1.35, 0.00, 0.00},
225 {1, 1.51869, 3.37, 1.18, 0.00, 0.16},
226 {1, 1.52667, 3.70, 0.71, 0.00, 0.10},
227 {1, 1.52223, 3.77, 0.79, 0.00, 0.00},
228 {1, 1.51898, 3.35, 1.23, 0.00, 0.00},
229 {1, 1.52320, 3.72, 0.51, 0.00, 0.16},
230 {1, 1.51926, 3.33, 1.28, 0.00, 0.11},
231 {1, 1.51808, 2.87, 1.19, 0.00, 0.00},
232 {1, 1.51837, 2.84, 1.28, 0.00, 0.00},
233 {1, 1.51778, 2.81, 1.29, 0.00, 0.09},
234 {1, 1.51769, 2.71, 1.29, 0.00, 0.24},
235 {1, 1.51215, 3.47, 1.12, 0.00, 0.31},
236 {1, 1.51824, 3.48, 1.29, 0.00, 0.00},
237 {1, 1.51754, 3.74, 1.17, 0.00, 0.00},
238 {1, 1.51754, 3.66, 1.19, 0.00, 0.11},
239 {1, 1.51905, 3.62, 1.11, 0.00, 0.00},
240 {1, 1.51977, 3.58, 1.32, 0.69, 0.00},
241 {1, 1.52172, 3.86, 0.88, 0.00, 0.11},
242 {1, 1.52227, 3.81, 0.78, 0.00, 0.00},
243 {1, 1.52172, 3.74, 0.90, 0.00, 0.07},
244 {1, 1.52099, 3.59, 1.12, 0.00, 0.00},
245 {1, 1.52152, 3.65, 0.87, 0.00, 0.17},
246 {1, 1.52152, 3.65, 0.87, 0.00, 0.17},
247 {1, 1.52152, 3.58, 0.90, 0.00, 0.16},
248 {1, 1.52300, 3.58, 0.82, 0.00, 0.03},
249 {3, 1.51769, 3.66, 1.11, 0.00, 0.00},
250 {3, 1.51610, 3.53, 1.34, 0.00, 0.00},
251 {3, 1.51670, 3.57, 1.38, 0.00, 0.10},
252 {3, 1.51643, 3.52, 1.35, 0.00, 0.00},
253 {3, 1.51665, 3.45, 1.76, 0.00, 0.17},
254 {3, 1.52127, 3.90, 0.83, 0.00, 0.00},
255 {3, 1.51779, 3.65, 0.65, 0.00, 0.00},
256 {3, 1.51610, 3.40, 1.22, 0.00, 0.00},
257 {3, 1.51694, 3.58, 1.31, 0.00, 0.00},
258 {3, 1.51646, 3.40, 1.26, 0.00, 0.00},
259 {3, 1.51655, 3.39, 1.28, 0.00, 0.00},
260 {3, 1.52121, 3.76, 0.58, 0.00, 0.00},
261 {3, 1.51776, 3.41, 1.52, 0.00, 0.00},
262 {3, 1.51796, 3.36, 1.63, 0.00, 0.09},
263 {3, 1.51832, 3.34, 1.54, 0.00, 0.00},
264 {3, 1.51934, 3.54, 0.75, 0.15, 0.24},
265 {3, 1.52211, 3.78, 0.91, 0.00, 0.37},
266 {7, 1.51131, 3.20, 1.81, 1.19, 0.00},
267 {7, 1.51838, 3.26, 2.22, 1.63, 0.00},
268 {7, 1.52315, 3.34, 1.23, 0.00, 0.00},
269 {7, 1.52247, 2.20, 2.06, 0.00, 0.00},
270 {7, 1.52365, 1.83, 1.31, 1.68, 0.00},
271 {7, 1.51613, 1.78, 1.79, 0.76, 0.00},
272 {7, 1.51602, 0.00, 2.38, 0.64, 0.09},
273 {7, 1.51623, 0.00, 2.79, 0.40, 0.09},
274 {7, 1.51719, 0.00, 2.00, 1.59, 0.08},
275 {7, 1.51683, 0.00, 1.98, 1.57, 0.07},
276 {7, 1.51545, 0.00, 2.68, 0.61, 0.05},
277 {7, 1.51556, 0.00, 2.54, 0.81, 0.01},
278 {7, 1.51727, 0.00, 2.34, 0.66, 0.00},
279 {7, 1.51531, 0.00, 2.66, 0.64, 0.00},
280 {7, 1.51609, 0.00, 2.51, 0.53, 0.00},
281 {7, 1.51508, 0.00, 2.25, 0.63, 0.00},
282 {7, 1.51653, 0.00, 1.19, 0.00, 0.00},
283 {7, 1.51514, 0.00, 2.42, 0.56, 0.00},
284 {7, 1.51658, 0.00, 1.99, 1.71, 0.00},
285 {7, 1.51617, 0.00, 2.27, 0.67, 0.00},
286 {7, 1.51732, 0.00, 1.80, 1.55, 0.00},
287 {7, 1.51645, 0.00, 1.87, 1.38, 0.00},
288 {7, 1.51831, 0.00, 1.82, 2.88, 0.00},
289 {7, 1.51640, 0.00, 2.74, 0.54, 0.00},
290 {7, 1.51623, 0.00, 2.88, 1.06, 0.00},
291 {7, 1.51685, 0.00, 1.99, 1.59, 0.00},
292 {7, 1.52065, 0.00, 2.02, 1.64, 0.00},
293 {7, 1.51651, 0.00, 1.94, 1.57, 0.00},
294 {7, 1.51711, 0.00, 2.08, 1.67, 0.00},
295 };
296 }