IGNITE-11232: [ML] Random Forest, NodeId exception
author Alexey Platonov <aplatonovv@gmail.com>
Fri, 8 Feb 2019 10:48:04 +0000 (13:48 +0300)
committer Yury Babak <ybabak@gridgain.com>
Fri, 8 Feb 2019 10:48:04 +0000 (13:48 +0300)
(Failed to serialize object)

This closes #6043

17 files changed:
modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/BucketMeta.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/ObjectHistogram.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeId.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniHistogram.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogram.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogram.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/BootstrappedVectorsHistogram.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/CountersHistogram.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/ClassifierLeafValuesComputer.java
modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/feature/ObjectHistogramTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestIntegrationTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTreeTestSuite.java

index 4ac9adb..5dab662 100644 (file)
@@ -23,6 +23,9 @@ import java.io.Serializable;
  * Bucket meta-information for feature histogram.
  */
 public class BucketMeta implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 7827158624437006995L;
+
     /** Feature meta. */
     private final FeatureMeta featureMeta;
 
index d894c3f..697e79d 100644 (file)
@@ -22,44 +22,23 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import java.util.TreeMap;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
 
 /**
  * Basic implementation of {@link Histogram} that implements also {@link DistributionComputer}.
  *
  * @param <T> Type of object for histogram.
  */
-public class ObjectHistogram<T> implements Histogram<T, ObjectHistogram<T>>, DistributionComputer {
+public abstract class ObjectHistogram<T> implements Histogram<T, ObjectHistogram<T>>, DistributionComputer {
     /** Serial version uid. */
     private static final long serialVersionUID = -2708731174031404487L;
 
-    /** Bucket mapping. */
-    private final IgniteFunction<T, Integer> bucketMapping;
-
-    /** Mapping to counter. */
-    private final IgniteFunction<T, Double> mappingToCntr;
-
     /** Histogram. */
-    private final Map<Integer, Double> hist;
-
-    /**
-     * Create an instance of ObjectHistogram.
-     *
-     * @param bucketMapping Bucket mapping.
-     * @param mappingToCntr Mapping to counter.
-     */
-    public ObjectHistogram(IgniteFunction<T, Integer> bucketMapping,
-        IgniteFunction<T, Double> mappingToCntr) {
-
-        this.bucketMapping = bucketMapping;
-        this.mappingToCntr = mappingToCntr;
-        this.hist = new TreeMap<>(Integer::compareTo);
-    }
+    private final TreeMap<Integer, Double> hist = new TreeMap<>();
 
     /** {@inheritDoc} */
     @Override public void addElement(T val) {
-        Integer bucket = bucketMapping.apply(val);
-        Double cntrVal = mappingToCntr.apply(val);
+        Integer bucket = mapToBucket(val);
+        Double cntrVal = mapToCounter(val);
 
         assert cntrVal >= 0;
         Double bucketVal = hist.getOrDefault(bucket, 0.0);
@@ -91,7 +70,7 @@ public class ObjectHistogram<T> implements Histogram<T, ObjectHistogram<T>>, Dis
 
     /** {@inheritDoc} */
     @Override public ObjectHistogram<T> plus(ObjectHistogram<T> other) {
-        ObjectHistogram<T> res = new ObjectHistogram<>(bucketMapping, mappingToCntr);
+        ObjectHistogram<T> res = newInstance();
         addTo(this.hist, res.hist);
         addTo(other.hist, res.hist);
         return res;
@@ -111,7 +90,7 @@ public class ObjectHistogram<T> implements Histogram<T, ObjectHistogram<T>>, Dis
     }
 
     /** {@inheritDoc} */
-    public boolean isEqualTo(ObjectHistogram<T> other) {
+    @Override public boolean isEqualTo(ObjectHistogram<T> other) {
         Set<Integer> totalBuckets = new HashSet<>(buckets());
         totalBuckets.addAll(other.buckets());
         if(totalBuckets.size() != buckets().size())
@@ -126,4 +105,27 @@ public class ObjectHistogram<T> implements Histogram<T, ObjectHistogram<T>>, Dis
 
         return true;
     }
+
+    /**
+     * Bucket mapping.
+     *
+     * @param obj Object.
+     * @return BucketId.
+     */
+    public abstract Integer mapToBucket(T obj);
+
+    /**
+     * Counter mapping.
+     *
+     * @param obj Object.
+     * @return counter.
+     */
+    public abstract Double mapToCounter(T obj);
+
+    /**
+     * Creates an instance of ObjectHistogram from child class.
+     *
+     * @return object histogram.
+     */
+    public abstract ObjectHistogram<T> newInstance();
 }
index 2f40af3..f0ecd62 100644 (file)
 
 package org.apache.ignite.ml.tree.randomforest.data;
 
-import org.apache.ignite.lang.IgniteBiTuple;
+import java.io.Serializable;
+import java.util.Objects;
 
 /**
  * Class represents Node id in Random Forest consisting of tree id and node id in tree in according to
  * breadth-first search in tree.
  */
-public class NodeId extends IgniteBiTuple<Integer, Long> {
+public class NodeId implements Serializable {
     /** Serial version uid. */
     private static final long serialVersionUID = 4400852013136423333L;
 
+    /** Tree id. */
+    private final int treeId;
+
+    /** Node id. */
+    private final long nodeId;
+
     /**
      * Create an instance of NodeId.
      *
      * @param treeId Tree id.
      * @param nodeId Node id.
      */
-    public NodeId(Integer treeId, Long nodeId) {
-        super(treeId, nodeId);
+    public NodeId(int treeId, long nodeId) {
+        this.treeId = treeId;
+        this.nodeId = nodeId;
     }
 
     /**
@@ -42,7 +50,7 @@ public class NodeId extends IgniteBiTuple<Integer, Long> {
      * @return Tree id.
      */
     public int treeId() {
-        return get1();
+        return treeId;
     }
 
     /**
@@ -50,6 +58,22 @@ public class NodeId extends IgniteBiTuple<Integer, Long> {
      * @return Node id.
      */
     public long nodeId() {
-        return get2();
+        return nodeId;
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+        NodeId id = (NodeId)o;
+        return Objects.equals(treeId, id.treeId) &&
+            Objects.equals(nodeId, id.nodeId);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        return Objects.hash(treeId, nodeId);
     }
 }
index 3ccb568..0bae0b3 100644 (file)
 
 package org.apache.ignite.ml.tree.randomforest.data;
 
+import java.io.Serializable;
 import java.util.List;
 
 /**
  * Class represents a split point for decision tree.
  */
-public class NodeSplit {
+public class NodeSplit implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 1331311529596106124L;
+
     /** Feature id in feature vector. */
     private final int featureId;
 
index 8c11bdb..96cdced 100644 (file)
@@ -17,7 +17,6 @@
 
 package org.apache.ignite.ml.tree.randomforest.data;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 import org.apache.ignite.ml.IgniteModel;
@@ -26,7 +25,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
 /**
  * Decision tree node class.
  */
-public class TreeNode implements IgniteModel<Vector, Double>, Serializable {
+public class TreeNode implements IgniteModel<Vector, Double> {
     /** Serial version uid. */
     private static final long serialVersionUID = -8546263332508653661L;
 
@@ -83,7 +82,7 @@ public class TreeNode implements IgniteModel<Vector, Double>, Serializable {
     }
 
     /** {@inheritDoc} */
-    public Double predict(Vector features) {
+    @Override public Double predict(Vector features) {
         assert type != Type.UNKNOWN;
 
         if (type == Type.LEAF)
index 3ca2a93..97f830c 100644 (file)
@@ -31,6 +31,7 @@ import org.apache.ignite.ml.dataset.feature.BucketMeta;
 import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
 import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
 import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
+import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.CountersHistogram;
 
 /**
  * Class contains implementation of splitting point finding algorithm based on Gini metric (see
@@ -68,11 +69,10 @@ public class GiniHistogram extends ImpurityHistogram implements ImpurityComputer
         this.sampleId = sampleId;
         this.bucketMeta = bucketMeta;
         this.lblMapping = lblMapping;
+        this.bucketIds = new TreeSet<>();
 
         for (int i = 0; i < lblMapping.size(); i++)
-            hists.add(new ObjectHistogram<>(this::bucketMap, this::counterMap));
-
-        this.bucketIds = new TreeSet<>();
+            hists.add(new CountersHistogram(bucketIds, bucketMeta, featureId, sampleId));
     }
 
     /** {@inheritDoc} */
@@ -175,28 +175,6 @@ public class GiniHistogram extends ImpurityHistogram implements ImpurityComputer
         return hists.get(lblMapping.get(lbl));
     }
 
-    /**
-     * Maps vector to counter value.
-     *
-     * @param vec Vector.
-     * @return Counter value.
-     */
-    private Double counterMap(BootstrappedVector vec) {
-        return (double)vec.counters()[sampleId];
-    }
-
-    /**
-     * Maps vector to bucket id.
-     *
-     * @param vec Vector.
-     * @return Bucket id.
-     */
-    private Integer bucketMap(BootstrappedVector vec) {
-        int bucketId = bucketMeta.getBucketId(vec.features().get(featureId));
-        this.bucketIds.add(bucketId);
-        return bucketId;
-    }
-
     /** {@inheritDoc} */
     @Override public boolean isEqualTo(GiniHistogram other) {
         HashSet<Integer> unionBuckets = new HashSet<>(buckets());
@@ -221,4 +199,5 @@ public class GiniHistogram extends ImpurityHistogram implements ImpurityComputer
 
         return true;
     }
+
 }
index 296d862..b81699f 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.tree.randomforest.data.impurity;
 
+import java.io.Serializable;
 import java.util.Optional;
 import java.util.Set;
 import java.util.TreeSet;
@@ -25,7 +26,10 @@ import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
 /**
  * Helper class for ImpurityHistograms.
  */
-public abstract class ImpurityHistogram {
+public abstract class ImpurityHistogram implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = -8982240673834216798L;
+
     /** Bucket ids. */
     protected final Set<Integer> bucketIds = new TreeSet<>();
 
index d202441..1d9df47 100644 (file)
@@ -168,7 +168,7 @@ public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<Boot
          * @param other Other instance.
          */
         public NodeImpurityHistograms<S> plus(NodeImpurityHistograms<S> other) {
-            assert nodeId == other.nodeId;
+            assert nodeId.equals(other.nodeId);
             NodeImpurityHistograms<S> res = new NodeImpurityHistograms<>(nodeId);
             addTo(this.perFeatureStatistics, res.perFeatureStatistics);
             addTo(other.perFeatureStatistics, res.perFeatureStatistics);
index c00b1c1..5b847ed 100644 (file)
@@ -25,10 +25,12 @@ import org.apache.ignite.ml.dataset.feature.BucketMeta;
 import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
 import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
 import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
+import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.BootstrappedVectorsHistogram;
+import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.CountersHistogram;
 
 /**
- * Class contains implementation of splitting point finding algorithm based on MSE metric (see https://en.wikipedia.org/wiki/Mean_squared_error)
- * and represents a set of histograms in according to this metric.
+ * Class contains an implementation of the split point finding algorithm based on the MSE metric (see
+ * https://en.wikipedia.org/wiki/Mean_squared_error) and represents a set of histograms according to this metric.
  */
 public class MSEHistogram extends ImpurityHistogram implements ImpurityComputer<BootstrappedVector, MSEHistogram> {
     /** Serial version uid. */
@@ -60,9 +62,9 @@ public class MSEHistogram extends ImpurityHistogram implements ImpurityComputer<
         this.bucketMeta = bucketMeta;
         this.sampleId = sampleId;
 
-        counters = new ObjectHistogram<>(this::bucketMap, this::counterMap);
-        sumOfLabels = new ObjectHistogram<>(this::bucketMap, this::ysMap);
-        sumOfSquaredLabels = new ObjectHistogram<>(this::bucketMap, this::y2sMap);
+        counters = new CountersHistogram(bucketIds, bucketMeta, featureId, sampleId);
+        sumOfLabels = new SumOfLabelsHistogram(bucketIds, bucketMeta, featureId, sampleId, 1);
+        sumOfSquaredLabels = new SumOfLabelsHistogram(bucketIds, bucketMeta, featureId, sampleId, 2);
     }
 
     /** {@inheritDoc} */
@@ -221,14 +223,62 @@ public class MSEHistogram extends ImpurityHistogram implements ImpurityComputer<
     @Override public boolean isEqualTo(MSEHistogram other) {
         HashSet<Integer> unionBuckets = new HashSet<>(buckets());
         unionBuckets.addAll(other.bucketIds);
-        if(unionBuckets.size() != bucketIds.size())
+        if (unionBuckets.size() != bucketIds.size())
             return false;
 
-        if(!this.counters.isEqualTo(other.counters))
+        if (!this.counters.isEqualTo(other.counters))
             return false;
-        if(!this.sumOfLabels.isEqualTo(other.sumOfLabels))
+        if (!this.sumOfLabels.isEqualTo(other.sumOfLabels))
             return false;
 
         return this.sumOfSquaredLabels.isEqualTo(other.sumOfSquaredLabels);
     }
+
+    /**
+     * Histogram that accumulates, per bucket, labels raised to {@code labelPower} and weighted by sample counter.
+     */
+    private static class SumOfLabelsHistogram extends BootstrappedVectorsHistogram {
+        /** Serial version uid. */
+        private static final long serialVersionUID = -3846156279667677800L;
+
+        /** Index into the vector's {@code counters()} array; the selected counter is used as the weight. */
+        private final int sampleId;
+
+        /** Exponent applied to the label before it is multiplied by the counter. */
+        private final double labelPower;
+
+        /**
+         * Create an instance of SumOfLabelsHistogram.
+         *
+         * @param bucketIds Set that receives every bucket id computed by this histogram.
+         * @param bucketMeta Bucket meta used to map a feature value to a bucket id.
+         * @param featureId Index of the feature used for bucketing.
+         * @param sampleId Index into the vector's counters array.
+         * @param labelPower Exponent applied to the label.
+         */
+        public SumOfLabelsHistogram(Set<Integer> bucketIds, BucketMeta bucketMeta, int featureId, int sampleId,
+            double labelPower) {
+
+            super(bucketIds, bucketMeta, featureId);
+            this.sampleId = sampleId;
+            this.labelPower = labelPower;
+        }
+
+        /** {@inheritDoc} NOTE(review): body duplicates the superclass implementation; override looks redundant. */
+        @Override public Integer mapToBucket(BootstrappedVector vec) {
+            int bucketId = bucketMeta.getBucketId(vec.features().get(featureId));
+            this.bucketIds.add(bucketId);
+            return bucketId;
+        }
+
+        /** {@inheritDoc} Returns {@code counters()[sampleId] * label^labelPower}. */
+        @Override public Double mapToCounter(BootstrappedVector vec) {
+            return vec.counters()[sampleId] * Math.pow(vec.label(), labelPower);
+        }
+
+        /** {@inheritDoc} The new instance is given the same {@code bucketIds} set reference as this one. */
+        @Override public ObjectHistogram<BootstrappedVector> newInstance() {
+            return new SumOfLabelsHistogram(bucketIds, bucketMeta, featureId, sampleId, labelPower);
+        }
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/BootstrappedVectorsHistogram.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/BootstrappedVectorsHistogram.java
new file mode 100644 (file)
index 0000000..8806606
--- /dev/null
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.randomforest.data.impurity.basic;
+
+import java.util.Set;
+import org.apache.ignite.ml.dataset.feature.BucketMeta;
+import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
+import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
+
+/**
+ * Histogram of bootstrapped vectors bucketed by the value of one feature ({@code featureId}).
+ */
+public abstract class BootstrappedVectorsHistogram extends ObjectHistogram<BootstrappedVector> {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 6369546706769440897L;
+
+    /** Set that collects every bucket id produced by {@link #mapToBucket}; stored by reference, not copied. */
+    protected final Set<Integer> bucketIds;
+
+    /** Bucket meta used to translate a feature value into a bucket id. */
+    protected final BucketMeta bucketMeta;
+
+    /** Index of the feature (within the vector's features) used for bucketing. */
+    protected final int featureId;
+
+    /**
+     * Creates an instance of BootstrappedVectorsHistogram.
+     *
+     * @param bucketIds Set that receives computed bucket ids (kept by reference).
+     * @param bucketMeta Bucket meta.
+     * @param featureId Index of the feature used for bucketing.
+     */
+    public BootstrappedVectorsHistogram(Set<Integer> bucketIds, BucketMeta bucketMeta, int featureId) {
+        this.bucketIds = bucketIds;
+        this.bucketMeta = bucketMeta;
+        this.featureId = featureId;
+    }
+
+    /** {@inheritDoc} Also records the computed bucket id in {@code bucketIds}. */
+    @Override public Integer mapToBucket(BootstrappedVector vec) {
+        int bucketId = bucketMeta.getBucketId(vec.features().get(featureId));
+        this.bucketIds.add(bucketId);
+        return bucketId;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/CountersHistogram.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/CountersHistogram.java
new file mode 100644 (file)
index 0000000..bd39e35
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.randomforest.data.impurity.basic;
+
+import java.util.Set;
+import org.apache.ignite.ml.dataset.feature.BucketMeta;
+import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
+import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
+
+/**
+ * Represents a histogram of element counts per bucket.
+ */
+public class CountersHistogram extends BootstrappedVectorsHistogram {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 7744564790854918891L;
+
+    /** Index into the vector's {@code counters()} array; the selected counter is the value added to a bucket. */
+    private final int sampleId;
+
+    /**
+     * Creates an instance of CountersHistogram.
+     *
+     * @param bucketIds Bucket ids.
+     * @param bucketMeta Bucket meta.
+     * @param featureId Feature id.
+     * @param sampleId Index into the vector's counters array.
+     */
+    public CountersHistogram(Set<Integer> bucketIds, BucketMeta bucketMeta, int featureId, int sampleId) {
+        super(bucketIds, bucketMeta, featureId);
+        this.sampleId = sampleId;
+    }
+
+    /** {@inheritDoc} Returns {@code counters()[sampleId]} widened to double. */
+    @Override public Double mapToCounter(BootstrappedVector vec) {
+        return (double)vec.counters()[sampleId];
+    }
+
+    /** {@inheritDoc} The new instance is given the same {@code bucketIds} set reference as this one. */
+    @Override public ObjectHistogram<BootstrappedVector> newInstance() {
+        return new CountersHistogram(bucketIds, bucketMeta, featureId, sampleId);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/basic/package-info.java
new file mode 100644 (file)
index 0000000..d56d1a8
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains implementation of basic classes for impurity computers.
+ */
+package org.apache.ignite.ml.tree.randomforest.data.impurity.basic;
index 64297ff..eeab0ed 100644 (file)
@@ -42,12 +42,14 @@ public class ClassifierLeafValuesComputer extends LeafValuesComputer<ObjectHisto
     }
 
     /** {@inheritDoc} */
-    @Override protected void addElementToLeafStatistic(ObjectHistogram<BootstrappedVector> leafStatAggr, BootstrappedVector vec, int sampleId) {
+    @Override protected void addElementToLeafStatistic(ObjectHistogram<BootstrappedVector> leafStatAggr,
+        BootstrappedVector vec, int sampleId) {
         leafStatAggr.addElement(vec);
     }
 
     /** {@inheritDoc} */
-    @Override protected ObjectHistogram<BootstrappedVector> mergeLeafStats(ObjectHistogram<BootstrappedVector> leftStats,
+    @Override protected ObjectHistogram<BootstrappedVector> mergeLeafStats(
+        ObjectHistogram<BootstrappedVector> leftStats,
         ObjectHistogram<BootstrappedVector> rightStats) {
 
         return leftStats.plus(rightStats);
@@ -55,10 +57,7 @@ public class ClassifierLeafValuesComputer extends LeafValuesComputer<ObjectHisto
 
     /** {@inheritDoc} */
     @Override protected ObjectHistogram<BootstrappedVector> createLeafStatsAggregator(int sampleId) {
-        return new ObjectHistogram<>(
-            x -> lblMapping.get(x.label()),
-            x -> (double)x.counters()[sampleId]
-        );
+        return new LeafStatsHistogram(lblMapping, sampleId);
     }
 
     /**
@@ -71,7 +70,7 @@ public class ClassifierLeafValuesComputer extends LeafValuesComputer<ObjectHisto
             .max(Comparator.comparing(b -> stat.getValue(b).orElse(0.0)))
             .orElse(-1);
 
-        if(bucketId == -1)
+        if (bucketId == -1)
             return Double.NaN;
 
         return lblMapping.entrySet().stream()
@@ -79,4 +78,40 @@ public class ClassifierLeafValuesComputer extends LeafValuesComputer<ObjectHisto
             .findFirst()
             .get().getKey();
     }
+
+    /** Histogram whose bucket is the mapped label of a vector and whose increment is {@code counters()[sampleId]}. */
+    private static class LeafStatsHistogram extends ObjectHistogram<BootstrappedVector> {
+        /** Serial version uid. */
+        private static final long serialVersionUID = 4017587488421003308L;
+
+        /** Mapping from label value to bucket id. */
+        private final Map<Double, Integer> lblMapping;
+
+        /** Index into the vector's {@code counters()} array. */
+        private final int sampleId;
+
+        /**
+         * @param lblMapping Mapping from label value to bucket id (kept by reference).
+         * @param sampleId Index into the vector's counters array.
+         */
+        public LeafStatsHistogram(Map<Double, Integer> lblMapping, int sampleId) {
+            this.lblMapping = lblMapping;
+            this.sampleId = sampleId;
+        }
+
+        /** {@inheritDoc} Looks the vector's label up in {@code lblMapping}. */
+        @Override public Integer mapToBucket(BootstrappedVector x) {
+            return lblMapping.get(x.label());
+        }
+
+        /** {@inheritDoc} Returns {@code counters()[sampleId]} widened to double. */
+        @Override public Double mapToCounter(BootstrappedVector x) {
+            return (double)x.counters()[sampleId];
+        }
+
+        /** {@inheritDoc} The new instance uses the same label mapping and sample id. */
+        @Override public ObjectHistogram<BootstrappedVector> newInstance() {
+            return new LeafStatsHistogram(lblMapping, sampleId);
+        }
+    }
 }
index f33c9ff..24939e5 100644 (file)
@@ -35,6 +35,7 @@ import org.apache.ignite.ml.selection.SelectionTestSuite;
 import org.apache.ignite.ml.structures.StructuresTestSuite;
 import org.apache.ignite.ml.svm.SVMTestSuite;
 import org.apache.ignite.ml.tree.DecisionTreeTestSuite;
+import org.apache.ignite.ml.tree.randomforest.RandomForestTreeTestSuite;
 import org.apache.ignite.ml.util.UtilTestSuite;
 import org.apache.ignite.ml.util.generators.DataStreamGeneratorTestSuite;
 import org.junit.runner.RunWith;
@@ -61,6 +62,7 @@ import org.junit.runners.Suite;
     MultiClassTestSuite.class,
     DataStreamGeneratorTestSuite.class,
     UtilTestSuite.class,
+    RandomForestTreeTestSuite.class,
 
     /** JUnit 3 tests. */
     DecisionTreeTestSuite.class,
index 9efb939..bec8cb0 100644 (file)
@@ -47,8 +47,8 @@ public class ObjectHistogramTest {
      */
     @Before
     public void setUp() throws Exception {
-        hist1 = new ObjectHistogram<>(this::computeBucket, x -> 1.);
-        hist2 = new ObjectHistogram<>(this::computeBucket, x -> 1.);
+        hist1 = new TestHist1();
+        hist2 = new TestHist1();
 
         fillHist(hist1, dataFirstPart);
         fillHist(hist2, dataSecondPart);
@@ -124,7 +124,7 @@ public class ObjectHistogramTest {
         double[] sums = new double[distribution.size()];
 
         int ptr = 0;
-        for(int bucket : distribution.keySet()) {
+        for (int bucket : distribution.keySet()) {
             sums[ptr] = distribution.get(bucket);
             buckets[ptr++] = bucket;
         }
@@ -136,25 +136,25 @@ public class ObjectHistogramTest {
     /** */
     @Test
     public void testOfSum() {
-        IgniteFunction<Double, Integer> bucketMap = x -> (int) (Math.ceil(x * 100) % 100);
+        IgniteFunction<Double, Integer> bucketMap = x -> (int)(Math.ceil(x * 100) % 100);
         IgniteFunction<Double, Double> cntrMap = x -> Math.pow(x, 2);
 
-        ObjectHistogram<Double> forAllHistogram = new ObjectHistogram<>(bucketMap, cntrMap);
+        ObjectHistogram<Double> forAllHistogram = new TestHist2();
         Random rnd = new Random();
         List<ObjectHistogram<Double>> partitions = new ArrayList<>();
         int cntOfPartitions = rnd.nextInt(100);
         int sizeOfDataset = rnd.nextInt(10000);
-        for(int i = 0; i < cntOfPartitions; i++)
-            partitions.add(new ObjectHistogram<>(bucketMap, cntrMap));
+        for (int i = 0; i < cntOfPartitions; i++)
+            partitions.add(new TestHist2());
 
-        for(int i = 0; i < sizeOfDataset; i++) {
+        for (int i = 0; i < sizeOfDataset; i++) {
             double objVal = rnd.nextDouble();
             forAllHistogram.addElement(objVal);
             partitions.get(rnd.nextInt(partitions.size())).addElement(objVal);
         }
 
         Optional<ObjectHistogram<Double>> leftSum = partitions.stream().reduce(ObjectHistogram::plus);
-        Optional<ObjectHistogram<Double>> rightSum = partitions.stream().reduce((x,y) -> y.plus(x));
+        Optional<ObjectHistogram<Double>> rightSum = partitions.stream().reduce((x, y) -> y.plus(x));
         assertTrue(leftSum.isPresent());
         assertTrue(rightSum.isPresent());
         assertTrue(forAllHistogram.isEqualTo(leftSum.get()));
@@ -168,4 +168,46 @@ public class ObjectHistogramTest {
     private int computeBucket(Double val) {
         return (int)Math.rint(val);
     }
+
+    /** Test histogram: bucket is the value rounded with {@link Math#rint}; every element contributes 1. */
+    private static class TestHist1 extends ObjectHistogram<Double> {
+        /** Serial version uid. */
+        private static final long serialVersionUID = 2397005559193012602L;
+
+        /** {@inheritDoc} Rounds the value to the nearest integer. */
+        @Override public Integer mapToBucket(Double obj) {
+            return (int)Math.rint(obj);
+        }
+
+        /** {@inheritDoc} Always 1, so the histogram counts elements. */
+        @Override public Double mapToCounter(Double obj) {
+            return 1.;
+        }
+
+        /** {@inheritDoc} */
+        @Override public ObjectHistogram<Double> newInstance() {
+            return new TestHist1();
+        }
+    }
+
+    /** Test histogram: bucket is {@code ceil(x * 100) % 100}; every element contributes its square. */
+    private static class TestHist2 extends ObjectHistogram<Double> {
+        /** Serial version uid. */
+        private static final long serialVersionUID = -2080037140817825107L;
+
+        /** {@inheritDoc} Computes {@code ceil(x * 100) % 100} truncated to int. */
+        @Override public Integer mapToBucket(Double x) {
+            return (int)(Math.ceil(x * 100) % 100);
+        }
+
+        /** {@inheritDoc} Returns the square of the value. */
+        @Override public Double mapToCounter(Double x) {
+            return Math.pow(x, 2);
+        }
+
+        /** {@inheritDoc} */
+        @Override public ObjectHistogram<Double> newInstance() {
+            return new TestHist2();
+        }
+    }
 }
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestIntegrationTest.java
new file mode 100644 (file)
index 0000000..4dff495
--- /dev/null
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.randomforest;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Random;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
+import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.junit.Test;
+
+/**
+ * Tests for {@link RandomForestTrainer}.
+ */
+public class RandomForestIntegrationTest extends GridCommonAbstractTest {
+    /** Number of nodes in grid */
+    private static final int NODE_COUNT = 3;
+
+    /** Ignite instance. */
+    private Ignite ignite;
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTestsStarted() throws Exception {
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void afterTestsStopped() {
+        stopAllGrids();
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override protected void beforeTest() {
+        /* Grid instance. */
+        ignite = grid(NODE_COUNT);
+        ignite.configuration().setPeerClassLoadingEnabled(true);
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+    }
+
+    /** Fits a 5-tree regression forest on cache data and checks the aggregator type and model count. */
+    @Test
+    public void testFit() {
+        int size = 100;
+
+        CacheConfiguration<Integer, double[]> trainingSetCacheCfg = new CacheConfiguration<>();
+        trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+        trainingSetCacheCfg.setName("TRAINING_SET");
+
+        IgniteCache<Integer, double[]> data = ignite.createCache(trainingSetCacheCfg);
+
+        Random rnd = new Random(0);
+        for (int i = 0; i < size; i++) {
+            double x = rnd.nextDouble() - 0.5;
+            data.put(i, new double[] {x, x > 0 ? 1 : 0});
+        }
+
+        ArrayList<FeatureMeta> meta = new ArrayList<>();
+        meta.add(new FeatureMeta("", 0, false));
+        RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(meta)
+            .withAmountOfTrees(5)
+            .withFeaturesCountSelectionStrgy(x -> 2);
+
+        ModelsComposition mdl = trainer.fit(
+            ignite, data,
+            (k, v) -> VectorUtils.of(Arrays.copyOf(v, v.length - 1)),
+            (k, v) -> v[v.length - 1]
+        );
+
+        assertTrue(mdl.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator);
+        assertEquals(5, mdl.getModels().size());
+    }
+}
index cc51352..6752652 100644 (file)
@@ -33,7 +33,8 @@ import org.junit.runners.Suite;
     GiniFeatureHistogramTest.class,
     MSEHistogramTest.class,
     NormalDistributionStatisticsComputerTest.class,
-    RandomForestTest.class
+    RandomForestTest.class,
+    RandomForestIntegrationTest.class
 })
 public class RandomForestTreeTestSuite {
 }