IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / knn / KNNRegressionExample.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.knn;
19
20 import java.util.Arrays;
21 import java.util.UUID;
22 import javax.cache.Cache;
23 import org.apache.ignite.Ignite;
24 import org.apache.ignite.IgniteCache;
25 import org.apache.ignite.Ignition;
26 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
27 import org.apache.ignite.cache.query.QueryCursor;
28 import org.apache.ignite.cache.query.ScanQuery;
29 import org.apache.ignite.configuration.CacheConfiguration;
30 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
31 import org.apache.ignite.ml.knn.classification.KNNStrategy;
32 import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
33 import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
34 import org.apache.ignite.ml.math.VectorUtils;
35 import org.apache.ignite.ml.math.distances.ManhattanDistance;
36 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
37 import org.apache.ignite.thread.IgniteThread;
38
39 /**
40 * Run kNN regression trainer over distributed dataset.
41 *
42 * @see KNNClassificationTrainer
43 */
44 public class KNNRegressionExample {
45 /** Run example. */
46 public static void main(String[] args) throws InterruptedException {
47 System.out.println();
48 System.out.println(">>> kNN regression over cached dataset usage example started.");
49 // Start ignite grid.
50 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
51 System.out.println(">>> Ignite grid started.");
52
53 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
54 KNNRegressionExample.class.getSimpleName(), () -> {
55 IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
56
57 KNNRegressionTrainer trainer = new KNNRegressionTrainer();
58
59 KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
60 ignite,
61 dataCache,
62 (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
63 (k, v) -> v[0]
64 ).withK(5)
65 .withDistanceMeasure(new ManhattanDistance())
66 .withStrategy(KNNStrategy.WEIGHTED);
67
68 int totalAmount = 0;
69 // Calculate mean squared error (MSE)
70 double mse = 0.0;
71 // Calculate mean absolute error (MAE)
72 double mae = 0.0;
73
74 try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
75 for (Cache.Entry<Integer, double[]> observation : observations) {
76 double[] val = observation.getValue();
77 double[] inputs = Arrays.copyOfRange(val, 1, val.length);
78 double groundTruth = val[0];
79
80 double prediction = knnMdl.apply(new DenseLocalOnHeapVector(inputs));
81
82 mse += Math.pow(prediction - groundTruth, 2.0);
83 mae += Math.abs(prediction - groundTruth);
84
85 totalAmount++;
86 }
87
88 mse = mse / totalAmount;
89 System.out.println("\n>>> Mean squared error (MSE) " + mse);
90
91 mae = mae / totalAmount;
92 System.out.println("\n>>> Mean absolute error (MAE) " + mae);
93 }
94 });
95
96 igniteThread.start();
97 igniteThread.join();
98 }
99 }
100
101 /**
102 * Fills cache with data and returns it.
103 *
104 * @param ignite Ignite instance.
105 * @return Filled Ignite Cache.
106 */
107 private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
108 CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
109 cacheConfiguration.setName("TEST_" + UUID.randomUUID());
110 cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
111
112 IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
113
114 for (int i = 0; i < data.length; i++)
115 cache.put(i, data[i]);
116
117 return cache;
118 }
119
120 /** The Iris dataset. */
121 private static final double[][] data = {
122 {199, 125, 256, 6000, 256, 16, 128},
123 {253, 29, 8000, 32000, 32, 8, 32},
124 {132, 29, 8000, 16000, 32, 8, 16},
125 {290, 26, 8000, 32000, 64, 8, 32},
126 {381, 23, 16000, 32000, 64, 16, 32},
127 {749, 23, 16000, 64000, 64, 16, 32},
128 {1238, 23, 32000, 64000, 128, 32, 64},
129 {23, 400, 1000, 3000, 0, 1, 2},
130 {24, 400, 512, 3500, 4, 1, 6},
131 {70, 60, 2000, 8000, 65, 1, 8},
132 {117, 50, 4000, 16000, 65, 1, 8},
133 {15, 350, 64, 64, 0, 1, 4},
134 {64, 200, 512, 16000, 0, 4, 32},
135 {23, 167, 524, 2000, 8, 4, 15},
136 {29, 143, 512, 5000, 0, 7, 32},
137 {22, 143, 1000, 2000, 0, 5, 16},
138 {124, 110, 5000, 5000, 142, 8, 64},
139 {35, 143, 1500, 6300, 0, 5, 32},
140 {39, 143, 3100, 6200, 0, 5, 20},
141 {40, 143, 2300, 6200, 0, 6, 64},
142 {45, 110, 3100, 6200, 0, 6, 64},
143 {28, 320, 128, 6000, 0, 1, 12},
144 {21, 320, 512, 2000, 4, 1, 3},
145 {28, 320, 256, 6000, 0, 1, 6},
146 {22, 320, 256, 3000, 4, 1, 3},
147 {28, 320, 512, 5000, 4, 1, 5},
148 {27, 320, 256, 5000, 4, 1, 6},
149 {102, 25, 1310, 2620, 131, 12, 24},
150 {74, 50, 2620, 10480, 30, 12, 24},
151 {138, 56, 5240, 20970, 30, 12, 24},
152 {136, 64, 5240, 20970, 30, 12, 24},
153 {23, 50, 500, 2000, 8, 1, 4},
154 {29, 50, 1000, 4000, 8, 1, 5},
155 {44, 50, 2000, 8000, 8, 1, 5},
156 {30, 50, 1000, 4000, 8, 3, 5},
157 {41, 50, 1000, 8000, 8, 3, 5},
158 {74, 50, 2000, 16000, 8, 3, 5},
159 {54, 133, 1000, 12000, 9, 3, 12},
160 {41, 133, 1000, 8000, 9, 3, 12},
161 {18, 810, 512, 512, 8, 1, 1},
162 {28, 810, 1000, 5000, 0, 1, 1},
163 {36, 320, 512, 8000, 4, 1, 5},
164 {38, 200, 512, 8000, 8, 1, 8},
165 {34, 700, 384, 8000, 0, 1, 1},
166 {19, 700, 256, 2000, 0, 1, 1},
167 {72, 140, 1000, 16000, 16, 1, 3},
168 {36, 200, 1000, 8000, 0, 1, 2},
169 {30, 110, 1000, 4000, 16, 1, 2},
170 {56, 110, 1000, 12000, 16, 1, 2},
171 {42, 220, 1000, 8000, 16, 1, 2},
172 {34, 800, 256, 8000, 0, 1, 4},
173 {19, 125, 512, 1000, 0, 8, 20},
174 {75, 75, 2000, 8000, 64, 1, 38},
175 {113, 75, 2000, 16000, 64, 1, 38},
176 {157, 75, 2000, 16000, 128, 1, 38},
177 {18, 90, 256, 1000, 0, 3, 10},
178 {20, 105, 256, 2000, 0, 3, 10},
179 {28, 105, 1000, 4000, 0, 3, 24},
180 {33, 105, 2000, 4000, 8, 3, 19},
181 {47, 75, 2000, 8000, 8, 3, 24},
182 {54, 75, 3000, 8000, 8, 3, 48},
183 {20, 175, 256, 2000, 0, 3, 24},
184 {23, 300, 768, 3000, 0, 6, 24},
185 {25, 300, 768, 3000, 6, 6, 24},
186 {52, 300, 768, 12000, 6, 6, 24},
187 {27, 300, 768, 4500, 0, 1, 24},
188 {50, 300, 384, 12000, 6, 1, 24},
189 {18, 300, 192, 768, 6, 6, 24},
190 {53, 180, 768, 12000, 6, 1, 31},
191 {23, 330, 1000, 3000, 0, 2, 4},
192 {30, 300, 1000, 4000, 8, 3, 64},
193 {73, 300, 1000, 16000, 8, 2, 112},
194 {20, 330, 1000, 2000, 0, 1, 2},
195 {25, 330, 1000, 4000, 0, 3, 6},
196 {28, 140, 2000, 4000, 0, 3, 6},
197 {29, 140, 2000, 4000, 0, 4, 8},
198 {32, 140, 2000, 4000, 8, 1, 20},
199 {175, 140, 2000, 32000, 32, 1, 20},
200 {57, 140, 2000, 8000, 32, 1, 54},
201 {181, 140, 2000, 32000, 32, 1, 54},
202 {32, 140, 2000, 4000, 8, 1, 20},
203 {82, 57, 4000, 16000, 1, 6, 12},
204 {171, 57, 4000, 24000, 64, 12, 16},
205 {361, 26, 16000, 32000, 64, 16, 24},
206 {350, 26, 16000, 32000, 64, 8, 24},
207 {220, 26, 8000, 32000, 0, 8, 24},
208 {113, 26, 8000, 16000, 0, 8, 16},
209 {15, 480, 96, 512, 0, 1, 1},
210 {21, 203, 1000, 2000, 0, 1, 5},
211 {35, 115, 512, 6000, 16, 1, 6},
212 {18, 1100, 512, 1500, 0, 1, 1},
213 {20, 1100, 768, 2000, 0, 1, 1},
214 {20, 600, 768, 2000, 0, 1, 1},
215 {28, 400, 2000, 4000, 0, 1, 1},
216 {45, 400, 4000, 8000, 0, 1, 1},
217 {18, 900, 1000, 1000, 0, 1, 2},
218 {17, 900, 512, 1000, 0, 1, 2},
219 {26, 900, 1000, 4000, 4, 1, 2},
220 {28, 900, 1000, 4000, 8, 1, 2},
221 {28, 900, 2000, 4000, 0, 3, 6},
222 {31, 225, 2000, 4000, 8, 3, 6},
223 {42, 180, 2000, 8000, 8, 1, 6},
224 {76, 185, 2000, 16000, 16, 1, 6},
225 {76, 180, 2000, 16000, 16, 1, 6},
226 {26, 225, 1000, 4000, 2, 3, 6},
227 {59, 25, 2000, 12000, 8, 1, 4},
228 {65, 25, 2000, 12000, 16, 3, 5},
229 {101, 17, 4000, 16000, 8, 6, 12},
230 {116, 17, 4000, 16000, 32, 6, 12},
231 {18, 1500, 768, 1000, 0, 0, 0},
232 {20, 1500, 768, 2000, 0, 0, 0},
233 {20, 800, 768, 2000, 0, 0, 0},
234 {30, 50, 2000, 4000, 0, 3, 6},
235 {44, 50, 2000, 8000, 8, 3, 6},
236 {82, 50, 2000, 16000, 24, 1, 6},
237 {128, 50, 8000, 16000, 48, 1, 10},
238 {37, 100, 1000, 8000, 0, 2, 6},
239 {46, 100, 1000, 8000, 24, 2, 6},
240 {46, 100, 1000, 8000, 24, 3, 6},
241 {80, 50, 2000, 16000, 12, 3, 16},
242 {88, 50, 2000, 16000, 24, 6, 16},
243 {33, 150, 512, 4000, 0, 8, 128},
244 {46, 115, 2000, 8000, 16, 1, 3},
245 {29, 115, 2000, 4000, 2, 1, 5},
246 {53, 92, 2000, 8000, 32, 1, 6},
247 {41, 92, 2000, 8000, 4, 1, 6},
248 {86, 75, 4000, 16000, 16, 1, 6},
249 {95, 60, 4000, 16000, 32, 1, 6},
250 {107, 60, 2000, 16000, 64, 5, 8},
251 {117, 60, 4000, 16000, 64, 5, 8},
252 {119, 50, 4000, 16000, 64, 5, 10},
253 {120, 72, 4000, 16000, 64, 8, 16},
254 {48, 72, 2000, 8000, 16, 6, 8},
255 {126, 40, 8000, 16000, 32, 8, 16},
256 {266, 40, 8000, 32000, 64, 8, 24},
257 {270, 35, 8000, 32000, 64, 8, 24},
258 {426, 38, 16000, 32000, 128, 16, 32},
259 {151, 48, 4000, 24000, 32, 8, 24},
260 {267, 38, 8000, 32000, 64, 8, 24},
261 {603, 30, 16000, 32000, 256, 16, 24},
262 {19, 112, 1000, 1000, 0, 1, 4},
263 {21, 84, 1000, 2000, 0, 1, 6},
264 {26, 56, 1000, 4000, 0, 1, 6},
265 {35, 56, 2000, 6000, 0, 1, 8},
266 {41, 56, 2000, 8000, 0, 1, 8},
267 {47, 56, 4000, 8000, 0, 1, 8},
268 {62, 56, 4000, 12000, 0, 1, 8},
269 {78, 56, 4000, 16000, 0, 1, 8},
270 {80, 38, 4000, 8000, 32, 16, 32},
271 {142, 38, 8000, 16000, 64, 4, 8},
272 {281, 38, 8000, 24000, 160, 4, 8},
273 {190, 38, 4000, 16000, 128, 16, 32},
274 {21, 200, 1000, 2000, 0, 1, 2},
275 {25, 200, 1000, 4000, 0, 1, 4},
276 {67, 200, 2000, 8000, 64, 1, 5},
277 {24, 250, 512, 4000, 0, 1, 7},
278 {24, 250, 512, 4000, 0, 4, 7},
279 {64, 250, 1000, 16000, 1, 1, 8},
280 {25, 160, 512, 4000, 2, 1, 5},
281 {20, 160, 512, 2000, 2, 3, 8},
282 {29, 160, 1000, 4000, 8, 1, 14},
283 {43, 160, 1000, 8000, 16, 1, 14},
284 {53, 160, 2000, 8000, 32, 1, 13},
285 {19, 240, 512, 1000, 8, 1, 3},
286 {22, 240, 512, 2000, 8, 1, 5},
287 {31, 105, 2000, 4000, 8, 3, 8},
288 {41, 105, 2000, 6000, 16, 6, 16},
289 {47, 105, 2000, 8000, 16, 4, 14},
290 {99, 52, 4000, 16000, 32, 4, 12},
291 {67, 70, 4000, 12000, 8, 6, 8},
292 {81, 59, 4000, 12000, 32, 6, 12},
293 {149, 59, 8000, 16000, 64, 12, 24},
294 {183, 26, 8000, 24000, 32, 8, 16},
295 {275, 26, 8000, 32000, 64, 12, 16},
296 {382, 26, 8000, 32000, 128, 24, 32},
297 {56, 116, 2000, 8000, 32, 5, 28},
298 {182, 50, 2000, 32000, 24, 6, 26},
299 {227, 50, 2000, 32000, 48, 26, 52},
300 {341, 50, 2000, 32000, 112, 52, 104},
301 {360, 50, 4000, 32000, 112, 52, 104},
302 {919, 30, 8000, 64000, 96, 12, 176},
303 {978, 30, 8000, 64000, 128, 12, 176},
304 {24, 180, 262, 4000, 0, 1, 3},
305 {37, 124, 1000, 8000, 0, 1, 8},
306 {50, 98, 1000, 8000, 32, 2, 8},
307 {41, 125, 2000, 8000, 0, 2, 14},
308 {47, 480, 512, 8000, 32, 0, 0},
309 {25, 480, 1000, 4000, 0, 0, 0}
310 };
311 }