IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / tutorial / Step_5_Scaling.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.tutorial;
19
20 import java.io.FileNotFoundException;
21 import org.apache.ignite.Ignite;
22 import org.apache.ignite.IgniteCache;
23 import org.apache.ignite.Ignition;
24 import org.apache.ignite.ml.math.Vector;
25 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
26 import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
27 import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
28 import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
29 import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
30 import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
31 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
32 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
33 import org.apache.ignite.ml.tree.DecisionTreeNode;
34 import org.apache.ignite.thread.IgniteThread;
35
36 /**
37 * MinMaxScalerTrainer and NormalizationTrainer are used in this example due to different values distribution in columns and rows.
38 */
39 public class Step_5_Scaling {
40 /** Run example. */
41 public static void main(String[] args) throws InterruptedException {
42 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
43 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
44 Step_5_Scaling.class.getSimpleName(), () -> {
45 try {
46 IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
47
48 // Defines first preprocessor that extracts features from an upstream data.
49 // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare"
50 IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
51 = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
52
53 IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
54
55 IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
56 .encodeFeature(1)
57 .encodeFeature(6) // <--- Changed index here
58 .fit(ignite,
59 dataCache,
60 featureExtractor
61 );
62
63 IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
64 .fit(ignite,
65 dataCache,
66 strEncoderPreprocessor
67 );
68
69
70 IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
71 .fit(
72 ignite,
73 dataCache,
74 imputingPreprocessor
75 );
76
77 IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
78 .withP(1)
79 .fit(
80 ignite,
81 dataCache,
82 minMaxScalerPreprocessor
83 );
84
85 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
86
87 // Train decision tree model.
88 DecisionTreeNode mdl = trainer.fit(
89 ignite,
90 dataCache,
91 normalizationPreprocessor,
92 lbExtractor
93 );
94
95 double accuracy = Evaluator.evaluate(
96 dataCache,
97 mdl,
98 normalizationPreprocessor,
99 lbExtractor,
100 new Accuracy<>()
101 );
102
103 System.out.println("\n>>> Accuracy " + accuracy);
104 System.out.println("\n>>> Test Error " + (1 - accuracy));
105 }
106 catch (FileNotFoundException e) {
107 e.printStackTrace();
108 }
109 });
110
111 igniteThread.start();
112 igniteThread.join();
113 }
114 }
115 }