IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / math / isolve / lsqr / LSQROnHeapTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.math.isolve.lsqr;
19
20 import java.util.Arrays;
21 import java.util.HashMap;
22 import java.util.Map;
23 import org.apache.ignite.ml.dataset.DatasetBuilder;
24 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
25 import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
26 import org.apache.ignite.ml.math.VectorUtils;
27 import org.junit.Test;
28 import org.junit.runner.RunWith;
29 import org.junit.runners.Parameterized;
30
31 import static org.junit.Assert.assertArrayEquals;
32
33 /**
34 * Tests for {@link LSQROnHeap}.
35 */
36 @RunWith(Parameterized.class)
37 public class LSQROnHeapTest {
38 /** Parameters. */
39 @Parameterized.Parameters(name = "Data divided on {0} partitions")
40 public static Iterable<Integer[]> data() {
41 return Arrays.asList(
42 new Integer[] {1},
43 new Integer[] {2},
44 new Integer[] {3},
45 new Integer[] {5},
46 new Integer[] {7},
47 new Integer[] {100},
48 new Integer[] {1000}
49 );
50 }
51
52 /** Number of partitions. */
53 @Parameterized.Parameter
54 public int parts;
55
56 /** Tests solving simple linear system. */
57 @Test
58 public void testSolveLinearSystem() {
59 Map<Integer, double[]> data = new HashMap<>();
60 data.put(0, new double[]{3, 2, -1, 1});
61 data.put(1, new double[]{2, -2, 4, -2});
62 data.put(2, new double[]{-1, 0.5, -1, 0});
63
64 DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
65
66 LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
67 datasetBuilder,
68 new SimpleLabeledDatasetDataBuilder<>(
69 (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
70 (k, v) -> new double[]{v[3]}
71 )
72 );
73
74 LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
75
76 assertArrayEquals(new double[]{1, -2, -2}, res.getX(), 1e-6);
77 }
78
79 /** Tests solving simple linear system with specified x0. */
80 @Test
81 public void testSolveLinearSystemWithX0() {
82 Map<Integer, double[]> data = new HashMap<>();
83 data.put(0, new double[]{3, 2, -1, 1});
84 data.put(1, new double[]{2, -2, 4, -2});
85 data.put(2, new double[]{-1, 0.5, -1, 0});
86
87 DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
88
89 LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
90 datasetBuilder,
91 new SimpleLabeledDatasetDataBuilder<>(
92 (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
93 (k, v) -> new double[]{v[3]}
94 )
95 );
96
97 LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false,
98 new double[] {999, 999, 999});
99
100 assertArrayEquals(new double[]{1, -2, -2}, res.getX(), 1e-6);
101 }
102
103 /** Tests solving least squares problem. */
104 @Test
105 public void testSolveLeastSquares() throws Exception {
106 Map<Integer, double[]> data = new HashMap<>();
107 data.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107});
108 data.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867});
109 data.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728});
110 data.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991});
111 data.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611});
112 data.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197});
113 data.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012});
114 data.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889});
115 data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949});
116 data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583});
117
118 DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, 1);
119
120 try (LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
121 datasetBuilder,
122 new SimpleLabeledDatasetDataBuilder<>(
123 (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
124 (k, v) -> new double[]{v[4]}
125 )
126 )) {
127 LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
128
129 assertArrayEquals(new double[]{72.26948107, 15.95144674, 24.07403921, 66.73038781}, res.getX(), 1e-6);
130 }
131 }
132 }