IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / selection / scoring / evaluator / Evaluator.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.selection.scoring.evaluator;
19
20 import org.apache.ignite.IgniteCache;
21 import org.apache.ignite.lang.IgniteBiPredicate;
22 import org.apache.ignite.ml.Model;
23 import org.apache.ignite.ml.math.Vector;
24 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
25 import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
26 import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
27 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
28
29 /**
30 * Binary classification evaluator that compute metrics from predictions.
31 */
32 public class Evaluator {
33 /**
34 * Computes the given metric on the given cache.
35 *
36 * @param dataCache The given cache.
37 * @param mdl The model.
38 * @param featureExtractor The feature extractor.
39 * @param lbExtractor The label extractor.
40 * @param metric The binary classification metric.
41 * @param <L> The type of label.
42 * @param <K> The type of cache entry key.
43 * @param <V> The type of cache entry value.
44 * @return Computed metric.
45 */
46 public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache,
47 Model<Vector, L> mdl,
48 IgniteBiFunction<K, V, Vector> featureExtractor,
49 IgniteBiFunction<K, V, L> lbExtractor,
50 Accuracy<L> metric) {
51 double metricRes;
52
53 try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<L, K, V>(
54 dataCache,
55 featureExtractor,
56 lbExtractor,
57 mdl
58 )) {
59 metricRes = metric.score(cursor.iterator());
60 }
61 catch (Exception e) {
62 throw new RuntimeException(e);
63 }
64
65 return metricRes;
66 }
67
68 /**
69 * Computes the given metric on the given cache.
70 *
71 * @param dataCache The given cache.
72 * @param filter The given filter.
73 * @param mdl The model.
74 * @param featureExtractor The feature extractor.
75 * @param lbExtractor The label extractor.
76 * @param metric The binary classification metric.
77 * @param <L> The type of label.
78 * @param <K> The type of cache entry key.
79 * @param <V> The type of cache entry value.
80 * @return Computed metric.
81 */
82 public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter,
83 Model<Vector, L> mdl,
84 IgniteBiFunction<K, V, Vector> featureExtractor,
85 IgniteBiFunction<K, V, L> lbExtractor,
86 Accuracy<L> metric) {
87 double metricRes;
88
89 try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<L, K, V>(
90 dataCache,
91 filter,
92 featureExtractor,
93 lbExtractor,
94 mdl
95 )) {
96 metricRes = metric.score(cursor.iterator());
97 }
98 catch (Exception e) {
99 throw new RuntimeException(e);
100 }
101
102 return metricRes;
103 }
104 }