/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.ignite.ml.regressions.linear;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
35 * Tests for {@link LinearRegressionSGDTrainer}.
37 @RunWith(Parameterized
.class)
38 public class LinearRegressionSGDTrainerTest
{
40 @Parameterized.Parameters(name
= "Data divided on {0} partitions")
41 public static Iterable
<Integer
[]> data() {
52 /** Number of partitions. */
53 @Parameterized.Parameter
57 * Tests {@code fit()} method on a simple small dataset.
60 public void testSmallDataFit() {
61 Map
<Integer
, double[]> data
= new HashMap
<>();
62 data
.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107});
63 data
.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867});
64 data
.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728});
65 data
.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991});
66 data
.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611});
67 data
.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197});
68 data
.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012});
69 data
.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889});
70 data
.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949});
71 data
.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583});
73 LinearRegressionSGDTrainer
<?
> trainer
= new LinearRegressionSGDTrainer
<>(new UpdatesStrategy
<>(
74 new RPropUpdateCalculator(),
75 RPropParameterUpdate
::sumLocal
,
76 RPropParameterUpdate
::avg
77 ), 100000, 10, 100, 123L);
79 LinearRegressionModel mdl
= trainer
.fit(
82 (k
, v
) -> VectorUtils
.of(Arrays
.copyOfRange(v
, 0, v
.length
- 1)),
87 new double[] {72.26948107, 15.95144674, 24.07403921, 66.73038781},
88 mdl
.getWeights().getStorage().data(),
92 assertEquals(2.8421709430404007e-14, mdl
.getIntercept(), 1e-1);