IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / tutorial / Step_8_CV.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.tutorial;
19
20 import java.io.FileNotFoundException;
21 import java.util.Arrays;
22 import org.apache.ignite.Ignite;
23 import org.apache.ignite.IgniteCache;
24 import org.apache.ignite.Ignition;
25 import org.apache.ignite.ml.math.Vector;
26 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
27 import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
28 import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
29 import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
30 import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
31 import org.apache.ignite.ml.selection.cv.CrossValidation;
32 import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
33 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
34 import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
35 import org.apache.ignite.ml.selection.split.TrainTestSplit;
36 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
37 import org.apache.ignite.ml.tree.DecisionTreeNode;
38 import org.apache.ignite.thread.IgniteThread;
39
40 /**
41 * To choose the best hyperparameters the cross-validation will be used in this example.
42 *
43 * The purpose of cross-validation is model checking, not model building.
44 *
45 * We train k different models.
46 *
47 * They differ in that 1/(k-1)th of the training data is exchanged against other cases.
48 *
49 * These models are sometimes called surrogate models because the (average) performance measured for these models
50 * is taken as a surrogate of the performance of the model trained on all cases.
51 *
52 * All scenarios are described there: https://sebastianraschka.com/faq/docs/evaluate-a-model.html
53 */
54 public class Step_8_CV {
55 /** Run example. */
56 public static void main(String[] args) throws InterruptedException {
57 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
58 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
59 Step_8_CV.class.getSimpleName(), () -> {
60 try {
61 IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
62
63 // Defines first preprocessor that extracts features from an upstream data.
64 // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare"
65 IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
66 = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
67
68 IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
69
70 TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
71 .split(0.75);
72
73 IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
74 .encodeFeature(1)
75 .encodeFeature(6) // <--- Changed index here
76 .fit(ignite,
77 dataCache,
78 featureExtractor
79 );
80
81 IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
82 .fit(ignite,
83 dataCache,
84 strEncoderPreprocessor
85 );
86
87 IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
88 .fit(
89 ignite,
90 dataCache,
91 imputingPreprocessor
92 );
93
94 // Tune hyperparams with K-fold Cross-Validation on the splitted training set.
95 int[] pSet = new int[]{1, 2};
96 int[] maxDeepSet = new int[]{1, 2, 3, 4, 5, 10, 20};
97 int bestP = 1;
98 int bestMaxDeep = 1;
99 double avg = Double.MIN_VALUE;
100
101 for(int p: pSet){
102 for(int maxDeep: maxDeepSet){
103 IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
104 .withP(p)
105 .fit(
106 ignite,
107 dataCache,
108 minMaxScalerPreprocessor
109 );
110
111 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(maxDeep, 0);
112
113 CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator
114 = new CrossValidation<>();
115
116 double[] scores = scoreCalculator.score(
117 trainer,
118 new Accuracy<>(),
119 ignite,
120 dataCache,
121 split.getTrainFilter(),
122 normalizationPreprocessor,
123 lbExtractor,
124 3
125 );
126
127 System.out.println("Scores are: " + Arrays.toString(scores));
128
129 final double currAvg = Arrays.stream(scores).average().orElse(Double.MIN_VALUE);
130
131 if(currAvg > avg) {
132 avg = currAvg;
133 bestP = p;
134 bestMaxDeep = maxDeep;
135 }
136
137 System.out.println("Avg is: " + currAvg + " with p: " + p + " with maxDeep: " + maxDeep);
138 }
139 }
140
141 System.out.println("Train with p: " + bestP + " and maxDeep: " + bestMaxDeep);
142
143 IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
144 .withP(bestP)
145 .fit(
146 ignite,
147 dataCache,
148 minMaxScalerPreprocessor
149 );
150
151 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(bestMaxDeep, 0);
152
153 // Train decision tree model.
154 DecisionTreeNode bestMdl = trainer.fit(
155 ignite,
156 dataCache,
157 split.getTrainFilter(),
158 normalizationPreprocessor,
159 lbExtractor
160 );
161
162 double accuracy = Evaluator.evaluate(
163 dataCache,
164 split.getTestFilter(),
165 bestMdl,
166 normalizationPreprocessor,
167 lbExtractor,
168 new Accuracy<>()
169 );
170
171 System.out.println("\n>>> Accuracy " + accuracy);
172 System.out.println("\n>>> Test Error " + (1 - accuracy));
173 }
174 catch (FileNotFoundException e) {
175 e.printStackTrace();
176 }
177 });
178
179 igniteThread.start();
180
181 igniteThread.join();
182 }
183 }
184 }