IGNITE-11010: [ML] Use seed from learningEnvironment for KMeans trainer
author YuriBabak <y.chief@gmail.com>
Wed, 23 Jan 2019 08:23:33 +0000 (11:23 +0300)
committer Yury Babak <ybabak@gridgain.com>
Wed, 23 Jan 2019 08:23:33 +0000 (11:23 +0300)
This closes #5884

examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java

index 46550f3..e748f4d 100644 (file)
@@ -57,8 +57,7 @@ public class KMeansClusterizationExample {
             IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
                 .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
 
-            KMeansTrainer trainer = new KMeansTrainer()
-                .withSeed(7867L);
+            KMeansTrainer trainer = new KMeansTrainer();
 
             KMeansModel mdl = trainer.fit(
                 ignite,
index 71546e9..a5d15d1 100644 (file)
@@ -65,7 +65,6 @@ public class ANNClassificationExample {
                 .withDistance(new ManhattanDistance())
                 .withK(50)
                 .withMaxIterations(1000)
-                .withSeed(1234L)
                 .withEpsilon(1e-2);
 
             long startTrainingTime = System.currentTimeMillis();
index 3206b5f..4bd017f 100644 (file)
@@ -61,9 +61,6 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
     /** Distance measure. */
     private DistanceMeasure distance = new EuclideanDistance();
 
-    /** KMeans initializer. */
-    private long seed;
-
     /**
      * Trains model based on the specified data.
      *
@@ -235,7 +232,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
 
             if (data.rowSize() != 0) {
                 if (data.rowSize() > k) { // If it's enough rows in partition to pick k vectors.
-                    final Random random = new Random(seed);
+                    final Random random = environment.randomNumbersGenerator();
 
                     for (int i = 0; i < k; i++) {
                         Set<Integer> uniqueIndices = new HashSet<>();
@@ -272,7 +269,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
         // Pick k vectors randomly.
         if (rndPnts.size() >= k) {
             for (int i = 0; i < k; i++) {
-                final LabeledVector rndPnt = rndPnts.get(new Random(seed).nextInt(rndPnts.size()));
+                final LabeledVector rndPnt = rndPnts.get(environment.randomNumbersGenerator().nextInt(rndPnts.size()));
                 rndPnts.remove(rndPnt);
                 initCenters[i] = rndPnt.features();
             }
@@ -394,24 +391,4 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
         this.distance = distance;
         return this;
     }
-
-    /**
-     * Gets the seed number.
-     *
-     * @return The parameter value.
-     */
-    public long getSeed() {
-        return seed;
-    }
-
-    /**
-     * Set up the seed.
-     *
-     * @param seed The parameter value.
-     * @return Model with new seed parameter value.
-     */
-    public KMeansTrainer withSeed(long seed) {
-        this.seed = seed;
-        return this;
-    }
 }
index 0cdfc52..2da09db 100644 (file)
@@ -60,9 +60,6 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
     /** Distance measure. */
     private DistanceMeasure distance = new EuclideanDistance();
 
-    /** KMeans initializer. */
-    private long seed;
-
     /**
      * Trains model based on the specified data.
      *
@@ -140,7 +137,6 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
         KMeansTrainer trainer = new KMeansTrainer()
             .withAmountOfClusters(k)
             .withMaxIterations(maxIterations)
-            .withSeed(seed)
             .withDistance(distance)
             .withEpsilon(epsilon);
 
@@ -334,26 +330,6 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
         return this;
     }
 
-    /**
-     * Gets the seed number.
-     *
-     * @return The parameter value.
-     */
-    public long getSeed() {
-        return seed;
-    }
-
-    /**
-     * Set up the seed.
-     *
-     * @param seed The parameter value.
-     * @return Model with new seed parameter value.
-     */
-    public ANNClassificationTrainer withSeed(long seed) {
-        this.seed = seed;
-        return this;
-    }
-
     /** Service class used for statistics. */
     public static class CentroidStat implements Serializable {
         /** Serial version uid. */
index e621801..111f1bb 100644 (file)
@@ -35,7 +35,7 @@ public class KNNRegressionTrainer extends SingleLabelDatasetTrainer<KNNRegressio
      * @param lbExtractor Label extractor.
      * @return Model.
      */
-    public <K, V> KNNRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public <K, V> KNNRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
 
         return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
index cf511ec..43df304 100644 (file)
@@ -108,7 +108,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
     }
 
     /** {@inheritDoc} */
-    public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
+    @Override public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
 
         return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
index e33ad08..fe4c74d 100644 (file)
@@ -109,10 +109,8 @@ public class KMeansTrainerTest extends TrainerTest {
             .withDistance(new EuclideanDistance())
             .withAmountOfClusters(10)
             .withMaxIterations(1)
-            .withEpsilon(PRECISION)
-            .withSeed(2);
+            .withEpsilon(PRECISION);
         assertEquals(10, trainer.getAmountOfClusters());
-        assertEquals(2, trainer.getSeed());
         assertTrue(trainer.getDistance() instanceof EuclideanDistance);
         return trainer;
     }
index 1d1103f..bc2e3d5 100644 (file)
@@ -81,7 +81,7 @@ public class KeepBinaryTest extends GridCommonAbstractTest {
 
         IgniteBiFunction<Integer, BinaryObject, Double> lbExtractor = (k, v) -> (double) v.field("label");
 
-        KMeansTrainer trainer = new KMeansTrainer().withSeed(123L);
+        KMeansTrainer trainer = new KMeansTrainer();
 
         CacheBasedDatasetBuilder<Integer, BinaryObject> datasetBuilder =
             new CacheBasedDatasetBuilder<>(ignite, dataCache).withKeepBinary(true);
index 9c75824..2f779a2 100644 (file)
@@ -44,14 +44,12 @@ public class ANNClassificationTest extends TrainerTest {
             .withK(10)
             .withMaxIterations(10)
             .withEpsilon(1e-4)
-            .withDistance(new EuclideanDistance())
-            .withSeed(1234L);
+            .withDistance(new EuclideanDistance());
 
         Assert.assertEquals(10, trainer.getK());
         Assert.assertEquals(10, trainer.getMaxIterations());
         TestUtils.assertEquals(1e-4, trainer.getEpsilon(), PRECISION);
         Assert.assertEquals(new EuclideanDistance(), trainer.getDistance());
-        Assert.assertEquals(1234L, trainer.getSeed());
 
         NNClassificationModel mdl = trainer.fit(
             cacheMock,
@@ -83,7 +81,7 @@ public class ANNClassificationTest extends TrainerTest {
             .withEpsilon(1e-4)
             .withDistance(new EuclideanDistance());
 
-        ANNClassificationModel originalMdl = (ANNClassificationModel) trainer.withSeed(1234L).fit(
+        ANNClassificationModel originalMdl = (ANNClassificationModel) trainer.fit(
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -92,7 +90,7 @@ public class ANNClassificationTest extends TrainerTest {
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(NNStrategy.SIMPLE);
 
-        ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl,
+        ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel) trainer.update(originalMdl,
             cacheMock, parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
@@ -100,7 +98,7 @@ public class ANNClassificationTest extends TrainerTest {
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(NNStrategy.SIMPLE);
 
-        ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl,
+        ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel) trainer.update(originalMdl,
             new HashMap<Integer, double[]>(), parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]