/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 package org
.apache
.ignite
.ml
.preprocessing
.minmaxscaling
;
20 import org
.apache
.ignite
.ml
.dataset
.Dataset
;
21 import org
.apache
.ignite
.ml
.dataset
.DatasetBuilder
;
22 import org
.apache
.ignite
.ml
.dataset
.UpstreamEntry
;
23 import org
.apache
.ignite
.ml
.dataset
.primitive
.context
.EmptyContext
;
24 import org
.apache
.ignite
.ml
.math
.Vector
;
25 import org
.apache
.ignite
.ml
.math
.functions
.IgniteBiFunction
;
26 import org
.apache
.ignite
.ml
.preprocessing
.PreprocessingTrainer
;
/**
 * Trainer of the minmaxscaling preprocessor.
 *
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 */
34 public class MinMaxScalerTrainer
<K
, V
> implements PreprocessingTrainer
<K
, V
, Vector
, Vector
> {
36 @Override public MinMaxScalerPreprocessor
<K
, V
> fit(DatasetBuilder
<K
, V
> datasetBuilder
,
37 IgniteBiFunction
<K
, V
, Vector
> basePreprocessor
) {
38 try (Dataset
<EmptyContext
, MinMaxScalerPartitionData
> dataset
= datasetBuilder
.build(
39 (upstream
, upstreamSize
) -> new EmptyContext(),
40 (upstream
, upstreamSize
, ctx
) -> {
44 while (upstream
.hasNext()) {
45 UpstreamEntry
<K
, V
> entity
= upstream
.next();
46 Vector row
= basePreprocessor
.apply(entity
.getKey(), entity
.getValue());
49 min
= new double[row
.size()];
50 for (int i
= 0; i
< min
.length
; i
++)
51 min
[i
] = Double
.MAX_VALUE
;
54 assert min
.length
== row
.size() : "Base preprocessor must return exactly " + min
.length
58 max
= new double[row
.size()];
59 for (int i
= 0; i
< max
.length
; i
++)
60 max
[i
] = -Double
.MAX_VALUE
;
63 assert max
.length
== row
.size() : "Base preprocessor must return exactly " + min
.length
66 for (int i
= 0; i
< row
.size(); i
++) {
67 if (row
.get(i
) < min
[i
])
69 if (row
.get(i
) > max
[i
])
74 return new MinMaxScalerPartitionData(min
, max
);
77 double[][] minMax
= dataset
.compute(
78 data
-> data
.getMin() != null ?
new double[][]{ data
.getMin(), data
.getMax() } : null
,
86 double[][] res
= new double[2][];
88 res
[0] = new double[a
[0].length
];
89 for (int i
= 0; i
< res
[0].length
; i
++)
90 res
[0][i
] = Math
.min(a
[0][i
], b
[0][i
]);
92 res
[1] = new double[a
[1].length
];
93 for (int i
= 0; i
< res
[1].length
; i
++)
94 res
[1][i
] = Math
.max(a
[1][i
], b
[1][i
]);
100 return new MinMaxScalerPreprocessor
<>(minMax
[0], minMax
[1], basePreprocessor
);
102 catch (Exception e
) {
103 throw new RuntimeException(e
);