IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
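The change referenced in the commit title shows up in every fit(...) call in this test: the feature extractor handed to the trainer now builds an ML Vector with VectorUtils.of(...) instead of returning a raw double[], while the label extractor still returns a plain double. Below is a minimal, self-contained sketch of that pattern, condensed from the test bodies; the class name, the two-row data set and the query vector are invented for illustration, but every API call appears verbatim in the file that follows.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;

/** Illustrative sketch only; not part of the committed test file. */
public class KnnVectorFeatureExtractorSketch {
    public static void main(String[] args) {
        // Each row: label in column 0, features in columns 1..n.
        Map<Integer, double[]> data = new HashMap<>();
        data.put(0, new double[] {11.0, 0, 0, 0, 0, 0});
        data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0});

        KNNRegressionTrainer trainer = new KNNRegressionTrainer();

        KNNRegressionModel mdl = (KNNRegressionModel) trainer.fit(
            new LocalDatasetBuilder<>(data, 2),
            // Feature extractor: wrap columns 1..n into a Vector.
            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
            // Label extractor: column 0.
            (k, v) -> v[0]
        ).withK(1)
            .withDistanceMeasure(new EuclideanDistance())
            .withStrategy(KNNStrategy.SIMPLE);

        // With k = 1 the prediction is the label of the nearest row, here 12.0.
        System.out.println(mdl.apply(new DenseLocalOnHeapVector(new double[] {2.0, 0, 0, 0, 0})));
    }
}

The tests below exercise the same pattern against 1, 2, 3, 4, 5, 7 and 100 local partitions.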
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.knn;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.VectorUtils;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests for {@link KNNRegressionTrainer}.
 */
@RunWith(Parameterized.class)
public class KNNRegressionTest {
    /** Number of parts to be tested. */
    private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7, 100};

    /** Number of partitions. */
    @Parameterized.Parameter
    public int parts;

    /** Parameters. */
    @Parameterized.Parameters(name = "Data divided into {0} partitions")
    public static Iterable<Integer[]> data() {
        List<Integer[]> res = new ArrayList<>();

        for (int part : partsToBeTested)
            res.add(new Integer[] {part});

        return res;
    }

    /** Tests 1-NN regression where the query exactly matches one training observation. */
    @Test
    public void testSimpleRegressionWithOneNeighbour() {
        Map<Integer, double[]> data = new HashMap<>();
        data.put(0, new double[] {11.0, 0, 0, 0, 0, 0});
        data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0});
        data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0});
        data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0});
        data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0});
        data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0});

        KNNRegressionTrainer trainer = new KNNRegressionTrainer();

        KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
            new LocalDatasetBuilder<>(data, parts),
            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
            (k, v) -> v[0]
        ).withK(1)
            .withDistanceMeasure(new EuclideanDistance())
            .withStrategy(KNNStrategy.SIMPLE);

        Vector vector = new DenseLocalOnHeapVector(new double[] {0, 0, 0, 5.0, 0.0});
        System.out.println(knnMdl.apply(vector));
        Assert.assertEquals(15, knnMdl.apply(vector), 1E-12);
    }
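
    /*
     * Why 15 is the expected value above: with k = 1 and the SIMPLE strategy the model
     * returns the label of the single nearest neighbour. The query {0, 0, 0, 5.0, 0.0}
     * coincides with the feature part of observation 4 (columns 1..5 of
     * {15.0, 0, 0, 0, 5.0, 0}), so that observation's label 15.0 is returned.
     */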

    /**
     * Tests 3-NN regression with the SIMPLE strategy on the Longley data set;
     * the 1956 observation is held out and used as the query.
     */
    @Test
    public void testLongly() {
        Map<Integer, double[]> data = new HashMap<>();
        data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
        data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
        data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949});
        data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950});
        data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951});
        data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952});
        data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953});
        data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954});
        data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955});
        data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957});
        data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958});
        data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959});
        data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960});
        data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961});
        data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962});

        KNNRegressionTrainer trainer = new KNNRegressionTrainer();

        KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
            new LocalDatasetBuilder<>(data, parts),
            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
            (k, v) -> v[0]
        ).withK(3)
            .withDistanceMeasure(new EuclideanDistance())
            .withStrategy(KNNStrategy.SIMPLE);

        Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
        System.out.println(knnMdl.apply(vector));
        Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
    }
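
    /*
     * Sanity check for the assertion above: with Euclidean distance the GNP column dominates,
     * so the three observations nearest to the held-out 1956 row are those for 1955, 1957 and
     * 1958, with labels 66019, 68169 and 66513. Their plain average is roughly 66900, which is
     * within the 2000 tolerance of the reference value 67857 (the actual 1956 employment).
     */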

    /**
     * Same Longley setup as {@link #testLongly()}, but with the WEIGHTED
     * neighbour-voting strategy.
     */
    @Test
    public void testLonglyWithWeightedStrategy() {
        Map<Integer, double[]> data = new HashMap<>();
        data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
        data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
        data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949});
        data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950});
        data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951});
        data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952});
        data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953});
        data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954});
        data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955});
        data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957});
        data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958});
        data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959});
        data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960});
        data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961});
        data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962});

        KNNRegressionTrainer trainer = new KNNRegressionTrainer();

        KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
            new LocalDatasetBuilder<>(data, parts),
            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
            (k, v) -> v[0]
        ).withK(3)
            .withDistanceMeasure(new EuclideanDistance())
            .withStrategy(KNNStrategy.WEIGHTED);

        Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
        System.out.println(knnMdl.apply(vector));
        Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
    }
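
    /*
     * Note on the WEIGHTED test: instead of a plain average, the strategy weights each of the
     * k neighbour labels by its distance to the query, so closer neighbours contribute more.
     * Any such weighted average of the three nearest labels (66019, 68169, 66513) still lies
     * between 66019 and 68169, so the 67857 +/- 2000 assertion holds here as well.
     */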
}