IGNITE-8907: [ML] Using vectors in featureExtractor
[ignite.git] / modules / ml / src / main / java / org / apache / ignite / ml / preprocessing / encoding / stringencoder / StringEncoderPreprocessor.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.ignite.ml.preprocessing.encoding.stringencoder;
19
20 import java.util.Map;
21 import java.util.Set;
22 import org.apache.ignite.ml.math.Vector;
23 import org.apache.ignite.ml.math.VectorUtils;
24 import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownStringValue;
25 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
26
27 /**
28 * Preprocessing function that makes String encoding.
29 *
30 * @param <K> Type of a key in {@code upstream} data.
31 * @param <V> Type of a value in {@code upstream} data.
32 */
33 public class StringEncoderPreprocessor<K, V> implements IgniteBiFunction<K, V, Vector> {
34 /** */
35 private static final long serialVersionUID = 6237812226382623469L;
36 /** */
37 private static final String KEY_FOR_NULL_VALUES = "";
38
39 /** Filling values. */
40 private final Map<String, Integer>[] encodingValues;
41
42 /** Base preprocessor. */
43 private final IgniteBiFunction<K, V, Object[]> basePreprocessor;
44
45 /** Feature indices to apply encoder.*/
46 private final Set<Integer> handledIndices;
47
48 /**
49 * Constructs a new instance of String Encoder preprocessor.
50 *
51 * @param basePreprocessor Base preprocessor.
52 * @param handledIndices Handled indices.
53 */
54 public StringEncoderPreprocessor(Map<String, Integer>[] encodingValues,
55 IgniteBiFunction<K, V, Object[]> basePreprocessor, Set<Integer> handledIndices) {
56 this.handledIndices = handledIndices;
57 this.encodingValues = encodingValues;
58 this.basePreprocessor = basePreprocessor;
59 }
60
61 /**
62 * Applies this preprocessor.
63 *
64 * @param k Key.
65 * @param v Value.
66 * @return Preprocessed row.
67 */
68 @Override public Vector apply(K k, V v) {
69 Object[] tmp = basePreprocessor.apply(k, v);
70 double[] res = new double[tmp.length];
71
72 for (int i = 0; i < res.length; i++) {
73 Object tmpObj = tmp[i];
74 if(handledIndices.contains(i)){
75 if(tmpObj.equals(Double.NaN) && encodingValues[i].containsKey(KEY_FOR_NULL_VALUES))
76 res[i] = encodingValues[i].get(KEY_FOR_NULL_VALUES);
77 else if (encodingValues[i].containsKey(tmpObj))
78 res[i] = encodingValues[i].get(tmpObj);
79 else
80 throw new UnknownStringValue(tmpObj.toString());
81 } else
82 res[i] = (double)tmpObj;
83 }
84 return VectorUtils.of(res);
85 }
86 }