IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / preprocessing / imputing / ImputerTrainerTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.preprocessing.imputing;
19
20 import java.util.Arrays;
21 import java.util.HashMap;
22 import java.util.Map;
23 import org.apache.ignite.ml.dataset.DatasetBuilder;
24 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
25 import org.apache.ignite.ml.math.Vector;
26 import org.apache.ignite.ml.math.VectorUtils;
27 import org.junit.Test;
28 import org.junit.runner.RunWith;
29 import org.junit.runners.Parameterized;
30
31 import static org.junit.Assert.assertArrayEquals;
32
33 /**
34 * Tests for {@link ImputerTrainer}.
35 */
36 @RunWith(Parameterized.class)
37 public class ImputerTrainerTest {
38 /** Parameters. */
39 @Parameterized.Parameters(name = "Data divided on {0} partitions")
40 public static Iterable<Integer[]> data() {
41 return Arrays.asList(
42 new Integer[] {1},
43 new Integer[] {2},
44 new Integer[] {3},
45 new Integer[] {5},
46 new Integer[] {7},
47 new Integer[] {100},
48 new Integer[] {1000}
49 );
50 }
51
52 /** Number of partitions. */
53 @Parameterized.Parameter
54 public int parts;
55
56 /** Tests {@code fit()} method. */
57 @Test
58 public void testFit() {
59 Map<Integer, Vector> data = new HashMap<>();
60 data.put(1, VectorUtils.of(1, 2, Double.NaN));
61 data.put(2, VectorUtils.of(1, Double.NaN, 22));
62 data.put(3, VectorUtils.of(Double.NaN, 10, 100));
63 data.put(4, VectorUtils.of(0, 2, 100));
64
65 DatasetBuilder<Integer, Vector> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
66
67 ImputerTrainer<Integer, Vector> imputerTrainer = new ImputerTrainer<Integer, Vector>()
68 .withImputingStrategy(ImputingStrategy.MOST_FREQUENT);
69
70 ImputerPreprocessor<Integer, Vector> preprocessor = imputerTrainer.fit(
71 datasetBuilder,
72 (k, v) -> v
73 );
74
75 assertArrayEquals(new double[] {1, 0, 100}, preprocessor.apply(5, VectorUtils.of(Double.NaN, 0, Double.NaN)).asArray(), 1e-8);
76 }
77 }