IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / selection / scoring / cursor / CacheBasedLabelPairCursor.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.selection.scoring.cursor;
19
20 import java.util.Iterator;
21 import javax.cache.Cache;
22 import org.apache.ignite.IgniteCache;
23 import org.apache.ignite.cache.query.QueryCursor;
24 import org.apache.ignite.cache.query.ScanQuery;
25 import org.apache.ignite.lang.IgniteBiPredicate;
26 import org.apache.ignite.ml.Model;
27 import org.apache.ignite.ml.math.Vector;
28 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
29 import org.apache.ignite.ml.selection.scoring.LabelPair;
30 import org.jetbrains.annotations.NotNull;
31
32 /**
33 * Truth with prediction cursor based on a data stored in Ignite cache.
34 *
35 * @param <L> Type of a label (truth or prediction).
36 * @param <K> Type of a key in {@code upstream} data.
37 * @param <V> Type of a value in {@code upstream} data.
38 */
39 public class CacheBasedLabelPairCursor<L, K, V> implements LabelPairCursor<L> {
40 /** Query cursor. */
41 private final QueryCursor<Cache.Entry<K, V>> cursor;
42
43 /** Feature extractor. */
44 private final IgniteBiFunction<K, V, Vector> featureExtractor;
45
46 /** Label extractor. */
47 private final IgniteBiFunction<K, V, L> lbExtractor;
48
49 /** Model for inference. */
50 private final Model<Vector, L> mdl;
51
52 /**
53 * Constructs a new instance of cache based truth with prediction cursor.
54 *
55 * @param upstreamCache Ignite cache with {@code upstream} data.
56 * @param filter Filter for {@code upstream} data.
57 * @param featureExtractor Feature extractor.
58 * @param lbExtractor Label extractor.
59 * @param mdl Model for inference.
60 */
61 public CacheBasedLabelPairCursor(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
62 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
63 Model<Vector, L> mdl) {
64 this.cursor = query(upstreamCache, filter);
65 this.featureExtractor = featureExtractor;
66 this.lbExtractor = lbExtractor;
67 this.mdl = mdl;
68 }
69
70 /**
71 * Constructs a new instance of cache based truth with prediction cursor.
72 *
73 * @param upstreamCache Ignite cache with {@code upstream} data.
74 * @param featureExtractor Feature extractor.
75 * @param lbExtractor Label extractor.
76 * @param mdl Model for inference.
77 */
78 public CacheBasedLabelPairCursor(IgniteCache<K, V> upstreamCache,
79 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
80 Model<Vector, L> mdl) {
81 this.cursor = query(upstreamCache);
82 this.featureExtractor = featureExtractor;
83 this.lbExtractor = lbExtractor;
84 this.mdl = mdl;
85 }
86
87 /** {@inheritDoc} */
88 @Override public void close() {
89 cursor.close();
90 }
91
92 /** {@inheritDoc} */
93 @NotNull @Override public Iterator<LabelPair<L>> iterator() {
94 return new TruthWithPredictionIterator(cursor.iterator());
95 }
96
97 /**
98 * Queries the specified cache using the specified filter.
99 *
100 * @param upstreamCache Ignite cache with {@code upstream} data.
101 * @param filter Filter for {@code upstream} data.
102 * @return Query cursor.
103 */
104 private QueryCursor<Cache.Entry<K, V>> query(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter) {
105 ScanQuery<K, V> qry = new ScanQuery<>();
106 qry.setFilter(filter);
107
108 return upstreamCache.query(qry);
109 }
110
111 /**
112 * Queries the specified cache using the specified filter.
113 *
114 * @param upstreamCache Ignite cache with {@code upstream} data.
115 * @return Query cursor.
116 */
117 private QueryCursor<Cache.Entry<K, V>> query(IgniteCache<K, V> upstreamCache) {
118 ScanQuery<K, V> qry = new ScanQuery<>();
119
120 return upstreamCache.query(qry);
121 }
122
123 /**
124 * Util iterator that makes predictions using the model.
125 */
126 private class TruthWithPredictionIterator implements Iterator<LabelPair<L>> {
127 /** Base iterator. */
128 private final Iterator<Cache.Entry<K, V>> iter;
129
130 /**
131 * Constructs a new instance of truth with prediction iterator.
132 *
133 * @param iter Base iterator.
134 */
135 public TruthWithPredictionIterator(Iterator<Cache.Entry<K, V>> iter) {
136 this.iter = iter;
137 }
138
139 /** {@inheritDoc} */
140 @Override public boolean hasNext() {
141 return iter.hasNext();
142 }
143
144 /** {@inheritDoc} */
145 @Override public LabelPair<L> next() {
146 Cache.Entry<K, V> entry = iter.next();
147
148 Vector features = featureExtractor.apply(entry.getKey(), entry.getValue());
149 L lb = lbExtractor.apply(entry.getKey(), entry.getValue());
150
151 return new LabelPair<>(lb, mdl.apply(features));
152 }
153 }
154 }