IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / test / java / org / apache / ignite / ml / selection / cv / CrossValidationTest.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.selection.cv;
19
20 import java.util.HashMap;
21 import java.util.Map;
22 import org.apache.ignite.ml.math.VectorUtils;
23 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
24 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
25 import org.apache.ignite.ml.tree.DecisionTreeNode;
26 import org.junit.Test;
27
28 import static junit.framework.TestCase.assertTrue;
29 import static org.junit.Assert.assertEquals;
30
31 /**
32 * Tests for {@link CrossValidation}.
33 */
34 public class CrossValidationTest {
35 /** */
36 @Test
37 public void testScoreWithGoodDataset() {
38 Map<Integer, Double> data = new HashMap<>();
39
40 for (int i = 0; i < 1000; i++)
41 data.put(i, i > 500 ? 1.0 : 0.0);
42
43 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
44
45 CrossValidation<DecisionTreeNode, Double, Integer, Double> scoreCalculator =
46 new CrossValidation<>();
47
48 int folds = 4;
49
50 double[] scores = scoreCalculator.score(
51 trainer,
52 new Accuracy<>(),
53 data,
54 1,
55 (k, v) -> VectorUtils.of(k),
56 (k, v) -> v,
57 folds
58 );
59
60 assertEquals(folds, scores.length);
61
62 for (int i = 0; i < folds; i++)
63 assertEquals(1, scores[i], 1e-1);
64 }
65
66 /** */
67 @Test
68 public void testScoreWithBadDataset() {
69 Map<Integer, Double> data = new HashMap<>();
70
71 for (int i = 0; i < 1000; i++)
72 data.put(i, i % 2 == 0 ? 1.0 : 0.0);
73
74 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
75
76 CrossValidation<DecisionTreeNode, Double, Integer, Double> scoreCalculator =
77 new CrossValidation<>();
78
79 int folds = 4;
80
81 double[] scores = scoreCalculator.score(
82 trainer,
83 new Accuracy<>(),
84 data,
85 1,
86 (k, v) -> VectorUtils.of(k),
87 (k, v) -> v,
88 folds
89 );
90
91 assertEquals(folds, scores.length);
92
93 for (int i = 0; i < folds; i++)
94 assertTrue(scores[i] < 0.6);
95 }
96 }