IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / tutorial / Step_6_KNN.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.tutorial;
19
20 import java.io.FileNotFoundException;
21 import org.apache.ignite.Ignite;
22 import org.apache.ignite.IgniteCache;
23 import org.apache.ignite.Ignition;
24 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
25 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
26 import org.apache.ignite.ml.knn.classification.KNNStrategy;
27 import org.apache.ignite.ml.math.Vector;
28 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
29 import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
30 import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
31 import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
32 import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
33 import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
34 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
35 import org.apache.ignite.thread.IgniteThread;
36
37 /**
38 * Sometimes is better to change algorithm, let's say on kNN.
39 */
40 public class Step_6_KNN {
41 /** Run example. */
42 public static void main(String[] args) throws InterruptedException {
43 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
44 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
45 Step_6_KNN.class.getSimpleName(), () -> {
46 try {
47 IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
48
49 // Defines first preprocessor that extracts features from an upstream data.
50 // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare"
51 IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
52 = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
53
54 IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
55
56 IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
57 .encodeFeature(1)
58 .encodeFeature(6) // <--- Changed index here
59 .fit(ignite,
60 dataCache,
61 featureExtractor
62 );
63
64 IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
65 .fit(ignite,
66 dataCache,
67 strEncoderPreprocessor
68 );
69
70
71 IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
72 .fit(
73 ignite,
74 dataCache,
75 imputingPreprocessor
76 );
77
78 IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
79 .withP(1)
80 .fit(
81 ignite,
82 dataCache,
83 minMaxScalerPreprocessor
84 );
85
86 KNNClassificationTrainer trainer = new KNNClassificationTrainer();
87
88 // Train decision tree model.
89 KNNClassificationModel mdl = trainer.fit(
90 ignite,
91 dataCache,
92 normalizationPreprocessor,
93 lbExtractor
94 ).withK(1).withStrategy(KNNStrategy.WEIGHTED);
95
96 double accuracy = Evaluator.evaluate(
97 dataCache,
98 mdl,
99 normalizationPreprocessor,
100 lbExtractor,
101 new Accuracy<>()
102 );
103
104 System.out.println("\n>>> Accuracy " + accuracy);
105 System.out.println("\n>>> Test Error " + (1 - accuracy));
106 }
107 catch (FileNotFoundException e) {
108 e.printStackTrace();
109 }
110 });
111
112 igniteThread.start();
113 igniteThread.join();
114 }
115 }
116 }