IGNITE-10700: [ML] Working with Binary Objects
author     YuriBabak <y.chief@gmail.com>
Mon, 21 Jan 2019 15:50:51 +0000 (18:50 +0300)
committer  Yury Babak <ybabak@gridgain.com>
Mon, 21 Jan 2019 15:50:51 +0000 (18:50 +0300)
This closes #5871
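
This change adds an opt-in keepBinary mode to the cache-based dataset infrastructure, so trainers can consume BinaryObject values straight from the upstream cache without deserializing them into user classes. The flag is exposed on CacheBasedDatasetBuilder via withKeepBinary(boolean) and propagated through CacheBasedDataset down to ComputeUtils, which switches the upstream cache to withKeepBinary() before running its local scan queries.

A minimal usage sketch, condensed from TrainingWithBinaryObjectExample below (the "PERSONS" cache name and the "feature1"/"label" fields are taken from that example):

    // Obtain the upstream cache in binary form.
    IgniteCache<Integer, BinaryObject> dataCache = ignite.cache("PERSONS").withKeepBinary();

    // Dataset builder with keepBinary enabled for the upstream cache.
    CacheBasedDatasetBuilder<Integer, BinaryObject> datasetBuilder =
        new CacheBasedDatasetBuilder<>(ignite, dataCache).withKeepBinary(true);

    // Feature and label extractors read fields directly from the binary objects.
    IgniteBiFunction<Integer, BinaryObject, Vector> featureExtractor =
        (k, v) -> VectorUtils.of(new double[] {v.field("feature1")});
    IgniteBiFunction<Integer, BinaryObject, Double> lbExtractor = (k, v) -> (double)v.field("label");

    KMeansModel mdl = new KMeansTrainer().fit(datasetBuilder, featureExtractor, lbExtractor);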

examples/src/main/java/org/apache/ignite/examples/ml/TrainingWithBinaryObjectExample.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
modules/ml/src/test/java/org/apache/ignite/ml/common/CommonTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/TrainingWithBinaryObjectExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/TrainingWithBinaryObjectExample.java
new file mode 100644 (file)
index 0000000..f8df0a8
--- /dev/null
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.binary.BinaryObjectBuilder;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
+import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+
+/**
+ * Example of model training on a cache of binary objects.
+ */
+public class TrainingWithBinaryObjectExample {
+    /** Run example. */
+    public static void main(String[] args) throws Exception {
+        System.out.println();
+        System.out.println(">>> Model training over a cache of binary objects example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, BinaryObject> dataCache = populateCache(ignite);
+
+            // Create dataset builder with enabled support of keeping binary for upstream cache.
+            CacheBasedDatasetBuilder<Integer, BinaryObject> datasetBuilder =
+                new CacheBasedDatasetBuilder<>(ignite, dataCache).withKeepBinary(true);
+
+            // Define feature and label extractors that read fields of the binary objects.
+            IgniteBiFunction<Integer, BinaryObject, Vector> featureExtractor
+                = (k, v) -> VectorUtils.of(new double[] {v.field("feature1")});
+
+            IgniteBiFunction<Integer, BinaryObject, Double> lbExtractor = (k, v) -> (double)v.field("label");
+
+            KMeansTrainer trainer = new KMeansTrainer();
+
+            KMeansModel kmdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
+
+            System.out.println(">>> Model trained over binary objects. Model: " + kmdl);
+        }
+    }
+
+    /** Populate cache with some binary objects. */
+    private static IgniteCache<Integer, BinaryObject> populateCache(Ignite ignite) {
+        CacheConfiguration<Integer, BinaryObject> cacheConfiguration = new CacheConfiguration<>();
+
+        cacheConfiguration.setName("PERSONS");
+        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 2));
+
+        IgniteCache<Integer, BinaryObject> cache = ignite.createCache(cacheConfiguration).withKeepBinary();
+
+        BinaryObjectBuilder builder = ignite.binary().builder("testType");
+
+        for (int i = 0; i < 100; i++) {
+            if (i > 50)
+                cache.put(i, builder.setField("feature1", 0.0).setField("label", 0.0).build());
+            else
+                cache.put(i, builder.setField("feature1", 1.0).setField("label", 1.0).build());
+        }
+        return cache;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
index b2aa00b..f6d5976 100644 (file)
@@ -78,6 +78,9 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
     /** Learning environment builder. */
     private final LearningEnvironmentBuilder envBuilder;
 
+    /** Upstream keep binary. */
+    private final boolean upstreamKeepBinary;
+
     /**
      * Constructs a new instance of dataset based on Ignite Cache, which is used as {@code upstream} and as reliable storage for
      * partition {@code context} as well.
@@ -98,7 +101,8 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
         IgniteCache<Integer, C> datasetCache,
         LearningEnvironmentBuilder envBuilder,
         PartitionDataBuilder<K, V, C, D> partDataBuilder,
-        UUID datasetId) {
+        UUID datasetId,
+        boolean upstreamKeepBinary) {
         this.ignite = ignite;
         this.upstreamCache = upstreamCache;
         this.filter = filter;
@@ -107,6 +111,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
         this.partDataBuilder = partDataBuilder;
         this.envBuilder = envBuilder;
         this.datasetId = datasetId;
+        this.upstreamKeepBinary = upstreamKeepBinary;
     }
 
     /** {@inheritDoc} */
@@ -127,7 +132,8 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
                 datasetCacheName,
                 datasetId,
                 partDataBuilder,
-                env
+                env,
+                upstreamKeepBinary
             );
 
 
@@ -160,7 +166,8 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
                 datasetCacheName,
                 datasetId,
                 partDataBuilder,
-                env
+                env,
+                upstreamKeepBinary
             );
             return data != null ? map.apply(data, env) : null;
         }, reduce, identity);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
index b85bfc2..f452904 100644 (file)
@@ -61,6 +61,8 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
     /** Upstream transformer builder. */
     private final UpstreamTransformerBuilder transformerBuilder;
 
+    /** Upstream keep binary. */
+    private final boolean upstreamKeepBinary;
     /**
      * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default
      * predicate that passes all upstream entries to dataset.
@@ -94,10 +96,28 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
         IgniteCache<K, V> upstreamCache,
         IgniteBiPredicate<K, V> filter,
         UpstreamTransformerBuilder transformerBuilder) {
+        this(ignite, upstreamCache, filter, transformerBuilder, false);
+    }
+
+    /**
+     * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset}.
+     *
+     * @param ignite Ignite.
+     * @param upstreamCache Upstream cache.
+     * @param filter Filter.
+     * @param transformerBuilder Transformer builder.
+     * @param isKeepBinary Whether binary objects should be kept for the upstream cache.
+     */
+    public CacheBasedDatasetBuilder(Ignite ignite,
+        IgniteCache<K, V> upstreamCache,
+        IgniteBiPredicate<K, V> filter,
+        UpstreamTransformerBuilder transformerBuilder,
+        boolean isKeepBinary) {
         this.ignite = ignite;
         this.upstreamCache = upstreamCache;
         this.filter = filter;
         this.transformerBuilder = transformerBuilder;
+        this.upstreamKeepBinary = isKeepBinary;
     }
 
     /** {@inheritDoc} */
@@ -129,10 +149,11 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
             partCtxBuilder,
             envBuilder,
             RETRIES,
-            RETRY_INTERVAL
+            RETRY_INTERVAL,
+            upstreamKeepBinary
         );
 
-        return new CacheBasedDataset<>(ignite, upstreamCache, filter, transformerBuilder, datasetCache, envBuilder, partDataBuilder, datasetId);
+        return new CacheBasedDataset<>(ignite, upstreamCache, filter, transformerBuilder, datasetCache, envBuilder, partDataBuilder, datasetId, upstreamKeepBinary);
     }
 
     /** {@inheritDoc} */
@@ -145,4 +166,13 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
         return new CacheBasedDatasetBuilder<>(ignite, upstreamCache,
             (e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2));
     }
+
+    /**
+     * Returns a new builder with the given keepBinary policy for the upstream cache ({@code false} by default).
+     *
+     * @param isKeepBinary Whether to keep binary objects in the upstream cache.
+     */
+    public CacheBasedDatasetBuilder<K, V> withKeepBinary(boolean isKeepBinary) {
+        return new CacheBasedDatasetBuilder<>(ignite, upstreamCache, filter, transformerBuilder, isKeepBinary);
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
index f12977c..868245d 100644 (file)
@@ -188,7 +188,8 @@ public class ComputeUtils {
         UpstreamTransformerBuilder transformerBuilder,
         String datasetCacheName, UUID datasetId,
         PartitionDataBuilder<K, V, C, D> partDataBuilder,
-        LearningEnvironment env) {
+        LearningEnvironment env,
+        boolean isKeepBinary) {
 
         PartitionDataStorage dataStorage = (PartitionDataStorage)ignite
             .cluster()
@@ -203,6 +204,9 @@ public class ComputeUtils {
 
             IgniteCache<K, V> upstreamCache = ignite.cache(upstreamCacheName);
 
+            if (isKeepBinary)
+                upstreamCache = upstreamCache.withKeepBinary();
+
             ScanQuery<K, V> qry = new ScanQuery<>();
             qry.setLocal(true);
             qry.setPartition(part);
@@ -260,6 +264,7 @@ public class ComputeUtils {
      * @param transformerBuilder Upstream transformer builder.
      * @param ctxBuilder Partition {@code context} builder.
      * @param envBuilder Environment builder.
+     * @param isKeepBinary Whether binary objects should be kept for the upstream cache.
      * @param <K> Type of a key in {@code upstream} data.
      * @param <V> Type of a value in {@code upstream} data.
      * @param <C> Type of a partition {@code context}.
@@ -273,13 +278,17 @@ public class ComputeUtils {
         PartitionContextBuilder<K, V, C> ctxBuilder,
         LearningEnvironmentBuilder envBuilder,
         int retries,
-        int interval) {
+        int interval,
+        boolean isKeepBinary) {
         affinityCallWithRetries(ignite, Arrays.asList(datasetCacheName, upstreamCacheName), part -> {
             Ignite locIgnite = Ignition.localIgnite();
             LearningEnvironment env = envBuilder.buildForWorker(part);
 
             IgniteCache<K, V> locUpstreamCache = locIgnite.cache(upstreamCacheName);
 
+            if (isKeepBinary)
+                locUpstreamCache = locUpstreamCache.withKeepBinary();
+
             ScanQuery<K, V> qry = new ScanQuery<>();
             qry.setLocal(true);
             qry.setPartition(part);
@@ -315,33 +324,6 @@ public class ComputeUtils {
     }
 
     /**
-     * Initializes partition {@code context} by loading it from a partition {@code upstream}.
-     *
-     * @param ignite Ignite instance.
-     * @param upstreamCacheName Name of an {@code upstream} cache.
-     * @param filter Filter for {@code upstream} data.
-     * @param transformerBuilder Builder of transformer of upstream data.
-     * @param datasetCacheName Name of a partition {@code context} cache.
-     * @param ctxBuilder Partition {@code context} builder.
-     * @param envBuilder Environment builder.
-     * @param retries Number of retries for the case when one of partitions not found on the node.
-     * @param <K> Type of a key in {@code upstream} data.
-     * @param <V> Type of a value in {@code upstream} data.
-     * @param <C> Type of a partition {@code context}.
-     */
-    public static <K, V, C extends Serializable> void initContext(
-        Ignite ignite,
-        String upstreamCacheName,
-        IgniteBiPredicate<K, V> filter,
-        UpstreamTransformerBuilder transformerBuilder,
-        String datasetCacheName,
-        PartitionContextBuilder<K, V, C> ctxBuilder,
-        LearningEnvironmentBuilder envBuilder,
-        int retries) {
-        initContext(ignite, upstreamCacheName, transformerBuilder, filter, datasetCacheName, ctxBuilder, envBuilder, retries, 0);
-    }
-
-    /**
      * Extracts partition {@code context} from the Ignite Cache.
      *
      * @param ignite Ignite instance.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/CommonTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/CommonTestSuite.java
index e3e1d2b..2f42dd5 100644 (file)
@@ -28,7 +28,8 @@ import org.junit.runners.Suite;
     LocalModelsTest.class,
     CollectionsTest.class,
     ExternalizeTest.class,
-    ModelTest.class
+    ModelTest.class,
+    KeepBinaryTest.class
 })
 public class CommonTestSuite {
 }
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java
new file mode 100644 (file)
index 0000000..3af62bd
--- /dev/null
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.common;
+
+import java.util.UUID;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.binary.BinaryObjectBuilder;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
+import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Test for keepBinary support in cache-based dataset building (IGNITE-10700).
+ */
+@RunWith(JUnit4.class)
+public class KeepBinaryTest extends GridCommonAbstractTest {
+    /** Number of nodes in grid. */
+    private static final int NODE_COUNT = 2;
+    /** Number of samples. */
+    public static final int NUMBER_OF_SAMPLES = 1000;
+    /** Half of samples. */
+    public static final int HALF = NUMBER_OF_SAMPLES / 2;
+
+    /** Ignite instance. */
+    private Ignite ignite;
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTestsStarted() throws Exception {
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void afterTestsStopped() {
+        stopAllGrids();
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTest() {
+        /* Grid instance. */
+        ignite = grid(NODE_COUNT);
+        ignite.configuration().setPeerClassLoadingEnabled(true);
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+    }
+
+    /**
+     * Populates a cache with binary objects and trains a k-means model on top of it.
+     */
+    @Test
+    public void test() {
+        IgniteCache<Integer, BinaryObject> dataCache = populateCache(ignite);
+
+        IgniteBiFunction<Integer, BinaryObject, Vector> featureExtractor
+            = (k, v) -> VectorUtils.of(new double[]{v.field("feature1")});
+
+        IgniteBiFunction<Integer, BinaryObject, Double> lbExtractor = (k, v) -> (double) v.field("label");
+
+        KMeansTrainer trainer = new KMeansTrainer().withSeed(123L);
+
+        CacheBasedDatasetBuilder<Integer, BinaryObject> datasetBuilder =
+            new CacheBasedDatasetBuilder<>(ignite, dataCache).withKeepBinary(true);
+
+        KMeansModel kmdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
+
+        assertTrue(kmdl.predict(VectorUtils.num2Vec(0.0)) == 0);
+        assertTrue(kmdl.predict(VectorUtils.num2Vec(10.0)) == 1);
+    }
+
+    /**
+     * Populate cache with binary objects.
+     */
+    private IgniteCache<Integer, BinaryObject> populateCache(Ignite ignite) {
+        CacheConfiguration<Integer, BinaryObject> cacheConfiguration = new CacheConfiguration<>();
+        cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+
+        IgniteCache<Integer, BinaryObject> cache = ignite.createCache(cacheConfiguration).withKeepBinary();
+
+        BinaryObjectBuilder builder = ignite.binary().builder("testType");
+
+        for (int i = 0; i < NUMBER_OF_SAMPLES; i++) {
+            if (i < HALF)
+                cache.put(i, builder.setField("feature1", 0.0).setField("label", 0.0).build());
+            else
+                cache.put(i, builder.setField("feature1", 10.0).setField("label", 1.0).build());
+        }
+
+        return cache;
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java
index bb8570d..1205d53 100644 (file)
@@ -195,7 +195,8 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
                         UpstreamEntry<Integer, Integer> e = upstream.next();
                         return new TestPartitionData(e.getKey() + e.getValue());
                     },
-                    TestUtils.testEnvBuilder().buildForWorker(part)
+                    TestUtils.testEnvBuilder().buildForWorker(part),
+                    false
                 ),
                 0
             );
@@ -234,8 +235,8 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
         ComputeUtils.<Integer, Integer, Integer>initContext(
             ignite,
             upstreamCacheName,
-            (k, v) -> true,
             UpstreamTransformerBuilder.identity(),
+            (k, v) -> true,
             datasetCacheName,
             (env, upstream, upstreamSize) -> {
 
@@ -245,7 +246,9 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
                 return e.getKey() + e.getValue();
             },
             TestUtils.testEnvBuilder(),
-            0
+            0,
+            0,
+            false
         );
 
         assertEquals(1, datasetCache.size());
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
index 3ffd5cf..7e20adf 100644 (file)
@@ -28,7 +28,6 @@ import org.junit.Test;
 
 /** Test for {@link DiscreteNaiveBayesTrainer} */
 public class DiscreteNaiveBayesTrainerTest extends TrainerTest {
-
     /** Precision in test checks. */
     private static final double PRECISION = 1e-2;
     /** */
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
index e934e96..9a266a6 100644 (file)
@@ -43,9 +43,9 @@ public class LogisticRegressionModelTest {
 
         assertEquals(0.1, new LogisticRegressionModel(weights, 1.0).withThreshold(0.1).threshold(), 0);
 
-        assertTrue(new LogisticRegressionModel(weights, 1.0).toString().length() > 0);
-        assertTrue(new LogisticRegressionModel(weights, 1.0).toString(true).length() > 0);
-        assertTrue(new LogisticRegressionModel(weights, 1.0).toString(false).length() > 0);
+        assertTrue(!new LogisticRegressionModel(weights, 1.0).toString().isEmpty());
+        assertTrue(!new LogisticRegressionModel(weights, 1.0).toString(true).isEmpty());
+        assertTrue(!new LogisticRegressionModel(weights, 1.0).toString(false).isEmpty());
 
         verifyPredict(new LogisticRegressionModel(weights, 1.0).withRawLabels(true));
         verifyPredict(new LogisticRegressionModel(null, 1.0).withRawLabels(true).withWeights(weights));