IGNITE-8907: [ML] Using vectors in featureExtractor
examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.selection.split;

import java.util.Arrays;
import java.util.UUID;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.thread.IgniteThread;

/**
 * Run linear regression model over a dataset split into train and test subsets.
 * <p>
 * The example trains a {@link LinearRegressionLSQRTrainer} on 75% of the cached data and prints
 * model predictions next to the ground truth values for the remaining 25%.
 *
 * @see TrainTestDatasetSplitter
 */
public class TrainTestDatasetSplitterExample {
    /** Dataset: the first element of each row is the target (label) value, the remaining elements are features. */
    private static final double[][] data = {
        {8, 78, 284, 9.100000381, 109},
        {9.300000191, 68, 433, 8.699999809, 144},
        {7.5, 70, 739, 7.199999809, 113},
        {8.899999619, 96, 1792, 8.899999619, 97},
        {10.19999981, 74, 477, 8.300000191, 206},
        {8.300000191, 111, 362, 10.89999962, 124},
        {8.800000191, 77, 671, 10, 152},
        {8.800000191, 168, 636, 9.100000381, 162},
        {10.69999981, 82, 329, 8.699999809, 150},
        {11.69999981, 89, 634, 7.599999905, 134},
        {8.5, 149, 631, 10.80000019, 292},
        {8.300000191, 60, 257, 9.5, 108},
        {8.199999809, 96, 284, 8.800000191, 111},
        {7.900000095, 83, 603, 9.5, 182},
        {10.30000019, 130, 686, 8.699999809, 129},
        {7.400000095, 145, 345, 11.19999981, 158},
        {9.600000381, 112, 1357, 9.699999809, 186},
        {9.300000191, 131, 544, 9.600000381, 177},
        {10.60000038, 80, 205, 9.100000381, 127},
        {9.699999809, 130, 1264, 9.199999809, 179},
        {11.60000038, 140, 688, 8.300000191, 80},
        {8.100000381, 154, 354, 8.399999619, 103},
        {9.800000191, 118, 1632, 9.399999619, 101},
        {7.400000095, 94, 348, 9.800000191, 117},
        {9.399999619, 119, 370, 10.39999962, 88},
        {11.19999981, 153, 648, 9.899999619, 78},
        {9.100000381, 116, 366, 9.199999809, 102},
        {10.5, 97, 540, 10.30000019, 95},
        {11.89999962, 176, 680, 8.899999619, 80},
        {8.399999619, 75, 345, 9.600000381, 92},
        {5, 134, 525, 10.30000019, 126},
        {9.800000191, 161, 870, 10.39999962, 108},
        {9.800000191, 111, 669, 9.699999809, 77},
        {10.80000019, 114, 452, 9.600000381, 60},
        {10.10000038, 142, 430, 10.69999981, 71},
        {10.89999962, 238, 822, 10.30000019, 86},
        {9.199999809, 78, 190, 10.69999981, 93},
        {8.300000191, 196, 867, 9.600000381, 106},
        {7.300000191, 125, 969, 10.5, 162},
        {9.399999619, 82, 499, 7.699999809, 95},
        {9.399999619, 125, 925, 10.19999981, 91},
        {9.800000191, 129, 353, 9.899999619, 52},
        {3.599999905, 84, 288, 8.399999619, 110},
        {8.399999619, 183, 718, 10.39999962, 69},
        {10.80000019, 119, 540, 9.199999809, 57},
        {10.10000038, 180, 668, 13, 106},
        {9, 82, 347, 8.800000191, 40},
        {10, 71, 345, 9.199999809, 50},
        {11.30000019, 118, 463, 7.800000191, 35},
        {11.30000019, 121, 728, 8.199999809, 86},
        {12.80000019, 68, 383, 7.400000095, 57},
        {10, 112, 316, 10.39999962, 57},
        {6.699999809, 109, 388, 8.899999619, 94}
    };

    /** Run example. */
    public static void main(String[] args) throws InterruptedException {
        System.out.println();
        System.out.println(">>> Linear regression model over cache-based dataset usage example started.");
        // Start Ignite grid.
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            System.out.println(">>> Ignite grid started.");

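            // Run the training and evaluation logic in a dedicated Ignite thread.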
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                TrainTestDatasetSplitterExample.class.getSimpleName(), () -> {
                IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);

                System.out.println(">>> Create new linear regression trainer object.");
                LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();

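                // Split the dataset: 75% of the rows form the train set, the remaining 25% form the test set.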
                TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer, double[]>()
                    .split(0.75);

                System.out.println(">>> Perform the training to get the model.");
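                // The feature extractor maps columns 1..4 of a row to a feature vector; the label extractor uses column 0.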
                LinearRegressionModel mdl = trainer.fit(
                    ignite,
                    dataCache,
                    split.getTrainFilter(),
                    (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
                    (k, v) -> v[0]
                );

                System.out.println(">>> Linear regression model: " + mdl);

                System.out.println(">>> ---------------------------------");
                System.out.println(">>> | Prediction\t| Ground Truth\t|");
                System.out.println(">>> ---------------------------------");

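                // Iterate only over the test part of the dataset and compare predictions with the ground truth.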
                ScanQuery<Integer, double[]> qry = new ScanQuery<>();
                qry.setFilter(split.getTestFilter());

                try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(qry)) {
                    for (Cache.Entry<Integer, double[]> observation : observations) {
                        double[] val = observation.getValue();
                        double[] inputs = Arrays.copyOfRange(val, 1, val.length);
                        double groundTruth = val[0];

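                        // Apply the trained model to the feature vector of the current observation.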
                        double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));

                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
                    }
                }

                System.out.println(">>> ---------------------------------");
            });

            igniteThread.start();

            igniteThread.join();
        }
    }

    /**
     * Fills cache with data and returns it.
     *
     * @param ignite Ignite instance.
     * @return Filled Ignite cache.
     */
    private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
        CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
        cacheConfiguration.setName("TEST_" + UUID.randomUUID());
        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 3));

        IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);

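        // Upload the dataset; the row index is used as the cache key.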
        for (int i = 0; i < data.length; i++)
            cache.put(i, data[i]);

        return cache;
    }
}