IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / regression / logistic / binary / LogisticRegressionSGDTrainerSample.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.regression.logistic.binary;
19
20 import java.util.Arrays;
21 import java.util.UUID;
22 import javax.cache.Cache;
23 import org.apache.ignite.Ignite;
24 import org.apache.ignite.IgniteCache;
25 import org.apache.ignite.Ignition;
26 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
27 import org.apache.ignite.cache.query.QueryCursor;
28 import org.apache.ignite.cache.query.ScanQuery;
29 import org.apache.ignite.configuration.CacheConfiguration;
30 import org.apache.ignite.ml.math.VectorUtils;
31 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
32 import org.apache.ignite.ml.nn.UpdatesStrategy;
33 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
34 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
35 import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
36 import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
37 import org.apache.ignite.thread.IgniteThread;
38
39 /**
40 * Run logistic regression model over distributed cache.
41 *
42 * @see LogisticRegressionSGDTrainer
43 */
44 public class LogisticRegressionSGDTrainerSample {
45 /** Run example. */
46 public static void main(String[] args) throws InterruptedException {
47 System.out.println();
48 System.out.println(">>> Logistic regression model over partitioned dataset usage example started.");
49 // Start ignite grid.
50 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
51 System.out.println(">>> Ignite grid started.");
52 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
53 LogisticRegressionSGDTrainerSample.class.getSimpleName(), () -> {
54
55 IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
56
57 System.out.println(">>> Create new logistic regression trainer object.");
58 LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
59 new SimpleGDUpdateCalculator(0.2),
60 SimpleGDParameterUpdate::sumLocal,
61 SimpleGDParameterUpdate::avg
62 ), 100000, 10, 100, 123L);
63
64 System.out.println(">>> Perform the training to get the model.");
65 LogisticRegressionModel mdl = trainer.fit(
66 ignite,
67 dataCache,
68 (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
69 (k, v) -> v[0]
70 ).withRawLabels(true);
71
72 System.out.println(">>> Logistic regression model: " + mdl);
73
74 int amountOfErrors = 0;
75 int totalAmount = 0;
76
77 // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
78 int[][] confusionMtx = {{0, 0}, {0, 0}};
79
80 try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
81 for (Cache.Entry<Integer, double[]> observation : observations) {
82 double[] val = observation.getValue();
83 double[] inputs = Arrays.copyOfRange(val, 1, val.length);
84 double groundTruth = val[0];
85
86 double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
87
88 totalAmount++;
89 if(groundTruth != prediction)
90 amountOfErrors++;
91
92 int idx1 = (int)prediction;
93 int idx2 = (int)groundTruth;
94
95 confusionMtx[idx1][idx2]++;
96
97 System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
98 }
99
100 System.out.println(">>> ---------------------------------");
101
102 System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
103 System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
104 }
105
106 System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
107 System.out.println(">>> ---------------------------------");
108 });
109
110 igniteThread.start();
111
112 igniteThread.join();
113 }
114 }
115 /**
116 * Fills cache with data and returns it.
117 *
118 * @param ignite Ignite instance.
119 * @return Filled Ignite Cache.
120 */
121 private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
122 CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
123 cacheConfiguration.setName("TEST_" + UUID.randomUUID());
124 cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
125
126 IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
127
128 for (int i = 0; i < data.length; i++)
129 cache.put(i, data[i]);
130
131 return cache;
132 }
133
134
135 /** The 1st and 2nd classes from the Iris dataset. */
136 private static final double[][] data = {
137 {0, 5.1, 3.5, 1.4, 0.2},
138 {0, 4.9, 3, 1.4, 0.2},
139 {0, 4.7, 3.2, 1.3, 0.2},
140 {0, 4.6, 3.1, 1.5, 0.2},
141 {0, 5, 3.6, 1.4, 0.2},
142 {0, 5.4, 3.9, 1.7, 0.4},
143 {0, 4.6, 3.4, 1.4, 0.3},
144 {0, 5, 3.4, 1.5, 0.2},
145 {0, 4.4, 2.9, 1.4, 0.2},
146 {0, 4.9, 3.1, 1.5, 0.1},
147 {0, 5.4, 3.7, 1.5, 0.2},
148 {0, 4.8, 3.4, 1.6, 0.2},
149 {0, 4.8, 3, 1.4, 0.1},
150 {0, 4.3, 3, 1.1, 0.1},
151 {0, 5.8, 4, 1.2, 0.2},
152 {0, 5.7, 4.4, 1.5, 0.4},
153 {0, 5.4, 3.9, 1.3, 0.4},
154 {0, 5.1, 3.5, 1.4, 0.3},
155 {0, 5.7, 3.8, 1.7, 0.3},
156 {0, 5.1, 3.8, 1.5, 0.3},
157 {0, 5.4, 3.4, 1.7, 0.2},
158 {0, 5.1, 3.7, 1.5, 0.4},
159 {0, 4.6, 3.6, 1, 0.2},
160 {0, 5.1, 3.3, 1.7, 0.5},
161 {0, 4.8, 3.4, 1.9, 0.2},
162 {0, 5, 3, 1.6, 0.2},
163 {0, 5, 3.4, 1.6, 0.4},
164 {0, 5.2, 3.5, 1.5, 0.2},
165 {0, 5.2, 3.4, 1.4, 0.2},
166 {0, 4.7, 3.2, 1.6, 0.2},
167 {0, 4.8, 3.1, 1.6, 0.2},
168 {0, 5.4, 3.4, 1.5, 0.4},
169 {0, 5.2, 4.1, 1.5, 0.1},
170 {0, 5.5, 4.2, 1.4, 0.2},
171 {0, 4.9, 3.1, 1.5, 0.1},
172 {0, 5, 3.2, 1.2, 0.2},
173 {0, 5.5, 3.5, 1.3, 0.2},
174 {0, 4.9, 3.1, 1.5, 0.1},
175 {0, 4.4, 3, 1.3, 0.2},
176 {0, 5.1, 3.4, 1.5, 0.2},
177 {0, 5, 3.5, 1.3, 0.3},
178 {0, 4.5, 2.3, 1.3, 0.3},
179 {0, 4.4, 3.2, 1.3, 0.2},
180 {0, 5, 3.5, 1.6, 0.6},
181 {0, 5.1, 3.8, 1.9, 0.4},
182 {0, 4.8, 3, 1.4, 0.3},
183 {0, 5.1, 3.8, 1.6, 0.2},
184 {0, 4.6, 3.2, 1.4, 0.2},
185 {0, 5.3, 3.7, 1.5, 0.2},
186 {0, 5, 3.3, 1.4, 0.2},
187 {1, 7, 3.2, 4.7, 1.4},
188 {1, 6.4, 3.2, 4.5, 1.5},
189 {1, 6.9, 3.1, 4.9, 1.5},
190 {1, 5.5, 2.3, 4, 1.3},
191 {1, 6.5, 2.8, 4.6, 1.5},
192 {1, 5.7, 2.8, 4.5, 1.3},
193 {1, 6.3, 3.3, 4.7, 1.6},
194 {1, 4.9, 2.4, 3.3, 1},
195 {1, 6.6, 2.9, 4.6, 1.3},
196 {1, 5.2, 2.7, 3.9, 1.4},
197 {1, 5, 2, 3.5, 1},
198 {1, 5.9, 3, 4.2, 1.5},
199 {1, 6, 2.2, 4, 1},
200 {1, 6.1, 2.9, 4.7, 1.4},
201 {1, 5.6, 2.9, 3.6, 1.3},
202 {1, 6.7, 3.1, 4.4, 1.4},
203 {1, 5.6, 3, 4.5, 1.5},
204 {1, 5.8, 2.7, 4.1, 1},
205 {1, 6.2, 2.2, 4.5, 1.5},
206 {1, 5.6, 2.5, 3.9, 1.1},
207 {1, 5.9, 3.2, 4.8, 1.8},
208 {1, 6.1, 2.8, 4, 1.3},
209 {1, 6.3, 2.5, 4.9, 1.5},
210 {1, 6.1, 2.8, 4.7, 1.2},
211 {1, 6.4, 2.9, 4.3, 1.3},
212 {1, 6.6, 3, 4.4, 1.4},
213 {1, 6.8, 2.8, 4.8, 1.4},
214 {1, 6.7, 3, 5, 1.7},
215 {1, 6, 2.9, 4.5, 1.5},
216 {1, 5.7, 2.6, 3.5, 1},
217 {1, 5.5, 2.4, 3.8, 1.1},
218 {1, 5.5, 2.4, 3.7, 1},
219 {1, 5.8, 2.7, 3.9, 1.2},
220 {1, 6, 2.7, 5.1, 1.6},
221 {1, 5.4, 3, 4.5, 1.5},
222 {1, 6, 3.4, 4.5, 1.6},
223 {1, 6.7, 3.1, 4.7, 1.5},
224 {1, 6.3, 2.3, 4.4, 1.3},
225 {1, 5.6, 3, 4.1, 1.3},
226 {1, 5.5, 2.5, 4, 1.3},
227 {1, 5.5, 2.6, 4.4, 1.2},
228 {1, 6.1, 3, 4.6, 1.4},
229 {1, 5.8, 2.6, 4, 1.2},
230 {1, 5, 2.3, 3.3, 1},
231 {1, 5.6, 2.7, 4.2, 1.3},
232 {1, 5.7, 3, 4.2, 1.2},
233 {1, 5.7, 2.9, 4.2, 1.3},
234 {1, 6.2, 2.9, 4.3, 1.3},
235 {1, 5.1, 2.5, 3, 1.1},
236 {1, 5.7, 2.8, 4.1, 1.3},
237 };
238
239 }