IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / clustering / KMeansClusterizationExample.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.clustering;
19
20 import java.util.Arrays;
21 import java.util.UUID;
22 import javax.cache.Cache;
23 import org.apache.ignite.Ignite;
24 import org.apache.ignite.IgniteCache;
25 import org.apache.ignite.Ignition;
26 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
27 import org.apache.ignite.cache.query.QueryCursor;
28 import org.apache.ignite.cache.query.ScanQuery;
29 import org.apache.ignite.configuration.CacheConfiguration;
30 import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
31 import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
32 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
33 import org.apache.ignite.ml.math.Tracer;
34 import org.apache.ignite.ml.math.VectorUtils;
35 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
36 import org.apache.ignite.thread.IgniteThread;
37
38 /**
39 * Run kNN multi-class classification trainer over distributed dataset.
40 *
41 * @see KNNClassificationTrainer
42 */
43 public class KMeansClusterizationExample {
44 /** Run example. */
45 public static void main(String[] args) throws InterruptedException {
46 System.out.println();
47 System.out.println(">>> KMeans clustering algorithm over cached dataset usage example started.");
48 // Start ignite grid.
49 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
50 System.out.println(">>> Ignite grid started.");
51
52 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
53 KMeansClusterizationExample.class.getSimpleName(), () -> {
54 IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
55
56 KMeansTrainer trainer = new KMeansTrainer()
57 .withSeed(7867L);
58
59 KMeansModel mdl = trainer.fit(
60 ignite,
61 dataCache,
62 (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
63 (k, v) -> v[0]
64 );
65
66 System.out.println(">>> KMeans centroids");
67 Tracer.showAscii(mdl.centers()[0]);
68 Tracer.showAscii(mdl.centers()[1]);
69 System.out.println(">>>");
70
71 System.out.println(">>> -----------------------------------");
72 System.out.println(">>> | Predicted cluster\t| Real Label\t|");
73 System.out.println(">>> -----------------------------------");
74
75 int amountOfErrors = 0;
76 int totalAmount = 0;
77
78 try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
79 for (Cache.Entry<Integer, double[]> observation : observations) {
80 double[] val = observation.getValue();
81 double[] inputs = Arrays.copyOfRange(val, 1, val.length);
82 double groundTruth = val[0];
83
84 double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
85
86 totalAmount++;
87 if (groundTruth != prediction)
88 amountOfErrors++;
89
90 System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
91 }
92
93 System.out.println(">>> ---------------------------------");
94
95 System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
96 System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
97 }
98 });
99
100 igniteThread.start();
101 igniteThread.join();
102 }
103 }
104
105 /**
106 * Fills cache with data and returns it.
107 *
108 * @param ignite Ignite instance.
109 * @return Filled Ignite Cache.
110 */
111 private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
112 CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
113 cacheConfiguration.setName("TEST_" + UUID.randomUUID());
114 cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
115
116 IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
117
118 for (int i = 0; i < data.length; i++)
119 cache.put(i, data[i]);
120
121 return cache;
122 }
123
124 /** The Iris dataset. */
125 private static final double[][] data = {
126 {0, 5.1, 3.5, 1.4, 0.2},
127 {0, 4.9, 3, 1.4, 0.2},
128 {0, 4.7, 3.2, 1.3, 0.2},
129 {0, 4.6, 3.1, 1.5, 0.2},
130 {0, 5, 3.6, 1.4, 0.2},
131 {0, 5.4, 3.9, 1.7, 0.4},
132 {0, 4.6, 3.4, 1.4, 0.3},
133 {0, 5, 3.4, 1.5, 0.2},
134 {0, 4.4, 2.9, 1.4, 0.2},
135 {0, 4.9, 3.1, 1.5, 0.1},
136 {0, 5.4, 3.7, 1.5, 0.2},
137 {0, 4.8, 3.4, 1.6, 0.2},
138 {0, 4.8, 3, 1.4, 0.1},
139 {0, 4.3, 3, 1.1, 0.1},
140 {0, 5.8, 4, 1.2, 0.2},
141 {0, 5.7, 4.4, 1.5, 0.4},
142 {0, 5.4, 3.9, 1.3, 0.4},
143 {0, 5.1, 3.5, 1.4, 0.3},
144 {0, 5.7, 3.8, 1.7, 0.3},
145 {0, 5.1, 3.8, 1.5, 0.3},
146 {0, 5.4, 3.4, 1.7, 0.2},
147 {0, 5.1, 3.7, 1.5, 0.4},
148 {0, 4.6, 3.6, 1, 0.2},
149 {0, 5.1, 3.3, 1.7, 0.5},
150 {0, 4.8, 3.4, 1.9, 0.2},
151 {0, 5, 3, 1.6, 0.2},
152 {0, 5, 3.4, 1.6, 0.4},
153 {0, 5.2, 3.5, 1.5, 0.2},
154 {0, 5.2, 3.4, 1.4, 0.2},
155 {0, 4.7, 3.2, 1.6, 0.2},
156 {0, 4.8, 3.1, 1.6, 0.2},
157 {0, 5.4, 3.4, 1.5, 0.4},
158 {0, 5.2, 4.1, 1.5, 0.1},
159 {0, 5.5, 4.2, 1.4, 0.2},
160 {0, 4.9, 3.1, 1.5, 0.1},
161 {0, 5, 3.2, 1.2, 0.2},
162 {0, 5.5, 3.5, 1.3, 0.2},
163 {0, 4.9, 3.1, 1.5, 0.1},
164 {0, 4.4, 3, 1.3, 0.2},
165 {0, 5.1, 3.4, 1.5, 0.2},
166 {0, 5, 3.5, 1.3, 0.3},
167 {0, 4.5, 2.3, 1.3, 0.3},
168 {0, 4.4, 3.2, 1.3, 0.2},
169 {0, 5, 3.5, 1.6, 0.6},
170 {0, 5.1, 3.8, 1.9, 0.4},
171 {0, 4.8, 3, 1.4, 0.3},
172 {0, 5.1, 3.8, 1.6, 0.2},
173 {0, 4.6, 3.2, 1.4, 0.2},
174 {0, 5.3, 3.7, 1.5, 0.2},
175 {0, 5, 3.3, 1.4, 0.2},
176 {1, 7, 3.2, 4.7, 1.4},
177 {1, 6.4, 3.2, 4.5, 1.5},
178 {1, 6.9, 3.1, 4.9, 1.5},
179 {1, 5.5, 2.3, 4, 1.3},
180 {1, 6.5, 2.8, 4.6, 1.5},
181 {1, 5.7, 2.8, 4.5, 1.3},
182 {1, 6.3, 3.3, 4.7, 1.6},
183 {1, 4.9, 2.4, 3.3, 1},
184 {1, 6.6, 2.9, 4.6, 1.3},
185 {1, 5.2, 2.7, 3.9, 1.4},
186 {1, 5, 2, 3.5, 1},
187 {1, 5.9, 3, 4.2, 1.5},
188 {1, 6, 2.2, 4, 1},
189 {1, 6.1, 2.9, 4.7, 1.4},
190 {1, 5.6, 2.9, 3.6, 1.3},
191 {1, 6.7, 3.1, 4.4, 1.4},
192 {1, 5.6, 3, 4.5, 1.5},
193 {1, 5.8, 2.7, 4.1, 1},
194 {1, 6.2, 2.2, 4.5, 1.5},
195 {1, 5.6, 2.5, 3.9, 1.1},
196 {1, 5.9, 3.2, 4.8, 1.8},
197 {1, 6.1, 2.8, 4, 1.3},
198 {1, 6.3, 2.5, 4.9, 1.5},
199 {1, 6.1, 2.8, 4.7, 1.2},
200 {1, 6.4, 2.9, 4.3, 1.3},
201 {1, 6.6, 3, 4.4, 1.4},
202 {1, 6.8, 2.8, 4.8, 1.4},
203 {1, 6.7, 3, 5, 1.7},
204 {1, 6, 2.9, 4.5, 1.5},
205 {1, 5.7, 2.6, 3.5, 1},
206 {1, 5.5, 2.4, 3.8, 1.1},
207 {1, 5.5, 2.4, 3.7, 1},
208 {1, 5.8, 2.7, 3.9, 1.2},
209 {1, 6, 2.7, 5.1, 1.6},
210 {1, 5.4, 3, 4.5, 1.5},
211 {1, 6, 3.4, 4.5, 1.6},
212 {1, 6.7, 3.1, 4.7, 1.5},
213 {1, 6.3, 2.3, 4.4, 1.3},
214 {1, 5.6, 3, 4.1, 1.3},
215 {1, 5.5, 2.5, 4, 1.3},
216 {1, 5.5, 2.6, 4.4, 1.2},
217 {1, 6.1, 3, 4.6, 1.4},
218 {1, 5.8, 2.6, 4, 1.2},
219 {1, 5, 2.3, 3.3, 1},
220 {1, 5.6, 2.7, 4.2, 1.3},
221 {1, 5.7, 3, 4.2, 1.2},
222 {1, 5.7, 2.9, 4.2, 1.3},
223 {1, 6.2, 2.9, 4.3, 1.3},
224 {1, 5.1, 2.5, 3, 1.1},
225 {1, 5.7, 2.8, 4.1, 1.3},
226 };
227 }