IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / examples / src / main / java / org / apache / ignite / examples / ml / selection / cv / CrossValidationExample.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.examples.ml.selection.cv;
19
20 import java.util.Arrays;
21 import java.util.Random;
22 import org.apache.ignite.Ignite;
23 import org.apache.ignite.IgniteCache;
24 import org.apache.ignite.Ignition;
25 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
26 import org.apache.ignite.configuration.CacheConfiguration;
27 import org.apache.ignite.examples.ml.tree.DecisionTreeClassificationTrainerExample;
28 import org.apache.ignite.ml.math.VectorUtils;
29 import org.apache.ignite.ml.selection.cv.CrossValidation;
30 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
31 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
32 import org.apache.ignite.ml.tree.DecisionTreeNode;
33 import org.apache.ignite.thread.IgniteThread;
34
35 /**
36 * Run decision tree classification with cross validation.
37 *
38 * @see CrossValidation
39 */
40 public class CrossValidationExample {
41 /**
42 * Executes example.
43 *
44 * @param args Command line arguments, none required.
45 */
46 public static void main(String... args) throws InterruptedException {
47 System.out.println(">>> Cross validation score calculator example started.");
48
49 // Start ignite grid.
50 try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
51 System.out.println(">>> Ignite grid started.");
52
53 IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
54 DecisionTreeClassificationTrainerExample.class.getSimpleName(), () -> {
55
56 // Create cache with training data.
57 CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
58 trainingSetCfg.setName("TRAINING_SET");
59 trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
60
61 IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
62
63 Random rnd = new Random(0);
64
65 // Fill training data.
66 for (int i = 0; i < 1000; i++)
67 trainingSet.put(i, generatePoint(rnd));
68
69 // Create classification trainer.
70 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
71
72 CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator
73 = new CrossValidation<>();
74
75 double[] scores = scoreCalculator.score(
76 trainer,
77 new Accuracy<>(),
78 ignite,
79 trainingSet,
80 (k, v) -> VectorUtils.of(v.x, v.y),
81 (k, v) -> v.lb,
82 4
83 );
84
85 System.out.println(">>> Accuracy: " + Arrays.toString(scores));
86
87 System.out.println(">>> Cross validation score calculator example completed.");
88 });
89
90 igniteThread.start();
91
92 igniteThread.join();
93 }
94 }
95
96 /**
97 * Generate point with {@code x} in (-0.5, 0.5) and {@code y} in the same interval. If {@code x * y > 0} then label
98 * is 1, otherwise 0.
99 *
100 * @param rnd Random.
101 * @return Point with label.
102 */
103 private static LabeledPoint generatePoint(Random rnd) {
104
105 double x = rnd.nextDouble() - 0.5;
106 double y = rnd.nextDouble() - 0.5;
107
108 return new LabeledPoint(x, y, x * y > 0 ? 1 : 0);
109 }
110
111 /** Point data class. */
112 private static class Point {
113 /** X coordinate. */
114 final double x;
115
116 /** Y coordinate. */
117 final double y;
118
119 /**
120 * Constructs a new instance of point.
121 *
122 * @param x X coordinate.
123 * @param y Y coordinate.
124 */
125 Point(double x, double y) {
126 this.x = x;
127 this.y = y;
128 }
129 }
130
131 /** Labeled point data class. */
132 private static class LabeledPoint extends Point {
133 /** Point label. */
134 final double lb;
135
136 /**
137 * Constructs a new instance of labeled point data.
138 *
139 * @param x X coordinate.
140 * @param y Y coordinate.
141 * @param lb Point label.
142 */
143 LabeledPoint(double x, double y, double lb) {
144 super(x, y);
145 this.lb = lb;
146 }
147 }
148 }