IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / tree / DecisionTreeClassificationTrainerExample.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.tree;
19
20 import java.util.Random;
21 import org.apache.ignite.Ignite;
22 import org.apache.ignite.IgniteCache;
23 import org.apache.ignite.Ignition;
24 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
25 import org.apache.ignite.configuration.CacheConfiguration;
26 import org.apache.ignite.ml.math.VectorUtils;
27 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
28 import org.apache.ignite.ml.tree.DecisionTreeNode;
29 import org.apache.ignite.thread.IgniteThread;
30
31 /**
32 * Example of using distributed {@link DecisionTreeClassificationTrainer}.
33 */
34 public class DecisionTreeClassificationTrainerExample {
35 /**
36 * Executes example.
37 *
38 * @param args Command line arguments, none required.
39 */
40 public static void main(String... args) throws InterruptedException {
41 System.out.println(">>> Decision tree classification trainer example started.");
42
43 // Start ignite grid.
44 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
45 System.out.println(">>> Ignite grid started.");
46
47 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
48 DecisionTreeClassificationTrainerExample.class.getSimpleName(), () -> {
49
50 // Create cache with training data.
51 CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
52 trainingSetCfg.setName("TRAINING_SET");
53 trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
54
55 IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
56
57 Random rnd = new Random(0);
58
59 // Fill training data.
60 for (int i = 0; i < 1000; i++)
61 trainingSet.put(i, generatePoint(rnd));
62
63 // Create classification trainer.
64 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
65
66 // Train decision tree model.
67 DecisionTreeNode mdl = trainer.fit(
68 ignite,
69 trainingSet,
70 (k, v) -> VectorUtils.of(v.x, v.y),
71 (k, v) -> v.lb
72 );
73
74 // Calculate score.
75 int correctPredictions = 0;
76 for (int i = 0; i < 1000; i++) {
77 LabeledPoint pnt = generatePoint(rnd);
78
79 double prediction = mdl.apply(VectorUtils.of(pnt.x, pnt.y));
80
81 if (prediction == pnt.lb)
82 correctPredictions++;
83 }
84
85 System.out.println(">>> Accuracy: " + correctPredictions / 10.0 + "%");
86
87 System.out.println(">>> Decision tree classification trainer example completed.");
88 });
89
90 igniteThread.start();
91
92 igniteThread.join();
93 }
94 }
95
96 /**
97 * Generate point with {@code x} in (-0.5, 0.5) and {@code y} in the same interval. If {@code x * y > 0} then label
98 * is 1, otherwise 0.
99 *
100 * @param rnd Random.
101 * @return Point with label.
102 */
103 private static LabeledPoint generatePoint(Random rnd) {
104
105 double x = rnd.nextDouble() - 0.5;
106 double y = rnd.nextDouble() - 0.5;
107
108 return new LabeledPoint(x, y, x * y > 0 ? 1 : 0);
109 }
110
111 /** Point data class. */
112 private static class Point {
113 /** X coordinate. */
114 final double x;
115
116 /** Y coordinate. */
117 final double y;
118
119 /**
120 * Constructs a new instance of point.
121 *
122 * @param x X coordinate.
123 * @param y Y coordinate.
124 */
125 Point(double x, double y) {
126 this.x = x;
127 this.y = y;
128 }
129 }
130
131 /** Labeled point data class. */
132 private static class LabeledPoint extends Point {
133 /** Point label. */
134 final double lb;
135
136 /**
137 * Constructs a new instance of labeled point data.
138 *
139 * @param x X coordinate.
140 * @param y Y coordinate.
141 * @param lb Point label.
142 */
143 LabeledPoint(double x, double y, double lb) {
144 super(x, y);
145 this.lb = lb;
146 }
147 }
148 }