IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / tree / boosting / GRBOnTreesRegressionTrainerExample.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.tree.boosting;
19
20 import org.apache.ignite.Ignite;
21 import org.apache.ignite.IgniteCache;
22 import org.apache.ignite.Ignition;
23 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
24 import org.apache.ignite.configuration.CacheConfiguration;
25 import org.apache.ignite.ml.Model;
26 import org.apache.ignite.ml.math.Vector;
27 import org.apache.ignite.ml.math.VectorUtils;
28 import org.apache.ignite.ml.trainers.DatasetTrainer;
29 import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
30 import org.apache.ignite.thread.IgniteThread;
31 import org.jetbrains.annotations.NotNull;
32
33 /**
34 * Example represents a solution for the task of regression learning based on
35 * Gradient Boosting on trees implementation. It shows an initialization of {@link org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer},
36 * initialization of Ignite Cache, learning step and comparing of predicted and real values.
37 *
38 * In this example dataset is creating automatically by parabolic function f(x) = x^2.
39 */
40 public class GRBOnTreesRegressionTrainerExample {
41 /**
42 * Run example.
43 *
44 * @param args Command line arguments, none required.
45 */
46 public static void main(String... args) throws InterruptedException {
47 // Start ignite grid.
48 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
49 System.out.println(">>> Ignite grid started.");
50
51 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
52 GRBOnTreesRegressionTrainerExample.class.getSimpleName(), () -> {
53
54 // Create cache with training data.
55 CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration();
56 IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg);
57
58 // Create regression trainer.
59 DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.);
60
61 // Train decision tree model.
62 Model<Vector, Double> mdl = trainer.fit(
63 ignite,
64 trainingSet,
65 (k, v) -> VectorUtils.of(v[0]),
66 (k, v) -> v[1]
67 );
68
69 System.out.println(">>> ---------------------------------");
70 System.out.println(">>> | Prediction\t| Valid answer \t|");
71 System.out.println(">>> ---------------------------------");
72
73 // Calculate score.
74 for (int x = -5; x < 5; x++) {
75 double predicted = mdl.apply(VectorUtils.of(x));
76
77 System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.pow(x, 2));
78 }
79
80 System.out.println(">>> ---------------------------------");
81
82 System.out.println(">>> GDB Regression trainer example completed.");
83 });
84
85 igniteThread.start();
86 igniteThread.join();
87 }
88 }
89
90 /**
91 * Create cache configuration.
92 */
93 @NotNull private static CacheConfiguration<Integer, double[]> createCacheConfiguration() {
94 CacheConfiguration<Integer, double[]> trainingSetCfg = new CacheConfiguration<>();
95 trainingSetCfg.setName("TRAINING_SET");
96 trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
97 return trainingSetCfg;
98 }
99
100 /**
101 * Fill parabola training data.
102 *
103 * @param ignite Ignite.
104 * @param trainingSetCfg Training set config.
105 */
106 @NotNull private static IgniteCache<Integer, double[]> fillTrainingData(Ignite ignite,
107 CacheConfiguration<Integer, double[]> trainingSetCfg) {
108 IgniteCache<Integer, double[]> trainingSet = ignite.createCache(trainingSetCfg);
109 for(int i = -50; i <= 50; i++) {
110 double x = ((double)i) / 10.0;
111 double y = Math.pow(x, 2);
112 trainingSet.put(i, new double[] {x, y});
113 }
114 return trainingSet;
115 }
116 }