IGNITE-9514: [ML] Reduce time for updating models on many partitions
author: zaleslaw <zaleslaw.sin@gmail.com>
Tue, 25 Sep 2018 12:51:54 +0000 (15:51 +0300)
committer: Yury Babak <ybabak@gridgain.com>
Tue, 25 Sep 2018 12:51:54 +0000 (15:51 +0300)
this closes #4788
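
The most visible API change in this commit is the rename of RandomForestTrainer's builder method
withCountOfTrees to withAmountOfTrees. A minimal configuration sketch against the renamed API (the
featureMetas list is a stand-in for the FeatureMeta wiring shown in the examples below):

    // Hedged sketch: random forest configuration after this commit.
    // Callers of withCountOfTrees(101) must switch to withAmountOfTrees(101).
    RandomForestClassifierTrainer classifier = new RandomForestClassifierTrainer(featureMetas)
        .withAmountOfTrees(101)
        .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
        .withMaxDepth(4)
        .withMinImpurityDelta(0.);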

65 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironment.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/CustomMLLogger.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/MLLogger.java
modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/NoParallelismStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java
modules/ml/src/main/java/org/apache/ignite/ml/util/ModelTrace.java
modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/common/TrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/MeanValuePredictionsAggregatorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/OnMajorityPredictionsAggregatorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/dataset/feature/ObjectHistogramTest.java
modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java
modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java
modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerPreprocessorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniFeatureHistogramTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogramTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/statistics/NormalDistributionStatisticsComputerTest.java

index 6194153..4693744 100644
@@ -70,7 +70,7 @@ public class RandomForestClassificationExample {
                 RandomForestClassifierTrainer classifier = new RandomForestClassifierTrainer(
                     IntStream.range(0, data[0].length - 1).mapToObj(
                         x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
-                ).withCountOfTrees(101)
+                ).withAmountOfTrees(101)
                     .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
                     .withMaxDepth(4)
                     .withMinImpurityDelta(0.)
index 5f010f2..ee0c1c2 100644
@@ -74,7 +74,7 @@ public class RandomForestRegressionExample {
                 RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(
                     IntStream.range(0, data[0].length - 1).mapToObj(
                         x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
-                ).withCountOfTrees(101)
+                ).withAmountOfTrees(101)
                     .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
                     .withMaxDepth(4)
                     .withMinImpurityDelta(0.)
index f6ddfed..8682a46 100644
@@ -82,13 +82,13 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
             );
 
         if (uniqLabels != null && uniqLabels.size() == 2) {
-            ArrayList<Double> lblsArray = new ArrayList<>(uniqLabels);
-            externalFirstCls = lblsArray.get(0);
-            externalSecondCls = lblsArray.get(1);
+            ArrayList<Double> lblsArr = new ArrayList<>(uniqLabels);
+            externalFirstCls = lblsArr.get(0);
+            externalSecondCls = lblsArr.get(1);
             return true;
-        } else {
-            return false;
         }
+        else
+            return false;
     }
 
     /** {@inheritDoc} */
index 737495e..e689b91 100644
@@ -57,7 +57,7 @@ public class GDBLearningStrategy {
     protected IgniteSupplier<DatasetTrainer<? extends Model<Vector, Double>, Double>> baseMdlTrainerBuilder;
 
     /** Mean label value. */
-    protected double meanLabelValue;
+    protected double meanLbVal;
 
     /** Sample size. */
     protected long sampleSize;
@@ -111,7 +111,7 @@ public class GDBLearningStrategy {
         for (int i = 0; i < cntOfIterations; i++) {
             double[] weights = Arrays.copyOf(compositionWeights, models.size());
 
-            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue);
+            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
             ModelsComposition currComposition = new ModelsComposition(models, aggregator);
             if (convCheck.isConverged(datasetBuilder, currComposition))
                 break;
@@ -142,13 +142,13 @@ public class GDBLearningStrategy {
         if(mdlToUpdate != null) {
             models.addAll(mdlToUpdate.getModels());
             WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator();
-            meanLabelValue = aggregator.getBias();
+            meanLbVal = aggregator.getBias();
             compositionWeights = new double[models.size() + cntOfIterations];
             for(int i = 0; i < models.size(); i++)
                 compositionWeights[i] = aggregator.getWeights()[i];
-        } else {
-            compositionWeights = new double[cntOfIterations];
         }
+        else
+            compositionWeights = new double[cntOfIterations];
 
         Arrays.fill(compositionWeights, models.size(), compositionWeights.length, defaultGradStepSize);
         return models;
@@ -208,10 +208,10 @@ public class GDBLearningStrategy {
     /**
      * Sets mean label value.
      *
-     * @param meanLabelValue Mean label value.
+     * @param meanLbVal Mean label value.
      */
-    public GDBLearningStrategy withMeanLabelValue(double meanLabelValue) {
-        this.meanLabelValue = meanLabelValue;
+    public GDBLearningStrategy withMeanLabelValue(double meanLbVal) {
+        this.meanLbVal = meanLbVal;
         return this;
     }
 
@@ -262,6 +262,6 @@ public class GDBLearningStrategy {
 
     /** */
     public double getMeanValue() {
-        return meanLabelValue;
+        return meanLbVal;
     }
 }
index aedd0fd..573b256 100644
@@ -68,9 +68,9 @@ public class BootstrappedVector extends LabeledVector<Vector, Double> {
 
     /** {@inheritDoc} */
     @Override public int hashCode() {
-        int result = super.hashCode();
-        result = 31 * result + Arrays.hashCode(counters);
-        return result;
+        int res = super.hashCode();
+        res = 31 * res + Arrays.hashCode(counters);
+        return res;
     }
 
     /** {@inheritDoc} */
index 2b94a2f..f5fb693 100644
@@ -41,9 +41,9 @@ public interface LearningEnvironment {
     /**
      * Returns an instance of logger for specific class.
      *
-     * @param forClass Logging class context.
+     * @param forCls Logging class context.
      */
-    public <T> MLLogger logger(Class<T> forClass);
+    public <T> MLLogger logger(Class<T> forCls);
 
     /**
      * Creates an instance of LearningEnvironmentBuilder.
index 7efa29c..e064fc3 100644
@@ -28,7 +28,7 @@ public class ConsoleLogger implements MLLogger {
     /** Maximum Verbose level. */
     private final VerboseLevel maxVerboseLevel;
     /** Class name. */
-    private final String className;
+    private final String clsName;
 
     /**
      * Creates an instance of ConsoleLogger.
@@ -37,7 +37,7 @@ public class ConsoleLogger implements MLLogger {
      * @param clsName Class name.
      */
     private ConsoleLogger(VerboseLevel maxVerboseLevel, String clsName) {
-        this.className = clsName;
+        this.clsName = clsName;
         this.maxVerboseLevel = maxVerboseLevel;
     }
 
@@ -75,7 +75,7 @@ public class ConsoleLogger implements MLLogger {
      */
     private void print(VerboseLevel verboseLevel, String line) {
         if (this.maxVerboseLevel.compareTo(verboseLevel) >= 0)
-            System.out.println(String.format("%s [%s] %s", className, verboseLevel.name(), line));
+            System.out.println(String.format("%s [%s] %s", clsName, verboseLevel.name(), line));
     }
 
     /**
index 65bc4cb..90aed14 100644
@@ -27,29 +27,29 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
  */
 public class CustomMLLogger implements MLLogger {
     /** Ignite logger instance. */
-    private final IgniteLogger logger;
+    private final IgniteLogger log;
 
     /**
      * Creates an instance of CustomMLLogger.
      *
-     * @param logger Basic Logger.
+     * @param log Basic Logger.
      */
-    private CustomMLLogger(IgniteLogger logger) {
-        this.logger = logger;
+    private CustomMLLogger(IgniteLogger log) {
+        this.log = log;
     }
 
     /**
      * Returns factory for OnIgniteLogger instantiating.
      *
-     * @param rootLogger Root logger.
+     * @param rootLog Root logger.
      */
-    public static Factory factory(IgniteLogger rootLogger) {
-        return new Factory(rootLogger);
+    public static Factory factory(IgniteLogger rootLog) {
+        return new Factory(rootLog);
     }
 
     /** {@inheritDoc} */
     @Override public Vector log(Vector vector) {
-        Tracer.showAscii(vector, logger);
+        Tracer.showAscii(vector, log);
         return vector;
     }
 
@@ -73,10 +73,10 @@ public class CustomMLLogger implements MLLogger {
     private void log(VerboseLevel verboseLevel, String line) {
         switch (verboseLevel) {
             case LOW:
-                logger.info(line);
+                log.info(line);
                 break;
             case HIGH:
-                logger.debug(line);
+                log.debug(line);
                 break;
         }
     }
@@ -86,20 +86,20 @@ public class CustomMLLogger implements MLLogger {
      */
     private static class Factory implements MLLogger.Factory {
         /** Root logger. */
-        private IgniteLogger rootLogger;
+        private IgniteLogger rootLog;
 
         /**
          * Creates an instance of factory.
          *
-         * @param rootLogger Root logger.
+         * @param rootLog Root logger.
          */
-        public Factory(IgniteLogger rootLogger) {
-            this.rootLogger = rootLogger;
+        public Factory(IgniteLogger rootLog) {
+            this.rootLog = rootLog;
         }
 
         /** {@inheritDoc} */
         @Override public <T> MLLogger create(Class<T> targetCls) {
-            return new CustomMLLogger(rootLogger.getLogger(targetCls));
+            return new CustomMLLogger(rootLog.getLogger(targetCls));
         }
     }
 }
index 872b947..b2b4739 100644
@@ -28,7 +28,14 @@ public interface MLLogger {
      * Logging verbose level.
      */
     enum VerboseLevel {
-        OFF, LOW, HIGH
+        /** Disabled. */
+        OFF,
+
+        /** Low. */
+        LOW,
+
+        /** High. */
+        HIGH
     }
 
     /**
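For reference, these verbosity levels are consumed through the learning environment's logger; a
hedged one-line sketch mirroring the DatasetTrainer hunk further down (the message text is a
placeholder):

    // Log at HIGH verbosity from inside a trainer; 'environment' is the
    // trainer's LearningEnvironment instance.
    environment.logger(getClass()).log(MLLogger.VerboseLevel.HIGH, "model update skipped");
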
index 5f605a7..759e06a 100644
@@ -46,15 +46,17 @@ public class NoParallelismStrategy implements ParallelismStrategy {
      * @param <T> Type of result.
      */
     public static class Stub<T> implements Promise<T> {
-        private T result;
+
+        /** Result. */
+        private T res;
 
         /**
          * Create an instance of Stub
          *
-         * @param result Execution result.
+         * @param res Execution result.
          */
-        public Stub(T result) {
-            this.result = result;
+        public Stub(T res) {
+            this.res = res;
         }
 
         /** {@inheritDoc} */
@@ -74,14 +76,14 @@ public class NoParallelismStrategy implements ParallelismStrategy {
 
         /** {@inheritDoc} */
         @Override public T get() throws InterruptedException, ExecutionException {
-            return result;
+            return res;
         }
 
         /** {@inheritDoc} */
         @Override public T get(long timeout,
             @NotNull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
 
-            return result;
+            return res;
         }
     }
 }
index 0d03ee5..3de73bd 100644
@@ -63,9 +63,9 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp
             List<LabeledVector> neighbors = findKNearestNeighbors(v);
 
             return classify(neighbors, v, stgy);
-        } else {
-            throw new IllegalStateException("The train kNN dataset is null");
         }
+        else
+            throw new IllegalStateException("The train kNN dataset is null");
     }
 
     /** */
@@ -91,6 +91,7 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp
         return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter)));
     }
 
+    /** */
     private List<LabeledVector> findKNearestNeighborsInDataset(Vector v,
         Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
         List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> {
@@ -137,10 +138,10 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp
     /**
      * Copy parameters from other model and save all datasets from it.
      *
-     * @param model Model.
+     * @param mdl Model.
      */
-    public void copyStateFrom(KNNClassificationModel model) {
-        this.copyParametersFrom(model);
-        datasets.addAll(model.datasets);
+    public void copyStateFrom(KNNClassificationModel mdl) {
+        this.copyParametersFrom(mdl);
+        datasets.addAll(mdl.datasets);
     }
 }
index 1cac909..c75c5bb 100644
@@ -115,7 +115,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
     }
 
     /** {@inheritDoc} */
-    @Override protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedModel,
+    @Override protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedMdl,
         DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
 
@@ -128,8 +128,8 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
             new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
         )) {
             MultilayerPerceptron mdl;
-            if (lastLearnedModel != null)
-                mdl = lastLearnedModel;
+            if (lastLearnedMdl != null)
+                mdl = lastLearnedMdl;
             else {
                 MLPArchitecture arch = archSupplier.apply(dataset);
                 mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
@@ -196,7 +196,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
                 );
 
                 if (totUp == null)
-                    return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedModel);
+                    return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedMdl);
 
                 P update = updatesStgy.allUpdatesReducer().apply(totUp);
                 mdl = updater.update(mdl, update);
index d3e5734..c8b1dca 100644
@@ -69,12 +69,12 @@ public class MaxAbsScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vec
                     if (b == null)
                         return a;
 
-                    double[] result = new double[a.length];
+                    double[] res = new double[a.length];
 
-                    for (int i = 0; i < result.length; i++) {
-                        result[i] = Math.max(Math.abs(a[i]), Math.abs(b[i]));
-                    }
-                    return result;
+                    for (int i = 0; i < res.length; i++)
+                        res[i] = Math.max(Math.abs(a[i]), Math.abs(b[i]));
+
+                    return res;
                 });
             return new MaxAbsScalerPreprocessor<>(maxAbs, basePreprocessor);
         }
index 7cbb1dc..ec60034 100644
@@ -96,12 +96,13 @@ public class SVMLinearMultiClassClassificationTrainer
                     return 0.0;
             };
 
-            SVMLinearBinaryClassificationModel model;
+            SVMLinearBinaryClassificationModel updatedMdl;
+
             if (mdl == null)
-                model = learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer);
+                updatedMdl = learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer);
             else
-                model = updateModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer);
-            multiClsMdl.add(clsLb, model);
+                updatedMdl = updateModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer);
+            multiClsMdl.add(clsLb, updatedMdl);
         });
 
         return multiClsMdl;
index 490c53d..5c3913e 100644
@@ -70,9 +70,9 @@ public abstract class DatasetTrainer<M extends Model, L> {
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
 
         if(mdl != null) {
-            if(checkState(mdl)) {
+            if (checkState(mdl))
                 return updateModel(mdl, datasetBuilder, featureExtractor, lbExtractor);
-            } else {
+            else {
                 environment.logger(getClass()).log(
                     MLLogger.VerboseLevel.HIGH,
                     "Model cannot be updated because of initial state of " +
index 45774cb..b40ca93 100644
@@ -164,7 +164,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
     private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset,
         TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc, int depth) {
 
-        StepFunction<T>[] result = dataset.compute(
+        return dataset.compute(
             part -> {
                 if (compressor != null)
                     return compressor.compress(impurityCalc.calculate(part, filter, depth));
@@ -172,8 +172,6 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
                     return impurityCalc.calculate(part, filter, depth);
             }, this::reduce
         );
-
-        return result;
     }
 
     /**
@@ -314,16 +312,16 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
                 .append(String.format("%.4f", leaf.getVal()));
         }
         else if (node instanceof DecisionTreeConditionalNode) {
-            DecisionTreeConditionalNode condition = (DecisionTreeConditionalNode)node;
+            DecisionTreeConditionalNode cond = (DecisionTreeConditionalNode)node;
             String prefix = depth == 0 ? "" : (isThen ? "then " : "else ");
             builder.append(String.format("%sif (x", prefix))
-                .append(condition.getCol())
+                .append(cond.getCol())
                 .append(" > ")
-                .append(String.format("%.4f", condition.getThreshold()))
+                .append(String.format("%.4f", cond.getThreshold()))
                 .append(pretty ? ")\n" : ") ");
-            printTree(condition.getThenNode(), depth + 1, builder, pretty, true);
+            printTree(cond.getThenNode(), depth + 1, builder, pretty, true);
             builder.append(pretty ? "\n" : " ");
-            printTree(condition.getElseNode(), depth + 1, builder, pretty, false);
+            printTree(cond.getElseNode(), depth + 1, builder, pretty, false);
         }
         else
             throw new IllegalArgumentException();
index 91ec8e1..58552f4 100644
@@ -87,11 +87,11 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity
     /**
      * Sets useIndex parameter and returns trainer instance.
      *
-     * @param useIndex Use index.
+     * @param useIdx Use index.
      * @return Decision tree trainer.
      */
-    public DecisionTreeClassificationTrainer withUseIndex(boolean useIndex) {
-        this.usingIdx = useIndex;
+    public DecisionTreeClassificationTrainer withUseIndex(boolean useIdx) {
+        this.usingIdx = useIdx;
         return this;
     }
 
index 6ebbda1..caac168 100644
@@ -43,15 +43,16 @@ import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
  * several learning iterations.
  */
 public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
-    private boolean useIndex;
+    /** Use index. */
+    private boolean useIdx;
 
     /**
      * Create an instance of learning strategy.
      *
-     * @param useIndex Use index.
+     * @param useIdx Use index.
      */
-    public GDBOnTreesLearningStrategy(boolean useIndex) {
-        this.useIndex = useIndex;
+    public GDBOnTreesLearningStrategy(boolean useIdx) {
+        this.useIdx = useIdx;
     }
 
     /** {@inheritDoc} */
@@ -70,23 +71,23 @@ public class GDBOnTreesLearningStrategy  extends GDBLearningStrategy {
 
         try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
             new EmptyContextBuilder<>(),
-            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex)
+            new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIdx)
         )) {
             for (int i = 0; i < cntOfIterations; i++) {
                 double[] weights = Arrays.copyOf(compositionWeights, models.size());
-                WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLabelValue);
+                WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
                 ModelsComposition currComposition = new ModelsComposition(models, aggregator);
 
                 if(convCheck.isConverged(dataset, currComposition))
                     break;
 
                 dataset.compute(part -> {
-                    if(part.getCopyOfOriginalLabels() == null)
-                        part.setCopyOfOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
+                    if (part.getCopiedOriginalLabels() == null)
+                        part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
 
                     for(int j = 0; j < part.getLabels().length; j++) {
                         double mdlAnswer = currComposition.apply(VectorUtils.of(part.getFeatures()[j]));
-                        double originalLbVal = externalLbToInternalMapping.apply(part.getCopyOfOriginalLabels()[j]);
+                        double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
                         part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                     }
                 });
index b8a16dc..335f751 100644
@@ -28,13 +28,13 @@ import org.apache.ignite.ml.tree.TreeFilter;
  */
 public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implements AutoCloseable {
     /** Copy of vector with original labels. Auxiliary for Gradient Boosting on Trees.*/
-    private double[] copyOfOriginalLabels;
+    private double[] copiedOriginalLabels;
 
     /** Indexes cache. */
     private final List<TreeDataIndex> indexesCache;
 
     /** Build index. */
-    private final boolean buildIndex;
+    private final boolean buildIdx;
 
     /**
      * Constructs a new instance of decision tree data.
@@ -45,7 +45,7 @@ public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implemen
      */
     public DecisionTreeData(double[][] features, double[] labels, boolean buildIdx) {
         super(features, labels);
-        this.buildIndex = buildIdx;
+        this.buildIdx = buildIdx;
 
         indexesCache = new ArrayList<>();
         if (buildIdx)
@@ -81,7 +81,7 @@ public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implemen
             }
         }
 
-        return new DecisionTreeData(newFeatures, newLabels, buildIndex);
+        return new DecisionTreeData(newFeatures, newLabels, buildIdx);
     }
 
     /**
@@ -129,13 +129,13 @@ public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData implemen
     }
 
     /** */
-    public double[] getCopyOfOriginalLabels() {
-        return copyOfOriginalLabels;
+    public double[] getCopiedOriginalLabels() {
+        return copiedOriginalLabels;
     }
 
     /** */
-    public void setCopyOfOriginalLabels(double[] copyOfOriginalLabels) {
-        this.copyOfOriginalLabels = copyOfOriginalLabels;
+    public void setCopiedOriginalLabels(double[] copiedOriginalLabels) {
+        this.copiedOriginalLabels = copiedOriginalLabels;
     }
 
     /** {@inheritDoc} */
index 6678218..4436b07 100644
@@ -43,7 +43,7 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable>
     private final IgniteBiFunction<K, V, Double> lbExtractor;
 
     /** Build index. */
-    private final boolean buildIndex;
+    private final boolean buildIdx;
 
     /**
      * Constructs a new instance of decision tree data builder.
@@ -56,7 +56,7 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable>
         IgniteBiFunction<K, V, Double> lbExtractor, boolean buildIdx) {
         this.featureExtractor = featureExtractor;
         this.lbExtractor = lbExtractor;
-        this.buildIndex = buildIdx;
+        this.buildIdx = buildIdx;
     }
 
     /** {@inheritDoc} */
@@ -75,6 +75,6 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable>
             ptr++;
         }
 
-        return new DecisionTreeData(features, labels, buildIndex);
+        return new DecisionTreeData(features, labels, buildIdx);
     }
 }
index 88ce190..a86f78d 100644
@@ -26,7 +26,7 @@ import org.apache.ignite.ml.tree.TreeFilter;
  */
 public class TreeDataIndex {
     /** Index containing IDs of rows as if they is sorted by feature values. */
-    private final int[][] index;
+    private final int[][] idx;
 
     /** Original features table. */
     private final double[][] features;
@@ -48,9 +48,9 @@ public class TreeDataIndex {
         int cols = features.length == 0 ? 0 : features[0].length;
 
         double[][] featuresCp = new double[rows][cols];
-        index = new int[rows][cols];
+        idx = new int[rows][cols];
         for (int row = 0; row < rows; row++) {
-            Arrays.fill(index[row], row);
+            Arrays.fill(idx[row], row);
             featuresCp[row] = Arrays.copyOf(features[row], cols);
         }
 
@@ -61,12 +61,12 @@ public class TreeDataIndex {
     /**
      * Constructs an instance of TreeDataIndex
      *
-     * @param indexProj Index projection.
+     * @param idxProj Index projection.
      * @param features Features.
      * @param labels Labels.
      */
-    private TreeDataIndex(int[][] indexProj, double[][] features, double[] labels) {
-        this.index = indexProj;
+    private TreeDataIndex(int[][] idxProj, double[][] features, double[] labels) {
+        this.idx = idxProj;
         this.features = features;
         this.labels = labels;
     }
@@ -79,7 +79,7 @@ public class TreeDataIndex {
      * @return Label value.
      */
     public double labelInSortedOrder(int k, int featureId) {
-        return labels[index[k][featureId]];
+        return labels[idx[k][featureId]];
     }
 
     /**
@@ -90,7 +90,7 @@ public class TreeDataIndex {
      * @return Features vector.
      */
     public double[] featuresInSortedOrder(int k, int featureId) {
-        return features[index[k][featureId]];
+        return features[idx[k][featureId]];
     }
 
     /**
@@ -117,30 +117,30 @@ public class TreeDataIndex {
                 projSize++;
         }
 
-        int[][] projection = new int[projSize][columnsCount()];
+        int[][] prj = new int[projSize][columnsCount()];
         for(int feature = 0; feature < columnsCount(); feature++) {
             int ptr = 0;
             for(int row = 0; row < rowsCount(); row++) {
                 if(filter.test(featuresInSortedOrder(row, feature)))
-                    projection[ptr++][feature] = index[row][feature];
+                    prj[ptr++][feature] = idx[row][feature];
             }
         }
 
-        return new TreeDataIndex(projection, features, labels);
+        return new TreeDataIndex(prj, features, labels);
     }
 
     /**
      * @return count of rows in current index.
      */
     public int rowsCount() {
-        return index.length;
+        return idx.length;
     }
 
     /**
      * @return count of columns in current index.
      */
     public int columnsCount() {
-        return rowsCount() == 0 ? 0 : index[0].length ;
+        return rowsCount() == 0 ? 0 : idx[0].length;
     }
 
     /**
@@ -168,9 +168,9 @@ public class TreeDataIndex {
                     features[i][col] = features[j][col];
                     features[j][col] = tmpFeature;
 
-                    int tmpLb = index[i][col];
-                    index[i][col] = index[j][col];
-                    index[j][col] = tmpLb;
+                    int tmpLb = idx[i][col];
+                    idx[i][col] = idx[j][col];
+                    idx[j][col] = tmpLb;
 
                     i++;
                     j--;
index 0c67535..b97e297 100644
@@ -32,15 +32,15 @@ import org.apache.ignite.ml.tree.impurity.util.StepFunction;
  */
 public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> implements Serializable {
     /** Use index structure instead of using sorting while learning. */
-    protected final boolean useIndex;
+    protected final boolean useIdx;
 
     /**
      * Constructs an instance of ImpurityMeasureCalculator.
      *
-     * @param useIndex Use index.
+     * @param useIdx Use index.
      */
-    public ImpurityMeasureCalculator(boolean useIndex) {
-        this.useIndex = useIndex;
+    public ImpurityMeasureCalculator(boolean useIdx) {
+        this.useIdx = useIdx;
     }
 
     /**
@@ -61,7 +61,7 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im
      * @return Columns count in current dataset.
      */
     protected int columnsCount(DecisionTreeData data, TreeDataIndex idx) {
-        return useIndex ? idx.columnsCount() : data.getFeatures()[0].length;
+        return useIdx ? idx.columnsCount() : data.getFeatures()[0].length;
     }
 
     /**
@@ -72,7 +72,7 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im
      * @return rows count in current dataset
      */
     protected int rowsCount(DecisionTreeData data, TreeDataIndex idx) {
-        return useIndex ? idx.rowsCount() : data.getFeatures().length;
+        return useIdx ? idx.rowsCount() : data.getFeatures().length;
     }
 
     /**
@@ -85,7 +85,7 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im
      * @return label value in according to kth order statistic
      */
     protected double getLabelValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) {
-        return useIndex ? idx.labelInSortedOrder(k, featureId) : data.getLabels()[k];
+        return useIdx ? idx.labelInSortedOrder(k, featureId) : data.getLabels()[k];
     }
 
     /**
@@ -98,10 +98,10 @@ public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> im
      * @return feature value in according to kth order statistic.
      */
     protected double getFeatureValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) {
-        return useIndex ? idx.featureInSortedOrder(k, featureId) : data.getFeatures()[k][featureId];
+        return useIdx ? idx.featureInSortedOrder(k, featureId) : data.getFeatures()[k][featureId];
     }
 
     protected Vector getFeatureValues(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) {
-        return VectorUtils.of(useIndex ? idx.featuresInSortedOrder(k, featureId) : data.getFeatures()[k]);
+        return VectorUtils.of(useIdx ? idx.featuresInSortedOrder(k, featureId) : data.getFeatures()[k]);
     }
 }
index 38b3097..6a1eb0c 100644
@@ -39,22 +39,22 @@ public class GiniImpurityMeasureCalculator extends ImpurityMeasureCalculator<Gin
      * Constructs a new instance of Gini impurity measure calculator.
      *
      * @param lbEncoder Label encoder which defines integer value for every label class.
-     * @param useIndex Use index while calculate.
+     * @param useIdx Use index while calculate.
      */
-    public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder, boolean useIndex) {
-        super(useIndex);
+    public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder, boolean useIdx) {
+        super(useIdx);
         this.lbEncoder = lbEncoder;
     }
 
     /** {@inheritDoc} */
     @SuppressWarnings("unchecked")
     @Override public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
-        TreeDataIndex index = null;
+        TreeDataIndex idx = null;
         boolean canCalculate = false;
 
-        if (useIndex) {
-            index = data.createIndexByFilter(depth, filter);
-            canCalculate = index.rowsCount() > 0;
+        if (useIdx) {
+            idx = data.createIndexByFilter(depth, filter);
+            canCalculate = idx.rowsCount() > 0;
         }
         else {
             data = data.filter(filter);
@@ -62,47 +62,47 @@ public class GiniImpurityMeasureCalculator extends ImpurityMeasureCalculator<Gin
         }
 
         if (canCalculate) {
-            int rowsCnt = rowsCount(data, index);
-            int colsCnt = columnsCount(data, index);
+            int rowsCnt = rowsCount(data, idx);
+            int colsCnt = columnsCount(data, idx);
 
             StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colsCnt];
 
             long right[] = new long[lbEncoder.size()];
             for (int i = 0; i < rowsCnt; i++) {
-                double lb = getLabelValue(data, index, 0, i);
+                double lb = getLabelValue(data, idx, 0, i);
                 right[getLabelCode(lb)]++;
             }
 
             for (int col = 0; col < res.length; col++) {
-                if(!useIndex)
+                if (!useIdx)
                     data.sort(col);
 
                 double[] x = new double[rowsCnt + 1];
                 GiniImpurityMeasure[] y = new GiniImpurityMeasure[rowsCnt + 1];
 
                 long[] left = new long[lbEncoder.size()];
-                long[] rightCopy = Arrays.copyOf(right, right.length);
+                long[] rightCp = Arrays.copyOf(right, right.length);
 
                 int xPtr = 0, yPtr = 0;
                 x[xPtr++] = Double.NEGATIVE_INFINITY;
                 y[yPtr++] = new GiniImpurityMeasure(
                     Arrays.copyOf(left, left.length),
-                    Arrays.copyOf(rightCopy, rightCopy.length)
+                    Arrays.copyOf(rightCp, rightCp.length)
                 );
 
                 for (int i = 0; i < rowsCnt; i++) {
-                    double lb = getLabelValue(data, index, col, i);
+                    double lb = getLabelValue(data, idx, col, i);
                     left[getLabelCode(lb)]++;
-                    rightCopy[getLabelCode(lb)]--;
+                    rightCp[getLabelCode(lb)]--;
 
-                    double featureVal = getFeatureValue(data, index, col, i);
-                    if (i < (rowsCnt - 1) && getFeatureValue(data, index, col, i + 1) == featureVal)
+                    double featureVal = getFeatureValue(data, idx, col, i);
+                    if (i < (rowsCnt - 1) && getFeatureValue(data, idx, col, i + 1) == featureVal)
                         continue;
 
                     x[xPtr++] = featureVal;
                     y[yPtr++] = new GiniImpurityMeasure(
                         Arrays.copyOf(left, left.length),
-                        Arrays.copyOf(rightCopy, rightCopy.length)
+                        Arrays.copyOf(rightCp, rightCp.length)
                     );
                 }
 
index 1788737..3629768 100644
@@ -33,20 +33,20 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI
     /**
      * Constructs an instance of MSEImpurityMeasureCalculator.
      *
-     * @param useIndex Use index while calculate.
+     * @param useIdx Use index while calculate.
      */
-    public MSEImpurityMeasureCalculator(boolean useIndex) {
-        super(useIndex);
+    public MSEImpurityMeasureCalculator(boolean useIdx) {
+        super(useIdx);
     }
 
     /** {@inheritDoc} */
     @Override public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
-        TreeDataIndex index = null;
-        boolean canCalculate = false;
+        TreeDataIndex idx = null;
+        boolean canCalculate;
 
-        if (useIndex) {
-            index = data.createIndexByFilter(depth, filter);
-            canCalculate = index.rowsCount() > 0;
+        if (useIdx) {
+            idx = data.createIndexByFilter(depth, filter);
+            canCalculate = idx.rowsCount() > 0;
         }
         else {
             data = data.filter(filter);
@@ -54,8 +54,8 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI
         }
 
         if (canCalculate) {
-            int rowsCnt = rowsCount(data, index);
-            int colsCnt = columnsCount(data, index);
+            int rowsCnt = rowsCount(data, idx);
+            int colsCnt = columnsCount(data, idx);
 
             @SuppressWarnings("unchecked")
             StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt];
@@ -63,14 +63,14 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI
             double rightYOriginal = 0;
             double rightY2Original = 0;
             for (int i = 0; i < rowsCnt; i++) {
-                double lbVal = getLabelValue(data, index, 0, i);
+                double lbVal = getLabelValue(data, idx, 0, i);
 
                 rightYOriginal += lbVal;
                 rightY2Original += Math.pow(lbVal, 2);
             }
 
             for (int col = 0; col < res.length; col++) {
-                if (!useIndex)
+                if (!useIdx)
                     data.sort(col);
 
                 double[] x = new double[rowsCnt + 1];
@@ -86,7 +86,7 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI
                 int leftSize = 0;
                 for (int i = 0; i <= rowsCnt; i++) {
                     if (leftSize > 0) {
-                        double lblVal = getLabelValue(data, index, col, i - 1);
+                        double lblVal = getLabelValue(data, idx, col, i - 1);
 
                         leftY += lblVal;
                         leftY2 += Math.pow(lblVal, 2);
@@ -96,7 +96,7 @@ public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEI
                     }
 
                     if (leftSize < rowsCnt)
-                        x[leftSize + 1] = getFeatureValue(data, index, col, i);
+                        x[leftSize + 1] = getFeatureValue(data, idx, col, i);
 
                     y[leftSize] = new MSEImpurityMeasure(
                         leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize
index c617d8d..4a83eb2 100644
@@ -73,7 +73,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
     private static final double BUCKET_SIZE_FACTOR = (1 / 10.0);
 
     /** Count of trees. */
-    private int cntOfTrees = 1;
+    private int amountOfTrees = 1;
 
     /** Subsample size. */
     private double subSampleSize = 1.0;
@@ -115,7 +115,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
         List<TreeRoot> models = null;
         try (Dataset<EmptyContext, BootstrappedDatasetPartition> dataset = datasetBuilder.build(
             new EmptyContextBuilder<>(),
-            new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, cntOfTrees, subSampleSize))) {
+            new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, amountOfTrees, subSampleSize))) {
 
             if(!init(dataset))
                 return buildComposition(Collections.emptyList());
@@ -138,8 +138,8 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
      * @param cntOfTrees Count of trees.
      * @return an instance of current object with valid type in according to inheritance.
      */
-    public T withCountOfTrees(int cntOfTrees) {
-        this.cntOfTrees = cntOfTrees;
+    public T withAmountOfTrees(int amountOfTrees) {
+        this.amountOfTrees = amountOfTrees;
         return instance();
     }
 
@@ -348,7 +348,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
      */
     private Queue<TreeNode> createRootsQueue() {
         Queue<TreeNode> roots = new LinkedList<>();
-        for (int i = 0; i < cntOfTrees; i++)
+        for (int i = 0; i < amountOfTrees; i++)
             roots.add(new TreeNode(1, i));
         return roots;
     }
index 52d0b74..3ccb568 100644
@@ -27,7 +27,7 @@ public class NodeSplit {
     private final int featureId;
 
     /** Feature split value. */
-    private final double value;
+    private final double val;
 
     /** Impurity at this split point. */
     private final double impurity;
@@ -36,12 +36,12 @@ public class NodeSplit {
      * Creates an instance of NodeSplit.
      *
      * @param featureId Feature id.
-     * @param value Feature split value.
+     * @param val Feature split value.
      * @param impurity Impurity value.
      */
-    public NodeSplit(int featureId, double value, double impurity) {
+    public NodeSplit(int featureId, double val, double impurity) {
         this.featureId = featureId;
-        this.value = value;
+        this.val = val;
         this.impurity = impurity;
     }
 
@@ -52,7 +52,7 @@ public class NodeSplit {
      * @return list of children.
      */
     public List<TreeNode> split(TreeNode node) {
-        List<TreeNode> children = node.toConditional(featureId, value);
+        List<TreeNode> children = node.toConditional(featureId, val);
         node.setImpurity(impurity);
         return children;
     }
@@ -73,7 +73,7 @@ public class NodeSplit {
     }
 
     /** */
-    public double getValue() {
-        return value;
+    public double getVal() {
+        return val;
     }
 }
index eb06143..528e31d 100644
@@ -51,7 +51,7 @@ public class TreeNode implements Model<Vector, Double>, Serializable {
     private int featureId;
 
     /** Value. */
-    private double value;
+    private double val;
 
     /** Type. */
     private Type type;
@@ -76,7 +76,7 @@ public class TreeNode implements Model<Vector, Double>, Serializable {
      */
     public TreeNode(long id, int treeId) {
         this.id = new NodeId(treeId, id);
-        this.value = -1;
+        this.val = -1;
         this.type = Type.UNKNOWN;
         this.impurity = Double.POSITIVE_INFINITY;
         this.depth = 1;
@@ -87,9 +87,9 @@ public class TreeNode implements Model<Vector, Double>, Serializable {
         assert type != Type.UNKNOWN;
 
         if (type == Type.LEAF)
-            return value;
+            return val;
         else {
-            if (features.get(featureId) <= value)
+            if (features.get(featureId) <= val)
                 return left.apply(features);
             else
                 return right.apply(features);
@@ -109,7 +109,7 @@ public class TreeNode implements Model<Vector, Double>, Serializable {
             case LEAF:
                 return id;
             default:
-                if (features.get(featureId) <= value)
+                if (features.get(featureId) <= val)
                     return left.predictNextNodeKey(features);
                 else
                     return right.predictNextNodeKey(features);
@@ -120,12 +120,12 @@ public class TreeNode implements Model<Vector, Double>, Serializable {
      * Convert node to conditional node.
      *
      * @param featureId Feature id.
-     * @param value Value.
+     * @param val Value.
      */
-    public List<TreeNode> toConditional(int featureId, double value) {
+    public List<TreeNode> toConditional(int featureId, double val) {
         assert type == Type.UNKNOWN;
 
-        toLeaf(value);
+        toLeaf(val);
         left = new TreeNode(2 * id.nodeId(), id.treeId());
         right = new TreeNode(2 * id.nodeId() + 1, id.treeId());
         this.type = Type.CONDITIONAL;
@@ -138,12 +138,12 @@ public class TreeNode implements Model<Vector, Double>, Serializable {
     /**
      * Convert node to leaf.
      *
-     * @param value Value.
+     * @param val Value.
      */
-    public void toLeaf(double value) {
+    public void toLeaf(double val) {
         assert type == Type.UNKNOWN;
 
-        this.value = value;
+        this.val = val;
         this.type = Type.LEAF;
 
         this.left = null;
@@ -156,8 +156,8 @@ public class TreeNode implements Model<Vector, Double>, Serializable {
     }
 
     /** */
-    public void setValue(double value) {
-        this.value = value;
+    public void setVal(double val) {
+        this.val = val;
     }
 
     /** */
index d1ed87f..8320461 100644
@@ -183,9 +183,9 @@ public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<Boot
          */
         private void addTo(Map<Integer, S> from, Map<Integer, S> to) {
             from.forEach((key, hist) -> {
-                if(!to.containsKey(key)) {
+                if (!to.containsKey(key))
                     to.put(key, hist);
-                } else {
+                else {
                     S sumOfHists = to.get(key).plus(hist);
                     to.put(key, sumOfHists);
                 }
index 056eece..cd343ef 100644
@@ -65,7 +65,7 @@ public abstract class LeafValuesComputer<T> implements Serializable {
             T stat = stats.get(id);
             if(stat != null) {
                 double leafVal = computeLeafValue(stat);
-                leaf.setValue(leafVal);
+                leaf.setVal(leafVal);
             }
         });
     }
index e6539d2..d34ab62 100644
@@ -68,10 +68,10 @@ public class ModelTrace {
      * Add field.
      *
      * @param name Name.
-     * @param value Value.
+     * @param val Value.
      */
-    public ModelTrace addField(String name, String value) {
-        mdlFields.add(new IgniteBiTuple<>(name, value));
+    public ModelTrace addField(String name, String val) {
+        mdlFields.add(new IgniteBiTuple<>(name, val));
         return this;
     }
 
index 74ff8f1..205f0ff 100644
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -36,7 +37,7 @@ import static org.junit.Assert.assertTrue;
 /**
  * Tests for {@link KMeansTrainer}.
  */
-public class KMeansTrainerTest {
+public class KMeansTrainerTest extends TrainerTest {
     /** Precision in test checks. */
     private static final double PRECISION = 1e-2;
 
@@ -59,7 +60,7 @@ public class KMeansTrainerTest {
     public void findOneClusters() {
         KMeansTrainer trainer = createAndCheckTrainer();
         KMeansModel knnMdl = trainer.withAmountOfClusters(1).fit(
-            new LocalDatasetBuilder<>(data, 2),
+            new LocalDatasetBuilder<>(data, parts),
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         );
@@ -77,19 +78,19 @@ public class KMeansTrainerTest {
     public void testUpdateMdl() {
         KMeansTrainer trainer = createAndCheckTrainer();
         KMeansModel originalMdl = trainer.withAmountOfClusters(1).fit(
-            new LocalDatasetBuilder<>(data, 2),
+            new LocalDatasetBuilder<>(data, parts),
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         );
         KMeansModel updatedMdlOnSameDataset = trainer.update(
             originalMdl,
-            new LocalDatasetBuilder<>(data, 2),
+            new LocalDatasetBuilder<>(data, parts),
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         );
         KMeansModel updatedMdlOnEmptyDataset = trainer.update(
             originalMdl,
-            new LocalDatasetBuilder<>(new HashMap<Integer, double[]>(), 2),
+            new LocalDatasetBuilder<>(new HashMap<Integer, double[]>(), parts),
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
         );
index 678ed44..5d3bb5f 100644
@@ -28,7 +28,7 @@ import org.junit.runners.Parameterized;
 @RunWith(Parameterized.class)
 public class TrainerTest {
     /** Number of parts to be tested. */
-    private static final int[] partsToBeTested = new int[]{1, 2, 3, 4, 5, 7, 100};
+    private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 13};
 
     /** Parameters. */
     @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}")
@@ -36,7 +36,7 @@ public class TrainerTest {
         List<Integer[]> res = new ArrayList<>();
 
         for (int part : partsToBeTested)
-            res.add(new Integer[]{part});
+            res.add(new Integer[] {part});
 
         return res;
     }
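
The partition counts exercised by the tests change here, and suites such as KMeansTrainerTest and
GDBTrainerTest now extend TrainerTest so the Parameterized runner re-runs them for each count. A
rough sketch of the pattern, assuming TrainerTest exposes the injected count as a field named
parts (consistent with the LocalDatasetBuilder calls in the KMeansTrainerTest hunk above):

    // Hedged sketch, not actual suite code: a subclass inherits the
    // @Parameterized partition counts from TrainerTest and builds its
    // dataset with the injected 'parts' value instead of a hard-coded 2.
    // Imports as in the KMeansTrainerTest hunk above.
    public class SomeTrainerTest extends TrainerTest {
        /** */
        @Test
        public void testFitOnEveryPartitionCount() {
            Map<Integer, double[]> data = new HashMap<>(); // filled with labeled rows
            KMeansModel mdl = new KMeansTrainer().withAmountOfClusters(1).fit(
                new LocalDatasetBuilder<>(data, parts), // 'parts' varies per run
                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
                (k, v) -> v[2]
            );
        }
    }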
index 4c3655b..4958b4b 100644
@@ -21,6 +21,7 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.function.BiFunction;
 import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
 import org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory;
@@ -38,7 +39,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
 /** */
-public class GDBTrainerTest {
+public class GDBTrainerTest extends TrainerTest {
     /** */
     @Test
     public void testFitRegression() {
index e738716..0d46361 100644
@@ -21,7 +21,9 @@ import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 
+/** */
 public class MeanValuePredictionsAggregatorTest {
+    /** Aggregator. */
     private PredictionsAggregator aggregator = new MeanValuePredictionsAggregator();
 
     /** */
index 8649b72..4d25a86 100644
@@ -22,6 +22,7 @@ import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 
 public class OnMajorityPredictionsAggregatorTest {
+    /** Aggregator. */
     private PredictionsAggregator aggregator = new OnMajorityPredictionsAggregator();
 
     /** */
index 131b69b..9efb939 100644
@@ -74,10 +74,10 @@ public class ObjectHistogramTest {
 
     /**
      * @param hist History.
-     * @param expectedBuckets Expected buckets.
-     * @param expectedCounters Expected counters.
+     * @param expBuckets Expected buckets.
+     * @param expCounters Expected counters.
      */
-    private void testBuckets(ObjectHistogram<Double> hist, int[] expectedBuckets, int[] expectedCounters) {
+    private void testBuckets(ObjectHistogram<Double> hist, int[] expBuckets, int[] expCounters) {
         int size = hist.buckets().size();
         int[] buckets = new int[size];
         int[] counters = new int[size];
@@ -87,8 +87,8 @@ public class ObjectHistogramTest {
             buckets[ptr++] = bucket;
         }
 
-        assertArrayEquals(expectedBuckets, buckets);
-        assertArrayEquals(expectedCounters, counters);
+        assertArrayEquals(expBuckets, buckets);
+        assertArrayEquals(expCounters, counters);
     }
 
     /**
@@ -96,12 +96,12 @@ public class ObjectHistogramTest {
      */
     @Test
     public void testAdd() {
-        double value = 100.;
-        hist1.addElement(value);
-        Optional<Double> counter = hist1.getValue(computeBucket(value));
+        double val = 100.0;
+        hist1.addElement(val);
+        Optional<Double> cntr = hist1.getValue(computeBucket(val));
 
-        assertTrue(counter.isPresent());
-        assertEquals(1, counter.get().intValue());
+        assertTrue(cntr.isPresent());
+        assertEquals(1, cntr.get().intValue());
     }
 
     /**
@@ -109,8 +109,8 @@ public class ObjectHistogramTest {
      */
     @Test
     public void testAddHist() {
-        ObjectHistogram<Double> result = hist1.plus(hist2);
-        testBuckets(result, new int[] {0, 1, 2, 3, 4, 5, 6}, new int[] {10, 8, 2, 1, 1, 2, 1});
+        ObjectHistogram<Double> res = hist1.plus(hist2);
+        testBuckets(res, new int[] {0, 1, 2, 3, 4, 5, 6}, new int[] {10, 8, 2, 1, 1, 2, 1});
     }
 
     /**
@@ -133,18 +133,19 @@ public class ObjectHistogramTest {
         assertArrayEquals(new double[] {4., 7., 9., 10., 11., 12.}, sums, 0.01);
     }
 
+    /** */
     @Test
     public void testOfSum() {
         IgniteFunction<Double, Integer> bucketMap = x -> (int) (Math.ceil(x * 100) % 100);
-        IgniteFunction<Double, Double> counterMap = x -> Math.pow(x, 2);
+        IgniteFunction<Double, Double> cntrMap = x -> Math.pow(x, 2);
 
-        ObjectHistogram<Double> forAllHistogram = new ObjectHistogram<>(bucketMap, counterMap);
+        ObjectHistogram<Double> forAllHistogram = new ObjectHistogram<>(bucketMap, cntrMap);
         Random rnd = new Random();
         List<ObjectHistogram<Double>> partitions = new ArrayList<>();
         int cntOfPartitions = rnd.nextInt(100);
         int sizeOfDataset = rnd.nextInt(10000);
         for(int i = 0; i < cntOfPartitions; i++)
-            partitions.add(new ObjectHistogram<>(bucketMap, counterMap));
+            partitions.add(new ObjectHistogram<>(bucketMap, cntrMap));
 
         for(int i = 0; i < sizeOfDataset; i++) {
             double objVal = rnd.nextDouble();
@@ -152,7 +153,7 @@ public class ObjectHistogramTest {
             partitions.get(rnd.nextInt(partitions.size())).addElement(objVal);
         }
 
-        Optional<ObjectHistogram<Double>> leftSum = partitions.stream().reduce((x,y) -> x.plus(y));
+        Optional<ObjectHistogram<Double>> leftSum = partitions.stream().reduce(ObjectHistogram::plus);
         Optional<ObjectHistogram<Double>> rightSum = partitions.stream().reduce((x,y) -> y.plus(x));
         assertTrue(leftSum.isPresent());
         assertTrue(rightSum.isPresent());
@@ -162,9 +163,9 @@ public class ObjectHistogramTest {
     }
 
     /**
-     * @param value Value.
+     * @param val Value.
      */
-    private int computeBucket(Double value) {
-        return (int)Math.rint(value);
+    private int computeBucket(Double val) {
+        return (int)Math.rint(val);
     }
 }
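
The switch to the ObjectHistogram::plus method reference also highlights what testOfSum actually checks: per-partition histograms must reduce to the same result whichever order they are summed in, which is the property that lets histograms be computed independently on each partition. A condensed sketch of that property using only API visible in this hunk; isEqualTo is assumed to be the same Histogram comparison the checkSums helper uses later in this patch:

    import java.util.Arrays;
    import java.util.List;
    import java.util.Optional;
    import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
    import org.apache.ignite.ml.math.functions.IgniteFunction;

    public class HistogramSumSketch {
        public static void main(String[] args) {
            // Bucket mapping and counter mapping are illustrative choices.
            IgniteFunction<Double, Integer> bucketMap = x -> (int)Math.rint(x);
            IgniteFunction<Double, Double> cntrMap = x -> 1.0;

            ObjectHistogram<Double> part1 = new ObjectHistogram<>(bucketMap, cntrMap);
            ObjectHistogram<Double> part2 = new ObjectHistogram<>(bucketMap, cntrMap);

            part1.addElement(1.0);
            part2.addElement(1.2);
            part2.addElement(3.0);

            List<ObjectHistogram<Double>> partitions = Arrays.asList(part1, part2);

            Optional<ObjectHistogram<Double>> leftSum = partitions.stream().reduce(ObjectHistogram::plus);
            Optional<ObjectHistogram<Double>> rightSum = partitions.stream().reduce((x, y) -> y.plus(x));

            // Both orders must produce the same buckets and counters.
            System.out.println(leftSum.get().isEqualTo(rightSum.get()));
        }
    }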
index 7e5a079..73192f0 100644
@@ -41,7 +41,7 @@ public class LearningEnvironmentTest {
         RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(
             IntStream.range(0, 0).mapToObj(
                 x -> new FeatureMeta("", 0, false)).collect(Collectors.toList())
-        ).withCountOfTrees(101)
+        ).withAmountOfTrees(101)
             .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
             .withMaxDepth(4)
             .withMinImpurityDelta(0.)
index 199644b..9c75824 100644
@@ -26,7 +26,6 @@ import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
 import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
 import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Assert;
 import org.junit.Test;
@@ -109,11 +108,16 @@ public class ANNClassificationTest extends TrainerTest {
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(NNStrategy.SIMPLE);
 
-        Vector v1 = VectorUtils.of(550, 550);
-        Vector v2 = VectorUtils.of(-550, -550);
-        TestUtils.assertEquals(originalMdl.apply(v1), updatedOnSameDataset.apply(v1), PRECISION);
-        TestUtils.assertEquals(originalMdl.apply(v2), updatedOnSameDataset.apply(v2), PRECISION);
-        TestUtils.assertEquals(originalMdl.apply(v1), updatedOnEmptyDataset.apply(v1), PRECISION);
-        TestUtils.assertEquals(originalMdl.apply(v2), updatedOnEmptyDataset.apply(v2), PRECISION);
+        Assert.assertNotNull(updatedOnSameDataset.getCandidates());
+
+        Assert.assertTrue(updatedOnSameDataset.toString().contains(NNStrategy.SIMPLE.name()));
+        Assert.assertTrue(updatedOnSameDataset.toString(true).contains(NNStrategy.SIMPLE.name()));
+        Assert.assertTrue(updatedOnSameDataset.toString(false).contains(NNStrategy.SIMPLE.name()));
+
+        Assert.assertNotNull(updatedOnEmptyDataset.getCandidates());
+
+        Assert.assertTrue(updatedOnEmptyDataset.toString().contains(NNStrategy.SIMPLE.name()));
+        Assert.assertTrue(updatedOnEmptyDataset.toString(true).contains(NNStrategy.SIMPLE.name()));
+        Assert.assertTrue(updatedOnEmptyDataset.toString(false).contains(NNStrategy.SIMPLE.name()));
     }
 }
index 52ff1ec..9ff0bc2 100644
 
 package org.apache.ignite.ml.knn;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
@@ -32,34 +31,13 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
 import org.junit.Assert;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static junit.framework.TestCase.assertEquals;
 
 /**
  * Tests for {@link KNNRegressionTrainer}.
  */
-@RunWith(Parameterized.class)
-public class KNNRegressionTest {
-    /** Number of parts to be tested. */
-    private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7, 100};
-
-    /** Number of partitions. */
-    @Parameterized.Parameter
-    public int parts;
-
-    /** Parameters. */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}")
-    public static Iterable<Integer[]> data() {
-        List<Integer[]> res = new ArrayList<>();
-
-        for (int part : partsToBeTested)
-            res.add(new Integer[] {part});
-
-        return res;
-    }
-
+public class KNNRegressionTest extends TrainerTest {
     /** */
     @Test
     public void testSimpleRegressionWithOneNeighbour() {
index f8dc078..42d7efd 100644
@@ -23,6 +23,9 @@ import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 
+/**
+ * Tests for {@link VectorUtils}.
+ */
 public class VectorUtilsTest {
     /** */
     @Test
@@ -55,14 +58,12 @@ public class VectorUtilsTest {
     /** */
     @Test(expected = NullPointerException.class)
     public void testFails1() {
-        double[] values = null;
-        VectorUtils.of(values);
+        VectorUtils.of((double[])null);
     }
 
     /** */
     @Test(expected = NullPointerException.class)
     public void testFails2() {
-        Double[] values = null;
-        VectorUtils.of(values);
+        VectorUtils.of((Double[])null);
     }
 }
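
The rewritten tests cast null instead of passing a typed local because a bare VectorUtils.of(null) would not compile: both the double[] and the Double[] overloads are applicable and neither is more specific than the other. A minimal illustration of the ambiguity, with placeholder methods rather than Ignite's implementation:

    public class OverloadSketch {
        static void of(double[] values) { /* primitive-array overload */ }
        static void of(Double[] values) { /* boxed-array overload */ }

        public static void main(String[] args) {
            // of(null);          // does not compile: reference to of is ambiguous
            of((double[])null);   // explicitly selects of(double[])
            of((Double[])null);   // explicitly selects of(Double[])
        }
    }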
index 6af03df..b720695 100644
@@ -20,13 +20,12 @@ package org.apache.ignite.ml.math.isolve.lsqr;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
@@ -35,26 +34,7 @@ import static org.junit.Assert.assertTrue;
 /**
  * Tests for {@link LSQROnHeap}.
  */
-@RunWith(Parameterized.class)
-public class LSQROnHeapTest {
-    /** Parameters. */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        return Arrays.asList(
-            new Integer[] {1},
-            new Integer[] {2},
-            new Integer[] {3},
-            new Integer[] {5},
-            new Integer[] {7},
-            new Integer[] {100},
-            new Integer[] {1000}
-        );
-    }
-
-    /** Number of partitions. */
-    @Parameterized.Parameter
-    public int parts;
-
+public class LSQROnHeapTest extends TrainerTest {
     /** Tests solving simple linear system. */
     @Test
     public void testSolveLinearSystem() {
index d740577..e59d515 100644
@@ -38,6 +38,11 @@ public class PipelineMdlTest {
         verifyPredict(getMdl(new LogisticRegressionModel(weights, 1.0).withRawLabels(true)));
     }
 
+    /**
+     * Wraps the given internal model into a pipeline model.
+     *
+     * @param internalMdl Internal model.
+     */
     private PipelineMdl<Integer, double[]> getMdl(LogisticRegressionModel internalMdl) {
         return new PipelineMdl<Integer, double[]>()
             .withFeatureExtractor(null)
index d465e82..4b7fa33 100644
 
 package org.apache.ignite.ml.preprocessing.binarization;
 
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
@@ -35,26 +33,7 @@ import static org.junit.Assert.assertEquals;
 /**
  * Tests for {@link BinarizationTrainer}.
  */
-@RunWith(Parameterized.class)
-public class BinarizationTrainerTest {
-    /** Parameters. */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        return Arrays.asList(
-            new Integer[] {1},
-            new Integer[] {2},
-            new Integer[] {3},
-            new Integer[] {5},
-            new Integer[] {7},
-            new Integer[] {100},
-            new Integer[] {1000}
-        );
-    }
-
-    /** Number of partitions. */
-    @Parameterized.Parameter
-    public int parts;
-
+public class BinarizationTrainerTest extends TrainerTest {
     /** Tests {@code fit()} method. */
     @Test
     public void testFit() {
index 6d01901..23afd30 100644
 
 package org.apache.ignite.ml.preprocessing.encoding;
 
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialFeatureValue;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static junit.framework.TestCase.fail;
 import static org.junit.Assert.assertArrayEquals;
@@ -33,26 +31,7 @@ import static org.junit.Assert.assertArrayEquals;
 /**
  * Tests for {@link EncoderTrainer}.
  */
-@RunWith(Parameterized.class)
-public class EncoderTrainerTest {
-    /** Parameters. */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        return Arrays.asList(
-            new Integer[]{1},
-            new Integer[]{2},
-            new Integer[]{3},
-            new Integer[]{5},
-            new Integer[]{7},
-            new Integer[]{100},
-            new Integer[]{1000}
-        );
-    }
-
-    /** Number of partitions. */
-    @Parameterized.Parameter
-    public int parts;
-
+public class EncoderTrainerTest extends TrainerTest {
     /** Tests {@code fit()} method. */
     @Test
     public void testFitOnStringCategorialFeatures() {
index 006ac29..9c11d13 100644
 
 package org.apache.ignite.ml.preprocessing.imputing;
 
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertArrayEquals;
 
 /**
  * Tests for {@link ImputerTrainer}.
  */
-@RunWith(Parameterized.class)
-public class ImputerTrainerTest {
-    /** Parameters. */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        return Arrays.asList(
-            new Integer[] {1},
-            new Integer[] {2},
-            new Integer[] {3},
-            new Integer[] {5},
-            new Integer[] {7},
-            new Integer[] {100},
-            new Integer[] {1000}
-        );
-    }
-
-    /** Number of partitions. */
-    @Parameterized.Parameter
-    public int parts;
-
+public class ImputerTrainerTest extends TrainerTest {
     /** Tests {@code fit()} method. */
     @Test
     public void testFit() {
index 3c30f3e..91562da 100644
@@ -42,7 +42,7 @@ public class MaxAbsScalerPreprocessorTest {
             (k, v) -> v
         );
 
-        double[][] expectedData = new double[][] {
+        double[][] expData = new double[][] {
             {.5, 4. / 22, 1. / 300},
             {.25, 8. / 22, 22. / 300},
             {-1., 10. / 22, 100. / 300},
@@ -50,6 +50,6 @@ public class MaxAbsScalerPreprocessorTest {
         };
 
         for (int i = 0; i < data.length; i++)
-            assertArrayEquals(expectedData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-8);
+            assertArrayEquals(expData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-8);
     }
 }
index 5711660..844468e 100644
 
 package org.apache.ignite.ml.preprocessing.maxabsscaling;
 
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertArrayEquals;
 
 /**
  * Tests for {@link MaxAbsScalerTrainer}.
  */
-@RunWith(Parameterized.class)
-public class MaxAbsScalerTrainerTest {
-    /** Parameters. */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        return Arrays.asList(
-            new Integer[] {1},
-            new Integer[] {2},
-            new Integer[] {3},
-            new Integer[] {5},
-            new Integer[] {7},
-            new Integer[] {100},
-            new Integer[] {1000}
-        );
-    }
-
-    /** Number of partitions. */
-    @Parameterized.Parameter
-    public int parts;
-
+public class MaxAbsScalerTrainerTest extends TrainerTest {
     /** Tests {@code fit()} method. */
     @Test
     public void testFit() {
index 451f5e9..4c0a99f 100644
 
 package org.apache.ignite.ml.preprocessing.minmaxscaling;
 
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertArrayEquals;
 
 /**
  * Tests for {@link MinMaxScalerTrainer}.
  */
-@RunWith(Parameterized.class)
-public class MinMaxScalerTrainerTest {
-    /** Parameters. */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        return Arrays.asList(
-            new Integer[] {1},
-            new Integer[] {2},
-            new Integer[] {3},
-            new Integer[] {5},
-            new Integer[] {7},
-            new Integer[] {100},
-            new Integer[] {1000}
-        );
-    }
-
-    /** Number of partitions. */
-    @Parameterized.Parameter
-    public int parts;
-
+public class MinMaxScalerTrainerTest extends TrainerTest {
     /** Tests {@code fit()} method. */
     @Test
     public void testFit() {
index 7b02f20..9d39354 100644
 
 package org.apache.ignite.ml.preprocessing.normalization;
 
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.preprocessing.binarization.BinarizationTrainer;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
@@ -34,26 +32,7 @@ import static org.junit.Assert.assertEquals;
 /**
  * Tests for {@link BinarizationTrainer}.
  */
-@RunWith(Parameterized.class)
-public class NormalizationTrainerTest {
-    /** Parameters. */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        return Arrays.asList(
-            new Integer[] {1},
-            new Integer[] {2},
-            new Integer[] {3},
-            new Integer[] {5},
-            new Integer[] {7},
-            new Integer[] {100},
-            new Integer[] {1000}
-        );
-    }
-
-    /** Number of partitions. */
-    @Parameterized.Parameter
-    public int parts;
-
+public class NormalizationTrainerTest extends TrainerTest {
     /** Tests {@code fit()} method. */
     @Test
     public void testFit() {
index 9c35ac7..3ca1a07 100644
@@ -123,7 +123,7 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest {
 
         LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
 
-        LinearRegressionModel originalModel = trainer.fit(
+        LinearRegressionModel originalMdl = trainer.fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
@@ -131,7 +131,7 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest {
         );
 
         LinearRegressionModel updatedOnSameDS = trainer.update(
-            originalModel,
+            originalMdl,
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
@@ -139,17 +139,17 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest {
         );
 
         LinearRegressionModel updatedOnEmpyDS = trainer.update(
-            originalModel,
+            originalMdl,
             new HashMap<Integer, double[]>(),
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[coef.length]
         );
 
-        assertArrayEquals(originalModel.getWeights().getStorage().data(), updatedOnSameDS.getWeights().getStorage().data(), 1e-6);
-        assertEquals(originalModel.getIntercept(), updatedOnSameDS.getIntercept(), 1e-6);
+        assertArrayEquals(originalMdl.getWeights().getStorage().data(), updatedOnSameDS.getWeights().getStorage().data(), 1e-6);
+        assertEquals(originalMdl.getIntercept(), updatedOnSameDS.getIntercept(), 1e-6);
 
-        assertArrayEquals(originalModel.getWeights().getStorage().data(), updatedOnEmpyDS.getWeights().getStorage().data(), 1e-6);
-        assertEquals(originalModel.getIntercept(), updatedOnEmpyDS.getIntercept(), 1e-6);
+        assertArrayEquals(originalMdl.getWeights().getStorage().data(), updatedOnEmpyDS.getWeights().getStorage().data(), 1e-6);
+        assertEquals(originalMdl.getIntercept(), updatedOnEmpyDS.getIntercept(), 1e-6);
     }
 }
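
This fit / update-on-same-data / update-on-empty-data pattern is the contract the patch standardizes across trainers: updating with an empty dataset must hand back a model equivalent to the original. A self-contained sketch of the pattern, with toy data and package names as assumed from Ignite's ml module:

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;
    import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
    import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
    import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;

    public class UpdateContractSketch {
        public static void main(String[] args) {
            Map<Integer, double[]> data = new HashMap<>();
            data.put(0, new double[] {1., 3.});   // {feature, label}
            data.put(1, new double[] {2., 5.});
            data.put(2, new double[] {3., 7.});

            LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();

            LinearRegressionModel originalMdl = trainer.fit(
                data, 2,
                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
                (k, v) -> v[v.length - 1]
            );

            // An update on an empty dataset is expected to return an equivalent model.
            LinearRegressionModel updatedOnEmptyDS = trainer.update(
                originalMdl, new HashMap<Integer, double[]>(), 2,
                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
                (k, v) -> v[v.length - 1]
            );

            System.out.println(originalMdl.getIntercept() + " vs " + updatedOnEmptyDS.getIntercept());
        }
    }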
index 86b0f27..1af9109 100644
@@ -94,7 +94,7 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest {
             RPropParameterUpdate::avg
         ), 100000, 10, 100, 0L);
 
-        LinearRegressionModel originalModel = trainer.withSeed(0).fit(
+        LinearRegressionModel originalMdl = trainer.withSeed(0).fit(
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
@@ -103,7 +103,7 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest {
 
 
         LinearRegressionModel updatedOnSameDS = trainer.withSeed(0).update(
-            originalModel,
+            originalMdl,
             data,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
@@ -111,7 +111,7 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest {
         );
 
         LinearRegressionModel updatedOnEmptyDS = trainer.withSeed(0).update(
-            originalModel,
+            originalMdl,
             new HashMap<Integer, double[]>(),
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
@@ -119,19 +119,19 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest {
         );
 
         assertArrayEquals(
-            originalModel.getWeights().getStorage().data(),
+            originalMdl.getWeights().getStorage().data(),
             updatedOnSameDS.getWeights().getStorage().data(),
             1.0
         );
 
-        assertEquals(originalModel.getIntercept(), updatedOnSameDS.getIntercept(), 1.0);
+        assertEquals(originalMdl.getIntercept(), updatedOnSameDS.getIntercept(), 1.0);
 
         assertArrayEquals(
-            originalModel.getWeights().getStorage().data(),
+            originalMdl.getWeights().getStorage().data(),
             updatedOnEmptyDS.getWeights().getStorage().data(),
             1e-1
         );
 
-        assertEquals(originalModel.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-1);
+        assertEquals(originalMdl.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-1);
     }
 }
index 73c8842..78cd08d 100644
@@ -103,7 +103,7 @@ public class LogRegMultiClassTrainerTest extends TrainerTest {
             .withBatchSize(100)
             .withSeed(123L);
 
-        LogRegressionMultiClassModel originalModel = trainer.fit(
+        LogRegressionMultiClassModel originalMdl = trainer.fit(
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -111,7 +111,7 @@ public class LogRegMultiClassTrainerTest extends TrainerTest {
         );
 
         LogRegressionMultiClassModel updatedOnSameDS = trainer.update(
-            originalModel,
+            originalMdl,
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -119,7 +119,7 @@ public class LogRegMultiClassTrainerTest extends TrainerTest {
         );
 
         LogRegressionMultiClassModel updatedOnEmptyDS = trainer.update(
-            originalModel,
+            originalMdl,
             new HashMap<Integer, double[]>(),
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -135,8 +135,8 @@ public class LogRegMultiClassTrainerTest extends TrainerTest {
 
 
         for (Vector vec : vectors) {
-            TestUtils.assertEquals(originalModel.apply(vec), updatedOnSameDS.apply(vec), PRECISION);
-            TestUtils.assertEquals(originalModel.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION);
+            TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION);
+            TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION);
         }
     }
 }
index 1da0d1a..723677c 100644
@@ -76,7 +76,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
             SimpleGDParameterUpdate::avg
         ), 100000, 10, 100, 123L);
 
-        LogisticRegressionModel originalModel = trainer.fit(
+        LogisticRegressionModel originalMdl = trainer.fit(
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -84,7 +84,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         );
 
         LogisticRegressionModel updatedOnSameDS = trainer.update(
-            originalModel,
+            originalMdl,
             cacheMock,
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -92,7 +92,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         );
 
         LogisticRegressionModel updatedOnEmptyDS = trainer.update(
-            originalModel,
+            originalMdl,
             new HashMap<Integer, double[]>(),
             parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -101,9 +101,9 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
 
         Vector v1 = VectorUtils.of(100, 10);
         Vector v2 = VectorUtils.of(10, 100);
-        TestUtils.assertEquals(originalModel.apply(v1), updatedOnSameDS.apply(v1), PRECISION);
-        TestUtils.assertEquals(originalModel.apply(v2), updatedOnSameDS.apply(v2), PRECISION);
-        TestUtils.assertEquals(originalModel.apply(v2), updatedOnEmptyDS.apply(v2), PRECISION);
-        TestUtils.assertEquals(originalModel.apply(v1), updatedOnEmptyDS.apply(v1), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v1), updatedOnSameDS.apply(v1), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v2), updatedOnSameDS.apply(v2), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v2), updatedOnEmptyDS.apply(v2), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v1), updatedOnEmptyDS.apply(v1), PRECISION);
     }
 }
index 84975a8..d89b9bf 100644
@@ -45,7 +45,7 @@ public class DecisionTreeRegressionTrainerTest {
 
     /** Use index [= 1 if true]. */
     @Parameterized.Parameter(1)
-    public int useIndex;
+    public int useIdx;
 
     /** Test parameters. */
     @Parameterized.Parameters(name = "Data divided on {0} partitions. Use index = {1}.")
@@ -73,7 +73,7 @@ public class DecisionTreeRegressionTrainerTest {
         }
 
         DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0)
-            .withUsingIdx(useIndex == 1);
+            .withUsingIdx(useIdx == 1);
 
         DecisionTreeNode tree = trainer.fit(
             data,
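
The useIndex to useIdx rename follows the codebase's abbreviation convention; functionally, withUsingIdx toggles the sorted-index (TreeDataIndex) training path that this parameterized test runs both with and without. A minimal configuration sketch:

    import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;

    public class UseIdxSketch {
        public static void main(String[] args) {
            // withUsingIdx(true) switches training to the sorted-index code path
            // (TreeDataIndex) that this patch exercises under both settings.
            DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0)
                .withUsingIdx(true);
        }
    }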
index 4ee717a..7405c16 100644
@@ -40,7 +40,7 @@ public class DecisionTreeDataTest {
 
     /** Use index. */
     @Parameterized.Parameter
-    public boolean useIndex;
+    public boolean useIdx;
 
     /** */
     @Test
@@ -48,7 +48,7 @@ public class DecisionTreeDataTest {
         double[][] features = new double[][]{{0}, {1}, {2}, {3}, {4}, {5}};
         double[] labels = new double[]{0, 1, 2, 3, 4, 5};
 
-        DecisionTreeData data = new DecisionTreeData(features, labels, useIndex);
+        DecisionTreeData data = new DecisionTreeData(features, labels, useIdx);
         DecisionTreeData filteredData = data.filter(obj -> obj[0] > 2);
 
         assertArrayEquals(new double[][]{{3}, {4}, {5}}, filteredData.getFeatures());
@@ -61,7 +61,7 @@ public class DecisionTreeDataTest {
         double[][] features = new double[][]{{4, 1}, {3, 3}, {2, 0}, {1, 4}, {0, 2}};
         double[] labels = new double[]{0, 1, 2, 3, 4};
 
-        DecisionTreeData data = new DecisionTreeData(features, labels, useIndex);
+        DecisionTreeData data = new DecisionTreeData(features, labels, useIdx);
 
         data.sort(0);
 
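For reference, the two operations this test covers on DecisionTreeData, filtering rows by a predicate over the feature array and sorting rows by a feature column, can be sketched as follows (toy data; the last constructor flag is the useIdx switch tested above):

    import org.apache.ignite.ml.tree.data.DecisionTreeData;

    public class TreeDataSketch {
        public static void main(String[] args) {
            double[][] features = new double[][] {{4}, {1}, {3}};
            double[] labels = new double[] {0, 1, 2};

            DecisionTreeData data = new DecisionTreeData(features, labels, true);

            // Keeps only rows whose first feature exceeds 2: {{4}, {3}}.
            DecisionTreeData filtered = data.filter(obj -> obj[0] > 2);

            // Sorts rows (and their labels) in place by feature column 0.
            filtered.sort(0);

            System.out.println(filtered.getFeatures().length); // 2
        }
    }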
index 78bdfdf..b8ad49a 100644
@@ -75,41 +75,41 @@ public class TreeDataIndexTest {
     };
 
     /** */
-    private TreeDataIndex index = new TreeDataIndex(features, labels);
+    private TreeDataIndex idx = new TreeDataIndex(features, labels);
 
     /** */
     @Test
     public void labelInSortedOrderTest() {
-        assertEquals(features.length, index.rowsCount());
-        assertEquals(features[0].length, index.columnsCount());
+        assertEquals(features.length, idx.rowsCount());
+        assertEquals(features[0].length, idx.columnsCount());
 
-        for (int k = 0; k < index.rowsCount(); k++) {
-            for (int featureId = 0; featureId < index.columnsCount(); featureId++)
-                assertEquals(labelsInSortedOrder[k][featureId], index.labelInSortedOrder(k, featureId), 0.01);
+        for (int k = 0; k < idx.rowsCount(); k++) {
+            for (int featureId = 0; featureId < idx.columnsCount(); featureId++)
+                assertEquals(labelsInSortedOrder[k][featureId], idx.labelInSortedOrder(k, featureId), 0.01);
         }
     }
 
     /** */
     @Test
     public void featuresInSortedOrderTest() {
-        assertEquals(features.length, index.rowsCount());
-        assertEquals(features[0].length, index.columnsCount());
+        assertEquals(features.length, idx.rowsCount());
+        assertEquals(features[0].length, idx.columnsCount());
 
-        for (int k = 0; k < index.rowsCount(); k++) {
-            for (int featureId = 0; featureId < index.columnsCount(); featureId++)
-                assertArrayEquals(featuresInSortedOrder[k][featureId], index.featuresInSortedOrder(k, featureId), 0.01);
+        for (int k = 0; k < idx.rowsCount(); k++) {
+            for (int featureId = 0; featureId < idx.columnsCount(); featureId++)
+                assertArrayEquals(featuresInSortedOrder[k][featureId], idx.featuresInSortedOrder(k, featureId), 0.01);
         }
     }
 
     /** */
     @Test
     public void featureInSortedOrderTest() {
-        assertEquals(features.length, index.rowsCount());
-        assertEquals(features[0].length, index.columnsCount());
+        assertEquals(features.length, idx.rowsCount());
+        assertEquals(features[0].length, idx.columnsCount());
 
-        for (int k = 0; k < index.rowsCount(); k++) {
-            for (int featureId = 0; featureId < index.columnsCount(); featureId++)
-                assertEquals((double)k + 1, index.featureInSortedOrder(k, featureId), 0.01);
+        for (int k = 0; k < idx.rowsCount(); k++) {
+            for (int featureId = 0; featureId < idx.columnsCount(); featureId++)
+                assertEquals((double)k + 1, idx.featureInSortedOrder(k, featureId), 0.01);
         }
     }
 
@@ -120,9 +120,9 @@ public class TreeDataIndexTest {
         TreeFilter filter2 = features -> features[1] > 2;
         TreeFilter filterAnd = filter1.and(features -> features[1] > 2);
 
-        TreeDataIndex filtered1 = index.filter(filter1);
+        TreeDataIndex filtered1 = idx.filter(filter1);
         TreeDataIndex filtered2 = filtered1.filter(filter2);
-        TreeDataIndex filtered3 = index.filter(filterAnd);
+        TreeDataIndex filtered3 = idx.filter(filterAnd);
 
         assertEquals(2, filtered1.rowsCount());
         assertEquals(4, filtered1.columnsCount());
index a328bd7..0c77a2c 100644
@@ -45,7 +45,7 @@ public class GiniImpurityMeasureCalculatorTest {
 
     /** Use index. */
     @Parameterized.Parameter
-    public boolean useIndex;
+    public boolean useIdx;
 
     /** */
     @Test
@@ -56,9 +56,9 @@ public class GiniImpurityMeasureCalculatorTest {
         Map<Double, Integer> encoder = new HashMap<>();
         encoder.put(0.0, 0);
         encoder.put(1.0, 1);
-        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex);
+        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIdx);
 
-        StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0);
+        StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIdx), fs -> true, 0);
 
         assertEquals(2, impurity.length);
 
@@ -88,9 +88,9 @@ public class GiniImpurityMeasureCalculatorTest {
         Map<Double, Integer> encoder = new HashMap<>();
         encoder.put(0.0, 0);
         encoder.put(1.0, 1);
-        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex);
+        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIdx);
 
-        StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0);
+        StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIdx), fs -> true, 0);
 
         assertEquals(1, impurity.length);
 
@@ -111,7 +111,7 @@ public class GiniImpurityMeasureCalculatorTest {
         encoder.put(1.0, 1);
         encoder.put(2.0, 2);
 
-        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex);
+        GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIdx);
 
         assertEquals(0, calculator.getLabelCode(0.0));
         assertEquals(1, calculator.getLabelCode(1.0));
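
As context for the renamed flag: the calculator is built from a label-to-class-index encoder plus the useIdx switch, and calculate produces one impurity step function per feature column. A compact sketch mirroring the calls above:

    import java.util.HashMap;
    import java.util.Map;
    import org.apache.ignite.ml.tree.data.DecisionTreeData;
    import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasureCalculator;

    public class GiniSketch {
        public static void main(String[] args) {
            double[][] data = new double[][] {{0}, {1}, {2}, {3}};
            double[] labels = new double[] {0, 1, 1, 1};

            // Encoder maps raw label values to class indexes.
            Map<Double, Integer> encoder = new HashMap<>();
            encoder.put(0.0, 0);
            encoder.put(1.0, 1);

            GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(encoder, true);

            // One step function per feature column; trailing arguments as in the test above.
            System.out.println(calc.calculate(new DecisionTreeData(data, labels, true), fs -> true, 0).length); // 1
        }
    }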
index 82b3805..ed1fce0 100644
@@ -43,7 +43,7 @@ public class MSEImpurityMeasureCalculatorTest {
 
     /** Use index. */
     @Parameterized.Parameter
-    public boolean useIndex;
+    public boolean useIdx;
 
     /** */
     @Test
@@ -51,9 +51,9 @@ public class MSEImpurityMeasureCalculatorTest {
         double[][] data = new double[][]{{0, 2}, {1, 1}, {2, 0}, {3, 3}};
         double[] labels = new double[]{1, 2, 2, 1};
 
-        MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator(useIndex);
+        MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator(useIdx);
 
-        StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0);
+        StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIdx), fs -> true, 0);
 
         assertEquals(2, impurity.length);
 
index 087f4e8..3a038ff 100644
@@ -19,16 +19,14 @@ package org.apache.ignite.ml.tree.randomforest;
 
 import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -36,31 +34,7 @@ import static org.junit.Assert.assertTrue;
 /**
  * Tests for {@link RandomForestClassifierTrainer}.
  */
-@RunWith(Parameterized.class)
-public class RandomForestClassifierTrainerTest {
-    /**
-     * Number of parts to be tested.
-     */
-    private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
-
-    /**
-     * Number of partitions.
-     */
-    @Parameterized.Parameter
-    public int parts;
-
-    /**
-     * Data iterator.
-     */
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        List<Integer[]> res = new ArrayList<>();
-        for (int part : partsToBeTested)
-            res.add(new Integer[] {part});
-
-        return res;
-    }
-
+public class RandomForestClassifierTrainerTest extends TrainerTest {
     /** */
     @Test
     public void testFit() {
@@ -79,7 +53,7 @@ public class RandomForestClassifierTrainerTest {
         for (int i = 0; i < 4; i++)
             meta.add(new FeatureMeta("", i, false));
         RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
-            .withCountOfTrees(5)
+            .withAmountOfTrees(5)
             .withFeaturesCountSelectionStrgy(x -> 2);
 
         ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
@@ -106,15 +80,15 @@ public class RandomForestClassifierTrainerTest {
         for (int i = 0; i < 4; i++)
             meta.add(new FeatureMeta("", i, false));
         RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
-            .withCountOfTrees(100)
+            .withAmountOfTrees(100)
             .withFeaturesCountSelectionStrgy(x -> 2);
 
-        ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
-        ModelsComposition updatedOnSameDS = trainer.update(originalModel, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
-        ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition originalMdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnSameDS = trainer.update(originalMdl, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
 
         Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005);
-        assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.01);
-        assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.01);
+        assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), 0.01);
+        assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), 0.01);
     }
 }
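
withCountOfTrees is renamed to withAmountOfTrees here and in the regression trainer, the examples, and RandomForestTest below. The renamed builder chain in isolation, using the same placeholder feature metadata the test uses:

    import java.util.ArrayList;
    import java.util.List;
    import org.apache.ignite.ml.dataset.feature.FeatureMeta;
    import org.apache.ignite.ml.tree.randomforest.RandomForestClassifierTrainer;

    public class RenameSketch {
        public static void main(String[] args) {
            List<FeatureMeta> meta = new ArrayList<>();
            for (int i = 0; i < 4; i++)
                meta.add(new FeatureMeta("", i, false));

            // Formerly withCountOfTrees(5).
            RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
                .withAmountOfTrees(5)
                .withFeaturesCountSelectionStrgy(x -> 2);
        }
    }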
index fcc20bd..08ff95d 100644
@@ -19,16 +19,14 @@ package org.apache.ignite.ml.tree.randomforest;
 
 import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -36,28 +34,7 @@ import static org.junit.Assert.assertTrue;
 /**
  * Tests for {@link RandomForestRegressionTrainer}.
  */
-@RunWith(Parameterized.class)
-public class RandomForestRegressionTrainerTest {
-    /**
-     * Number of parts to be tested.
-     */
-    private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
-
-    /**
-     * Number of partitions.
-     */
-    @Parameterized.Parameter
-    public int parts;
-
-    @Parameterized.Parameters(name = "Data divided on {0} partitions")
-    public static Iterable<Integer[]> data() {
-        List<Integer[]> res = new ArrayList<>();
-        for (int part : partsToBeTested)
-            res.add(new Integer[] {part});
-
-        return res;
-    }
-
+public class RandomForestRegressionTrainerTest extends TrainerTest {
     /** */
     @Test
     public void testFit() {
@@ -76,7 +53,7 @@ public class RandomForestRegressionTrainerTest {
         for(int i = 0; i < 4; i++)
             meta.add(new FeatureMeta("", i, false));
         RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(meta)
-            .withCountOfTrees(5)
+            .withAmountOfTrees(5)
             .withFeaturesCountSelectionStrgy(x -> 2);
 
         ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(v), (k, v) -> k);
@@ -102,15 +79,15 @@ public class RandomForestRegressionTrainerTest {
         for (int i = 0; i < 4; i++)
             meta.add(new FeatureMeta("", i, false));
         RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(meta)
-            .withCountOfTrees(100)
+            .withAmountOfTrees(100)
             .withFeaturesCountSelectionStrgy(x -> 2);
 
-        ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
-        ModelsComposition updatedOnSameDS = trainer.update(originalModel, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
-        ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition originalMdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnSameDS = trainer.update(originalMdl, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
 
         Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005);
-        assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.1);
-        assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.1);
+        assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), 0.1);
+        assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), 0.1);
     }
 }
index 9fa7f0e..eb81b36 100644
@@ -34,7 +34,7 @@ public class RandomForestTest {
     private final long seed = 0;
 
     /** Count of trees. */
-    private final int countOfTrees = 10;
+    private final int cntOfTrees = 10;
 
     /** Min imp delta. */
     private final double minImpDelta = 1.0;
@@ -55,7 +55,7 @@ public class RandomForestTest {
 
     /** Rf. */
     private RandomForestClassifierTrainer rf = new RandomForestClassifierTrainer(meta)
-        .withCountOfTrees(countOfTrees)
+        .withAmountOfTrees(cntOfTrees)
         .withSeed(seed)
         .withFeaturesCountSelectionStrgy(x -> 4)
         .withMaxDepth(maxDepth)
index 7ca6411..a82bb95 100644
@@ -44,7 +44,7 @@ public class GiniFeatureHistogramTest extends ImpurityHistogramTest {
 
     /** */
     @Before
-    public void setUp() throws Exception {
+    public void setUp() {
         feature2Meta.setMinVal(-5);
         feature2Meta.setBucketSize(1);
     }
@@ -129,12 +129,13 @@ public class GiniFeatureHistogramTest extends ImpurityHistogramTest {
 
         NodeSplit catSplit = catFeatureSmpl1.findBestSplit().get();
         NodeSplit contSplit = contFeatureSmpl1.findBestSplit().get();
-        assertEquals(1.0, catSplit.getValue(), 0.01);
-        assertEquals(-0.5, contSplit.getValue(), 0.01);
+        assertEquals(1.0, catSplit.getVal(), 0.01);
+        assertEquals(-0.5, contSplit.getVal(), 0.01);
         assertFalse(emptyHist.findBestSplit().isPresent());
         assertFalse(catFeatureSmpl2.findBestSplit().isPresent());
     }
 
+    /** */
     @Test
     public void testOfSums() {
         int sampleId = 0;
@@ -148,22 +149,22 @@ public class GiniFeatureHistogramTest extends ImpurityHistogramTest {
 
         List<GiniHistogram> partitions1 = new ArrayList<>();
         List<GiniHistogram> partitions2 = new ArrayList<>();
-        int countOfPartitions = rnd.nextInt(1000);
-        for(int i = 0; i < countOfPartitions; i++) {
+        int cntOfPartitions = rnd.nextInt(1000);
+        for (int i = 0; i < cntOfPartitions; i++) {
             partitions1.add(new GiniHistogram(sampleId,lblMapping, bucketMeta1));
             partitions2.add(new GiniHistogram(sampleId,lblMapping, bucketMeta2));
         }
 
         int datasetSize = rnd.nextInt(10000);
         for(int i = 0; i < datasetSize; i++) {
-            BootstrappedVector vec = randomVector(2, 1, true);
+            BootstrappedVector vec = randomVector(true);
             vec.features().set(1, (vec.features().get(1) * 100) % 100);
 
             forAllHist1.addElement(vec);
             forAllHist2.addElement(vec);
-            int partitionId = rnd.nextInt(countOfPartitions);
-            partitions1.get(partitionId).addElement(vec);
-            partitions2.get(partitionId).addElement(vec);
+            int partId = rnd.nextInt(cntOfPartitions);
+            partitions1.get(partId).addElement(vec);
+            partitions2.get(partId).addElement(vec);
         }
 
         checkSums(forAllHist1, partitions1);
index df4c154..54bd0df 100644
@@ -32,9 +32,17 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertTrue;
 
+/**
+ * Tests for {@link ImpurityHistogram}.
+ */
 public class ImpurityHistogramTest {
-    protected static final int COUNT_OF_CLASSES = 3;
-    protected static final Map<Double, Integer> lblMapping = new HashMap<>();
+    /** Count of classes. */
+    private static final int COUNT_OF_CLASSES = 3;
+
+    /** Lbl mapping. */
+    static final Map<Double, Integer> lblMapping = new HashMap<>();
+
+    /** Random generator. */
     protected Random rnd = new Random();
 
     static {
@@ -42,28 +50,41 @@ public class ImpurityHistogramTest {
             lblMapping.put((double)i, i);
     }
 
-    protected void checkBucketIds(Set<Integer> bucketIdsSet, Integer[] expected) {
+    /** */
+    void checkBucketIds(Set<Integer> bucketIdsSet, Integer[] exp) {
         Integer[] bucketIds = new Integer[bucketIdsSet.size()];
         bucketIdsSet.toArray(bucketIds);
-        assertArrayEquals(expected, bucketIds);
+        assertArrayEquals(exp, bucketIds);
     }
 
-    protected void checkCounters(ObjectHistogram<BootstrappedVector> hist, double[] expected) {
+    /** */
+    void checkCounters(ObjectHistogram<BootstrappedVector> hist, double[] exp) {
         double[] counters = hist.buckets().stream().mapToDouble(x -> hist.getValue(x).get()).toArray();
-        assertArrayEquals(expected, counters, 0.01);
+        assertArrayEquals(exp, counters, 0.01);
     }
 
-    protected BootstrappedVector randomVector(int countOfFeatures, int countOfSampes, boolean isClassification) {
-        double[] features = DoubleStream.generate(() -> rnd.nextDouble()).limit(countOfFeatures).toArray();
-        int[] counters = IntStream.generate(() -> rnd.nextInt(10)).limit(countOfSampes).toArray();
+    /**
+     * Generates random vector.
+     *
+     * @param isClassification Is classification.
+     */
+    BootstrappedVector randomVector(boolean isClassification) {
+        double[] features = DoubleStream.generate(() -> rnd.nextDouble()).limit(2).toArray();
+        int[] counters = IntStream.generate(() -> rnd.nextInt(10)).limit(1).toArray();
         double lbl = isClassification ? Math.abs(rnd.nextInt() % COUNT_OF_CLASSES) : rnd.nextDouble();
         return new BootstrappedVector(VectorUtils.of(features), lbl, counters);
     }
 
-    protected <T extends Histogram<BootstrappedVector, T>> void checkSums(T expected, List<T> partitions) {
+    /**
+     * Check sums.
+     *
+     * @param exp Expected value.
+     * @param partitions Partitions.
+     */
+    <T extends Histogram<BootstrappedVector, T>> void checkSums(T exp, List<T> partitions) {
         T leftSum = partitions.stream().reduce((x,y) -> x.plus(y)).get();
         T rightSum = partitions.stream().reduce((x,y) -> y.plus(x)).get();
-        assertTrue(expected.isEqualTo(leftSum));
-        assertTrue(expected.isEqualTo(rightSum));
+        assertTrue(exp.isEqualTo(leftSum));
+        assertTrue(exp.isEqualTo(rightSum));
     }
 }
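
randomVector now fixes the shape it generates (two features, one sample counter) instead of taking counts as parameters, matching how every caller used it. For reference, the three parts a BootstrappedVector carries, sketched with illustrative values:

    import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
    import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

    public class BootstrapSketch {
        public static void main(String[] args) {
            // Features, label, and per-sample repetition counters.
            BootstrappedVector vec =
                new BootstrappedVector(VectorUtils.of(0.5, 0.7), 1.0, new int[] {3});

            // The same rescaling trick the histogram tests apply to feature 1.
            vec.features().set(1, (vec.features().get(1) * 100) % 100);
        }
    }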
index 41bd5ff..872ecec 100644
@@ -82,6 +82,7 @@ public class MSEHistogramTest extends ImpurityHistogramTest {
         checkCounters(contHist2.getSumOfSquaredLabels(), new double[]{ 2 * 5 * 5, 2 * 1 * 1, 1 * 4 * 4, 1 * 2 * 2, 0 * 3 * 3 });
     }
 
+    /** */
     @Test
     public void testOfSums() {
         int sampleId = 0;
@@ -95,22 +96,24 @@ public class MSEHistogramTest extends ImpurityHistogramTest {
 
         List<MSEHistogram> partitions1 = new ArrayList<>();
         List<MSEHistogram> partitions2 = new ArrayList<>();
-        int countOfPartitions = rnd.nextInt(100);
-        for(int i = 0; i < countOfPartitions; i++) {
+
+        int cntOfPartitions = rnd.nextInt(100);
+
+        for (int i = 0; i < cntOfPartitions; i++) {
             partitions1.add(new MSEHistogram(sampleId, bucketMeta1));
             partitions2.add(new MSEHistogram(sampleId, bucketMeta2));
         }
 
         int datasetSize = rnd.nextInt(1000);
         for(int i = 0; i < datasetSize; i++) {
-            BootstrappedVector vec = randomVector(2, 1, false);
+            BootstrappedVector vec = randomVector(false);
             vec.features().set(1, (vec.features().get(1) * 100) % 100);
 
             forAllHist1.addElement(vec);
             forAllHist2.addElement(vec);
-            int partitionId = rnd.nextInt(countOfPartitions);
-            partitions1.get(partitionId).addElement(vec);
-            partitions2.get(partitionId).addElement(vec);
+            int partId = rnd.nextInt(cntOfPartitions);
+            partitions1.get(partId).addElement(vec);
+            partitions2.get(partId).addElement(vec);
         }
 
         checkSums(forAllHist1, partitions1);
index 79ee3b6..c65a9ac 100644
@@ -54,13 +54,14 @@ public class NormalDistributionStatisticsComputerTest {
         new BootstrappedVector(VectorUtils.of(9, 0, 11, 2, 13, 3, 15), 0., null),
     });
 
+    /** Normal Distribution Statistics Computer. */
     private NormalDistributionStatisticsComputer computer = new NormalDistributionStatisticsComputer();
 
     /** */
     @Test
     public void computeStatsOnPartitionTest() {
-        List<NormalDistributionStatistics> result = computer.computeStatsOnPartition(partition, meta);
-        NormalDistributionStatistics[] expected = new NormalDistributionStatistics[] {
+        List<NormalDistributionStatistics> res = computer.computeStatsOnPartition(partition, meta);
+        NormalDistributionStatistics[] exp = new NormalDistributionStatistics[] {
             new NormalDistributionStatistics(0, 9, 285, 45, 10),
             new NormalDistributionStatistics(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0, 0, 10),
             new NormalDistributionStatistics(2, 11, 505, 65, 10),
@@ -70,15 +71,15 @@ public class NormalDistributionStatisticsComputerTest {
             new NormalDistributionStatistics(6, 15, 1185, 105, 10),
         };
 
-        assertEquals(expected.length, result.size());
-        for (int i = 0; i < expected.length; i++) {
-            NormalDistributionStatistics expectedStat = expected[i];
-            NormalDistributionStatistics resultStat = result.get(i);
-            assertEquals(expectedStat.mean(), resultStat.mean(), 0.01);
-            assertEquals(expectedStat.variance(), resultStat.variance(), 0.01);
-            assertEquals(expectedStat.std(), resultStat.std(), 0.01);
-            assertEquals(expectedStat.min(), resultStat.min(), 0.01);
-            assertEquals(expectedStat.max(), resultStat.max(), 0.01);
+        assertEquals(exp.length, res.size());
+        for (int i = 0; i < exp.length; i++) {
+            NormalDistributionStatistics expStat = exp[i];
+            NormalDistributionStatistics resStat = res.get(i);
+            assertEquals(expStat.mean(), resStat.mean(), 0.01);
+            assertEquals(expStat.variance(), resStat.variance(), 0.01);
+            assertEquals(expStat.std(), resStat.std(), 0.01);
+            assertEquals(expStat.min(), resStat.min(), 0.01);
+            assertEquals(expStat.max(), resStat.max(), 0.01);
         }
     }
 
@@ -105,8 +106,8 @@ public class NormalDistributionStatisticsComputerTest {
             new NormalDistributionStatistics(0, 9, 285, 45, 10)
         );
 
-        List<NormalDistributionStatistics> result = computer.reduceStats(left, right, meta);
-        NormalDistributionStatistics[] expected = new NormalDistributionStatistics[] {
+        List<NormalDistributionStatistics> res = computer.reduceStats(left, right, meta);
+        NormalDistributionStatistics[] exp = new NormalDistributionStatistics[] {
             new NormalDistributionStatistics(0, 15, 1470, 150, 20),
             new NormalDistributionStatistics(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0, 0, 10),
             new NormalDistributionStatistics(2, 13, 1310, 150, 20),
@@ -116,15 +117,15 @@ public class NormalDistributionStatisticsComputerTest {
             new NormalDistributionStatistics(0, 15, 1470, 150, 20)
         };
 
-        assertEquals(expected.length, result.size());
-        for (int i = 0; i < expected.length; i++) {
-            NormalDistributionStatistics expectedStat = expected[i];
-            NormalDistributionStatistics resultStat = result.get(i);
-            assertEquals(expectedStat.mean(), resultStat.mean(), 0.01);
-            assertEquals(expectedStat.variance(), resultStat.variance(), 0.01);
-            assertEquals(expectedStat.std(), resultStat.std(), 0.01);
-            assertEquals(expectedStat.min(), resultStat.min(), 0.01);
-            assertEquals(expectedStat.max(), resultStat.max(), 0.01);
+        assertEquals(exp.length, res.size());
+        for (int i = 0; i < exp.length; i++) {
+            NormalDistributionStatistics expStat = exp[i];
+            NormalDistributionStatistics resStat = res.get(i);
+            assertEquals(expStat.mean(), resStat.mean(), 0.01);
+            assertEquals(expStat.variance(), resStat.variance(), 0.01);
+            assertEquals(expStat.std(), resStat.std(), 0.01);
+            assertEquals(expStat.min(), resStat.min(), 0.01);
+            assertEquals(expStat.max(), resStat.max(), 0.01);
         }
     }
 }
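
Assuming the five NormalDistributionStatistics constructor arguments are (min, max, sum of squares, sum, count), which is consistent with the 0..9 feature column above (min 0, max 9, sum of squares 285, sum 45, count 10), the derived statistics the asserts compare can be reproduced by hand:

    public class StatsSketch {
        public static void main(String[] args) {
            // Assumed constructor order: (min, max, sum of squares, sum, count).
            double sumSq = 285, sum = 45;
            int n = 10;

            double mean = sum / n;                       // 45 / 10 = 4.5
            double variance = sumSq / n - mean * mean;   // 28.5 - 20.25 = 8.25
            double std = Math.sqrt(variance);            // ~2.87

            System.out.printf("mean=%.2f variance=%.2f std=%.2f%n", mean, variance, std);
        }
    }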