IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / trainers / DatasetTrainer.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.trainers;
19
20 import java.util.Map;
21 import org.apache.ignite.Ignite;
22 import org.apache.ignite.IgniteCache;
23 import org.apache.ignite.lang.IgniteBiPredicate;
24 import org.apache.ignite.ml.Model;
25 import org.apache.ignite.ml.dataset.DatasetBuilder;
26 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
27 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
28 import org.apache.ignite.ml.math.Vector;
29 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
30
31 /**
32 * Interface for trainers. Trainer is just a function which produces model from the data.
33 *
34 * @param <M> Type of a produced model.
35 * @param <L> Type of a label.
36 */
37 public interface DatasetTrainer<M extends Model, L> {
38 /**
39 * Trains model based on the specified data.
40 *
41 * @param datasetBuilder Dataset builder.
42 * @param featureExtractor Feature extractor.
43 * @param lbExtractor Label extractor.
44 * @param <K> Type of a key in {@code upstream} data.
45 * @param <V> Type of a value in {@code upstream} data.
46 * @return Model.
47 */
48 public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
49 IgniteBiFunction<K, V, L> lbExtractor);
50
51 /**
52 * Trains model based on the specified data.
53 *
54 * @param ignite Ignite instance.
55 * @param cache Ignite cache.
56 * @param featureExtractor Feature extractor.
57 * @param lbExtractor Label extractor.
58 * @param <K> Type of a key in {@code upstream} data.
59 * @param <V> Type of a value in {@code upstream} data.
60 * @return Model.
61 */
62 public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache,
63 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
64 return fit(
65 new CacheBasedDatasetBuilder<>(ignite, cache),
66 featureExtractor,
67 lbExtractor
68 );
69 }
70
71 /**
72 * Trains model based on the specified data.
73 *
74 * @param ignite Ignite instance.
75 * @param cache Ignite cache.
76 * @param filter Filter for {@code upstream} data.
77 * @param featureExtractor Feature extractor.
78 * @param lbExtractor Label extractor.
79 * @param <K> Type of a key in {@code upstream} data.
80 * @param <V> Type of a value in {@code upstream} data.
81 * @return Model.
82 */
83 public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache, IgniteBiPredicate<K, V> filter,
84 IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
85 return fit(
86 new CacheBasedDatasetBuilder<>(ignite, cache, filter),
87 featureExtractor,
88 lbExtractor
89 );
90 }
91
92 /**
93 * Trains model based on the specified data.
94 *
95 * @param data Data.
96 * @param parts Number of partitions.
97 * @param featureExtractor Feature extractor.
98 * @param lbExtractor Label extractor.
99 * @param <K> Type of a key in {@code upstream} data.
100 * @param <V> Type of a value in {@code upstream} data.
101 * @return Model.
102 */
103 public default <K, V> M fit(Map<K, V> data, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
104 IgniteBiFunction<K, V, L> lbExtractor) {
105 return fit(
106 new LocalDatasetBuilder<>(data, parts),
107 featureExtractor,
108 lbExtractor
109 );
110 }
111
112 /**
113 * Trains model based on the specified data.
114 *
115 * @param data Data.
116 * @param filter Filter for {@code upstream} data.
117 * @param parts Number of partitions.
118 * @param featureExtractor Feature extractor.
119 * @param lbExtractor Label extractor.
120 * @param <K> Type of a key in {@code upstream} data.
121 * @param <V> Type of a value in {@code upstream} data.
122 * @return Model.
123 */
124 public default <K, V> M fit(Map<K, V> data, IgniteBiPredicate<K, V> filter, int parts,
125 IgniteBiFunction<K, V, Vector> featureExtractor,
126 IgniteBiFunction<K, V, L> lbExtractor) {
127 return fit(
128 new LocalDatasetBuilder<>(data, filter, parts),
129 featureExtractor,
130 lbExtractor
131 );
132 }
133 }