/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 package org
.apache
.ignite
.examples
.ml
.regression
.linear
;
20 import java
.util
.Arrays
;
21 import java
.util
.UUID
;
22 import javax
.cache
.Cache
;
23 import org
.apache
.ignite
.Ignite
;
24 import org
.apache
.ignite
.IgniteCache
;
25 import org
.apache
.ignite
.Ignition
;
26 import org
.apache
.ignite
.cache
.affinity
.rendezvous
.RendezvousAffinityFunction
;
27 import org
.apache
.ignite
.cache
.query
.QueryCursor
;
28 import org
.apache
.ignite
.cache
.query
.ScanQuery
;
29 import org
.apache
.ignite
.configuration
.CacheConfiguration
;
30 import org
.apache
.ignite
.ml
.math
.Vector
;
31 import org
.apache
.ignite
.ml
.math
.VectorUtils
;
32 import org
.apache
.ignite
.ml
.math
.functions
.IgniteBiFunction
;
33 import org
.apache
.ignite
.ml
.preprocessing
.minmaxscaling
.MinMaxScalerPreprocessor
;
34 import org
.apache
.ignite
.ml
.preprocessing
.minmaxscaling
.MinMaxScalerTrainer
;
35 import org
.apache
.ignite
.ml
.regressions
.linear
.LinearRegressionLSQRTrainer
;
36 import org
.apache
.ignite
.ml
.regressions
.linear
.LinearRegressionModel
;
37 import org
.apache
.ignite
.thread
.IgniteThread
;
/**
 * Run linear regression model over cached dataset.
 *
 * @see LinearRegressionLSQRTrainer
 * @see MinMaxScalerTrainer
 * @see MinMaxScalerPreprocessor
 */
46 public class LinearRegressionLSQRTrainerWithMinMaxScalerExample
{
48 private static final double[][] data
= {
49 {8, 78, 284, 9.100000381, 109},
50 {9.300000191, 68, 433, 8.699999809, 144},
51 {7.5, 70, 739, 7.199999809, 113},
52 {8.899999619, 96, 1792, 8.899999619, 97},
53 {10.19999981, 74, 477, 8.300000191, 206},
54 {8.300000191, 111, 362, 10.89999962, 124},
55 {8.800000191, 77, 671, 10, 152},
56 {8.800000191, 168, 636, 9.100000381, 162},
57 {10.69999981, 82, 329, 8.699999809, 150},
58 {11.69999981, 89, 634, 7.599999905, 134},
59 {8.5, 149, 631, 10.80000019, 292},
60 {8.300000191, 60, 257, 9.5, 108},
61 {8.199999809, 96, 284, 8.800000191, 111},
62 {7.900000095, 83, 603, 9.5, 182},
63 {10.30000019, 130, 686, 8.699999809, 129},
64 {7.400000095, 145, 345, 11.19999981, 158},
65 {9.600000381, 112, 1357, 9.699999809, 186},
66 {9.300000191, 131, 544, 9.600000381, 177},
67 {10.60000038, 80, 205, 9.100000381, 127},
68 {9.699999809, 130, 1264, 9.199999809, 179},
69 {11.60000038, 140, 688, 8.300000191, 80},
70 {8.100000381, 154, 354, 8.399999619, 103},
71 {9.800000191, 118, 1632, 9.399999619, 101},
72 {7.400000095, 94, 348, 9.800000191, 117},
73 {9.399999619, 119, 370, 10.39999962, 88},
74 {11.19999981, 153, 648, 9.899999619, 78},
75 {9.100000381, 116, 366, 9.199999809, 102},
76 {10.5, 97, 540, 10.30000019, 95},
77 {11.89999962, 176, 680, 8.899999619, 80},
78 {8.399999619, 75, 345, 9.600000381, 92},
79 {5, 134, 525, 10.30000019, 126},
80 {9.800000191, 161, 870, 10.39999962, 108},
81 {9.800000191, 111, 669, 9.699999809, 77},
82 {10.80000019, 114, 452, 9.600000381, 60},
83 {10.10000038, 142, 430, 10.69999981, 71},
84 {10.89999962, 238, 822, 10.30000019, 86},
85 {9.199999809, 78, 190, 10.69999981, 93},
86 {8.300000191, 196, 867, 9.600000381, 106},
87 {7.300000191, 125, 969, 10.5, 162},
88 {9.399999619, 82, 499, 7.699999809, 95},
89 {9.399999619, 125, 925, 10.19999981, 91},
90 {9.800000191, 129, 353, 9.899999619, 52},
91 {3.599999905, 84, 288, 8.399999619, 110},
92 {8.399999619, 183, 718, 10.39999962, 69},
93 {10.80000019, 119, 540, 9.199999809, 57},
94 {10.10000038, 180, 668, 13, 106},
95 {9, 82, 347, 8.800000191, 40},
96 {10, 71, 345, 9.199999809, 50},
97 {11.30000019, 118, 463, 7.800000191, 35},
98 {11.30000019, 121, 728, 8.199999809, 86},
99 {12.80000019, 68, 383, 7.400000095, 57},
100 {10, 112, 316, 10.39999962, 57},
101 {6.699999809, 109, 388, 8.899999619, 94}
105 public static void main(String
[] args
) throws InterruptedException
{
106 System
.out
.println();
107 System
.out
.println(">>> Linear regression model over cached dataset usage example started.");
108 // Start ignite grid.
109 try (Ignite ignite
= Ignition
.start("examples/config/example-ignite.xml")) {
110 System
.out
.println(">>> Ignite grid started.");
112 IgniteThread igniteThread
= new IgniteThread(ignite
.configuration().getIgniteInstanceName(),
113 LinearRegressionLSQRTrainerWithMinMaxScalerExample
.class.getSimpleName(), () -> {
114 IgniteCache
<Integer
, Vector
> dataCache
= getTestCache(ignite
);
116 System
.out
.println(">>> Create new minmaxscaling trainer object.");
117 MinMaxScalerTrainer
<Integer
, Vector
> normalizationTrainer
= new MinMaxScalerTrainer
<>();
119 System
.out
.println(">>> Perform the training to get the minmaxscaling preprocessor.");
120 IgniteBiFunction
<Integer
, Vector
, Vector
> preprocessor
= normalizationTrainer
.fit(
124 double[] arr
= v
.asArray();
125 return VectorUtils
.of(Arrays
.copyOfRange(arr
, 1, arr
.length
));
129 System
.out
.println(">>> Create new linear regression trainer object.");
130 LinearRegressionLSQRTrainer trainer
= new LinearRegressionLSQRTrainer();
132 System
.out
.println(">>> Perform the training to get the model.");
133 LinearRegressionModel mdl
= trainer
.fit(ignite
, dataCache
, preprocessor
, (k
, v
) -> v
.get(0));
135 System
.out
.println(">>> Linear regression model: " + mdl
);
137 System
.out
.println(">>> ---------------------------------");
138 System
.out
.println(">>> | Prediction\t| Ground Truth\t|");
139 System
.out
.println(">>> ---------------------------------");
141 try (QueryCursor
<Cache
.Entry
<Integer
, Vector
>> observations
= dataCache
.query(new ScanQuery
<>())) {
142 for (Cache
.Entry
<Integer
, Vector
> observation
: observations
) {
143 Integer key
= observation
.getKey();
144 Vector val
= observation
.getValue();
145 double groundTruth
= val
.get(0);
147 double prediction
= mdl
.apply(preprocessor
.apply(key
, val
));
149 System
.out
.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction
, groundTruth
);
153 System
.out
.println(">>> ---------------------------------");
156 igniteThread
.start();
163 * Fills cache with data and returns it.
165 * @param ignite Ignite instance.
166 * @return Filled Ignite Cache.
168 private static IgniteCache
<Integer
, Vector
> getTestCache(Ignite ignite
) {
169 CacheConfiguration
<Integer
, Vector
> cacheConfiguration
= new CacheConfiguration
<>();
170 cacheConfiguration
.setName("TEST_" + UUID
.randomUUID());
171 cacheConfiguration
.setAffinity(new RendezvousAffinityFunction(false
, 10));
173 IgniteCache
<Integer
, Vector
> cache
= ignite
.createCache(cacheConfiguration
);
175 for (int i
= 0; i
< data
.length
; i
++)
176 cache
.put(i
, VectorUtils
.of(data
[i
]));