/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.tutorial;

import java.io.FileNotFoundException;
import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.thread.IgniteThread;

/**
 * In this example cross-validation is used to choose the best hyperparameters.
 * <p>
 * The purpose of cross-validation is model checking, not model building.
 * <p>
 * We train k different models. They differ in that 1/(k-1)th of the training data is exchanged against other cases.
 * These models are sometimes called surrogate models because the (average) performance measured for these models
 * is taken as a surrogate of the performance of the model trained on all cases.
 * <p>
 * All scenarios are described here: https://sebastianraschka.com/faq/docs/evaluate-a-model.html
 */
public class Step_8_CV {
    /** Run example. */
    public static void main(String[] args) throws InterruptedException {
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                Step_8_CV.class.getSimpleName(), () -> {
                try {
                    IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);

                    // Defines first preprocessor that extracts features from an upstream data.
                    // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
                    IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
                        = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
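
                    // Defines the label extractor that reads the second column ("survived") as a double label.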
                    IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
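
                    // Splits the dataset into train and test parts.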
                    TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
                        .split(0.75);
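
                    // Defines preprocessor that encodes the categorical (string) features as numbers.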
                    IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
                        .encodeFeature(1)
                        .encodeFeature(6) // <--- Changed index here
                        .fit(ignite,
                            dataCache,
                            featureExtractor
                        );
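
                    // Defines preprocessor that imputes missing feature values.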
                    IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
                        .fit(ignite,
                            dataCache,
                            strEncoderPreprocessor
                        );
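
                    // Defines preprocessor that scales each feature with min-max scaling.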
                    IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
                        .fit(ignite,
                            dataCache,
                            imputingPreprocessor
                        );

                    // Tune hyperparams with K-fold Cross-Validation on the split training set.
                    int[] pSet = new int[]{1, 2};
                    int[] maxDeepSet = new int[]{1, 2, 3, 4, 5, 10, 20};
                    int bestP = 1;
                    int bestMaxDeep = 1;
                    double avg = Double.MIN_VALUE;

                    for (int p : pSet) {
                        for (int maxDeep : maxDeepSet) {
                            IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                                .withP(p)
                                .fit(ignite,
                                    dataCache,
                                    minMaxScalerPreprocessor
                                );

                            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(maxDeep, 0);
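
                            // Scores the trainer for the current (p, maxDeep) pair with K-fold cross-validation
                            // on the train part of the dataset.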
                            CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator
                                = new CrossValidation<>();

                            double[] scores = scoreCalculator.score(
                                trainer,
                                new Accuracy<>(),
                                ignite,
                                dataCache,
                                split.getTrainFilter(),
                                normalizationPreprocessor,
                                lbExtractor,
                                3 // Number of folds.
                            );

                            System.out.println("Scores are: " + Arrays.toString(scores));

                            final double currAvg = Arrays.stream(scores).average().orElse(Double.MIN_VALUE);

                            if (currAvg > avg) {
                                avg = currAvg;
                                bestP = p;
                                bestMaxDeep = maxDeep;
                            }

                            System.out.println("Avg is: " + currAvg + " with p: " + p + " with maxDeep: " + maxDeep);
                        }
                    }

                    System.out.println("Train with p: " + bestP + " and maxDeep: " + bestMaxDeep);
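
                    // Re-trains the model on the whole train part with the best hyperparameters found above.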
                    IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                        .withP(bestP)
                        .fit(ignite,
                            dataCache,
                            minMaxScalerPreprocessor
                        );

                    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(bestMaxDeep, 0);

                    // Train decision tree model.
                    DecisionTreeNode bestMdl = trainer.fit(
                        ignite,
                        dataCache,
                        split.getTrainFilter(),
                        normalizationPreprocessor,
                        lbExtractor
                    );
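
                    // Estimates the model quality on the test part of the dataset.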
                    double accuracy = Evaluator.evaluate(
                        dataCache,
                        split.getTestFilter(),
                        bestMdl,
                        normalizationPreprocessor,
                        lbExtractor,
                        new Accuracy<>()
                    );

                    System.out.println("\n>>> Accuracy " + accuracy);
                    System.out.println("\n>>> Test Error " + (1 - accuracy));
                }
                catch (FileNotFoundException e) {
                    e.printStackTrace();
                }
            });

            igniteThread.start();
            igniteThread.join();
        }
    }
}