IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / tree / DecisionTreeRegressionTrainerExample.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.tree;
19
20 import org.apache.ignite.Ignite;
21 import org.apache.ignite.IgniteCache;
22 import org.apache.ignite.Ignition;
23 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
24 import org.apache.ignite.configuration.CacheConfiguration;
25 import org.apache.ignite.ml.math.VectorUtils;
26 import org.apache.ignite.ml.tree.DecisionTreeNode;
27 import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
28 import org.apache.ignite.thread.IgniteThread;
29
30 /**
31 * Example of using distributed {@link DecisionTreeRegressionTrainer}.
32 */
33 public class DecisionTreeRegressionTrainerExample {
34 /**
35 * Executes example.
36 *
37 * @param args Command line arguments, none required.
38 */
39 public static void main(String... args) throws InterruptedException {
40 System.out.println(">>> Decision tree regression trainer example started.");
41
42 // Start ignite grid.
43 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
44 System.out.println(">>> Ignite grid started.");
45
46 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
47 DecisionTreeRegressionTrainerExample.class.getSimpleName(), () -> {
48
49 // Create cache with training data.
50 CacheConfiguration<Integer, Point> trainingSetCfg = new CacheConfiguration<>();
51 trainingSetCfg.setName("TRAINING_SET");
52 trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
53
54 IgniteCache<Integer, Point> trainingSet = ignite.createCache(trainingSetCfg);
55
56 // Fill training data.
57 generatePoints(trainingSet);
58
59 // Create regression trainer.
60 DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
61
62 // Train decision tree model.
63 DecisionTreeNode mdl = trainer.fit(
64 ignite,
65 trainingSet,
66 (k, v) -> VectorUtils.of(v.x),
67 (k, v) -> v.y
68 );
69
70 System.out.println(">>> Decision tree regression model: " + mdl);
71
72 System.out.println(">>> ---------------------------------");
73 System.out.println(">>> | Prediction\t| Ground Truth\t|");
74 System.out.println(">>> ---------------------------------");
75
76 // Calculate score.
77 for (int x = 0; x < 10; x++) {
78 double predicted = mdl.apply(VectorUtils.of(x));
79
80 System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
81 }
82
83 System.out.println(">>> ---------------------------------");
84
85 System.out.println(">>> Decision tree regression trainer example completed.");
86 });
87
88 igniteThread.start();
89
90 igniteThread.join();
91 }
92 }
93
94 /**
95 * Generates {@code sin(x)} on interval [0, 10) and loads into the specified cache.
96 */
97 private static void generatePoints(IgniteCache<Integer, Point> trainingSet) {
98 for (int i = 0; i < 1000; i++) {
99 double x = i / 100.0;
100 double y = Math.sin(x);
101
102 trainingSet.put(i, new Point(x, y));
103 }
104 }
105
106 /** Point data class. */
107 private static class Point {
108 /** X coordinate. */
109 final double x;
110
111 /** Y coordinate. */
112 final double y;
113
114 /**
115 * Constructs a new instance of point.
116 *
117 * @param x X coordinate.
118 * @param y Y coordinate.
119 */
120 Point(double x, double y) {
121 this.x = x;
122 this.y = y;
123 }
124 }
125 }