IGNITE-11144: [ML] Create example for FeatureLabelExtractor
author: Artem Malykh <amalykhgh@gmail.com>
Mon, 4 Feb 2019 14:23:25 +0000 (17:23 +0300)
committer: Yury Babak <ybabak@gridgain.com>
Mon, 4 Feb 2019 14:23:25 +0000 (17:23 +0300)
This closes #5993

examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/FeatureLabelExtractor.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java
modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java [deleted file]

index 1bb4146..772a35b 100644 (file)
@@ -27,6 +27,8 @@ import org.apache.ignite.cache.query.ScanQuery;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
 import org.apache.ignite.ml.util.MLSandboxDatasets;
 import org.apache.ignite.ml.util.SandboxMLCache;
 
@@ -59,11 +61,22 @@ public class LinearRegressionLSQRTrainerExample {
             LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
 
             System.out.println(">>> Perform the training to get the model.");
+
+            // This object is used to extract features and labels from upstream entities, which are
+            // essentially tuples of the form (key, value) (in our case (Integer, Vector)).
+            // The key part of the tuple is ignored in our example.
+            // The label is extracted from the 0th entry of the value (which is a Vector)
+            // and the features are the remaining part of the vector. Alternatively, we could use the
+            // DatasetTrainer#fit(Ignite, IgniteCache, IgniteBiFunction, IgniteBiFunction) method call,
+            // where there is a separate lambda for extracting the label from (key, value) and a separate
+            // lambda for extracting the features.
+            FeatureLabelExtractor<Integer, Vector, Double> extractor =
+                (k, v) -> new LabeledVector<>(v.copyOfRange(1, v.size()), v.get(0));
+
             LinearRegressionModel mdl = trainer.fit(
                 ignite,
                 dataCache,
-                (k, v) -> v.copyOfRange(1, v.size()),
-                (k, v) -> v.get(0)
+                extractor
             );
 
             System.out.println(">>> Linear regression model: " + mdl);
index a63ef62..b588b25 100644 (file)
 
 package org.apache.ignite.ml.composition.bagging;
 
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import org.apache.ignite.ml.IgniteModel;
 import org.apache.ignite.ml.composition.CompositionUtils;
 import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
@@ -31,12 +36,6 @@ import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
 import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
 import org.apache.ignite.ml.util.Utils;
 
-import java.util.Collections;
-import java.util.List;
-import java.util.Random;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
 /**
  * Trainer encapsulating logic of bootstrap aggregating (bagging).
  * This trainer accepts some other trainer and returns bagged version of it.
index 0c12672..72f95af 100644 (file)
 
 package org.apache.ignite.ml.math.primitives.vector;
 
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
 import org.apache.ignite.internal.util.typedef.internal.A;
 import org.apache.ignite.ml.math.StorageConstants;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -25,11 +29,6 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DelegatingNamedVector;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
 import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
 
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Objects;
-
 /**
  * Some utils for {@link Vector}.
  */
index 7455ff1..a78396d 100644 (file)
@@ -198,6 +198,24 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> {
     }
 
     /**
+     * Trains model based on the specified data.
+     *
+     * @param ignite Ignite instance.
+     * @param cache Ignite cache.
+     * @param extractor Features and labels extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Model.
+     */
+    public <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache,
+        FeatureLabelExtractor<K, V, L> extractor) {
+        return fit(
+            new CacheBasedDatasetBuilder<>(ignite, cache),
+            extractor
+        );
+    }
+
+    /**
      * Gets state of model in arguments, update in according to new data and return new model.
      *
      * @param mdl Learned model.
index cd8a0ae..37a2e57 100644 (file)
 
 package org.apache.ignite.ml.trainers;
 
+import java.io.Serializable;
+import java.util.Objects;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.structures.LabeledVector;
 
-import java.io.Serializable;
-import java.util.Objects;
-
 /**
  * Class fro extracting features and vectors from upstream.
  *
index 0cba06c..8661d4b 100644 (file)
 
 package org.apache.ignite.ml.trainers;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import org.apache.ignite.ml.IgniteModel;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
@@ -34,12 +39,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
 import org.apache.ignite.ml.util.Utils;
 
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
 /**
  * Class containing various trainer transformers.
  */
index 1203cfb..3267790 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.ignite.ml.composition;
 
 import java.util.Arrays;
+import org.apache.ignite.IgniteCache;
 import org.apache.ignite.ml.IgniteModel;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
@@ -167,6 +168,6 @@ public class StackingTest extends TrainerTest {
         StackedDatasetTrainer<Void, Void, Void, IgniteModel<Void, Void>, Void> trainer =
             new StackedDatasetTrainer<>();
         thrown.expect(IllegalStateException.class);
-        trainer.fit(null, null, null);
+        trainer.fit(null, (IgniteCache<Object, Object>)null, null);
     }
 }
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java
deleted file mode 100644 (file)
index 9c089ce..0000000
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trainers;
-
-import java.util.Arrays;
-import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer;
-import org.apache.ignite.ml.composition.stacking.StackedModel;
-import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.matrix.Matrix;
-import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.nn.Activators;
-import org.apache.ignite.ml.nn.MLPTrainer;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.nn.UpdatesStrategy;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.optimization.SmoothParametrized;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
-
-import static junit.framework.TestCase.assertEquals;
-
-/**
- * Tests stacked trainers.
- */
-public class StackingTest extends TrainerTest {
-    /** Rule to check exceptions. */
-    @Rule
-    public ExpectedException thrown = ExpectedException.none();
-
-    /**
-     * Tests simple stack training.
-     */
-    @Test
-    public void testSimpleStack() {
-        StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer =
-            new StackedDatasetTrainer<>();
-
-        UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
-            new SimpleGDUpdateCalculator(0.2),
-            SimpleGDParameterUpdate::sumLocal,
-            SimpleGDParameterUpdate::avg
-        );
-
-        MLPArchitecture arch = new MLPArchitecture(2).
-            withAddedLayer(10, true, Activators.RELU).
-            withAddedLayer(1, false, Activators.SIGMOID);
-
-        MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>(
-            arch,
-            LossFunctions.MSE,
-            updatesStgy,
-            3000,
-            10,
-            50,
-            123L
-        );
-
-        // Convert model trainer to produce Vector -> Vector model
-        DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer =
-            AdaptableDatasetTrainer.of(trainer1)
-                .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
-                .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
-                .withConvertedLabels(VectorUtils::num2Arr);
-
-        final double factor = 3;
-
-        StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
-            .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
-            .addTrainer(mlpTrainer)
-            .withAggregatorInputMerger(VectorUtils::concat)
-            .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
-            .withVector2SubmodelInputConverter(IgniteFunction.identity())
-            .withOriginalFeaturesKept(IgniteFunction.identity())
-            .withEnvironmentBuilder(TestUtils.testEnvBuilder())
-            .fit(getCacheMock(xor),
-                parts,
-                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
-                (k, v) -> v[v.length - 1]);
-
-        assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
-        assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
-        assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
-        assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
-    }
-
-    /**
-     * Tests simple stack training.
-     */
-    @Test
-    public void testSimpleVectorStack() {
-        StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> trainer =
-            new StackedVectorDatasetTrainer<>();
-
-        UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
-            new SimpleGDUpdateCalculator(0.2),
-            SimpleGDParameterUpdate::sumLocal,
-            SimpleGDParameterUpdate::avg
-        );
-
-        MLPArchitecture arch = new MLPArchitecture(2).
-            withAddedLayer(10, true, Activators.RELU).
-            withAddedLayer(1, false, Activators.SIGMOID);
-
-        DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>(
-            arch,
-            LossFunctions.MSE,
-            updatesStgy,
-            3000,
-            10,
-            50,
-            123L
-        ).withConvertedLabels(VectorUtils::num2Arr);
-
-        final double factor = 3;
-
-        StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
-            .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
-            .addMatrix2MatrixTrainer(mlpTrainer)
-            .withEnvironmentBuilder(TestUtils.testEnvBuilder())
-            .fit(getCacheMock(xor),
-                parts,
-                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
-                (k, v) -> v[v.length - 1]);
-
-        assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
-        assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
-        assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
-        assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
-    }
-
-    /**
-     * Tests that if there is no any way for input of first layer to propagate to second layer,
-     * exception will be thrown.
-     */
-    @Test
-    public void testINoWaysOfPropagation() {
-        StackedDatasetTrainer<Void, Void, Void, IgniteModel<Void, Void>, Void> trainer =
-            new StackedDatasetTrainer<>();
-        thrown.expect(IllegalStateException.class);
-        trainer.fit(null, null, null);
-    }
-}