IGNITE-9387: [ML] Model updating
author Alexey Platonov <aplatonovv@gmail.com>
Tue, 4 Sep 2018 15:11:48 +0000 (18:11 +0300)
committer Yury Babak <ybabak@gridgain.com>
Tue, 4 Sep 2018 15:11:48 +0000 (18:11 +0300)
this closes #4659

46 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java

index 130b91a..075eab2 100644
@@ -23,6 +23,7 @@ import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
@@ -58,7 +59,7 @@ public class GDBOnTreesClassificationTrainerExample {
                 IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg);
 
                 // Create classification trainer.
-                DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.);
+                DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.);
 
                 // Train decision tree model.
                 Model<Vector, Double> mdl = trainer.fit(
index 31dd2b0..b2b08d0 100644
@@ -23,6 +23,7 @@ import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
@@ -58,7 +59,7 @@ public class GDBOnTreesRegressionTrainerExample {
                 IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg);
 
                 // Create regression trainer.
-                DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.);
+                DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.);
 
                 // Train decision tree model.
                 Model<Vector, Double> mdl = trainer.fit(
index 5b880fc..2596dbc 100644
@@ -21,6 +21,7 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Optional;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
@@ -72,6 +73,14 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
      */
     @Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
@@ -85,7 +94,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
-            final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
+            final Integer cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
                 if (a == null)
                     return b == null ? 0 : b;
                 if (b == null)
@@ -93,7 +102,12 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
                 return b;
             });
 
-            centers = initClusterCentersRandomly(dataset, k);
+            if (cols == null)
+                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
+
+            centers = Optional.ofNullable(mdl)
+                .map(KMeansModel::centers)
+                .orElseGet(() -> initClusterCentersRandomly(dataset, k));
 
             boolean converged = false;
             int iteration = 0;
@@ -127,6 +141,11 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
         return new KMeansModel(centers, distance);
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(KMeansModel mdl) {
+        return mdl.centers().length == k && mdl.distanceMeasure().equals(distance);
+    }
+
     /**
      * Prepares the data to define new centroids on current iteration.
      *
@@ -281,10 +300,12 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
             return this;
         }
 
+        /**
+         * @return Centroid statistics.
+         */
         public ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> getCentroidStat() {
             return centroidStat;
         }
-
     }
 
     /**
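
The hunks above reroute fit() through updateModel(null, ...) and seed the centers from a previous model when one is passed, falling back to random initialization otherwise. A minimal local sketch of that warm-start path (data values and names below are hypothetical; it relies on the public DatasetTrainer.update wrapper that this change set calls elsewhere, e.g. in LogRegressionMultiClassTrainer):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;
    import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
    import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
    import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
    import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

    public class KMeansUpdateSketch {
        public static void main(String[] args) {
            // Row format: {label, feature1, feature2}; k-means ignores the label.
            Map<Integer, double[]> batch1 = new HashMap<>();
            batch1.put(0, new double[] {0.0, 1.0, 1.0});
            batch1.put(1, new double[] {0.0, 1.1, 0.9});
            batch1.put(2, new double[] {1.0, 9.0, 9.1});
            batch1.put(3, new double[] {1.0, 9.1, 8.9});

            KMeansTrainer trainer = new KMeansTrainer().withK(2);

            KMeansModel mdl = trainer.fit(
                new LocalDatasetBuilder<>(batch1, 1),
                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
                (k, v) -> v[0]);

            // Later data: update() continues from mdl.centers() instead of
            // re-initializing the centers randomly (see the hunk above).
            Map<Integer, double[]> batch2 = new HashMap<>();
            batch2.put(0, new double[] {1.0, 8.8, 9.2});
            batch2.put(1, new double[] {0.0, 0.9, 1.2});

            KMeansModel updated = trainer.update(mdl,
                new LocalDatasetBuilder<>(batch2, 1),
                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
                (k, v) -> v[0]);
        }
    }
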
index f439789..493c1da 100644
@@ -177,4 +177,24 @@ public abstract class BaggingModelTrainer extends DatasetTrainer<ModelsCompositi
             return VectorUtils.of(newFeaturesValues);
         });
     }
+
+    /**
+     * Learns new models on the dataset and creates a new composition over both the new models and the already learned ones.
+     *
+     * @param mdl Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return New models composition.
+     */
+    @Override public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        ArrayList<Model<Vector, Double>> newModels = new ArrayList<>(mdl.getModels());
+        newModels.addAll(fit(datasetBuilder, featureExtractor, lbExtractor).getModels());
+
+        return new ModelsComposition(newModels, predictionsAggregator);
+    }
 }
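
The javadoc above is the whole contract: an update is additive. The composition keeps its previous members and its predictions aggregator, and fit() on the new data simply contributes additional models. A rough sketch of the observable effect (baggingTrainer, builder1/builder2 and the extractors are hypothetical stand-ins for any BaggingModelTrainer subclass, e.g. the random forest trainers changed below):

    // baggingTrainer, builder1, builder2, featureExtractor, lbExtractor: hypothetical.
    ModelsComposition forest = baggingTrainer.fit(builder1, featureExtractor, lbExtractor);
    int before = forest.getModels().size();

    // Trains a fresh batch of models on builder2's data and concatenates them
    // with the already learned ones under the same aggregator.
    ModelsComposition grown = baggingTrainer.update(forest, builder2, featureExtractor, lbExtractor);
    assert grown.getModels().size() == 2 * before; // assuming equal ensemble sizes per fit
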
index e14fa6d..36ee626 100644
@@ -19,6 +19,8 @@ package org.apache.ignite.ml.composition;
 
 import java.util.Collections;
 import java.util.List;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -27,7 +29,7 @@ import org.apache.ignite.ml.util.ModelTrace;
 /**
  * Model consisting of several models and prediction aggregation strategy.
  */
-public class ModelsComposition implements Model<Vector, Double> {
+public class ModelsComposition implements Model<Vector, Double>, Exportable<ModelsCompositionFormat> {
     /**
      * Predictions aggregator.
      */
@@ -78,6 +80,12 @@ public class ModelsComposition implements Model<Vector, Double> {
     }
 
     /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<ModelsCompositionFormat, P> exporter, P path) {
+        ModelsCompositionFormat format = new ModelsCompositionFormat(models, predictionsAggregator);
+        exporter.save(format, path);
+    }
+
+    /** {@inheritDoc} */
     @Override public String toString() {
         return toString(false);
     }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java
new file mode 100644
index 0000000..68af0a9
--- /dev/null
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import java.io.Serializable;
+import java.util.List;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * ModelsComposition representation.
+ *
+ * @see ModelsComposition
+ */
+public class ModelsCompositionFormat implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 9115341364082681837L;
+
+    /** Models. */
+    private List<Model<Vector, Double>> models;
+
+    /** Predictions aggregator. */
+    private PredictionsAggregator predictionsAggregator;
+
+    /**
+     * Creates an instance of ModelsCompositionFormat.
+     *
+     * @param models Models.
+     * @param predictionsAggregator Predictions aggregator.
+     */
+    public ModelsCompositionFormat(List<Model<Vector, Double>> models, PredictionsAggregator predictionsAggregator) {
+        this.models = models;
+        this.predictionsAggregator = predictionsAggregator;
+    }
+
+    /** */
+    public List<Model<Vector, Double>> models() {
+        return models;
+    }
+
+    /** */
+    public PredictionsAggregator predictionsAggregator() {
+        return predictionsAggregator;
+    }
+}
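
With Exportable wired into ModelsComposition (previous file) and this serializable snapshot class, a trained composition can be persisted through any Exporter implementation. A minimal sketch, assuming the module's existing FileExporter (an Exporter keyed by file-system path); the path is hypothetical:

    // composition: any trained ModelsComposition, e.g. from a random forest trainer.
    composition.saveModel(new FileExporter<ModelsCompositionFormat>(), "/tmp/models.composition");
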
index 5a0f52a..c7f21dd 100644
@@ -52,7 +52,7 @@ import org.jetbrains.annotations.NotNull;
  *
  * But in practice decision trees are the most commonly used regressors (see: {@link DecisionTreeRegressionTrainer}).
  */
-public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> {
+public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Double> {
     /** Gradient step. */
     private final double gradientStep;
 
@@ -81,7 +81,7 @@ public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, D
     }
 
     /** {@inheritDoc} */
-    @Override public <K, V> Model<Vector, Double> fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
 
@@ -119,6 +119,21 @@ public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, D
         };
     }
 
+    //TODO: This method will be implemented in IGNITE-9412
+    /** {@inheritDoc} */
+    @Override public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        throw new UnsupportedOperationException();
+    }
+
+    //TODO: This method will be implemented in IGNITE-9412
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(ModelsComposition mdl) {
+        throw new UnsupportedOperationException();
+    }
+
     /**
      * Defines unique labels in dataset if need (useful in case of classification).
      *
index b7a57f5..d435f91 100644
@@ -174,6 +174,11 @@ public abstract class NNClassificationModel implements Model<Vector, Double>, Ex
             return 1.0; // strategy.SIMPLE
     }
 
+    /** */
+    public DistanceMeasure getDistanceMeasure() {
+        return distanceMeasure;
+    }
+
     /** {@inheritDoc} */
     @Override public int hashCode() {
         int res = 1;
@@ -212,6 +217,17 @@ public abstract class NNClassificationModel implements Model<Vector, Double>, Ex
             .toString();
     }
 
+    /**
+     * Sets parameters from other model to this model.
+     *
+     * @param mdl Model.
+     */
+    protected void copyParametersFrom(NNClassificationModel mdl) {
+        this.k = mdl.k;
+        this.distanceMeasure = mdl.distanceMeasure;
+        this.stgy = mdl.stgy;
+    }
+
     /** */
     public abstract <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path);
 }
index e8c0b4a..bec82a9 100644
@@ -44,12 +44,18 @@ public class ANNClassificationModel extends NNClassificationModel  {
     /** The labeled set of candidates. */
     private final LabeledVectorSet<ProbableLabel, LabeledVector> candidates;
 
+    /** Centroid statistics. */
+    private final ANNClassificationTrainer.CentroidStat centroidsStat;
+
     /**
      * Build the model based on a candidates set.
      * @param centers The candidates set.
+     * @param centroidsStat Centroid statistics.
      */
-    public ANNClassificationModel(LabeledVectorSet<ProbableLabel, LabeledVector> centers) {
+    public ANNClassificationModel(LabeledVectorSet<ProbableLabel, LabeledVector> centers,
+        ANNClassificationTrainer.CentroidStat centroidsStat) {
        this.candidates = centers;
+       this.centroidsStat = centroidsStat;
     }
 
     /** */
@@ -57,6 +63,11 @@ public class ANNClassificationModel extends NNClassificationModel  {
         return candidates;
     }
 
+    /** */
+    public ANNClassificationTrainer.CentroidStat getCentroidsStat() {
+        return centroidsStat;
+    }
+
     /** {@inheritDoc} */
     @Override public Double apply(Vector v) {
             List<LabeledVector> neighbors = findKNearestNeighbors(v);
@@ -65,7 +76,7 @@ public class ANNClassificationModel extends NNClassificationModel  {
 
     /** */
     @Override public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) {
-        ANNModelFormat mdlData = new ANNModelFormat(k, distanceMeasure, stgy, candidates);
+        ANNModelFormat mdlData = new ANNModelFormat(k, distanceMeasure, stgy, candidates, centroidsStat);
         exporter.save(mdlData, path);
     }
 
index 1c45812..3e32b67 100644
 
 package org.apache.ignite.ml.knn.ann;
 
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
 import java.util.TreeMap;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.stream.Collectors;
 import org.apache.ignite.lang.IgniteBiTuple;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
@@ -39,8 +43,8 @@ import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.jetbrains.annotations.NotNull;
 
 /**
- * ANN algorithm trainer to solve multi-class classification task.
- * This trainer is based on ACD strategy and KMeans clustering algorithm to find centroids.
+ * ANN algorithm trainer to solve multi-class classification task. This trainer is based on ACD strategy and KMeans
+ * clustering algorithm to find centroids.
  */
 public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClassificationModel> {
     /** Amount of clusters. */
@@ -61,29 +65,55 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
     /**
      * Trains model based on the specified data.
      *
-     * @param datasetBuilder   Dataset builder.
+     * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
+     * @param lbExtractor Label extractor.
      * @return Model.
      */
-    @Override public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
-        final Vector[] centers = getCentroids(featureExtractor, lbExtractor, datasetBuilder);
+    @Override public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
-        final CentroidStat centroidStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> ANNClassificationModel updateModel(ANNClassificationModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        List<Vector> centers;
+        CentroidStat centroidStat;
+        if (mdl != null) {
+            centers = Arrays.stream(mdl.getCandidates().data()).map(x -> x.features()).collect(Collectors.toList());
+            CentroidStat newStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);
+            if (newStat == null)
+                return mdl;
+            CentroidStat oldStat = mdl.getCentroidsStat();
+            centroidStat = newStat.merge(oldStat);
+        } else {
+            centers = getCentroids(featureExtractor, lbExtractor, datasetBuilder);
+            centroidStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);
+        }
 
         final LabeledVectorSet<ProbableLabel, LabeledVector> dataset = buildLabelsForCandidates(centers, centroidStat);
 
-        return new ANNClassificationModel(dataset);
+        return new ANNClassificationModel(dataset, centroidStat);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(ANNClassificationModel mdl) {
+        return mdl.getDistanceMeasure().equals(distance) && mdl.getCandidates().rowSize() == k;
     }
 
     /** */
-    @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> buildLabelsForCandidates(Vector[] centers, CentroidStat centroidStat) {
+    @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> buildLabelsForCandidates(List<Vector> centers,
+        CentroidStat centroidStat) {
         // init
-        final LabeledVector<Vector, ProbableLabel>[] arr = new LabeledVector[centers.length];
+        final LabeledVector<Vector, ProbableLabel>[] arr = new LabeledVector[centers.size()];
 
         // fill label for each centroid
-        for (int i = 0; i < centers.length; i++)
-            arr[i] = new LabeledVector<>(centers[i], fillProbableLabel(i, centroidStat));
+        for (int i = 0; i < centers.size(); i++)
+            arr[i] = new LabeledVector<>(centers.get(i), fillProbableLabel(i, centroidStat));
 
         return new LabeledVectorSet<>(arr);
     }
@@ -92,13 +122,14 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
      * Perform KMeans clusterization algorithm to find centroids.
      *
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
-     * @param datasetBuilder   The dataset builder.
-     * @param <K>              Type of a key in {@code upstream} data.
-     * @param <V>              Type of a value in {@code upstream} data.
+     * @param lbExtractor Label extractor.
+     * @param datasetBuilder The dataset builder.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
      * @return The arrays of vectors.
      */
-    private <K, V> Vector[] getCentroids(IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> datasetBuilder) {
+    private <K, V> List<Vector> getCentroids(IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> datasetBuilder) {
         KMeansTrainer trainer = new KMeansTrainer()
             .withK(k)
             .withMaxIterations(maxIterations)
@@ -112,7 +143,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
             lbExtractor
         );
 
-        return mdl.centers();
+        return Arrays.asList(mdl.centers());
     }
 
     /** */
@@ -125,21 +156,24 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
         ConcurrentHashMap<Double, Integer> centroidLbDistribution
             = centroidStat.centroidStat().get(centroidIdx);
 
-        if(centroidStat.counts.containsKey(centroidIdx)){
+        if (centroidStat.counts.containsKey(centroidIdx)) {
 
             int clusterSize = centroidStat
                 .counts
                 .get(centroidIdx);
 
             clsLbls.keySet().forEach(
-                (label) -> clsLbls.put(label, centroidLbDistribution.containsKey(label) ? ((double) (centroidLbDistribution.get(label)) / clusterSize) : 0.0)
+                (label) -> clsLbls.put(label, centroidLbDistribution.containsKey(label) ? ((double)(centroidLbDistribution.get(label)) / clusterSize) : 0.0)
             );
         }
         return new ProbableLabel(clsLbls);
     }
 
     /** */
-    private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, Vector[] centers) {
+    private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor, List<Vector> centers) {
+
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
             featureExtractor,
             lbExtractor
@@ -174,7 +208,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
                     }
 
                     res.counts.merge(centroidIdx, 1,
-                        (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
+                        (IgniteBiFunction<Integer, Integer, Integer>)(i1, i2) -> i1 + i2);
                 }
                 return res;
             }, (a, b) -> {
@@ -194,15 +228,15 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
      * Find the closest cluster center index and distance to it from a given point.
      *
      * @param centers Centers to look in.
-     * @param pnt     Point.
+     * @param pnt Point.
      */
-    private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) {
+    private IgniteBiTuple<Integer, Double> findClosestCentroid(List<Vector> centers, LabeledVector pnt) {
         double bestDistance = Double.POSITIVE_INFINITY;
         int bestInd = 0;
 
-        for (int i = 0; i < centers.length; i++) {
-            if (centers[i] != null) {
-                double dist = distance.compute(centers[i], pnt.features());
+        for (int i = 0; i < centers.size(); i++) {
+            if (centers.get(i) != null) {
+                double dist = distance.compute(centers.get(i), pnt.features());
                 if (dist < bestDistance) {
                     bestDistance = dist;
                     bestInd = i;
@@ -212,7 +246,6 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
         return new IgniteBiTuple<>(bestInd, bestDistance);
     }
 
-
     /**
      * Gets the amount of clusters.
      *
@@ -314,7 +347,9 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
     }
 
     /** Service class used for statistics. */
-    public static class CentroidStat {
+    public static class CentroidStat implements Serializable {
+        /** Serial version uid. */
+        private static final long serialVersionUID = 7624883170532045144L;
 
         /** Count of points closest to the center with a given index. */
         ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();
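
The update path above keeps the centroid geometry fixed — the candidates of the previous model are reused as centers — and only gathers fresh CentroidStat on the new data, merging it with the stored statistics (which is why CentroidStat becomes Serializable here). The intended call pattern, sketched with hypothetical builders and extractors; the trainer settings mirror the existing ANN example:

    // builder1, builder2, featureExtractor, lbExtractor: hypothetical.
    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(100);

    ANNClassificationModel mdl = trainer.fit(builder1, featureExtractor, lbExtractor);

    // Reuses mdl's candidates as centers; recomputes label statistics on the
    // new data and merges them with mdl's stored CentroidStat.
    ANNClassificationModel updated = trainer.update(mdl, builder2, featureExtractor, lbExtractor);
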
index e10f3b2..be09828 100644
@@ -30,6 +30,9 @@ import org.apache.ignite.ml.structures.LabeledVectorSet;
  * @see ANNClassificationModel
  */
 public class ANNModelFormat extends KNNModelFormat implements Serializable {
+    /** Centroid statistics. */
+    private final ANNClassificationTrainer.CentroidStat candidatesStat;
+
     /** The labeled set of candidates. */
     private LabeledVectorSet<ProbableLabel, LabeledVector> candidates;
 
@@ -38,15 +41,18 @@ public class ANNModelFormat extends KNNModelFormat implements Serializable {
      * @param k Amount of nearest neighbors.
      * @param measure Distance measure.
      * @param stgy kNN strategy.
+     * @param candidatesStat Centroid statistics.
      */
     public ANNModelFormat(int k,
-                          DistanceMeasure measure,
-                          NNStrategy stgy,
-                          LabeledVectorSet<ProbableLabel, LabeledVector> candidates) {
+        DistanceMeasure measure,
+        NNStrategy stgy,
+        LabeledVectorSet<ProbableLabel, LabeledVector> candidates,
+        ANNClassificationTrainer.CentroidStat candidatesStat) {
         this.k = k;
         this.distanceMeasure = measure;
         this.stgy = stgy;
         this.candidates = candidates;
+        this.candidatesStat = candidatesStat;
     }
 
     /** */
index 0b88f81..0d03ee5 100644
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.knn.classification;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -42,25 +43,29 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp
     /** */
     private static final long serialVersionUID = -127386523291350345L;
 
-    /** Dataset. */
-    private Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset;
+    /** Datasets. */
+    private List<Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>>> datasets;
 
     /**
      * Builds the model via prepared dataset.
+     *
      * @param dataset Specially prepared object to run algorithm over it.
      */
     public KNNClassificationModel(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
-        this.dataset = dataset;
+        this.datasets = new ArrayList<>();
+        if (dataset != null)
+            datasets.add(dataset);
     }
 
     /** {@inheritDoc} */
     @Override public Double apply(Vector v) {
-        if(dataset != null) {
+        if (!datasets.isEmpty()) {
             List<LabeledVector> neighbors = findKNearestNeighbors(v);
 
             return classify(neighbors, v, stgy);
-        } else
+        } else {
-            throw new IllegalStateException("The train kNN dataset is null");
+            throw new IllegalStateException("The train kNN dataset is empty");
+        }
     }
 
     /** */
@@ -77,6 +82,17 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp
      * @return K-nearest neighbors.
      */
     protected List<LabeledVector> findKNearestNeighbors(Vector v) {
+        List<LabeledVector> neighborsFromPartitions = datasets.stream()
+            .flatMap(dataset -> findKNearestNeighborsInDataset(v, dataset).stream())
+            .collect(Collectors.toList());
+
+        LabeledVectorSet<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
+
+        return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter)));
+    }
+
+    private List<LabeledVector> findKNearestNeighborsInDataset(Vector v,
+        Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
         List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> {
             TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, data);
             return Arrays.asList(getKClosestVectors(data, distanceIdxPairs));
@@ -88,12 +104,14 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp
             return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList());
         });
 
+        if (neighborsFromPartitions == null)
+            return Collections.emptyList();
+
         LabeledVectorSet<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
 
         return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter)));
     }
 
-
     /** */
     private double classify(List<LabeledVector> neighbors, Vector v, NNStrategy stgy) {
         Map<Double, Double> clsVotes = new HashMap<>();
@@ -116,5 +134,13 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp
         return getClassWithMaxVotes(clsVotes);
     }
 
-
+    /**
+     * Copy parameters from other model and save all datasets from it.
+     *
+     * @param model Model.
+     */
+    public void copyStateFrom(KNNClassificationModel model) {
+        this.copyParametersFrom(model);
+        datasets.addAll(model.datasets);
+    }
 }
index e0a81f9..1a3ff73 100644
@@ -37,6 +37,24 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass
      */
     @Override public <K, V> KNNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
-        return new KNNClassificationModel(KNNUtils.buildDataset(datasetBuilder, featureExtractor, lbExtractor));
+
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> KNNClassificationModel updateModel(KNNClassificationModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        KNNClassificationModel res = new KNNClassificationModel(KNNUtils.buildDataset(datasetBuilder,
+            featureExtractor, lbExtractor));
+        if (mdl != null)
+            res.copyStateFrom(mdl);
+        return res;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(KNNClassificationModel mdl) {
+        return true;
     }
 }
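
Since kNN keeps the training data itself, updating cannot mean refitting parameters; the model instead accumulates datasets (see the datasets list introduced in the previous file) and answers queries against all of them. Sketch with hypothetical builders and extractors:

    // builder1, builder2, featureExtractor, lbExtractor: hypothetical.
    KNNClassificationTrainer trainer = new KNNClassificationTrainer();

    KNNClassificationModel mdl = trainer.fit(builder1, featureExtractor, lbExtractor);

    // The returned model wraps builder2's dataset and copies k, the distance
    // measure, the strategy and all previously accumulated datasets from mdl.
    KNNClassificationModel updated = trainer.update(mdl, builder2, featureExtractor, lbExtractor);
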
index 395ce61..7a42dc8 100644
@@ -37,6 +37,23 @@ public class KNNRegressionTrainer extends SingleLabelDatasetTrainer<KNNRegressio
      */
     public <K, V> KNNRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
-        return new KNNRegressionModel(KNNUtils.buildDataset(datasetBuilder, featureExtractor, lbExtractor));
+
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> KNNRegressionModel updateModel(KNNRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        KNNRegressionModel res = new KNNRegressionModel(KNNUtils.buildDataset(datasetBuilder,
+            featureExtractor, lbExtractor));
+        if (mdl != null)
+            res.copyStateFrom(mdl);
+        return res;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(KNNRegressionModel mdl) {
+        return true;
     }
 }
index 7a362f7..c9281c0 100644
@@ -78,7 +78,9 @@ public abstract class AbstractLSQR {
      */
     public LSQRResult solve(double damp, double atol, double btol, double conlim, double iterLim, boolean calcVar,
         double[] x0) {
-        int n = getColumns();
+        Integer n = getColumns();
+        if (n == null)
+            return null;
 
         if (iterLim < 0)
             iterLim = 2 * n;
@@ -313,7 +315,7 @@ public abstract class AbstractLSQR {
     protected abstract double[] iter(double bnorm, double[] target);
 
     /** */
-    protected abstract int getColumns();
+    protected abstract Integer getColumns();
 
     /** */
     private static double[] symOrtho(double a, double b) {
index f75caef..14356e1 100644
@@ -100,7 +100,7 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
      *
      * @return number of columns
      */
-    @Override protected int getColumns() {
+    @Override protected Integer getColumns() {
         return dataset.compute(
             data -> data.getFeatures() == null ? null : data.getFeatures().length / data.getRows(),
             (a, b) -> {
index 6727ba9..8f1a4cb 100644
@@ -111,12 +111,25 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
     public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
 
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedModel,
+        DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
+
         try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
             new EmptyContextBuilder<>(),
             new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
         )) {
-            MLPArchitecture arch = archSupplier.apply(dataset);
-            MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
+            MultilayerPerceptron mdl;
+            if (lastLearnedModel != null) {
+                mdl = lastLearnedModel;
+            } else {
+                MLPArchitecture arch = archSupplier.apply(dataset);
+                mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
+            }
             ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();
 
             for (int i = 0; i < maxIterations; i += locIterations) {
@@ -178,6 +191,9 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
                     }
                 );
 
+                if (totUp == null)
+                    return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedModel);
+
                 P update = updatesStgy.allUpdatesReducer().apply(totUp);
                 mdl = updater.update(mdl, update);
             }
@@ -189,6 +205,11 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
         }
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(MultilayerPerceptron mdl) {
+        return true;
+    }
+
     /**
      * Builds a batch of the data by fetching specified rows.
      *
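
For the perceptron an update is plain continued training: when a last learned model is passed, it becomes the starting point and the same distributed SGD loop runs over the new data. A sketch; the constructor arguments mirror the module's usual MLPTrainer setup, while builders and extractors are hypothetical (note that the label extractor returns double[] here):

    // builder1, builder2, featureExtractor, lbExtractor: hypothetical.
    MLPArchitecture arch = new MLPArchitecture(2)
        .withAddedLayer(10, true, Activators.RELU)
        .withAddedLayer(1, true, Activators.SIGMOID);

    MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
        arch,
        LossFunctions.MSE,
        new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.1),
            SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg),
        3000, 4, 50, 123L);

    MultilayerPerceptron mlp = trainer.fit(builder1, featureExtractor, lbExtractor);

    // Continues from the learned weights instead of RandomInitializer(seed).
    mlp = trainer.update(mlp, builder2, featureExtractor, lbExtractor);
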
index 1886ee5..b977864 100644
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.preprocessing;
 
+import java.util.Map;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -24,8 +25,6 @@ import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 
-import java.util.Map;
-
 /**
  * Trainer for preprocessor.
  *
index 8197779..5497177 100644
@@ -38,16 +38,34 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea
     @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl,
+        DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
         LSQRResult res;
 
         try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(
             datasetBuilder,
             new SimpleLabeledDatasetDataBuilder<>(
                 new FeatureExtractorWrapper<>(featureExtractor),
-                lbExtractor.andThen(e -> new double[]{e})
+                lbExtractor.andThen(e -> new double[] {e})
             )
         )) {
-            res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
+            double[] x0 = null;
+            if (mdl != null) {
+                int x0Size = mdl.getWeights().size() + 1;
+                Vector weights = mdl.getWeights().like(x0Size);
+                mdl.getWeights().nonZeroes().forEach(ith -> weights.set(ith.index(), ith.get()));
+                weights.set(weights.size() - 1, mdl.getIntercept());
+                x0 = weights.asArray();
+            }
+            res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, x0);
+            if (res == null)
+                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
         }
         catch (Exception e) {
             throw new RuntimeException(e);
@@ -58,4 +76,9 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea
 
         return new LinearRegressionModel(weights, x[x.length - 1]);
     }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(LinearRegressionModel mdl) {
+        return true;
+    }
 }
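
Here the warm start is expressed through LSQR's x0 argument: the previous weights plus intercept are packed into the initial solution vector, so an update re-runs the solver from the last optimum (and the new null-result branch covers an empty dataset). Sketch with hypothetical builders and extractors:

    // builder1, builder2, featureExtractor, lbExtractor: hypothetical.
    LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();

    LinearRegressionModel mdl = trainer.fit(builder1, featureExtractor, lbExtractor);

    // Solves again starting from x0 = [weights | intercept] of mdl.
    LinearRegressionModel updated = trainer.update(mdl, builder2, featureExtractor, lbExtractor);
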
index 44f60d1..125ed24 100644
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.regressions.linear;
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.Optional;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
@@ -34,6 +35,7 @@ import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.optimization.LossFunctions;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * Trainer of the linear regression model based on stochastic gradient descent algorithm.
@@ -43,16 +45,16 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa
     private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
 
     /** Max number of iterations. */
-    private final int maxIterations;
+    private int maxIterations = 1000;
 
     /** Batch size. */
-    private final int batchSize;
+    private int batchSize = 10;
 
     /** Number of local iterations. */
-    private final int locIterations;
+    private int locIterations = 100;
 
     /** Seed for random generator. */
-    private final long seed;
+    private long seed = System.currentTimeMillis();
 
     /**
      * Constructs a new instance of linear regression SGD trainer.
@@ -72,10 +74,24 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa
         this.seed = seed;
     }
 
+    /**
+     * Constructs a new instance of linear regression SGD trainer.
+     */
+    public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) {
+        this.updatesStgy = updatesStgy;
+    }
+
     /** {@inheritDoc} */
     @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
         IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
 
             int cols = dataset.compute(data -> {
@@ -108,7 +124,10 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa
 
         IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, double[]>)(k, v) -> new double[] {lbExtractor.apply(k, v)};
 
-        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE);
+        MultilayerPerceptron mlp = Optional.ofNullable(mdl)
+            .map(this::restoreMLPState)
+            .map(m -> trainer.update(m, datasetBuilder, featureExtractor, lbE))
+            .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, lbE));
 
         double[] p = mlp.parameters().getStorage().data();
 
@@ -117,4 +136,72 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa
             p[p.length - 1]
         );
     }
+
+    /**
+     * @param mdl Model.
+     * @return state of MLP from last learning.
+     */
+    @NotNull private MultilayerPerceptron restoreMLPState(LinearRegressionModel mdl) {
+        Vector weights = mdl.getWeights();
+        double intercept = mdl.getIntercept();
+        MLPArchitecture architecture = new MLPArchitecture(weights.size())
+            .withAddedLayer(1, true, Activators.LINEAR);
+        MultilayerPerceptron perceptron = new MultilayerPerceptron(architecture);
+
+        Vector mlpState = weights.like(weights.size() + 1);
+        weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), ith.get()));
+        mlpState.set(mlpState.size() - 1, intercept);
+        perceptron.setParameters(mlpState);
+        return perceptron;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(LinearRegressionModel mdl) {
+        return true;
+    }
+
+    /**
+     * Set up the max number of iterations before convergence.
+     *
+     * @param maxIterations The parameter value.
+     * @return Model with new max number of iterations before convergence parameter value.
+     */
+    public LinearRegressionSGDTrainer<P> withMaxIterations(int maxIterations) {
+        this.maxIterations = maxIterations;
+        return this;
+    }
+
+    /**
+     * Set up the batchSize parameter.
+     *
+     * @param batchSize The size of learning batch.
+     * @return Trainer with new batch size parameter value.
+     */
+    public LinearRegressionSGDTrainer<P> withBatchSize(int batchSize) {
+        this.batchSize = batchSize;
+        return this;
+    }
+
+    /**
+     * Set up the amount of local iterations of SGD algorithm.
+     *
+     * @param amountOfLocIterations The parameter value.
+     * @return Trainer with new locIterations parameter value.
+     */
+    public LinearRegressionSGDTrainer<P> withLocIterations(int amountOfLocIterations) {
+        this.locIterations = amountOfLocIterations;
+        return this;
+    }
+
+    /**
+     * Set up the random seed parameter.
+     *
+     * @param seed Seed for random generator.
+     * @return Trainer with new seed parameter value.
+     */
+    public LinearRegressionSGDTrainer<P> withSeed(long seed) {
+        this.seed = seed;
+        return this;
+    }
 }
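
The new single-argument constructor plus the fluent withXXX setters allow configuring the SGD trainer after construction, and updateModel restores an MLP from the model's (weights, intercept) and hands it to MLPTrainer.update. Sketch combining both (the RProp updates strategy mirrors the module's usual SGD setup; builders and extractors are hypothetical):

    // builder1, builder2, featureExtractor, lbExtractor: hypothetical.
    LinearRegressionSGDTrainer<RPropParameterUpdate> trainer =
        new LinearRegressionSGDTrainer<RPropParameterUpdate>(
            new UpdatesStrategy<>(new RPropUpdateCalculator(),
                RPropParameterUpdate::sumLocal, RPropParameterUpdate::avg))
            .withMaxIterations(500)
            .withBatchSize(10)
            .withLocIterations(10)
            .withSeed(123L);

    LinearRegressionModel mdl = trainer.fit(builder1, featureExtractor, lbExtractor);

    // Restores an MLP from (weights, intercept) and continues SGD on the new data.
    LinearRegressionModel updated = trainer.update(mdl, builder2, featureExtractor, lbExtractor);
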
index 6396279..839dab5 100644
@@ -34,6 +34,7 @@ import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.optimization.LossFunctions;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * Trainer of the logistic regression model based on stochastic gradient descent algorithm.
@@ -76,8 +77,15 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
     @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
-        IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
 
+        IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
             int cols = dataset.compute(data -> {
                 if (data.getFeatures() == null)
                     return null;
@@ -106,7 +114,13 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
             seed
         );
 
-        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[] {lbExtractor.apply(k, v)});
+        IgniteBiFunction<K, V, double[]> lbExtractorWrapper = (k, v) -> new double[] {lbExtractor.apply(k, v)};
+        MultilayerPerceptron mlp;
+        if (mdl != null) {
+            mlp = restoreMLPState(mdl);
+            mlp = trainer.update(mlp, datasetBuilder, featureExtractor, lbExtractorWrapper);
+        } else
+            mlp = trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrapper);
 
         double[] params = mlp.parameters().getStorage().data();
 
@@ -114,4 +128,28 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
             params[params.length - 1]
         );
     }
+
+    /**
+     * @param mdl Model.
+     * @return state of MLP from last learning.
+     */
+    @NotNull private MultilayerPerceptron restoreMLPState(LogisticRegressionModel mdl) {
+        Vector weights = mdl.weights();
+        double intercept = mdl.intercept();
+        MLPArchitecture architecture = new MLPArchitecture(weights.size())
+            .withAddedLayer(1, true, Activators.SIGMOID);
+        MultilayerPerceptron perceptron = new MultilayerPerceptron(architecture);
+
+        Vector mlpState = weights.like(weights.size() + 1);
+        weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), ith.get()));
+        mlpState.set(mlpState.size() - 1, intercept);
+        perceptron.setParameters(mlpState);
+        return perceptron;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(LogisticRegressionModel mdl) {
+        return true;
+    }
 }
index 56d2d29..a7c9118 100644
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.TreeMap;
 import org.apache.ignite.ml.Exportable;
 import org.apache.ignite.ml.Exporter;
@@ -103,4 +104,12 @@ public class LogRegressionMultiClassModel implements Model<Vector, Double>, Expo
     public void add(double clsLb, LogisticRegressionModel mdl) {
         models.put(clsLb, mdl);
     }
+
+    /**
+     * @param clsLb Class label.
+     * @return model for class label if it exists.
+     */
+    public Optional<LogisticRegressionModel> getModel(Double clsLb) {
+        return Optional.ofNullable(models.get(clsLb));
+    }
 }
index 4885373..eb44301 100644
@@ -22,6 +22,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -33,6 +34,7 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.nn.MultilayerPerceptron;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
 import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
 import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
 import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
@@ -71,6 +73,19 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
         IgniteBiFunction<K, V, Double> lbExtractor) {
-        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
 
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
+
+        if (classes.isEmpty())
+            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
+
         LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel();
 
         classes.forEach(clsLb -> {
@@ -85,12 +100,23 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
                 else
                     return 0.0;
             };
-            multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer));
+
+            LogisticRegressionModel model = Optional.ofNullable(mdl)
+                .flatMap(multiClassModel -> multiClassModel.getModel(clsLb))
+                .map(learnedModel -> trainer.update(learnedModel, datasetBuilder, featureExtractor, lbTransformer))
+                .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, lbTransformer));
+
+            multiClsMdl.add(clsLb, model);
         });
 
         return multiClsMdl;
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(LogRegressionMultiClassModel mdl) {
+        return true;
+    }
+
     /** Iterates among dataset and collects class labels. */
     private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Double> lbExtractor) {
@@ -121,7 +147,8 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
                 return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
             });
 
-            res.addAll(clsLabels);
+            if (clsLabels != null)
+                res.addAll(clsLabels);
 
         }
         catch (Exception e) {
@@ -191,7 +218,7 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
     }
 
     /**
-     * Set up the regularization parameter.
+     * Set up the random seed parameter.
      *
      * @param seed Seed for random generator.
      * @return Trainer with new seed parameter value.
index 933a712..573df1a 100644
@@ -22,9 +22,11 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.StorageConstants;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
 import org.apache.ignite.ml.structures.LabeledVector;
 import org.apache.ignite.ml.structures.LabeledVectorSet;
 import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
@@ -61,6 +63,14 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
     @Override public <K, V> SVMLinearBinaryClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> SVMLinearBinaryClassificationModel updateModel(SVMLinearBinaryClassificationModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
@@ -74,29 +84,57 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
-            final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
-                if (a == null)
-                    return b == null ? 0 : b;
-                if (b == null)
-                    return a;
-                return b;
-            });
-
-            final int weightVectorSizeWithIntercept = cols + 1;
-
-            weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept);
+            if (mdl == null) {
+                final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
+                    if (a == null)
+                        return b == null ? 0 : b;
+                    if (b == null)
+                        return a;
+                    return b;
+                });
+
+                final int weightVectorSizeWithIntercept = cols + 1;
+                weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept);
+            } else {
+                weights = getStateVector(mdl);
+            }
 
             for (int i = 0; i < this.getAmountOfIterations(); i++) {
                 Vector deltaWeights = calculateUpdates(weights, dataset);
+                if (deltaWeights == null)
+                    return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
+
                 weights = weights.plus(deltaWeights); // creates new vector
             }
-        }
-        catch (Exception e) {
+        } catch (Exception e) {
             throw new RuntimeException(e);
         }
         return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(SVMLinearBinaryClassificationModel mdl) {
+        return true;
+    }
+
+    /**
+     * @param mdl Model.
+     * @return vector of model weights with intercept.
+     */
+    private Vector getStateVector(SVMLinearBinaryClassificationModel mdl) {
+        double intercept = mdl.intercept();
+        Vector weights = mdl.weights();
+
+        int stateVectorSize = weights.size() + 1;
+        Vector result = weights.isDense() ?
+            new DenseVector(stateVectorSize) :
+            new SparseVector(stateVectorSize, StorageConstants.RANDOM_ACCESS_MODE);
+
+        result.set(0, intercept);
+        // Weight i goes to state index i + 1: index 0 is reserved for the intercept
+        // (mirrors the viewPart(1, n - 1) unpacking at the end of updateModel).
+        weights.nonZeroes().forEach(ith -> result.set(ith.index() + 1, ith.get()));
+        return result;
+    }
+
     /** */
     @NotNull private Vector initializeWeightsWithZeros(int vectorSize) {
         return new DenseVector(vectorSize);
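The warm start above serializes the model into a single state vector: the intercept at index 0 and weight i at index i + 1, which is exactly what the viewPart(1, n - 1) call at the end of updateModel unpacks. A small self-contained sketch of that round trip, assuming only the Vector/DenseVector types already used in this file (StateVectorSketch is a hypothetical name):

    import org.apache.ignite.ml.math.primitives.vector.Vector;
    import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;

    /** Hypothetical sketch of the [intercept | weights] packing used for SVM warm starts. */
    public final class StateVectorSketch {
        /** Packs the intercept at index 0 and weight i at index i + 1. */
        static Vector pack(double intercept, Vector weights) {
            Vector state = new DenseVector(weights.size() + 1);
            state.set(0, intercept);
            for (int i = 0; i < weights.size(); i++)
                state.set(i + 1, weights.get(i));
            return state;
        }

        public static void main(String[] args) {
            Vector state = pack(1.5, new DenseVector(new double[] {0.2, -0.4}));

            // Unpacking mirrors the model construction at the end of updateModel.
            double intercept = state.get(0);                      // 1.5
            Vector weights = state.viewPart(1, state.size() - 1); // [0.2, -0.4]
            System.out.println(intercept + " / " + weights);
        }
    }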
index 4b04824..46bf4b2 100644 (file)
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.TreeMap;
 import org.apache.ignite.ml.Exportable;
 import org.apache.ignite.ml.Exporter;
@@ -102,4 +103,12 @@ public class SVMLinearMultiClassClassificationModel implements Model<Vector, Dou
     public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) {
         models.put(clsLb, mdl);
     }
+
+    /**
+     * @param clsLb Class label.
+     * @return Model trained for the target class if it exists, otherwise an empty Optional.
+     */
+    public Optional<SVMLinearBinaryClassificationModel> getModelForClass(double clsLb) {
+        // Optional.ofNullable: Optional.of would throw an NPE for labels without a trained model,
+        // breaking the orElseGet fallback in the multi-class trainer.
+        return Optional.ofNullable(models.get(clsLb));
+    }
 }
index 4b7cc95..b77baa2 100644 (file)
@@ -57,15 +57,26 @@ public class SVMLinearMultiClassClassificationTrainer
     /**
      * Trains model based on the specified data.
      *
-     * @param datasetBuilder   Dataset builder.
+     * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
+     * @param lbExtractor Label extractor.
      * @return Model.
      */
     @Override public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
-                                                                IgniteBiFunction<K, V, Vector> featureExtractor,
-                                                                IgniteBiFunction<K, V, Double> lbExtractor) {
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> SVMLinearMultiClassClassificationModel updateModel(
+        SVMLinearMultiClassClassificationModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
         List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
+        if (classes.isEmpty())
+            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
 
         SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();
 
@@ -84,14 +95,60 @@ public class SVMLinearMultiClassClassificationTrainer
                 else
                     return -1.0;
             };
-            multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer));
+
+            SVMLinearBinaryClassificationModel model;
+            if (mdl == null)
+                model = learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer);
+            else
+                model = updateModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer);
+            multiClsMdl.add(clsLb, model);
         });
 
         return multiClsMdl;
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(SVMLinearMultiClassClassificationModel mdl) {
+        return true;
+    }
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param svmTrainer Prepared SVM trainer.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     */
+    private <K, V> SVMLinearBinaryClassificationModel learnNewModel(SVMLinearBinaryClassificationTrainer svmTrainer,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /**
+     * Updates the already learned model, or fits a new model if there is no model for the current class label.
+     *
+     * @param multiClsMdl Multi-class model learned so far.
+     * @param clsLb Current class label.
+     * @param svmTrainer Prepared SVM trainer.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     */
+    private <K, V> SVMLinearBinaryClassificationModel updateModel(SVMLinearMultiClassClassificationModel multiClsMdl,
+        Double clsLb, SVMLinearBinaryClassificationTrainer svmTrainer, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return multiClsMdl.getModelForClass(clsLb)
+            .map(learnedModel -> svmTrainer.update(learnedModel, datasetBuilder, featureExtractor, lbExtractor))
+            .orElseGet(() -> svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor));
+    }
+
     /** Iterates among dataset and collects class labels. */
-    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
+    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
@@ -107,7 +164,8 @@ public class SVMLinearMultiClassClassificationTrainer
 
                 final double[] lbs = data.getY();
 
-                for (double lb : lbs) locClsLabels.add(lb);
+                for (double lb : lbs)
+                    locClsLabels.add(lb);
 
                 return locClsLabels;
             }, (a, b) -> {
@@ -118,8 +176,8 @@ public class SVMLinearMultiClassClassificationTrainer
                 return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
             });
 
-            res.addAll(clsLabels);
-
+            if (clsLabels != null)
+                res.addAll(clsLabels);
         } catch (Exception e) {
             throw new RuntimeException(e);
         }
@@ -132,7 +190,7 @@ public class SVMLinearMultiClassClassificationTrainer
      * @param lambda The regularization parameter. Should be more than 0.0.
      * @return Trainer with new lambda parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer  withLambda(double lambda) {
+    public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
         assert lambda > 0.0;
         this.lambda = lambda;
         return this;
@@ -162,7 +220,7 @@ public class SVMLinearMultiClassClassificationTrainer
      * @param amountOfIterations The parameter value.
      * @return Trainer with new amountOfIterations parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer  withAmountOfIterations(int amountOfIterations) {
+    public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
         this.amountOfIterations = amountOfIterations;
         return this;
     }
@@ -182,7 +240,7 @@ public class SVMLinearMultiClassClassificationTrainer
      * @param amountOfLocIterations The parameter value.
      * @return Trainer with new amountOfLocIterations parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer  withAmountOfLocIterations(int amountOfLocIterations) {
+    public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
         this.amountOfLocIterations = amountOfLocIterations;
         return this;
     }
index 2f5d5d6..fb34c93 100644 (file)
@@ -26,8 +26,10 @@ import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * Interface for trainers. Trainer is just a function which produces model from the data.
@@ -53,6 +55,71 @@ public abstract class DatasetTrainer<M extends Model, L> {
         IgniteBiFunction<K, V, L> lbExtractor);
 
     /**
+     * Gets the state of the given model, compares it with the training parameters of this trainer and, if they are
+     * compatible, updates the model according to the new data and returns the updated model. Otherwise trains a new model.
+     *
+     * @param mdl Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+
+        if (mdl != null) {
+            if (checkState(mdl)) {
+                return updateModel(mdl, datasetBuilder, featureExtractor, lbExtractor);
+            } else {
+                environment.logger(getClass()).log(
+                    MLLogger.VerboseLevel.HIGH,
+                    "Model cannot be updated because its initial state " +
+                        "doesn't correspond to the trainer parameters"
+                );
+            }
+        }
+        }
+
+        return fit(datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /**
+     * @param mdl Model.
+     * @return True if the trainer parameters that are critical for training correspond to the parameters used for the
+     * last training of the given model.
+     */
+    protected abstract boolean checkState(M mdl);
+
+    /**
+     * Used in the update phase when the given dataset is empty.
+     * If a last trained model exists, the method returns it; otherwise it throws an {@link EmptyDatasetException}.
+     *
+     * @param lastTrainedMdl Model.
+     */
+    @NotNull protected M getLastTrainedModelOrThrowEmptyDatasetException(M lastTrainedMdl) {
+        String msg = "Cannot train model on empty dataset";
+        if (lastTrainedMdl != null) {
+            environment.logger(getClass()).log(MLLogger.VerboseLevel.HIGH, msg);
+            return lastTrainedMdl;
+        } else
+            throw new EmptyDatasetException();
+    }
+
+    /**
+     * Gets the state of the given model, updates it according to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    protected abstract <K, V> M updateModel(M mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor);
+
+    /**
      * Trains model based on the specified data.
      *
      * @param ignite Ignite instance.
@@ -73,6 +140,27 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Gets the state of the given model, updates it according to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param ignite Ignite instance.
+     * @param cache Ignite cache.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, Ignite ignite, IgniteCache<K, V> cache,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        return update(
+            mdl, new CacheBasedDatasetBuilder<>(ignite, cache),
+            featureExtractor,
+            lbExtractor
+        );
+    }
+
+    /**
      * Trains model based on the specified data.
      *
      * @param ignite Ignite instance.
@@ -94,6 +182,28 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Gets the state of the given model, updates it according to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param ignite Ignite instance.
+     * @param cache Ignite cache.
+     * @param filter Filter for {@code upstream} data.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, Ignite ignite, IgniteCache<K, V> cache, IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+        return update(
+            mdl, new CacheBasedDatasetBuilder<>(ignite, cache, filter),
+            featureExtractor,
+            lbExtractor
+        );
+    }
+
+    /**
      * Trains model based on the specified data.
      *
      * @param data Data.
@@ -114,6 +224,27 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Gets the state of the given model, updates it according to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param data Data.
+     * @param parts Number of partitions.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, Map<K, V> data, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor) {
+        return update(
+            mdl, new LocalDatasetBuilder<>(data, parts),
+            featureExtractor,
+            lbExtractor
+        );
+    }
+
+    /**
      * Trains model based on the specified data.
      *
      * @param data Data.
@@ -136,10 +267,45 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Gets the state of the given model, updates it according to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param data Data.
+     * @param filter Filter for {@code upstream} data.
+     * @param parts Number of partitions.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, Map<K, V> data, IgniteBiPredicate<K, V> filter, int parts,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor) {
+        return update(
+            mdl, new LocalDatasetBuilder<>(data, filter, parts),
+            featureExtractor,
+            lbExtractor
+        );
+    }
+
+    /**
      * Sets learning Environment
      * @param environment Environment.
      */
     public void setEnvironment(LearningEnvironment environment) {
         this.environment = environment;
     }
+
+    /** */
+    public static class EmptyDatasetException extends IllegalArgumentException {
+        /** Serial version uid. */
+        private static final long serialVersionUID = 6914650522523293521L;
+
+        /**
+         * Constructs an instance of EmptyDatasetException.
+         */
+        public EmptyDatasetException() {
+            super("Cannot train model on empty dataset");
+        }
+    }
 }
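Taken together, the update overloads above give the following caller-side flow. This is a minimal sketch, not a definitive usage pattern: it assumes the KMeansTrainer touched by this change and a toy Map-based upstream like the ones in the tests below.

    import java.util.HashMap;
    import java.util.Map;
    import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
    import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
    import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

    /** Hypothetical caller-side sketch of the fit-then-update flow. */
    public class UpdateFlowSketch {
        public static void main(String[] args) {
            Map<Integer, double[]> batch1 = new HashMap<>();
            batch1.put(0, new double[] {1.0, 1.0, 0.0});
            batch1.put(1, new double[] {-1.0, -1.0, 1.0});

            KMeansTrainer trainer = new KMeansTrainer().withK(2);

            // Initial training; fit(...) simply delegates to updateModel(null, ...).
            KMeansModel mdl = trainer.fit(batch1, 1,
                (k, v) -> VectorUtils.of(v[0], v[1]),
                (k, v) -> v[2]);

            // Incremental step: checkState(mdl) decides whether the model is refined
            // on the new batch or retrained from scratch.
            Map<Integer, double[]> batch2 = new HashMap<>();
            batch2.put(0, new double[] {1.1, 0.9, 0.0});
            mdl = trainer.update(mdl, batch2, 1,
                (k, v) -> VectorUtils.of(v[0], v[1]),
                (k, v) -> v[2]);

            // An empty batch returns the last trained model as-is; only an update
            // with no previously trained model ends in an EmptyDatasetException.
            mdl = trainer.update(mdl, new HashMap<Integer, double[]>(), 1,
                (k, v) -> VectorUtils.of(v[0], v[1]),
                (k, v) -> v[2]);
        }
    }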
index de8994a..355048a 100644 (file)
@@ -86,6 +86,29 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
         }
     }
 
+    /**
+     * Trains a new model on the given dataset: there is no valid way to update an already built decision tree.
+     *
+     * @param mdl Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return New model based on new dataset.
+     */
+    @Override public <K, V> DecisionTreeNode updateModel(DecisionTreeNode mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return fit(datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(DecisionTreeNode mdl) {
+        return true;
+    }
+
+    /** */
     public <K,V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> dataset) {
         return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
     }
index 559dfff..7832584 100644 (file)
@@ -64,8 +64,9 @@ public class RandomForestClassifierTrainer
      * This id can be used as index in arrays or lists.
      *
      * @param dataset Dataset.
+     * @return true if initialization was done.
      */
-    @Override protected void init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
+    @Override protected boolean init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
         Set<Double> uniqLabels = dataset.compute(
             x -> {
                 Set<Double> labels = new HashSet<>();
@@ -85,11 +86,14 @@ public class RandomForestClassifierTrainer
             }
         );
 
+        if (uniqLabels == null)
+            return false;
+
         int i = 0;
         for (Double label : uniqLabels)
             lblMapping.put(label, i++);
 
-        super.init(dataset);
+        return super.init(dataset);
     }
 
     /** {@inheritDoc} */
index cb25aa3..91fcf0a 100644 (file)
@@ -30,6 +30,7 @@ import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -116,7 +117,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
             new EmptyContextBuilder<>(),
             new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, cntOfTrees, subsampleSize))) {
 
-            init(dataset);
+            if (!init(dataset))
+                return buildComposition(Collections.emptyList());
             models = fit(dataset);
         }
         catch (Exception e) {
@@ -202,7 +204,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
      *
      * @param dataset Dataset.
      */
-    protected void init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
+    protected boolean init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
+        return true;
     }
 
     /**
@@ -215,6 +218,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
         Queue<TreeNode> treesQueue = createRootsQueue();
         ArrayList<TreeRoot> roots = initTrees(treesQueue);
         Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, dataset);
+        if (histMeta.isEmpty())
+            return Collections.emptyList();
 
         ImpurityHistogramsComputer<S> histogramsComputer = createImpurityHistogramsComputer();
         while (!treesQueue.isEmpty()) {
@@ -232,6 +237,23 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
         return roots;
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(ModelsComposition mdl) {
+        ModelsComposition fakeComposition = buildComposition(Collections.emptyList());
+        return mdl.getPredictionsAggregator().getClass() == fakeComposition.getPredictionsAggregator().getClass();
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        ArrayList<Model<Vector, Double>> oldModels = new ArrayList<>(mdl.getModels());
+        ModelsComposition newModels = fit(datasetBuilder, featureExtractor, lbExtractor);
+        oldModels.addAll(newModels.getModels());
+
+        return new ModelsComposition(oldModels, mdl.getPredictionsAggregator());
+    }
+
     /**
      * Split node with NodeId if need.
      *
@@ -302,6 +324,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
 
         List<NormalDistributionStatistics> stats = new NormalDistributionStatisticsComputer()
             .computeStatistics(meta, dataset);
+        if (stats == null)
+            return Collections.emptyMap();
 
         Map<Integer, BucketMeta> bucketsMeta = new HashMap<>();
         for (int i = 0; i < stats.size(); i++) {
index aae5af1..03f044a 100644 (file)
@@ -27,6 +27,7 @@ import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.jetbrains.annotations.NotNull;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
@@ -39,19 +40,70 @@ public class KMeansTrainerTest {
     /** Precision in test checks. */
     private static final double PRECISION = 1e-2;
 
+    /** Data. */
+    private static final Map<Integer, double[]> data = new HashMap<>();
+
+    static {
+        data.put(0, new double[] {1.0, 1.0, 1.0});
+        data.put(1, new double[] {1.0, 2.0, 1.0});
+        data.put(2, new double[] {2.0, 1.0, 1.0});
+        data.put(3, new double[] {-1.0, -1.0, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
+    }
+
     /**
      * A few points, one cluster, one iteration
      */
     @Test
     public void findOneClusters() {
-        Map<Integer, double[]> data = new HashMap<>();
-        data.put(0, new double[]{1.0, 1.0, 1.0});
-        data.put(1, new double[]{1.0, 2.0, 1.0});
-        data.put(2, new double[]{2.0, 1.0, 1.0});
-        data.put(3, new double[]{-1.0, -1.0, 2.0});
-        data.put(4, new double[]{-1.0, -2.0, 2.0});
-        data.put(5, new double[]{-2.0, -1.0, 2.0});
+        KMeansTrainer trainer = createAndCheckTrainer();
+        KMeansModel knnMdl = trainer.withK(1).fit(
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
+        assertEquals(knnMdl.apply(firstVector), 0.0, PRECISION);
+        Vector secondVector = new DenseVector(new double[] {-2.0, -2.0});
+        assertEquals(knnMdl.apply(secondVector), 0.0, PRECISION);
+        assertEquals(trainer.getMaxIterations(), 1);
+        assertEquals(trainer.getEpsilon(), PRECISION, PRECISION);
+    }
 
+    /** */
+    @Test
+    public void testUpdateMdl() {
+        KMeansTrainer trainer = createAndCheckTrainer();
+        KMeansModel originalMdl = trainer.withK(1).fit(
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+        KMeansModel updatedMdlOnSameDataset = trainer.update(
+            originalMdl,
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+        KMeansModel updatedMdlOnEmptyDataset = trainer.update(
+            originalMdl,
+            new LocalDatasetBuilder<>(new HashMap<Integer, double[]>(), 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
+        Vector secondVector = new DenseVector(new double[] {-2.0, -2.0});
+        assertEquals(originalMdl.apply(firstVector), updatedMdlOnSameDataset.apply(firstVector), PRECISION);
+        assertEquals(originalMdl.apply(secondVector), updatedMdlOnSameDataset.apply(secondVector), PRECISION);
+        assertEquals(originalMdl.apply(firstVector), updatedMdlOnEmptyDataset.apply(firstVector), PRECISION);
+        assertEquals(originalMdl.apply(secondVector), updatedMdlOnEmptyDataset.apply(secondVector), PRECISION);
+    }
+
+    /** */
+    @NotNull private KMeansTrainer createAndCheckTrainer() {
         KMeansTrainer trainer = new KMeansTrainer()
             .withDistance(new EuclideanDistance())
             .withK(10)
@@ -61,20 +113,6 @@ public class KMeansTrainerTest {
         assertEquals(10, trainer.getK());
         assertEquals(2, trainer.getSeed());
         assertTrue(trainer.getDistance() instanceof EuclideanDistance);
-
-        KMeansModel knnMdl = trainer
-            .withK(1)
-            .fit(
-                new LocalDatasetBuilder<>(data, 2),
-                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
-                (k, v) -> v[2]
-            );
-
-        Vector firstVector = new DenseVector(new double[]{2.0, 2.0});
-        assertEquals(knnMdl.apply(firstVector), 0.0, PRECISION);
-        Vector secondVector = new DenseVector(new double[]{-2.0, -2.0});
-        assertEquals(knnMdl.apply(secondVector), 0.0, PRECISION);
-        assertEquals(trainer.getMaxIterations(), 1);
-        assertEquals(trainer.getEpsilon(), PRECISION, PRECISION);
+        return trainer;
     }
 }
index acf28e9..745eac9 100644 (file)
@@ -22,6 +22,7 @@ import java.util.Set;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
 import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
 import org.apache.ignite.ml.knn.ann.ANNModelFormat;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNModelFormat;
@@ -103,11 +104,11 @@ public class CollectionsTest {
 
         test(new SVMLinearBinaryClassificationModel(null, 1.0), new SVMLinearBinaryClassificationModel(null, 0.5));
 
-        test(new ANNClassificationModel(new LabeledVectorSet<>()),
-            new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true)));
+        test(new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()),
+            new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true), new ANNClassificationTrainer.CentroidStat()));
 
-        test(new ANNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE, new LabeledVectorSet<>()),
-            new ANNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE, new LabeledVectorSet<>()));
+        test(new ANNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()),
+            new ANNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()));
     }
 
     /** Test classes that have all instances equal (eg, metrics). */
index 17d9c1a..9315850 100644 (file)
@@ -32,6 +32,7 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.knn.NNClassificationModel;
 import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
 import org.apache.ignite.ml.knn.ann.ANNModelFormat;
 import org.apache.ignite.ml.knn.ann.ProbableLabel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
@@ -237,7 +238,7 @@ public class LocalModelsTest {
         executeModelTest(mdlFilePath -> {
             final LabeledVectorSet<ProbableLabel, LabeledVector> centers = new LabeledVectorSet<>();
 
-            NNClassificationModel mdl = new ANNClassificationModel(centers)
+            NNClassificationModel mdl = new ANNClassificationModel(centers, new ANNClassificationTrainer.CentroidStat())
                 .withK(4)
                 .withDistanceMeasure(new ManhattanDistance())
                 .withStrategy(NNStrategy.WEIGHTED);
@@ -250,7 +251,7 @@ public class LocalModelsTest {
             Assert.assertNotNull(load);
 
 
-            NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates())
+            NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates(), new ANNClassificationTrainer.CentroidStat())
                 .withK(load.getK())
                 .withDistanceMeasure(load.getDistanceMeasure())
                 .withStrategy(load.getStgy());
index 4452668..3e340f6 100644 (file)
@@ -54,7 +54,7 @@ public class GDBTrainerTest {
             learningSample.put(i, new double[] {xs[i], ys[i]});
         }
 
-        DatasetTrainer<Model<Vector, Double>, Double> trainer
+        DatasetTrainer<ModelsComposition, Double> trainer
             = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0).withUseIndex(true);
 
         Model<Vector, Double> mdl = trainer.fit(
index 7289b1d..d8fb620 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
 import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
 import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Assert;
 import org.junit.Test;
@@ -68,4 +69,47 @@ public class ANNClassificationTest extends TrainerTest {
         Assert.assertTrue(mdl.toString(true).contains(NNStrategy.SIMPLE.name()));
         Assert.assertTrue(mdl.toString(false).contains(NNStrategy.SIMPLE.name()));
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoClusters.length; i++)
+            cacheMock.put(i, twoClusters[i]);
+
+        ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+            .withK(10)
+            .withMaxIterations(10)
+            .withEpsilon(1e-4)
+            .withDistance(new EuclideanDistance());
+
+        ANNClassificationModel originalMdl = (ANNClassificationModel) trainer.withSeed(1234L).fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
+
+        ANNClassificationModel updatedOnSameDataset = trainer.withSeed(1234L).update(originalMdl,
+            cacheMock, parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        ANNClassificationModel updatedOnEmptyDataset = trainer.withSeed(1234L).update(originalMdl,
+            new HashMap<Integer, double[]>(), parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Vector v1 = VectorUtils.of(550, 550);
+        Vector v2 = VectorUtils.of(-550, -550);
+        TestUtils.assertEquals(originalMdl.apply(v1), updatedOnSameDataset.apply(v1), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v2), updatedOnSameDataset.apply(v2), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v1), updatedOnEmptyDataset.apply(v1), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v2), updatedOnEmptyDataset.apply(v2), PRECISION);
+    }
 }
index c5a5c1c..748123a 100644 (file)
@@ -174,4 +174,43 @@ public class KNNClassificationTest {
         Vector vector = new DenseVector(new double[] {-1.01, -1.01});
         assertEquals(knnMdl.apply(vector), 1.0);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[] {10.0, 10.0, 1.0});
+        data.put(1, new double[] {10.0, 20.0, 1.0});
+        data.put(2, new double[] {-1, -1, 1.0});
+        data.put(3, new double[] {-2, -2, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
+
+        KNNClassificationTrainer trainer = new KNNClassificationTrainer();
+
+        KNNClassificationModel originalMdl = (KNNClassificationModel)trainer.fit(
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.WEIGHTED);
+
+        KNNClassificationModel updatedOnSameDataset = trainer.update(originalMdl,
+            data, parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        KNNClassificationModel updatedOnEmptyDataset = trainer.update(originalMdl,
+            new HashMap<Integer, double[]>(), parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Vector vector = new DenseVector(new double[] {-1.01, -1.01});
+        assertEquals(originalMdl.apply(vector), updatedOnSameDataset.apply(vector));
+        assertEquals(originalMdl.apply(vector), updatedOnEmptyDataset.apply(vector));
+    }
 }
index 5504e1a..52ff1ec 100644 (file)
@@ -35,6 +35,8 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import static junit.framework.TestCase.assertEquals;
+
 /**
  * Tests for {@link KNNRegressionTrainer}.
  */
@@ -135,4 +137,42 @@ public class KNNRegressionTest {
         Assert.assertTrue(knnMdl.toString(true).contains(stgy.name()));
         Assert.assertTrue(knnMdl.toString(false).contains(stgy.name()));
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[] {11.0, 0, 0, 0, 0, 0});
+        data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0});
+        data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0});
+        data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0});
+        data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0});
+        data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0});
+
+        KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+
+        KNNRegressionModel originalMdl = (KNNRegressionModel) trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        ).withK(1)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
+
+        KNNRegressionModel updatedOnSameDataset = trainer.update(originalMdl,
+            data, parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        KNNRegressionModel updatedOnEmptyDataset = trainer.update(originalMdl,
+            new HashMap<Integer, double[]>(), parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Vector vector = new DenseVector(new double[] {0, 0, 0, 5.0, 0.0});
+        assertEquals(originalMdl.apply(vector), updatedOnSameDataset.apply(vector));
+        assertEquals(originalMdl.apply(vector), updatedOnEmptyDataset.apply(vector));
+    }
 }
index a1d601c..6a6555e 100644 (file)
@@ -29,6 +29,7 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
 import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
 import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
 import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
@@ -154,6 +155,69 @@ public class MLPTrainerTest {
 
             TestUtils.checkIsInEpsilonNeighbourhood(new DenseVector(new double[]{0.0}), predict.getRow(0), 1E-1);
         }
+
+        /** */
+        @Test
+        public void testUpdate() {
+            UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+                new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal,
+                SimpleGDParameterUpdate::avg
+            );
+
+            Map<Integer, double[][]> xorData = new HashMap<>();
+            xorData.put(0, new double[][]{{0.0, 0.0}, {0.0}});
+            xorData.put(1, new double[][]{{0.0, 1.0}, {1.0}});
+            xorData.put(2, new double[][]{{1.0, 0.0}, {1.0}});
+            xorData.put(3, new double[][]{{1.0, 1.0}, {0.0}});
+
+            MLPArchitecture arch = new MLPArchitecture(2).
+                withAddedLayer(10, true, Activators.RELU).
+                withAddedLayer(1, false, Activators.SIGMOID);
+
+            MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
+                arch,
+                LossFunctions.MSE,
+                updatesStgy,
+                3000,
+                batchSize,
+                50,
+                123L
+            );
+
+            MultilayerPerceptron originalMdl = trainer.fit(
+                xorData,
+                parts,
+                (k, v) -> VectorUtils.of(v[0]),
+                (k, v) -> v[1]
+            );
+
+            MultilayerPerceptron updatedOnSameDS = trainer.update(
+                originalMdl,
+                xorData,
+                parts,
+                (k, v) -> VectorUtils.of(v[0]),
+                (k, v) -> v[1]
+            );
+
+            MultilayerPerceptron updatedOnEmptyDS = trainer.update(
+                originalMdl,
+                new HashMap<Integer, double[][]>(),
+                parts,
+                (k, v) -> VectorUtils.of(v[0]),
+                (k, v) -> v[1]
+            );
+
+            DenseMatrix matrix = new DenseMatrix(new double[][] {
+                {0.0, 0.0},
+                {0.0, 1.0},
+                {1.0, 0.0},
+                {1.0, 1.0}
+            });
+
+            TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.apply(matrix).getRow(0), updatedOnSameDS.apply(matrix).getRow(0), 1E-1);
+            TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.apply(matrix).getRow(0), updatedOnEmptyDS.apply(matrix).getRow(0), 1E-1);
+        }
     }
 
     /**
index d16ae72..9c35ac7 100644 (file)
@@ -101,4 +101,55 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest {
 
         assertEquals(intercept, mdl.getIntercept(), 1e-6);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Random rnd = new Random(0);
+        Map<Integer, double[]> data = new HashMap<>();
+        double[] coef = new double[100];
+        double intercept = rnd.nextDouble() * 10;
+
+        for (int i = 0; i < 100000; i++) {
+            double[] x = new double[coef.length + 1];
+
+            for (int j = 0; j < coef.length; j++)
+                x[j] = rnd.nextDouble() * 10;
+
+            x[coef.length] = intercept;
+
+            data.put(i, x);
+        }
+
+        LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+
+        LinearRegressionModel originalModel = trainer.fit(
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[coef.length]
+        );
+
+        LinearRegressionModel updatedOnSameDS = trainer.update(
+            originalModel,
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[coef.length]
+        );
+
+        LinearRegressionModel updatedOnEmptyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[coef.length]
+        );
+
+        assertArrayEquals(originalModel.getWeights().getStorage().data(), updatedOnSameDS.getWeights().getStorage().data(), 1e-6);
+        assertEquals(originalModel.getIntercept(), updatedOnSameDS.getIntercept(), 1e-6);
+
+        assertArrayEquals(originalModel.getWeights().getStorage().data(), updatedOnEmptyDS.getWeights().getStorage().data(), 1e-6);
+        assertEquals(originalModel.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-6);
+    }
 }
index 349e712..86b0f27 100644 (file)
@@ -72,4 +72,66 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest {
 
         assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-1);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[]{-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107});
+        data.put(1, new double[]{-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867});
+        data.put(2, new double[]{0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728});
+        data.put(3, new double[]{-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991});
+        data.put(4, new double[]{0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611});
+        data.put(5, new double[]{0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197});
+        data.put(6, new double[]{-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012});
+        data.put(7, new double[]{-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889});
+        data.put(8, new double[]{0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949});
+        data.put(9, new double[]{-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583});
+
+        LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
+            new RPropUpdateCalculator(),
+            RPropParameterUpdate::sumLocal,
+            RPropParameterUpdate::avg
+        ), 100000, 10, 100, 0L);
+
+        LinearRegressionModel originalModel = trainer.withSeed(0).fit(
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[4]
+        );
+
+        LinearRegressionModel updatedOnSameDS = trainer.withSeed(0).update(
+            originalModel,
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[4]
+        );
+
+        LinearRegressionModel updatedOnEmptyDS = trainer.withSeed(0).update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[4]
+        );
+
+        assertArrayEquals(
+            originalModel.getWeights().getStorage().data(),
+            updatedOnSameDS.getWeights().getStorage().data(),
+            1.0
+        );
+
+        assertEquals(originalModel.getIntercept(), updatedOnSameDS.getIntercept(), 1.0);
+
+        assertArrayEquals(
+            originalModel.getWeights().getStorage().data(),
+            updatedOnEmptyDS.getWeights().getStorage().data(),
+            1e-1
+        );
+
+        assertEquals(originalModel.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-1);
+    }
 }
index 1f8c5d1..f08501c 100644 (file)
@@ -19,9 +19,11 @@ package org.apache.ignite.ml.regressions.logistic;
 
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.optimization.SmoothParametrized;
@@ -81,4 +83,60 @@ public class LogRegMultiClassTrainerTest extends TrainerTest {
         TestUtils.assertEquals(2, mdl.apply(VectorUtils.of(-10, -10)), PRECISION);
         TestUtils.assertEquals(3, mdl.apply(VectorUtils.of(10, -10)), PRECISION);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < fourSetsInSquareVertices.length; i++)
+            cacheMock.put(i, fourSetsInSquareVertices[i]);
+
+        LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>()
+            .withUpdatesStgy(new UpdatesStrategy<>(
+                new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal,
+                SimpleGDParameterUpdate::avg
+            ))
+            .withAmountOfIterations(1000)
+            .withAmountOfLocIterations(10)
+            .withBatchSize(100)
+            .withSeed(123L);
+
+        LogRegressionMultiClassModel originalModel = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        LogRegressionMultiClassModel updatedOnSameDS = trainer.update(
+            originalModel,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        LogRegressionMultiClassModel updatedOnEmptyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        List<Vector> vectors = Arrays.asList(
+            VectorUtils.of(10, 10),
+            VectorUtils.of(-10, 10),
+            VectorUtils.of(-10, -10),
+            VectorUtils.of(10, -10)
+        );
+
+        for (Vector vec : vectors) {
+            TestUtils.assertEquals(originalModel.apply(vec), updatedOnSameDS.apply(vec), PRECISION);
+            TestUtils.assertEquals(originalModel.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION);
+        }
+    }
 }
index 5bd2dbd..1da0d1a 100644 (file)
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
@@ -60,4 +61,49 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
         TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator().withLearningRate(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        ), 100000, 10, 100, 123L);
+
+        LogisticRegressionModel originalModel = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        LogisticRegressionModel updatedOnSameDS = trainer.update(
+            originalModel,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        LogisticRegressionModel updatedOnEmptyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Vector v1 = VectorUtils.of(100, 10);
+        Vector v2 = VectorUtils.of(10, 100);
+        TestUtils.assertEquals(originalModel.apply(v1), updatedOnSameDS.apply(v1), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v2), updatedOnSameDS.apply(v2), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v2), updatedOnEmptyDS.apply(v2), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v1), updatedOnEmptyDS.apply(v1), PRECISION);
+    }
 }
index 5630bee..263bb6d 100644 (file)
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
 
@@ -52,4 +53,44 @@ public class SVMBinaryTrainerTest extends TrainerTest {
         TestUtils.assertEquals(-1, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
         TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
+            .withAmountOfIterations(1000)
+            .withSeed(1234L);
+
+        SVMLinearBinaryClassificationModel originalModel = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        SVMLinearBinaryClassificationModel updatedOnSameDS = trainer.update(
+            originalModel,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        SVMLinearBinaryClassificationModel updatedOnEmptyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Vector v = VectorUtils.of(100, 10);
+        TestUtils.assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), PRECISION);
+    }
 }
index 7ea28c2..e0c62af 100644 (file)
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
 
@@ -54,4 +55,46 @@ public class SVMMultiClassTrainerTest extends TrainerTest {
         TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
         TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer()
+            .withLambda(0.3)
+            .withAmountOfLocIterations(10)
+            .withAmountOfIterations(100)
+            .withSeed(1234L);
+
+        SVMLinearMultiClassClassificationModel originalModel = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        SVMLinearMultiClassClassificationModel updatedOnSameDS = trainer.update(
+            originalModel,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        SVMLinearMultiClassClassificationModel updatedOnEmptyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Vector v = VectorUtils.of(100, 10);
+        TestUtils.assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), PRECISION);
+    }
 }
index 4abf508..087f4e8 100644 (file)
@@ -24,6 +24,7 @@ import java.util.Map;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -75,7 +76,7 @@ public class RandomForestClassifierTrainerTest {
         }
 
         ArrayList<FeatureMeta> meta = new ArrayList<>();
-        for(int i = 0; i < 4; i++)
+        for (int i = 0; i < 4; i++)
             meta.add(new FeatureMeta("", i, false));
         RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
             .withCountOfTrees(5)
@@ -86,4 +87,34 @@ public class RandomForestClassifierTrainerTest {
         assertTrue(mdl.getPredictionsAggregator() instanceof OnMajorityPredictionsAggregator);
         assertEquals(5, mdl.getModels().size());
     }
+
+    /** Tests that {@code update} on the same or an empty dataset keeps the model's predictions. */
+    @Test
+    public void testUpdate() {
+        int sampleSize = 1000;
+        Map<double[], Double> sample = new HashMap<>();
+        for (int i = 0; i < sampleSize; i++) {
+            double x1 = i;
+            double x2 = x1 / 10.0;
+            double x3 = x2 / 10.0;
+            double x4 = x3 / 10.0;
+
+            sample.put(new double[] {x1, x2, x3, x4}, (double)(i % 2));
+        }
+
+        ArrayList<FeatureMeta> meta = new ArrayList<>();
+        for (int i = 0; i < 4; i++)
+            meta.add(new FeatureMeta("", i, false));
+        RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
+            .withCountOfTrees(100)
+            .withFeaturesCountSelectionStrgy(x -> 2);
+
+        ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnSameDS = trainer.update(originalModel, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+
+        Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005);
+        assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.01);
+        assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.01);
+    }
 }
index c4a4a75..fcc20bd 100644
@@ -24,6 +24,7 @@ import java.util.Map;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -82,4 +83,34 @@ public class RandomForestRegressionTrainerTest {
         assertTrue(mdl.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator);
         assertEquals(5, mdl.getModels().size());
     }
+
+    /** Tests that {@code update} on the same or an empty dataset keeps the model's predictions. */
+    @Test
+    public void testUpdate() {
+        int sampleSize = 1000;
+        Map<double[], Double> sample = new HashMap<>();
+        for (int i = 0; i < sampleSize; i++) {
+            double x1 = i;
+            double x2 = x1 / 10.0;
+            double x3 = x2 / 10.0;
+            double x4 = x3 / 10.0;
+
+            sample.put(new double[] {x1, x2, x3, x4}, (double)(i % 2));
+        }
+
+        ArrayList<FeatureMeta> meta = new ArrayList<>();
+        for (int i = 0; i < 4; i++)
+            meta.add(new FeatureMeta("", i, false));
+        RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(meta)
+            .withCountOfTrees(100)
+            .withFeaturesCountSelectionStrgy(x -> 2);
+
+        ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnSameDS = trainer.update(originalModel, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+
+        Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005);
+        assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.1);
+        assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.1);
+    }
 }
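The random forest variants assert with a tolerance (0.01 for classification, 0.1 for regression) rather than exact equality, since tree induction is randomized and an updated forest only approximately reproduces the original predictions. A minimal sketch of the same update pattern for the regression trainer ('freshSample', a Map<double[], Double> of newly collected rows, is assumed; everything else mirrors the test above):

    // Sketch only: fold fresh rows into an already trained forest.
    ArrayList<FeatureMeta> meta = new ArrayList<>();
    for (int i = 0; i < 4; i++)
        meta.add(new FeatureMeta("", i, false));

    RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(meta)
        .withCountOfTrees(100)
        .withFeaturesCountSelectionStrgy(x -> 2);

    ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);

    // 'freshSample' is an assumed Map<double[], Double> of rows collected after the initial fit.
    mdl = trainer.update(mdl, freshSample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);

    double prediction = mdl.apply(VectorUtils.of(5, 0.5, 0.05, 0.005));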