IGNITE-10544: [ML] GMM with fixed components
authorAlexey Platonov <aplatonovv@gmail.com>
Tue, 12 Feb 2019 14:10:55 +0000 (17:10 +0300)
committerYury Babak <ybabak@gridgain.com>
Tue, 12 Feb 2019 14:10:55 +0000 (17:10 +0300)
This closes #6063

29 files changed:
modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregator.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmModel.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/SingularMatrixException.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/AbstractMatrix.java
modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/LUDecomposition.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/Matrix.java
modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/stat/MultivariateGaussianDistribution.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/stat/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/nn/ReplicatedVectorMatrix.java
modules/ml/src/test/java/org/apache/ignite/ml/clustering/ClusteringTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregatorTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmModelTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionDataTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmTrainerIntegrationTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmTrainerTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplMainTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/matrix/LUDecompositionTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/math/stat/DistributionMixtureTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/math/stat/MultivariateGaussianDistributionTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/math/stat/StatsTestSuite.java [new file with mode: 0644]
modules/yardstick/README.txt

diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregator.java
new file mode 100644 (file)
index 0000000..c36d030
--- /dev/null
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * This class encapsulates statistics aggregation logic for feature vector covariance matrix computation of one GMM
+ * component (cluster).
+ */
+public class CovarianceMatricesAggregator implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 4163253784526780812L;
+
+    /** Mean vector. */
+    private final Vector mean;
+
+    /** Weighted by P(c|xi) sum of (xi - mean) * (xi - mean)^T values. */
+    private Matrix weightedSum;
+
+    /** Count of rows. */
+    private int rowCount;
+
+    /**
+     * Creates an instance of CovarianceMatricesAggregator.
+     *
+     * @param mean Mean vector.
+     */
+    CovarianceMatricesAggregator(Vector mean) {
+        this.mean = mean;
+    }
+
+    /**
+     * Creates an instance of CovarianceMatricesAggregator.
+     *
+     * @param mean Mean vector.
+     * @param weightedSum Weighted sums for covariace computation.
+     * @param rowCount Count of rows.
+     */
+    CovarianceMatricesAggregator(Vector mean, Matrix weightedSum, int rowCount) {
+        this.mean = mean;
+        this.weightedSum = weightedSum;
+        this.rowCount = rowCount;
+    }
+
+    /**
+     * Computes covatiation matrices for feature vector for each GMM component.
+     *
+     * @param dataset Dataset.
+     * @param clusterProbs Probabilities of each GMM component.
+     * @param means Means for each GMM component.
+     */
+    static List<Matrix> computeCovariances(Dataset<EmptyContext, GmmPartitionData> dataset,
+        Vector clusterProbs, Vector[] means) {
+
+        List<CovarianceMatricesAggregator> aggregators = dataset.compute(
+            data -> map(data, means),
+            CovarianceMatricesAggregator::reduce
+        );
+
+        if (aggregators == null)
+            return Collections.emptyList();
+
+        List<Matrix> res = new ArrayList<>();
+        for (int i = 0; i < aggregators.size(); i++)
+            res.add(aggregators.get(i).covariance(clusterProbs.get(i)));
+
+        return res;
+    }
+
+    /**
+     * @param x Feature vector (xi).
+     * @param pcxi P(c|xi) for GMM component "c" and vector xi.
+     */
+    void add(Vector x, double pcxi) {
+        Matrix deltaCol = x.minus(mean).toMatrix(false);
+        Matrix weightedCovComponent = deltaCol.times(deltaCol.transpose()).times(pcxi);
+        if (weightedSum == null)
+            weightedSum = weightedCovComponent;
+        else
+            weightedSum = weightedSum.plus(weightedCovComponent);
+        rowCount += 1;
+    }
+
+    /**
+     * @param other Other.
+     * @return sum of aggregators.
+     */
+    CovarianceMatricesAggregator plus(CovarianceMatricesAggregator other) {
+        A.ensure(this.mean.equals(other.mean), "this.mean == other.mean");
+
+        return new CovarianceMatricesAggregator(
+            mean,
+            this.weightedSum.plus(other.weightedSum),
+            this.rowCount + other.rowCount
+        );
+    }
+
+    /**
+     * Map stage for covariance computation over dataset.
+     *
+     * @param data Data partition.
+     * @param means Means vector.
+     * @return Covariance aggregators.
+     */
+    static List<CovarianceMatricesAggregator> map(GmmPartitionData data, Vector[] means) {
+        int countOfComponents = means.length;
+
+        List<CovarianceMatricesAggregator> aggregators = new ArrayList<>();
+        for (int i = 0; i < countOfComponents; i++)
+            aggregators.add(new CovarianceMatricesAggregator(means[i]));
+
+        for (int i = 0; i < data.size(); i++) {
+            for (int c = 0; c < countOfComponents; c++)
+                aggregators.get(c).add(data.getX(i), data.pcxi(c, i));
+        }
+
+        return aggregators;
+    }
+
+    /**
+     * @param clusterProb GMM component probability.
+     * @return computed covariance matrix.
+     */
+    private Matrix covariance(double clusterProb) {
+        return weightedSum.divide(rowCount * clusterProb);
+    }
+
+    /**
+     * Reduce stage for covariance computation over dataset.
+     *
+     * @param l first partition.
+     * @param r second partition.
+     */
+    static List<CovarianceMatricesAggregator> reduce(List<CovarianceMatricesAggregator> l,
+        List<CovarianceMatricesAggregator> r) {
+
+        A.ensure(l != null || r != null, "Both partitions cannot equal to null");
+
+        if (l == null || l.isEmpty())
+            return r;
+        if (r == null || r.isEmpty())
+            return l;
+
+        A.ensure(l.size() == r.size(), "l.size() == r.size()");
+        List<CovarianceMatricesAggregator> res = new ArrayList<>();
+        for (int i = 0; i < l.size(); i++)
+            res.add(l.get(i).plus(r.get(i)));
+
+        return res;
+    }
+
+    /**
+     * @return mean vector.
+     */
+    Vector mean() {
+        return mean.copy();
+    }
+
+    /**
+     * @return weighted sum.
+     */
+    Matrix weightedSum() {
+        return weightedSum.copy();
+    }
+
+    /**
+     * @return rows count.
+     */
+    public int rowCount() {
+        return rowCount;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmModel.java
new file mode 100644 (file)
index 0000000..b2a526e
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.List;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.stat.DistributionMixture;
+import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
+
+/**
+ * Gaussian Mixture Model. This algorithm represents a soft clustering model where each cluster is gaussian distribution
+ * with own mean value and covariation matrix. Such model can predict cluster using maximum likelihood priciple (see
+ * {@link #predict(Vector)}). Also * this model can estimate probability of given vector (see {@link #prob(Vector)}) and
+ * compute likelihood vector where each component of it is a probability of cluster of mixture (see {@link
+ * #likelihood(Vector)}).
+ */
+public class GmmModel extends DistributionMixture<MultivariateGaussianDistribution> implements IgniteModel<Vector, Double> {
+    /** Serial version uid. */
+    private static final long serialVersionUID = -4484174539118240037L;
+
+    /**
+     * Creates an instance of GmmModel.
+     *
+     * @param componentProbs Probabilities of components.
+     * @param distributions Gaussian distributions for each component.
+     */
+    public GmmModel(Vector componentProbs, List<MultivariateGaussianDistribution> distributions) {
+        super(componentProbs, distributions);
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double predict(Vector input) {
+        return (double)likelihood(input).maxElement().index();
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
new file mode 100644 (file)
index 0000000..942c511
--- /dev/null
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+/**
+ * Partition data for GMM algorithm. Unlike partition data for other algorithms this class aggregate probabilities of
+ * each cluster of gaussians mixture (see {@link #pcxi}) for each vector in dataset.
+ */
+class GmmPartitionData implements AutoCloseable {
+    /** Dataset vectors. */
+    private List<LabeledVector<Double>> xs;
+
+    /** P(cluster|xi) where second idx is a cluster and first is a index of point. */
+    private double[][] pcxi;
+
+    /**
+     * Creates an instance of GmmPartitionData.
+     *
+     * @param xs Dataset.
+     * @param pcxi P(cluster|xi) per cluster.
+     */
+    GmmPartitionData(List<LabeledVector<Double>> xs, double[][] pcxi) {
+        A.ensure(xs.size() == pcxi.length, "xs.size() == pcxi.length");
+
+        this.xs = xs;
+        this.pcxi = pcxi;
+    }
+
+    /**
+     * @param i Index of vector in partition.
+     * @return Vector.
+     */
+    public Vector getX(int i) {
+        return xs.get(i).features();
+    }
+
+    /**
+     * @return all vectors from partition.
+     */
+    public List<LabeledVector<Double>> getAllXs() {
+        return Collections.unmodifiableList(xs);
+    }
+
+    /**
+     * @param cluster Cluster id.
+     * @param i Vector id.
+     * @return P(cluster | xi) value.
+     */
+    public double pcxi(int cluster, int i) {
+        return pcxi[i][cluster];
+    }
+
+    /**
+     * @param cluster Cluster id.
+     * @param i Vector id.
+     * @param value P(cluster|xi) value.
+     */
+    public void setPcxi(int cluster, int i, double value) {
+        pcxi[i][cluster] = value;
+    }
+
+    /**
+     * @return size of dataset partition.
+     */
+    public int size() {
+        return pcxi.length;
+    }
+
+    /**
+     * @return count of GMM components.
+     */
+    public int countOfComponents() {
+        return size() != 0 ? pcxi[0].length : 0;
+    }
+
+    /** {@inheritDoc} */
+    @Override public void close() throws Exception {
+        //NOP
+    }
+
+    /**
+     * Builder for GMM partition data.
+     */
+    public static class Builder<K, V> implements PartitionDataBuilder<K, V, EmptyContext, GmmPartitionData> {
+        /** Serial version uid. */
+        private static final long serialVersionUID = 1847063348042022561L;
+
+        /** Extractor. */
+        private final FeatureLabelExtractor<K, V, Double> extractor;
+
+        /** Count of components of mixture. */
+        private final int countOfComponents;
+
+        /**
+         * Creates an instance of Builder.
+         *
+         * @param extractor Extractor.
+         * @param countOfComponents Count of components.
+         */
+        public Builder(FeatureLabelExtractor<K, V, Double> extractor, int countOfComponents) {
+            this.extractor = extractor;
+            this.countOfComponents = countOfComponents;
+        }
+
+        /** {@inheritDoc} */
+        @Override public GmmPartitionData build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData,
+            long upstreamDataSize, EmptyContext ctx) {
+
+            int rowsCount = Math.toIntExact(upstreamDataSize);
+            List<LabeledVector<Double>> xs = new ArrayList<>(rowsCount);
+            double[][] pcxi = new double[rowsCount][countOfComponents];
+
+            while (upstreamData.hasNext()) {
+                UpstreamEntry<K, V> entry = upstreamData.next();
+                LabeledVector<Double> x = extractor.extract(entry.getKey(), entry.getValue());
+                xs.add(x);
+            }
+
+            return new GmmPartitionData(xs, pcxi);
+        }
+    }
+
+    /**
+     * Sets P(c|xi) = 1 for closest cluster "c" for each vector in partition data using initial means as cluster centers
+     * (like in k-means).
+     *
+     * @param initMeans Initial means.
+     */
+    static void estimateLikelihoodClusters(GmmPartitionData data, Vector[] initMeans) {
+        for (int i = 0; i < data.size(); i++) {
+            int closestClusterId = -1;
+            double minSquaredDist = Double.MAX_VALUE;
+
+            Vector x = data.getX(i);
+            for (int c = 0; c < initMeans.length; c++) {
+                double distance = initMeans[c].getDistanceSquared(x);
+                if (distance < minSquaredDist) {
+                    closestClusterId = c;
+                    minSquaredDist = distance;
+                }
+            }
+
+            data.setPcxi(closestClusterId, i, 1.);
+        }
+    }
+
+    /**
+     * Updates P(c|xi) values in partitions given components probabilities and components of GMM.
+     *
+     * @param clusterProbs Component probabilities.
+     * @param components Components.
+     */
+    static void updatePcxi(GmmPartitionData data, Vector clusterProbs,
+        List<MultivariateGaussianDistribution> components) {
+
+        for (int i = 0; i < data.size(); i++) {
+            Vector x = data.getX(i);
+            double normalizer = 0.0;
+            for (int c = 0; c < clusterProbs.size(); c++)
+                normalizer += components.get(c).prob(x) * clusterProbs.get(c);
+
+            for (int c = 0; c < clusterProbs.size(); c++)
+                data.pcxi[i][c] = (components.get(c).prob(x) * clusterProbs.get(c)) / normalizer;
+        }
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
new file mode 100644 (file)
index 0000000..883de24
--- /dev/null
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
+import org.apache.ignite.ml.structures.DatasetRow;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Traner for GMM model.
+ */
+public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
+    /** Min divergence of mean vectors beween iterations. If divergence will less then trainer stops. */
+    private double eps = 1e-3;
+
+    /** Count of components. */
+    private int countOfComponents = 2;
+
+    /** Max count of iterations. */
+    private int maxCountOfIterations = 10;
+
+    /** Initial means. */
+    private Vector[] initialMeans;
+
+    /** Maximum initialization tries count. */
+    private int maxCountOfInitTries = 3;
+
+    /**
+     * Creates an instance of GmmTrainer.
+     */
+    public GmmTrainer() {
+    }
+
+    /**
+     * Creates an instance of GmmTrainer.
+     *
+     * @param countOfComponents Count of components.
+     * @param maxCountOfIterations Max count of iterations.
+     */
+    public GmmTrainer(int countOfComponents, int maxCountOfIterations) {
+        this.countOfComponents = countOfComponents;
+        this.maxCountOfIterations = maxCountOfIterations;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> GmmModel fit(DatasetBuilder<K, V> datasetBuilder,
+        FeatureLabelExtractor<K, V, Double> extractor) {
+        return updateModel(null, datasetBuilder, extractor);
+    }
+
+    /**
+     * Sets numberOfComponents.
+     *
+     * @param numberOfComponents Number of components.
+     * @return trainer.
+     */
+    public GmmTrainer withCountOfComponents(int numberOfComponents) {
+        A.ensure(numberOfComponents > 0, "Number of components in GMM cannot equal 0");
+
+        this.countOfComponents = numberOfComponents;
+        initialMeans = null;
+        return this;
+    }
+
+    /**
+     * Sets initial means.
+     *
+     * @param means Initial means for clusters.
+     * @return trainer.
+     */
+    public GmmTrainer withInitialMeans(List<Vector> means) {
+        A.notEmpty(means, "GMM should start with non empty initial components list");
+
+        this.initialMeans = means.toArray(new Vector[means.size()]);
+        this.countOfComponents = means.size();
+        return this;
+    }
+
+    /**
+     * Sets max count of iterations
+     *
+     * @param maxCountOfIterations Max count of iterations.
+     * @return trainer.
+     */
+    public GmmTrainer withMaxCountIterations(int maxCountOfIterations) {
+        A.ensure(maxCountOfIterations > 0, "Max count iterations cannot be less or equal zero or negative");
+
+        this.maxCountOfIterations = maxCountOfIterations;
+        return this;
+    }
+
+    /**
+     * Sets min divergence beween iterations.
+     *
+     * @param eps Eps.
+     * @return trainer.
+     */
+    public GmmTrainer withEps(double eps) {
+        A.ensure(eps > 0 && eps < 1.0, "Min divergence beween iterations should be between 0.0 and 1.0");
+
+        this.eps = eps;
+        return this;
+    }
+
+    /**
+     * Sets MaxCountOfInitTries parameter. If means initialization were unsuccessfull then algorithm try to reinitialize
+     * means randomly MaxCountOfInitTries times.
+     *
+     * @param maxCountOfInitTries Max count of init tries.
+     * @return trainer.
+     */
+    public GmmTrainer withMaxCountOfInitTries(int maxCountOfInitTries) {
+        A.ensure(maxCountOfInitTries > 0, "Max initialization count should be great than zero.");
+
+        this.maxCountOfInitTries = maxCountOfInitTries;
+        return this;
+    }
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param dataset Dataset.
+     * @return GMM model.
+     */
+    private Optional<GmmModel> fit(Dataset<EmptyContext, GmmPartitionData> dataset) {
+        return init(dataset).map(model -> updateModel(dataset, model));
+    }
+
+    /**
+     * Gets older model and returns updated model on given data.
+     *
+     * @param dataset Dataset.
+     * @param model Model.
+     * @return updated model.
+     */
+    @NotNull private GmmModel updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
+        boolean isConverged = false;
+        int countOfIterations = 0;
+        while (!isConverged) {
+            MeanWithClusterProbAggregator.AggregatedStats stats = MeanWithClusterProbAggregator.aggreateStats(dataset);
+            Vector clusterProbs = stats.clusterProbabilities();
+            Vector[] newMeans = stats.means().toArray(new Vector[countOfComponents]);
+
+            A.ensure(newMeans.length == model.countOfComponents(), "newMeans.size() == count of components");
+            A.ensure(newMeans[0].size() == initialMeans[0].size(), "newMeans[0].size() == initialMeans[0].size()");
+            List<Matrix> newCovs = CovarianceMatricesAggregator.computeCovariances(dataset, clusterProbs, newMeans);
+
+            try {
+                List<MultivariateGaussianDistribution> components = buildComponents(newMeans, newCovs);
+                GmmModel newModel = new GmmModel(clusterProbs, components);
+
+                countOfIterations += 1;
+                isConverged = isConverged(model, newModel) || countOfIterations > maxCountOfIterations;
+                model = newModel;
+
+                if (!isConverged)
+                    dataset.compute(data -> GmmPartitionData.updatePcxi(data, clusterProbs, components));
+            }
+            catch (SingularMatrixException | IllegalArgumentException e) {
+                String msg = "Cannot construct non-singular covariance matrix by data. " +
+                    "Try to select other initial means or other model trainer. Iterations will stop.";
+                environment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
+                isConverged = true;
+            }
+        }
+
+        return model;
+    }
+
+    /**
+     * Init means and covariances.
+     *
+     * @param dataset Dataset.
+     * @return initial model.
+     */
+    private Optional<GmmModel> init(Dataset<EmptyContext, GmmPartitionData> dataset) {
+        int countOfTries = 0;
+
+        while (true) {
+            try {
+                if (initialMeans == null) {
+                    List<List<Vector>> randomMeansSets = Stream.of(dataset.compute(
+                        selectNRandomXsMapper(countOfComponents),
+                        GmmTrainer::selectNRandomXsReducer
+                    )).map(this::asList).collect(Collectors.toList());
+
+                    A.ensure(
+                        randomMeansSets.stream().mapToInt(List::size).sum() >= countOfComponents,
+                        "There is not enough data in dataset for select N random means"
+                    );
+
+                    initialMeans = new Vector[countOfComponents];
+                    int j = 0;
+                    for (int i = 0; i < countOfComponents; ) {
+                        List<Vector> randomMeansPart = randomMeansSets.get(j);
+                        if (!randomMeansPart.isEmpty()) {
+                            initialMeans[i] = randomMeansPart.remove(0);
+                            i++;
+                        }
+
+                        j = (j + 1) % randomMeansSets.size();
+                    }
+                }
+
+                dataset.compute(data -> GmmPartitionData.estimateLikelihoodClusters(data, initialMeans));
+
+                List<Matrix> initialCovs = CovarianceMatricesAggregator.computeCovariances(
+                    dataset,
+                    VectorUtils.fill(1. / countOfComponents, countOfComponents),
+                    initialMeans
+                );
+
+                if (initialCovs.isEmpty())
+                    return Optional.empty();
+
+                List<MultivariateGaussianDistribution> distributions = new ArrayList<>();
+                for (int i = 0; i < countOfComponents; i++)
+                    distributions.add(new MultivariateGaussianDistribution(initialMeans[i], initialCovs.get(i)));
+
+                return Optional.of(new GmmModel(
+                    VectorUtils.of(DoubleStream.generate(() -> 1. / countOfComponents).limit(countOfComponents).toArray()),
+                    distributions
+                ));
+            }
+            catch (SingularMatrixException | IllegalArgumentException e) {
+                String msg = "Cannot construct non-singular covariance matrix by data. " +
+                    "Try to select other initial means or other model trainer [number of tries = " + countOfTries + "]";
+                environment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
+                countOfTries += 1;
+                initialMeans = null;
+                if (countOfTries >= maxCountOfInitTries)
+                    throw new RuntimeException(msg, e);
+            }
+        }
+    }
+
+    /**
+     * @param vectors Array of vectors.
+     * @return list of vectors.
+     */
+    private LinkedList<Vector> asList(Vector... vectors) {
+        LinkedList<Vector> res = new LinkedList<>();
+        for (Vector v : vectors)
+            res.addFirst(v);
+        return res;
+    }
+
+    /**
+     * Create new model components with provided means and covariances.
+     *
+     * @param means Means.
+     * @param covs Covariances.
+     * @return gmm components.
+     */
+    private List<MultivariateGaussianDistribution> buildComponents(Vector[] means, List<Matrix> covs) {
+        A.ensure(means.length == covs.size(), "means.size() == covs.size()");
+
+        List<MultivariateGaussianDistribution> res = new ArrayList<>();
+        for (int i = 0; i < means.length; i++)
+            res.add(new MultivariateGaussianDistribution(means[i], covs.get(i)));
+
+        return res;
+    }
+
+    /**
+     * Check algorithm covergency. If it's true then algorithm stops.
+     *
+     * @param oldModel Old model.
+     * @param newModel New model.
+     * @return true if algorithm gonverged.
+     */
+    private boolean isConverged(GmmModel oldModel, GmmModel newModel) {
+        A.ensure(oldModel.countOfComponents() == newModel.countOfComponents(),
+            "oldModel.countOfComponents() == newModel.countOfComponents()");
+
+        for (int i = 0; i < oldModel.countOfComponents(); i++) {
+            MultivariateGaussianDistribution d1 = oldModel.distributions().get(i);
+            MultivariateGaussianDistribution d2 = newModel.distributions().get(i);
+
+            if (Math.sqrt(d1.mean().getDistanceSquared(d2.mean())) >= eps)
+                return false;
+        }
+
+        return true;
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean isUpdateable(GmmModel mdl) {
+        return mdl.countOfComponents() == countOfComponents;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> GmmModel updateModel(GmmModel mdl, DatasetBuilder<K, V> datasetBuilder,
+        FeatureLabelExtractor<K, V, Double> extractor) {
+
+        try (Dataset<EmptyContext, GmmPartitionData> dataset = datasetBuilder.build(
+            LearningEnvironmentBuilder.defaultBuilder(),
+            new EmptyContextBuilder<>(),
+            new GmmPartitionData.Builder<>(extractor, countOfComponents)
+        )) {
+            if (mdl != null) {
+                if (initialMeans != null)
+                    environment.logger().log(MLLogger.VerboseLevel.HIGH, "Initial means will be replaced by model from update");
+                initialMeans = mdl.distributions().stream()
+                    .map(MultivariateGaussianDistribution::mean)
+                    .toArray(Vector[]::new);
+            }
+
+            Optional<GmmModel> model = fit(dataset);
+            if (model.isPresent())
+                return model.get();
+            else if (mdl != null)
+                return mdl;
+            else
+                throw new IllegalArgumentException("Cannot learn model on empty dataset.");
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Returns mapper for initial means selection.
+     *
+     * @param n Number of components.
+     * @return mapper.
+     */
+    private static IgniteBiFunction<GmmPartitionData, LearningEnvironment, Vector[][]> selectNRandomXsMapper(int n) {
+        return (data, env) -> {
+            Vector[] result;
+
+            if (data.size() <= n) {
+                result = data.getAllXs().stream()
+                    .map(DatasetRow::features)
+                    .toArray(Vector[]::new);
+            }
+            else {
+                result = env.randomNumbersGenerator().ints(0, data.size())
+                    .distinct().mapToObj(data::getX).limit(n)
+                    .toArray(Vector[]::new);
+            }
+
+            return new Vector[][] {result};
+        };
+    }
+
+    /**
+     * Reducer for means selection.
+     *
+     * @return reducer.
+     */
+    private static Vector[][] selectNRandomXsReducer(Vector[][] l, Vector[][] r) {
+        A.ensure(l != null || r != null, "l != null || r != null");
+
+        if (l == null)
+            return r;
+        if (r == null)
+            return l;
+
+        Vector[][] res = new Vector[l.length + r.length][];
+        System.arraycopy(l, 0, res, 0, l.length);
+        System.arraycopy(r, 0, res, l.length, r.length);
+
+        return res;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
new file mode 100644 (file)
index 0000000..58044a7
--- /dev/null
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+
+/**
+ * Statistics aggregator for mean values and cluster probabilities computing.
+ */
+class MeanWithClusterProbAggregator implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 2700985110021774629L;
+
+    /** Weighted sum of vectors. */
+    private Vector weightedXsSum;
+
+    /** P(C|xi) sum. */
+    private double pcxiSum;
+
+    /** Aggregated partition data size. */
+    private int rowCount;
+
+    /**
+     * Create an instance of MeanWithClusterProbAggregator.
+     */
+    MeanWithClusterProbAggregator() {
+    }
+
+    /**
+     * Create an instance of MeanWithClusterProbAggregator.
+     *
+     * @param weightedXsSum Weighted sum of vectors.
+     * @param pcxiSum P(c|xi) sum.
+     * @param rowCount Count of rows.
+     */
+    MeanWithClusterProbAggregator(Vector weightedXsSum, double pcxiSum, int rowCount) {
+        this.weightedXsSum = weightedXsSum;
+        this.pcxiSum = pcxiSum;
+        this.rowCount = rowCount;
+    }
+
+    /**
+     * @return compute mean value by aggregated data.
+     */
+    public Vector mean() {
+        return weightedXsSum.divide(pcxiSum);
+    }
+
+    /**
+     * @return compute cluster probability by aggreated data.
+     */
+    public double clusterProb() {
+        return pcxiSum / rowCount;
+    }
+
+    /**
+     * Aggregates statistics for means and cluster probabilities computing given dataset.
+     *
+     * @param dataset Dataset.
+     */
+    public static AggregatedStats aggreateStats(Dataset<EmptyContext, GmmPartitionData> dataset) {
+        return new AggregatedStats(dataset.compute(
+            MeanWithClusterProbAggregator::map,
+            MeanWithClusterProbAggregator::reduce
+        ));
+    }
+
+    /**
+     * Add vector to statistics.
+     *
+     * @param x Vector.
+     * @param pcxi P(c|xi).
+     */
+    void add(Vector x, double pcxi) {
+        A.ensure(pcxi >= 0 && pcxi <= 1., "pcxi >= 0 && pcxi <= 1.");
+
+        Vector weightedVector = x.times(pcxi);
+        if (weightedXsSum == null)
+            weightedXsSum = weightedVector;
+        else
+            weightedXsSum = weightedXsSum.plus(weightedVector);
+
+        pcxiSum += pcxi;
+        rowCount += 1;
+    }
+
+    /**
+     * @param other Other.
+     * @return Sum of aggregators.
+     */
+    MeanWithClusterProbAggregator plus(MeanWithClusterProbAggregator other) {
+        return new MeanWithClusterProbAggregator(
+            weightedXsSum.plus(other.weightedXsSum),
+            pcxiSum + other.pcxiSum,
+            rowCount + other.rowCount
+        );
+    }
+
+    /**
+     * Map stage for statistics aggregation.
+     *
+     * @param data Partition data.
+     * @return Aggregated statistics.
+     */
+    static List<MeanWithClusterProbAggregator> map(GmmPartitionData data) {
+        List<MeanWithClusterProbAggregator> aggregators = new ArrayList<>();
+        for (int i = 0; i < data.countOfComponents(); i++)
+            aggregators.add(new MeanWithClusterProbAggregator());
+
+        for (int i = 0; i < data.size(); i++) {
+            for (int c = 0; c < data.countOfComponents(); c++)
+                aggregators.get(c).add(data.getX(i), data.pcxi(c, i));
+        }
+
+        return aggregators;
+    }
+
+    /**
+     * Reduce stage for statistics aggregation.
+     *
+     * @param l Reft part.
+     * @param r Right part.
+     * @return Sum of statistics for each cluster.
+     */
+    static List<MeanWithClusterProbAggregator> reduce(List<MeanWithClusterProbAggregator> l,
+        List<MeanWithClusterProbAggregator> r) {
+        A.ensure(l != null || r != null, "Both partitions cannot equal to null");
+
+        if (l == null || l.isEmpty())
+            return r;
+        if (r == null || r.isEmpty())
+            return l;
+
+        A.ensure(l.size() == r.size(), "l.size() == r.size()");
+        List<MeanWithClusterProbAggregator> res = new ArrayList<>();
+        for (int i = 0; i < l.size(); i++)
+            res.add(l.get(i).plus(r.get(i)));
+
+        return res;
+    }
+
+    /**
+     * Computed cluster probabilities and means.
+     */
+    public static class AggregatedStats {
+        /** Cluster probs. */
+        private final Vector clusterProbs;
+
+        /** Means. */
+        private final List<Vector> means;
+
+        /**
+         * Creates an instance of AggregatedStats.
+         *
+         * @param stats Statistics.
+         */
+        private AggregatedStats(List<MeanWithClusterProbAggregator> stats) {
+            clusterProbs = VectorUtils.of(stats.stream()
+                .mapToDouble(MeanWithClusterProbAggregator::clusterProb)
+                .toArray()
+            );
+
+            means = stats.stream()
+                .map(MeanWithClusterProbAggregator::mean)
+                .collect(Collectors.toList());
+        }
+
+        /**
+         * @return clusters probabilities.
+         */
+        public Vector clusterProbabilities() {
+            return clusterProbs;
+        }
+
+        /**
+         * @return means.
+         */
+        public List<Vector> means() {
+            return means;
+        }
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/package-info.java
new file mode 100644 (file)
index 0000000..ed0485e
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. --> Contains Gauss Mixture Model clustering algorithm (see {@link
+ * org.apache.ignite.ml.clustering.gmm.GmmModel}).
+ */
+package org.apache.ignite.ml.clustering.gmm;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/SingularMatrixException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/SingularMatrixException.java
new file mode 100644 (file)
index 0000000..c7acc80
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.exceptions;
+
+/**
+ * Exception to be thrown when a non-singular matrix is expected.
+ */
+public class SingularMatrixException extends MathIllegalArgumentException {
+    /** */
+    public SingularMatrixException() {
+        super("Regular (or non-singular) matrix expected.");
+    }
+
+    /** */
+    public SingularMatrixException(String format, Object... args) {
+        super(format, args);
+    }
+}
index 161b336..c9ab0f6 100644 (file)
@@ -918,4 +918,34 @@ public abstract class AbstractMatrix implements Matrix {
     @Override public String toString() {
         return "Matrix [rows=" + rowSize() + ", cols=" + columnSize() + "]";
     }
+
+    /** {@inheritDoc} */
+    @Override public double determinant() {
+        //TODO: IGNITE-11192, use nd4j
+        try (LUDecomposition dec = new LUDecomposition(this)) {
+            return dec.determinant();
+        }
+    }
+
+    /** {@inheritDoc} */
+    @Override public Matrix inverse() {
+        if (rowSize() != columnSize())
+            throw new CardinalityException(rowSize(), columnSize());
+
+        //TODO: IGNITE-11192, use nd4j
+        try (LUDecomposition dec = new LUDecomposition(this)) {
+            return dec.solve(likeIdentity());
+        }
+    }
+
+    /** */
+    protected Matrix likeIdentity() {
+        int n = rowSize();
+        Matrix res = like(n, n);
+
+        for (int i = 0; i < n; i++)
+            res.setX(i, i, 1.0);
+
+        return res;
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/LUDecomposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/LUDecomposition.java
new file mode 100644 (file)
index 0000000..f418936
--- /dev/null
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.primitives.matrix;
+
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+import static org.apache.ignite.ml.math.util.MatrixUtil.copy;
+import static org.apache.ignite.ml.math.util.MatrixUtil.like;
+import static org.apache.ignite.ml.math.util.MatrixUtil.likeVector;
+
+/**
+ * Calculates the LU-decomposition of a square matrix.
+ * <p>
+ * This class is inspired by class from Apache Common Math with similar name.</p>
+ *
+ * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
+ * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
+ *
+ * <p>TODO: IGNITE-11192, remove after resolve this issue.</p>
+ */
+public class LUDecomposition implements AutoCloseable {
+    /** Default bound to determine effective singularity in LU decomposition. */
+    private static final double DEFAULT_TOO_SMALL = 1e-11;
+
+    /** Pivot permutation associated with LU decomposition. */
+    private final Vector pivot;
+
+    /** Parity of the permutation associated with the LU decomposition. */
+    private boolean even;
+    /** Singularity indicator. */
+    private boolean singular;
+
+    /** Cached value of L. */
+    private Matrix cachedL;
+    /** Cached value of U. */
+    private Matrix cachedU;
+    /** Cached value of P. */
+    private Matrix cachedP;
+
+    /** Original matrix. */
+    private Matrix matrix;
+
+    /** Entries of LU decomposition. */
+    private Matrix lu;
+
+    /**
+     * Calculates the LU-decomposition of the given matrix. This constructor uses 1e-11 as default value for the
+     * singularity threshold.
+     *
+     * @param matrix Matrix to decompose.
+     * @throws CardinalityException if matrix is not square.
+     */
+    public LUDecomposition(Matrix matrix) {
+        this(matrix, DEFAULT_TOO_SMALL);
+    }
+
+    /**
+     * Calculates the LUP-decomposition of the given matrix.
+     *
+     * @param matrix Matrix to decompose.
+     * @param singularityThreshold threshold (based on partial row norm).
+     * @throws CardinalityException if matrix is not square.
+     */
+    public LUDecomposition(Matrix matrix, double singularityThreshold) {
+        assert matrix != null;
+
+        int rows = matrix.rowSize();
+        int cols = matrix.columnSize();
+
+        if (rows != cols)
+            throw new CardinalityException(rows, cols);
+
+        this.matrix = matrix;
+
+        lu = copy(matrix);
+
+        pivot = likeVector(matrix);
+
+        for (int i = 0; i < pivot.size(); i++)
+            pivot.setX(i, i);
+
+        even = true;
+        singular = false;
+
+        cachedL = null;
+        cachedU = null;
+        cachedP = null;
+
+        for (int col = 0; col < cols; col++) {
+
+            //upper
+            for (int row = 0; row < col; row++) {
+                Vector luRow = lu.viewRow(row);
+                double sum = luRow.get(col);
+
+                for (int i = 0; i < row; i++)
+                    sum -= luRow.getX(i) * lu.getX(i, col);
+
+                luRow.setX(col, sum);
+            }
+
+            // permutation row
+            int max = col;
+
+            double largest = Double.NEGATIVE_INFINITY;
+
+            // lower
+            for (int row = col; row < rows; row++) {
+                Vector luRow = lu.viewRow(row);
+                double sum = luRow.getX(col);
+
+                for (int i = 0; i < col; i++)
+                    sum -= luRow.getX(i) * lu.getX(i, col);
+
+                luRow.setX(col, sum);
+
+                if (Math.abs(sum) > largest) {
+                    largest = Math.abs(sum);
+                    max = row;
+                }
+            }
+
+            // Singularity check
+            if (Math.abs(lu.getX(max, col)) < singularityThreshold) {
+                singular = true;
+                return;
+            }
+
+            // Pivot if necessary
+            if (max != col) {
+                double tmp;
+                Vector luMax = lu.viewRow(max);
+                Vector luCol = lu.viewRow(col);
+
+                for (int i = 0; i < cols; i++) {
+                    tmp = luMax.getX(i);
+                    luMax.setX(i, luCol.getX(i));
+                    luCol.setX(i, tmp);
+                }
+
+                int temp = (int)pivot.getX(max);
+                pivot.setX(max, pivot.getX(col));
+                pivot.setX(col, temp);
+
+                even = !even;
+            }
+
+            // Divide the lower elements by the "winning" diagonal elt.
+            final double luDiag = lu.getX(col, col);
+
+            for (int row = col + 1; row < cols; row++) {
+                double val = lu.getX(row, col) / luDiag;
+                lu.setX(row, col, val);
+            }
+        }
+    }
+
+    /**
+     * Destroys decomposition components and other internal components of decomposition.
+     */
+    @Override public void close() {
+        if (cachedL != null)
+            cachedL.destroy();
+        if (cachedU != null)
+            cachedU.destroy();
+        if (cachedP != null)
+            cachedP.destroy();
+        lu.destroy();
+    }
+
+    /**
+     * Returns the matrix L of the decomposition.
+     * <p>L is a lower-triangular matrix</p>
+     *
+     * @return the L matrix (or null if decomposed matrix is singular).
+     */
+    public Matrix getL() {
+        if ((cachedL == null) && !singular) {
+            final int m = pivot.size();
+
+            cachedL = like(matrix);
+            cachedL.assign(0.0);
+
+            for (int i = 0; i < m; ++i) {
+                for (int j = 0; j < i; ++j)
+                    cachedL.setX(i, j, lu.getX(i, j));
+
+                cachedL.setX(i, i, 1.0);
+            }
+        }
+
+        return cachedL;
+    }
+
+    /**
+     * Returns the matrix U of the decomposition.
+     * <p>U is an upper-triangular matrix</p>
+     *
+     * @return the U matrix (or null if decomposed matrix is singular).
+     */
+    public Matrix getU() {
+        if ((cachedU == null) && !singular) {
+            final int m = pivot.size();
+
+            cachedU = like(matrix);
+            cachedU.assign(0.0);
+
+            for (int i = 0; i < m; ++i)
+                for (int j = i; j < m; ++j)
+                    cachedU.setX(i, j, lu.getX(i, j));
+        }
+
+        return cachedU;
+    }
+
+    /**
+     * Returns the P rows permutation matrix.
+     * <p>P is a sparse matrix with exactly one element set to 1.0 in
+     * each row and each column, all other elements being set to 0.0.</p>
+     * <p>The positions of the 1 elements are given by the {@link #getPivot()
+     * pivot permutation vector}.</p>
+     *
+     * @return the P rows permutation matrix (or null if decomposed matrix is singular).
+     * @see #getPivot()
+     */
+    public Matrix getP() {
+        if ((cachedP == null) && !singular) {
+            final int m = pivot.size();
+
+            cachedP = like(matrix);
+            cachedP.assign(0.0);
+
+            for (int i = 0; i < m; ++i)
+                cachedP.setX(i, (int)pivot.get(i), 1.0);
+        }
+
+        return cachedP;
+    }
+
+    /**
+     * Returns the pivot permutation vector.
+     *
+     * @return the pivot permutation vector.
+     * @see #getP()
+     */
+    public Vector getPivot() {
+        return pivot.copy();
+    }
+
+    /**
+     * Return the determinant of the matrix.
+     *
+     * @return determinant of the matrix.
+     */
+    public double determinant() {
+        if (singular)
+            return 0;
+
+        final int m = pivot.size();
+        double determinant = even ? 1 : -1;
+
+        for (int i = 0; i < m; i++)
+            determinant *= lu.getX(i, i);
+
+        return determinant;
+    }
+
+    /**
+     * @param b Vector to solve using this decomposition.
+     * @return Solution vector.
+     */
+    public Vector solve(Vector b) {
+        final int m = pivot.size();
+
+        if (b.size() != m)
+            throw new CardinalityException(b.size(), m);
+
+        if (singular)
+            throw new SingularMatrixException();
+
+        final double[] bp = new double[m];
+
+        // Apply permutations to b
+        for (int row = 0; row < m; row++)
+            bp[row] = b.get((int)pivot.get(row));
+
+        // Solve LY = b
+        for (int col = 0; col < m; col++) {
+            final double bpCol = bp[col];
+
+            for (int i = col + 1; i < m; i++)
+                bp[i] -= bpCol * lu.get(i, col);
+        }
+
+        // Solve UX = Y
+        for (int col = m - 1; col >= 0; col--) {
+            bp[col] /= lu.get(col, col);
+            final double bpCol = bp[col];
+
+            for (int i = 0; i < col; i++)
+                bp[i] -= bpCol * lu.get(i, col);
+        }
+
+        return b.like(m).assign(bp);
+    }
+
+    /**
+     * @param b Matrix to solve using this decomposition.
+     * @return Solution matrix.
+     */
+    public Matrix solve(Matrix b) {
+        final int m = pivot.size();
+
+        if (b.rowSize() != m)
+            throw new CardinalityException(b.rowSize(), m);
+
+        if (singular)
+            throw new SingularMatrixException();
+
+        final int nColB = b.columnSize();
+
+        // Apply permutations to b
+        final double[][] bp = new double[m][nColB];
+        for (int row = 0; row < m; row++) {
+            final double[] bpRow = bp[row];
+            final int pRow = (int)pivot.get(row);
+
+            for (int col = 0; col < nColB; col++)
+                bpRow[col] = b.get(pRow, col);
+        }
+
+        // Solve LY = b
+        for (int col = 0; col < m; col++) {
+            final double[] bpCol = bp[col];
+            for (int i = col + 1; i < m; i++) {
+                final double[] bpI = bp[i];
+                final double luICol = lu.get(i, col);
+
+                for (int j = 0; j < nColB; j++)
+                    bpI[j] -= bpCol[j] * luICol;
+            }
+        }
+
+        // Solve UX = Y
+        for (int col = m - 1; col >= 0; col--) {
+            final double[] bpCol = bp[col];
+            final double luDiag = lu.getX(col, col);
+
+            for (int j = 0; j < nColB; j++)
+                bpCol[j] /= luDiag;
+
+            for (int i = 0; i < col; i++) {
+                final double[] bpI = bp[i];
+                final double luICol = lu.get(i, col);
+
+                for (int j = 0; j < nColB; j++)
+                    bpI[j] -= bpCol[j] * luICol;
+            }
+        }
+
+        return b.like(b.rowSize(), b.columnSize()).assign(bp);
+    }
+}
index 52b2dc1..6d55483 100644 (file)
@@ -520,4 +520,19 @@ public interface Matrix extends MetaAttributes, Externalizable, StorageOpsMetric
      * @param f Function used for replacing.
      */
     public void compute(int row, int col, IgniteTriFunction<Integer, Integer, Double, Double> f);
+
+    /**
+     * Returns matrix determinant using Laplace theorem.
+     *
+     * @return A determinant for this matrix.
+     * @throws CardinalityException Thrown if matrix is not square.
+     */
+    public double determinant();
+
+    /**
+     * Returns the inverse matrix of this matrix
+     *
+     * @return Inverse of this matrix
+     */
+    public Matrix inverse();
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java
new file mode 100644 (file)
index 0000000..fe236c3
--- /dev/null
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.stat;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Interface for distributions.
+ */
+public interface Distribution extends Serializable {
+    /**
+     * @param x Vector.
+     * @return probability of vector.
+     */
+    public double prob(Vector x);
+
+    /**
+     * @return dimension of vector space.
+     */
+    public int dimension();
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
new file mode 100644 (file)
index 0000000..29bb22f
--- /dev/null
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.stat;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.DoubleStream;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+
+/**
+ * Mixture of distributions class where each component has own probability and probability of input vector can be
+ * computed as a sum of likelihoods of each component.
+ *
+ * @param <C> distributions mixture component class.
+ */
+public abstract class DistributionMixture<C extends Distribution> implements Distribution {
+    /** */
+    private static final double EPS = 1e-5;
+
+    /** Component probabilities. */
+    private final Vector componentProbs;
+
+    /** Distributions. */
+    private final List<C> distributions;
+
+    /** Dimension. */
+    private final int dimension;
+
+    /**
+     * Creates an instance of DistributionMixture.
+     *
+     * @param componentProbs Component probabilities.
+     * @param distributions Distributions.
+     */
+    public DistributionMixture(Vector componentProbs, List<C> distributions) {
+        A.ensure(DoubleStream.of(componentProbs.asArray()).allMatch(v -> v > 0), "All distribution components should be greater than zero");
+        A.ensure(Math.abs(componentProbs.sum() - 1.) < EPS, "Components distribution should be nomalized");
+
+        A.ensure(!distributions.isEmpty(), "Distribution mixture should have at least one component");
+
+        final int dimension = distributions.get(0).dimension();
+        A.ensure(dimension > 0, "Dimension should be greater than zero");
+        A.ensure(distributions.stream().allMatch(d -> d.dimension() == dimension), "All distributions should have same dimension");
+
+        this.distributions = distributions;
+        this.componentProbs = componentProbs;
+        this.dimension = dimension;
+    }
+
+    /** {@inheritDoc} */
+    @Override public double prob(Vector x) {
+        return likelihood(x).sum();
+    }
+
+    /**
+     * @param x Vector.
+     * @return Vector consists of likelihoods of each mixture components.
+     */
+    public Vector likelihood(Vector x) {
+        return VectorUtils.of(distributions.stream().mapToDouble(f -> f.prob(x)).toArray())
+            .times(componentProbs);
+    }
+
+    /**
+     * @return an amount of components.
+     */
+    public int countOfComponents() {
+        return componentProbs.size();
+    }
+
+    /**
+     * @return component probabilities.
+     */
+    public Vector componentsProbs() {
+        return componentProbs.copy();
+    }
+
+    /**
+     * @return list of components.
+     */
+    public List<C> distributions() {
+        return Collections.unmodifiableList(distributions);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int dimension() {
+        return dimension;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/MultivariateGaussianDistribution.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/MultivariateGaussianDistribution.java
new file mode 100644 (file)
index 0000000..2051e56
--- /dev/null
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.stat;
+
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Distribution represents multidimentional gaussian distribution.
+ */
+public class MultivariateGaussianDistribution implements Distribution {
+    /** Mean. */
+    private Vector mean;
+
+    /** Covariance^-1. */
+    private Matrix invCovariance;
+
+    /** Normalizer. */
+    private double normalizer;
+
+    /**
+     * Constructs an instance of MultivariateGaussianDistribution.
+     *
+     * @param mean Mean.
+     * @param covariance Covariance.
+     */
+    public MultivariateGaussianDistribution(Vector mean, Matrix covariance) {
+        A.ensure(covariance.columnSize() == covariance.rowSize(), "Covariance matrix should be square");
+        A.ensure(mean.size() == covariance.rowSize(), "Covariance matrix should be built from same space as mean vector");
+
+        this.mean = mean;
+        invCovariance = covariance.inverse();
+
+        double determinant = covariance.determinant();
+        A.ensure(determinant > 0, "Covariance matrix should be positife definite");
+        normalizer = Math.pow(2 * Math.PI, ((double)invCovariance.rowSize()) / 2) * Math.sqrt(determinant);
+    }
+
+    /** {@inheritDoc} */
+    @Override public double prob(Vector x) {
+        Vector delta = x.minus(mean);
+        Matrix ePower = delta.toMatrix(true)
+            .times(invCovariance)
+            .times(delta.toMatrix(false))
+            .times(-0.5);
+        assert ePower.columnSize() == 1 && ePower.rowSize() == 1;
+
+        return Math.pow(Math.E, ePower.get(0, 0)) / normalizer;
+    }
+
+    /** {@inheritDoc} */
+    @Override public int dimension() {
+        return mean.size();
+    }
+
+    /**
+     * @return mean vector.
+     */
+    public Vector mean() {
+        return mean.copy();
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/package-info.java
new file mode 100644 (file)
index 0000000..69d1230
--- /dev/null
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. --> Contains utility classes for distributions.
+ */
+package org.apache.ignite.ml.math.stat;
index 2aa0437..eaf98a6 100644 (file)
@@ -551,4 +551,18 @@ class ReplicatedVectorMatrix implements Matrix {
     public Vector replicant() {
         return vector;
     }
+
+    /** {@inheritDoc} */
+    @Override public double determinant() {
+        // If matrix is not square throw exception.
+        checkCardinality(vector.size(), replicationCnt);
+
+        // If matrix is 1x1 then determinant is its single element otherwise there are linear dependence and determinant is 0.
+        return vector.size() > 1 ? 0 : vector.get(1);
+    }
+
+    /** {@inheritDoc} */
+    @Override public Matrix inverse() {
+        throw new UnsupportedOperationException();
+    }
 }
index 80538a0..cae8bef 100644 (file)
 
 package org.apache.ignite.ml.clustering;
 
+import org.apache.ignite.ml.clustering.gmm.CovarianceMatricesAggregatorTest;
+import org.apache.ignite.ml.clustering.gmm.GmmModelTest;
+import org.apache.ignite.ml.clustering.gmm.GmmPartitionDataTest;
+import org.apache.ignite.ml.clustering.gmm.GmmTrainerIntegrationTest;
+import org.apache.ignite.ml.clustering.gmm.GmmTrainerTest;
+import org.apache.ignite.ml.clustering.gmm.MeanWithClusterProbAggregatorTest;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
@@ -25,8 +31,17 @@ import org.junit.runners.Suite;
  */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
+    //k-means tests
     KMeansTrainerTest.class,
-    KMeansModelTest.class
+    KMeansModelTest.class,
+
+    //GMM tests
+    CovarianceMatricesAggregatorTest.class,
+    GmmModelTest.class,
+    GmmPartitionDataTest.class,
+    MeanWithClusterProbAggregatorTest.class,
+    GmmTrainerTest.class,
+    GmmTrainerIntegrationTest.class
 })
 public class ClusteringTestSuite {
 }
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregatorTest.java
new file mode 100644 (file)
index 0000000..b9753b5
--- /dev/null
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.Arrays;
+import java.util.List;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@link CovarianceMatricesAggregator}.
+ */
+public class CovarianceMatricesAggregatorTest {
+    /** */
+    @Test
+    public void testAdd() {
+        CovarianceMatricesAggregator agg = new CovarianceMatricesAggregator(VectorUtils.of(1., 0.));
+        assertEquals(0, agg.rowCount());
+
+        agg.add(VectorUtils.of(1., 0.), 100.);
+        assertArrayEquals(VectorUtils.of(1., 0.).asArray(), agg.mean().asArray(), 1e-4);
+        assertArrayEquals(
+            agg.weightedSum().getStorage().data(),
+            fromArray(2, 0., 0., 0., 0.).getStorage().data(),
+            1e-4
+        );
+        assertEquals(1, agg.rowCount());
+
+        agg.add(VectorUtils.of(0., 1.), 10.);
+        assertArrayEquals(VectorUtils.of(1., 0.).asArray(), agg.mean().asArray(), 1e-4);
+        assertArrayEquals(
+            agg.weightedSum().getStorage().data(),
+            fromArray(2, 10., -10., -10., 10.).getStorage().data(),
+            1e-4
+        );
+        assertEquals(2, agg.rowCount());
+    }
+
+    /** */
+    @Test
+    public void testPlus() {
+        Vector mean = VectorUtils.of(1, 0);
+
+        CovarianceMatricesAggregator agg1 = new CovarianceMatricesAggregator(mean, identity(2), 1);
+        CovarianceMatricesAggregator agg2 = new CovarianceMatricesAggregator(mean, identity(2).times(2), 3);
+        CovarianceMatricesAggregator res = agg1.plus(agg2);
+
+        assertArrayEquals(mean.asArray(), res.mean().asArray(), 1e-4);
+        assertArrayEquals(identity(2).times(3).getStorage().data(), res.weightedSum().getStorage().data(), 1e-4);
+        assertEquals(4, res.rowCount());
+    }
+
+    /** */
+    @Test
+    public void testReduce() {
+        Vector mean1 = VectorUtils.of(1, 0);
+        Vector mean2 = VectorUtils.of(0, 1);
+
+        CovarianceMatricesAggregator agg11 = new CovarianceMatricesAggregator(mean1, identity(2), 1);
+        CovarianceMatricesAggregator agg12 = new CovarianceMatricesAggregator(mean1, identity(2), 1);
+
+        CovarianceMatricesAggregator agg21 = new CovarianceMatricesAggregator(mean2, identity(2), 2);
+        CovarianceMatricesAggregator agg22 = new CovarianceMatricesAggregator(mean2, identity(2), 2);
+
+        List<CovarianceMatricesAggregator> result = CovarianceMatricesAggregator.reduce(
+            Arrays.asList(agg11, agg21),
+            Arrays.asList(agg12, agg22)
+        );
+
+        assertEquals(2, result.size());
+        CovarianceMatricesAggregator res1 = result.get(0);
+        CovarianceMatricesAggregator res2 = result.get(1);
+
+        assertArrayEquals(mean1.asArray(), res1.mean().asArray(), 1e-4);
+        assertArrayEquals(identity(2).times(2).getStorage().data(), res1.weightedSum().getStorage().data(), 1e-4);
+        assertEquals(2, res1.rowCount());
+
+        assertArrayEquals(mean2.asArray(), res2.mean().asArray(), 1e-4);
+        assertArrayEquals(identity(2).times(2).getStorage().data(), res2.weightedSum().getStorage().data(), 1e-4);
+        assertEquals(4, res2.rowCount());
+    }
+
+    /** */
+    @Test
+    public void testMap() {
+        List<LabeledVector<Double>> xs = Arrays.asList(
+            new LabeledVector<>(VectorUtils.of(1, 0), 0.),
+            new LabeledVector<>(VectorUtils.of(0, 1), 0.),
+            new LabeledVector<>(VectorUtils.of(1, 1), 0.)
+        );
+
+        double[][] pcxi = new double[][] {
+            new double[] {0.1, 0.2},
+            new double[] {0.4, 0.3},
+            new double[] {0.5, 0.6}
+        };
+
+        GmmPartitionData data = new GmmPartitionData(xs, pcxi);
+        Vector mean1 = VectorUtils.of(1, 1);
+        Vector mean2 = VectorUtils.of(0, 1);
+        List<CovarianceMatricesAggregator> result = CovarianceMatricesAggregator.map(data, new Vector[] {mean1, mean2});
+
+        assertEquals(pcxi[0].length, result.size());
+
+        CovarianceMatricesAggregator res1 = result.get(0);
+        assertArrayEquals(mean1.asArray(), res1.mean().asArray(), 1e-4);
+        assertArrayEquals(
+            res1.weightedSum().getStorage().data(),
+            fromArray(2, 0.4, 0., 0., 0.1).getStorage().data(),
+            1e-4
+        );
+        assertEquals(3, res1.rowCount());
+
+        CovarianceMatricesAggregator res2 = result.get(1);
+        assertArrayEquals(mean2.asArray(), res2.mean().asArray(), 1e-4);
+        assertArrayEquals(
+            res2.weightedSum().getStorage().data(),
+            fromArray(2, 0.8, -0.2, -0.2, 0.2).getStorage().data(),
+            1e-4
+        );
+        assertEquals(3, res2.rowCount());
+    }
+
+    /** */
+    private Matrix identity(int n) {
+        DenseMatrix matrix = new DenseMatrix(n, n);
+        for (int i = 0; i < n; i++)
+            matrix.set(i, i, 1.);
+        return matrix;
+    }
+
+    /** */
+    private Matrix fromArray(int n, double... values) {
+        assertTrue(n == values.length / n);
+
+        return new DenseMatrix(values, n);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmModelTest.java
new file mode 100644 (file)
index 0000000..0c90738
--- /dev/null
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.Arrays;
+import java.util.Collections;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link GmmModelTest}.
+ */
+public class GmmModelTest {
+    /** */
+    @Test
+    public void testTrivialCasesWithOneComponent() {
+        Vector mean = VectorUtils.of(1., 2.);
+        DenseMatrix covariance = MatrixUtil.fromList(Arrays.asList(
+            VectorUtils.of(1, -0.5),
+            VectorUtils.of(-0.5, 1)),
+            true
+        );
+
+        GmmModel gmm = new GmmModel(
+            VectorUtils.of(1.0),
+            Collections.singletonList(new MultivariateGaussianDistribution(mean, covariance))
+        );
+
+        Assert.assertEquals(2, gmm.dimension());
+        Assert.assertEquals(1, gmm.countOfComponents());
+        Assert.assertEquals(VectorUtils.of(1.), gmm.componentsProbs());
+        Assert.assertEquals(0., gmm.predict(mean), 0.01);
+        Assert.assertEquals(1, gmm.likelihood(mean).size());
+        Assert.assertEquals(0.183, gmm.likelihood(mean).get(0), 0.01);
+        Assert.assertEquals(0.183, gmm.prob(mean), 0.01);
+    }
+
+    /** */
+    @Test
+    public void testTwoComponents() {
+        Vector mean1 = VectorUtils.of(1., 2.);
+        DenseMatrix covariance1 = MatrixUtil.fromList(Arrays.asList(
+            VectorUtils.of(1, -0.25),
+            VectorUtils.of(-0.25, 1)),
+            true
+        );
+
+        Vector mean2 = VectorUtils.of(2., 1.);
+        DenseMatrix covariance2 = MatrixUtil.fromList(Arrays.asList(
+            VectorUtils.of(1, 0.5),
+            VectorUtils.of(0.5, 1)),
+            true
+        );
+
+        GmmModel gmm = new GmmModel(
+            VectorUtils.of(0.5, 0.5),
+            Arrays.asList(
+                new MultivariateGaussianDistribution(mean1, covariance1),
+                new MultivariateGaussianDistribution(mean2, covariance2)
+            )
+        );
+
+        Assert.assertEquals(0., gmm.predict(mean1), 0.01);
+        Assert.assertEquals(1., gmm.predict(mean2), 0.01);
+        Assert.assertEquals(0., gmm.predict(VectorUtils.of(1.5, 1.5)), 0.01);
+        Assert.assertEquals(1., gmm.predict(VectorUtils.of(3., 0.)), 0.01);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionDataTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionDataTest.java
new file mode 100644 (file)
index 0000000..2aca0db
--- /dev/null
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link GmmPartitionDataTest}.
+ */
+public class GmmPartitionDataTest {
+    private GmmPartitionData data;
+
+    /** */
+    @Before
+    public void setUp() throws Exception {
+        data = new GmmPartitionData(
+            Arrays.asList(
+                new LabeledVector<>(VectorUtils.of(1, 0), 0.),
+                new LabeledVector<>(VectorUtils.of(0, 1), 0.),
+                new LabeledVector<>(VectorUtils.of(1, 1), 0.)
+            ),
+            new double[3][2]
+        );
+    }
+
+    /** */
+    @Test
+    public void testEstimateLikelihoodClusters() {
+        GmmPartitionData.estimateLikelihoodClusters(data, new Vector[] {
+            VectorUtils.of(1.0, 0.5),
+            VectorUtils.of(0.0, 0.5)
+        });
+
+        assertEquals(1.0, data.pcxi(0, 0), 1e-4);
+        assertEquals(0.0, data.pcxi(1, 0), 1e-4);
+
+        assertEquals(0.0, data.pcxi(0, 1), 1e-4);
+        assertEquals(1.0, data.pcxi(1, 1), 1e-4);
+
+        assertEquals(1.0, data.pcxi(0, 2), 1e-4);
+        assertEquals(0.0, data.pcxi(1, 2), 1e-4);
+    }
+
+    /** */
+    @Test
+    public void testUpdatePcxi() {
+        GmmPartitionData.updatePcxi(
+            data,
+            VectorUtils.of(0.3, 0.7),
+            Arrays.asList(
+                new MultivariateGaussianDistribution(VectorUtils.of(1.0, 0.5), new DenseMatrix(new double[] {0.5, 0., 0., 1.}, 2)),
+                new MultivariateGaussianDistribution(VectorUtils.of(0.0, 0.5), new DenseMatrix(new double[] {1.0, 0., 0., 1.}, 2))
+            )
+        );
+
+        assertEquals(0.49, data.pcxi(0, 0), 1e-2);
+        assertEquals(0.50, data.pcxi(1, 0), 1e-2);
+
+        assertEquals(0.18, data.pcxi(0, 1), 1e-2);
+        assertEquals(0.81, data.pcxi(1, 1), 1e-2);
+
+        assertEquals(0.49, data.pcxi(0, 2), 1e-2);
+        assertEquals(0.50, data.pcxi(1, 2), 1e-2);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmTrainerIntegrationTest.java
new file mode 100644 (file)
index 0000000..deb9948
--- /dev/null
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Integration test for GmmTrainer.
+ */
+public class GmmTrainerIntegrationTest extends GridCommonAbstractTest {
+    /** Number of nodes in grid */
+    private static final int NODE_COUNT = 3;
+
+    /** Ignite instance. */
+    private Ignite ignite;
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTestsStarted() throws Exception {
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void afterTestsStopped() {
+        stopAllGrids();
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override protected void beforeTest() {
+        /* Grid instance. */
+        ignite = grid(NODE_COUNT);
+        ignite.configuration().setPeerClassLoadingEnabled(true);
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+    }
+
+    /** */
+    @Test
+    public void testFit() {
+        CacheConfiguration<Integer, double[]> trainingSetCacheCfg = new CacheConfiguration<>();
+        trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 3));
+        trainingSetCacheCfg.setName("TRAINING_SET");
+
+        IgniteCache<Integer, double[]> data = ignite.createCache(trainingSetCacheCfg);
+        data.put(0, new double[] {1.0, 1.0, 1.0});
+        data.put(1, new double[] {1.0, 2.0, 1.0});
+        data.put(2, new double[] {2.0, 1.0, 1.0});
+        data.put(3, new double[] {-1.0, -1.0, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
+
+        GmmTrainer trainer = new GmmTrainer(2, 1)
+            .withInitialMeans(Arrays.asList(
+                VectorUtils.of(1.0, 2.0),
+                VectorUtils.of(-1.0, -2.0)));
+        GmmModel model = trainer.fit(
+            new CacheBasedDatasetBuilder<>(ignite, data),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Assert.assertEquals(2, model.countOfComponents());
+        Assert.assertEquals(2, model.dimension());
+        Assert.assertArrayEquals(new double[] {1.33, 1.33}, model.distributions().get(0).mean().asArray(), 1e-2);
+        Assert.assertArrayEquals(new double[] {-1.33, -1.33}, model.distributions().get(1).mean().asArray(), 1e-2);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/GmmTrainerTest.java
new file mode 100644 (file)
index 0000000..09d0abc
--- /dev/null
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for GMM trainer.
+ */
+public class GmmTrainerTest extends TrainerTest {
+    /** Data. */
+    private static final Map<Integer, double[]> data = new HashMap<>();
+
+    static {
+        data.put(0, new double[] {1.0, 1.0, 1.0});
+        data.put(1, new double[] {1.0, 2.0, 1.0});
+        data.put(2, new double[] {2.0, 1.0, 1.0});
+        data.put(3, new double[] {-1.0, -1.0, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
+    }
+
+    /** */
+    @Test
+    public void testFit() {
+        GmmTrainer trainer = new GmmTrainer(2, 1)
+            .withInitialMeans(Arrays.asList(
+                VectorUtils.of(1.0, 2.0),
+                VectorUtils.of(-1.0, -2.0)));
+        GmmModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Assert.assertEquals(2, model.countOfComponents());
+        Assert.assertEquals(2, model.dimension());
+        Assert.assertArrayEquals(new double[] {1.33, 1.33}, model.distributions().get(0).mean().asArray(), 1e-2);
+        Assert.assertArrayEquals(new double[] {-1.33, -1.33}, model.distributions().get(1).mean().asArray(), 1e-2);
+    }
+
+    /** */
+    @Test(expected = IllegalArgumentException.class)
+    public void testOnEmptyPartition() throws Throwable {
+        GmmTrainer trainer = new GmmTrainer(2, 1)
+            .withInitialMeans(Arrays.asList(VectorUtils.of(1.0, 2.0), VectorUtils.of(-1.0, -2.0)));
+
+        try {
+            trainer.fit(
+                new LocalDatasetBuilder<>(new HashMap<>(), parts),
+                (k, v) -> new DenseVector(2),
+                (k, v) -> 1.0
+            );
+        }
+        catch (RuntimeException e) {
+            throw e.getCause();
+        }
+    }
+
+    /** */
+    @Test
+    public void testUpdateOnEmptyDataset() {
+        GmmTrainer trainer = new GmmTrainer(2, 1)
+            .withInitialMeans(Arrays.asList(
+                VectorUtils.of(1.0, 2.0),
+                VectorUtils.of(-1.0, -2.0)));
+        GmmModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        model = trainer.updateModel(model,
+            new LocalDatasetBuilder<>(new HashMap<>(), parts),
+            new FeatureLabelExtractor<Double, Vector, Double>() {
+                private static final long serialVersionUID = -7245682432641745217L;
+
+                @Override public LabeledVector<Double> extract(Double aDouble, Vector vector) {
+                    return new LabeledVector<>(new DenseVector(2), 1.0);
+                }
+            }
+        );
+
+        Assert.assertEquals(2, model.countOfComponents());
+        Assert.assertEquals(2, model.dimension());
+        Assert.assertArrayEquals(new double[] {1.33, 1.33}, model.distributions().get(0).mean().asArray(), 1e-2);
+        Assert.assertArrayEquals(new double[] {-1.33, -1.33}, model.distributions().get(1).mean().asArray(), 1e-2);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregatorTest.java
new file mode 100644 (file)
index 0000000..e6307e1
--- /dev/null
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.clustering.gmm;
+
+import java.util.Arrays;
+import java.util.List;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link MeanWithClusterProbAggregator}.
+ */
+public class MeanWithClusterProbAggregatorTest {
+    /** */
+    private MeanWithClusterProbAggregator aggregator1 = new MeanWithClusterProbAggregator();
+
+    /** */
+    private MeanWithClusterProbAggregator aggregator2 = new MeanWithClusterProbAggregator();
+
+    /**
+     * Default constructor.
+     */
+    public MeanWithClusterProbAggregatorTest() {
+        aggregator1.add(VectorUtils.of(1., 1.), 0.5);
+        aggregator1.add(VectorUtils.of(0., 1.), 0.25);
+        aggregator1.add(VectorUtils.of(1., 0.), 0.75);
+        aggregator1.add(VectorUtils.of(0., 0.), 0.10);
+
+        aggregator2.add(VectorUtils.of(1., 1.), 1.0);
+        aggregator2.add(VectorUtils.of(0., 1.), 1.0);
+        aggregator2.add(VectorUtils.of(1., 0.), 1.0);
+        aggregator2.add(VectorUtils.of(0., 0.), 1.0);
+    }
+
+    /** */
+    @Test
+    public void testAdd() {
+        assertArrayEquals(new double[] {0.781, 0.468}, aggregator1.mean().asArray(), 1e-2);
+        assertArrayEquals(new double[] {0.5, 0.5}, aggregator2.mean().asArray(), 1e-2);
+
+        assertEquals(0.4, aggregator1.clusterProb(), 1e-4);
+        assertEquals(1.0, aggregator2.clusterProb(), 1e-4);
+    }
+
+    /** */
+    @Test
+    public void testPlus() {
+        MeanWithClusterProbAggregator res = aggregator1.plus(aggregator2);
+
+        assertEquals(0.7, res.clusterProb(), 1e-4);
+        assertArrayEquals(new double[] {0.580, 0.491}, res.mean().asArray(), 1e-2);
+    }
+
+    /** */
+    @Test
+    public void testReduce() {
+        MeanWithClusterProbAggregator aggregator3 = new MeanWithClusterProbAggregator();
+        MeanWithClusterProbAggregator aggregator4 = new MeanWithClusterProbAggregator();
+
+        aggregator3.add(VectorUtils.of(1., 1.), 0.5);
+        aggregator3.add(VectorUtils.of(0., 1.), 0.25);
+        aggregator3.add(VectorUtils.of(1., 0.), 0.25);
+        aggregator3.add(VectorUtils.of(0., 0.), 0.5);
+
+        aggregator4.add(VectorUtils.of(1., 1.), 1.0);
+        aggregator4.add(VectorUtils.of(0., 1.), 1.0);
+        aggregator4.add(VectorUtils.of(1., 0.), 1.0);
+        aggregator4.add(VectorUtils.of(0., 0.), 1.0);
+
+        List<MeanWithClusterProbAggregator> res = MeanWithClusterProbAggregator.reduce(
+            Arrays.asList(aggregator1, aggregator3),
+            Arrays.asList(aggregator2, aggregator4)
+        );
+
+        MeanWithClusterProbAggregator res1 = res.get(0);
+        assertEquals(0.70, res1.clusterProb(), 1e-2);
+        assertArrayEquals(new double[] {0.580, 0.491}, res1.mean().asArray(), 1e-2);
+
+        MeanWithClusterProbAggregator res2 = res.get(1);
+        assertEquals(0.68, res2.clusterProb(), 1e-2);
+        assertArrayEquals(new double[] {0.50, 0.50}, res2.mean().asArray(), 1e-2);
+    }
+
+    /** */
+    @Test
+    public void testMap() {
+        GmmPartitionData data = new GmmPartitionData(
+            Arrays.asList(
+                new LabeledVector<>(VectorUtils.of(1, 0), 0.),
+                new LabeledVector<>(VectorUtils.of(0, 1), 0.),
+                new LabeledVector<>(VectorUtils.of(1, 1), 0.)
+            ),
+
+            new double[][] {
+                new double[] {0.5, 0.1},
+                new double[] {1.0, 0.4},
+                new double[] {0.3, 0.2}
+            }
+        );
+
+        List<MeanWithClusterProbAggregator> res = MeanWithClusterProbAggregator.map(data);
+        assertEquals(2, res.size());
+
+        MeanWithClusterProbAggregator agg1 = res.get(0);
+        assertEquals(0.6, agg1.clusterProb(), 1e-2);
+        assertArrayEquals(new double[] {0.44, 0.72}, agg1.mean().asArray(), 1e-2);
+
+        MeanWithClusterProbAggregator agg2 = res.get(1);
+        assertEquals(0.23, agg2.clusterProb(), 1e-2);
+        assertArrayEquals(new double[] {0.42, 0.85}, agg2.mean().asArray(), 1e-2);
+    }
+}
index b76a4c2..d6366d3 100644 (file)
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.math;
 import org.apache.ignite.ml.math.distances.DistanceTest;
 import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeapTest;
 import org.apache.ignite.ml.math.primitives.matrix.DenseMatrixConstructorTest;
+import org.apache.ignite.ml.math.primitives.matrix.LUDecompositionTest;
 import org.apache.ignite.ml.math.primitives.matrix.MatrixArrayStorageTest;
 import org.apache.ignite.ml.math.primitives.matrix.MatrixAttributeTest;
 import org.apache.ignite.ml.math.primitives.matrix.MatrixStorageImplementationTest;
@@ -69,7 +70,8 @@ import org.junit.runners.Suite;
     // Matrix tests.
     MatrixAttributeTest.class,
     DistanceTest.class,
-    LSQROnHeapTest.class
+    LSQROnHeapTest.class,
+    LUDecompositionTest.class
 })
 public class MathImplLocalTestSuite {
     // No-op.
index cd6ae98..500900a 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.math;
 
+import org.apache.ignite.ml.math.stat.StatsTestSuite;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
@@ -25,6 +26,7 @@ import org.junit.runners.Suite;
  */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
+    StatsTestSuite.class,
     MathImplLocalTestSuite.class,
     TracerTest.class,
     BlasTest.class
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/matrix/LUDecompositionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/primitives/matrix/LUDecompositionTest.java
new file mode 100644 (file)
index 0000000..366f644
--- /dev/null
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.primitives.matrix;
+
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link LUDecomposition}.
+ */
+public class LUDecompositionTest {
+    /** */
+    private Matrix testL;
+    /** */
+    private Matrix testU;
+    /** */
+    private Matrix testP;
+    /** */
+    private Matrix testMatrix;
+    /** */
+    private int[] rawPivot;
+
+    /** */
+    @Before
+    public void setUp() {
+        double[][] rawMatrix = new double[][] {
+            {2.0d, 1.0d, 1.0d, 0.0d},
+            {4.0d, 3.0d, 3.0d, 1.0d},
+            {8.0d, 7.0d, 9.0d, 5.0d},
+            {6.0d, 7.0d, 9.0d, 8.0d}};
+        double[][] rawL = {
+            {1.0d, 0.0d, 0.0d, 0.0d},
+            {3.0d / 4.0d, 1.0d, 0.0d, 0.0d},
+            {1.0d / 2.0d, -2.0d / 7.0d, 1.0d, 0.0d},
+            {1.0d / 4.0d, -3.0d / 7.0d, 1.0d / 3.0d, 1.0d}};
+        double[][] rawU = {
+            {8.0d, 7.0d, 9.0d, 5.0d},
+            {0.0d, 7.0d / 4.0d, 9.0d / 4.0d, 17.0d / 4.0d},
+            {0.0d, 0.0d, -6.0d / 7.0d, -2.0d / 7.0d},
+            {0.0d, 0.0d, 0.0d, 2.0d / 3.0d}};
+        double[][] rawP = new double[][] {
+            {0, 0, 1.0d, 0},
+            {0, 0, 0, 1.0d},
+            {0, 1.0d, 0, 0},
+            {1.0d, 0, 0, 0}};
+
+        rawPivot = new int[] {3, 4, 2, 1};
+
+        testMatrix = new DenseMatrix(rawMatrix);
+        testL = new DenseMatrix(rawL);
+        testU = new DenseMatrix(rawU);
+        testP = new DenseMatrix(rawP);
+    }
+
+    /** */
+    @Test
+    public void getL() throws Exception {
+        Matrix luDecompositionL = new LUDecomposition(testMatrix).getL();
+
+        assertEquals("Unexpected row size.", testL.rowSize(), luDecompositionL.rowSize());
+        assertEquals("Unexpected column size.", testL.columnSize(), luDecompositionL.columnSize());
+
+        for (int i = 0; i < testL.rowSize(); i++)
+            for (int j = 0; j < testL.columnSize(); j++)
+                assertEquals("Unexpected value at (" + i + "," + j + ").",
+                    testL.getX(i, j), luDecompositionL.getX(i, j), 0.0000001d);
+
+        luDecompositionL.destroy();
+    }
+
+    /** */
+    @Test
+    public void getU() throws Exception {
+        Matrix luDecompositionU = new LUDecomposition(testMatrix).getU();
+
+        assertEquals("Unexpected row size.", testU.rowSize(), luDecompositionU.rowSize());
+        assertEquals("Unexpected column size.", testU.columnSize(), luDecompositionU.columnSize());
+
+        for (int i = 0; i < testU.rowSize(); i++)
+            for (int j = 0; j < testU.columnSize(); j++)
+                assertEquals("Unexpected value at (" + i + "," + j + ").",
+                    testU.getX(i, j), luDecompositionU.getX(i, j), 0.0000001d);
+
+        luDecompositionU.destroy();
+    }
+
+    /** */
+    @Test
+    public void getP() throws Exception {
+        Matrix luDecompositionP = new LUDecomposition(testMatrix).getP();
+
+        assertEquals("Unexpected row size.", testP.rowSize(), luDecompositionP.rowSize());
+        assertEquals("Unexpected column size.", testP.columnSize(), luDecompositionP.columnSize());
+
+        for (int i = 0; i < testP.rowSize(); i++)
+            for (int j = 0; j < testP.columnSize(); j++)
+                assertEquals("Unexpected value at (" + i + "," + j + ").",
+                    testP.getX(i, j), luDecompositionP.getX(i, j), 0.0000001d);
+
+        luDecompositionP.destroy();
+    }
+
+    /** */
+    @Test
+    public void getPivot() throws Exception {
+        Vector pivot = new LUDecomposition(testMatrix).getPivot();
+
+        assertEquals("Unexpected pivot size.", rawPivot.length, pivot.size());
+
+        for (int i = 0; i < testU.rowSize(); i++)
+            assertEquals("Unexpected value at " + i, rawPivot[i], (int)pivot.get(i) + 1);
+    }
+
+    /**
+     * Test for {@link MatrixUtil} features (more specifically, we test matrix which does not have a native like/copy
+     * methods support).
+     */
+    @Test
+    public void matrixUtilTest() {
+        LUDecomposition dec = new LUDecomposition(testMatrix);
+        Matrix luDecompositionL = dec.getL();
+
+        assertEquals("Unexpected L row size.", testL.rowSize(), luDecompositionL.rowSize());
+        assertEquals("Unexpected L column size.", testL.columnSize(), luDecompositionL.columnSize());
+
+        for (int i = 0; i < testL.rowSize(); i++)
+            for (int j = 0; j < testL.columnSize(); j++)
+                assertEquals("Unexpected L value at (" + i + "," + j + ").",
+                    testL.getX(i, j), luDecompositionL.getX(i, j), 0.0000001d);
+
+        Matrix luDecompositionU = dec.getU();
+
+        assertEquals("Unexpected U row size.", testU.rowSize(), luDecompositionU.rowSize());
+        assertEquals("Unexpected U column size.", testU.columnSize(), luDecompositionU.columnSize());
+
+        for (int i = 0; i < testU.rowSize(); i++)
+            for (int j = 0; j < testU.columnSize(); j++)
+                assertEquals("Unexpected U value at (" + i + "," + j + ").",
+                    testU.getX(i, j), luDecompositionU.getX(i, j), 0.0000001d);
+
+        Matrix luDecompositionP = dec.getP();
+
+        assertEquals("Unexpected P row size.", testP.rowSize(), luDecompositionP.rowSize());
+        assertEquals("Unexpected P column size.", testP.columnSize(), luDecompositionP.columnSize());
+
+        for (int i = 0; i < testP.rowSize(); i++)
+            for (int j = 0; j < testP.columnSize(); j++)
+                assertEquals("Unexpected P value at (" + i + "," + j + ").",
+                    testP.getX(i, j), luDecompositionP.getX(i, j), 0.0000001d);
+
+        dec.close();
+    }
+
+    /** */
+    @Test
+    public void singularDeterminant() throws Exception {
+        assertEquals("Unexpected determinant for singular matrix decomposition.",
+            0d, new LUDecomposition(new DenseMatrix(2, 2)).determinant(), 0d);
+    }
+
+    /** */
+    @Test(expected = CardinalityException.class)
+    public void solveVecWrongSize() throws Exception {
+        new LUDecomposition(testMatrix).solve(new DenseVector(testMatrix.rowSize() + 1));
+    }
+
+    /** */
+    @Test(expected = SingularMatrixException.class)
+    public void solveVecSingularMatrix() throws Exception {
+        new LUDecomposition(new DenseMatrix(testMatrix.rowSize(), testMatrix.rowSize()))
+            .solve(new DenseVector(testMatrix.rowSize()));
+    }
+
+    /** */
+    @Test
+    public void solveVec() throws Exception {
+        Vector sol = new LUDecomposition(testMatrix)
+            .solve(new DenseVector(testMatrix.rowSize()));
+
+        assertEquals("Wrong solution vector size.", testMatrix.rowSize(), sol.size());
+
+        for (int i = 0; i < sol.size(); i++)
+            assertEquals("Unexpected value at index " + i, 0d, sol.getX(i), 0.0000001d);
+    }
+
+    /** */
+    @Test(expected = CardinalityException.class)
+    public void solveMtxWrongSize() throws Exception {
+        new LUDecomposition(testMatrix).solve(
+            new DenseMatrix(testMatrix.rowSize() + 1, testMatrix.rowSize()));
+    }
+
+    /** */
+    @Test(expected = SingularMatrixException.class)
+    public void solveMtxSingularMatrix() throws Exception {
+        new LUDecomposition(new DenseMatrix(testMatrix.rowSize(), testMatrix.rowSize()))
+            .solve(new DenseMatrix(testMatrix.rowSize(), testMatrix.rowSize()));
+    }
+
+    /** */
+    @Test
+    public void solveMtx() throws Exception {
+        Matrix sol = new LUDecomposition(testMatrix)
+            .solve(new DenseMatrix(testMatrix.rowSize(), testMatrix.rowSize()));
+
+        assertEquals("Wrong solution matrix row size.", testMatrix.rowSize(), sol.rowSize());
+
+        assertEquals("Wrong solution matrix column size.", testMatrix.rowSize(), sol.columnSize());
+
+        for (int row = 0; row < sol.rowSize(); row++)
+            for (int col = 0; col < sol.columnSize(); col++)
+                assertEquals("Unexpected P value at (" + row + "," + col + ").",
+                    0d, sol.getX(row, col), 0.0000001d);
+    }
+
+    /** */
+    @Test(expected = AssertionError.class)
+    public void nullMatrixTest() {
+        new LUDecomposition(null);
+    }
+
+    /** */
+    @Test(expected = CardinalityException.class)
+    public void nonSquareMatrixTest() {
+        new LUDecomposition(new DenseMatrix(2, 3));
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/stat/DistributionMixtureTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/stat/DistributionMixtureTest.java
new file mode 100644 (file)
index 0000000..c46d212
--- /dev/null
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.stat;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ *
+ */
+public class DistributionMixtureTest {
+    /** */
+    private DistributionMixture<Constant> mixture;
+
+    @Before
+    public void setUp() throws Exception {
+        mixture = new DistributionMixture<Constant>(
+            VectorUtils.of(0.3, 0.3, 0.4),
+            Arrays.asList(new Constant(0.5), new Constant(1.0), new Constant(0.))
+        ) {
+        };
+
+        assertEquals(1, mixture.dimension());
+        assertEquals(3, mixture.countOfComponents());
+    }
+
+    /** */
+    @Test
+    public void testLikelihood() {
+        assertArrayEquals(
+            new double[] {0.15, 0.3, 0.},
+            mixture.likelihood(VectorUtils.of(1.)).asArray(), 1e-4
+        );
+    }
+
+    /** */
+    @Test
+    public void testProb() {
+        assertEquals(0.45, mixture.prob(VectorUtils.of(1.)), 1e-4);
+    }
+
+    /** */
+    private static class Constant implements Distribution {
+        /** Value. */
+        private final double value;
+
+        /** */
+        public Constant(double value) {
+            this.value = value;
+        }
+
+        /** {@inheritDoc} */
+        @Override public double prob(Vector x) {
+            return value;
+        }
+
+        /** {@inheritDoc} */
+        @Override public int dimension() {
+            return 1;
+        }
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/stat/MultivariateGaussianDistributionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/stat/MultivariateGaussianDistributionTest.java
new file mode 100644 (file)
index 0000000..415bc18
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.stat;
+
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link MultivariateGaussianDistribution}.
+ */
+public class MultivariateGaussianDistributionTest {
+    /** */
+    @Test
+    public void testApply() {
+        MultivariateGaussianDistribution distribution = new MultivariateGaussianDistribution(
+            VectorUtils.of(1, 2),
+            new DenseMatrix(new double[][] {new double[] {1, -0.5}, new double[] {-0.5, 1}})
+        );
+
+        Assert.assertEquals(0.183, distribution.prob(VectorUtils.of(1, 2)), 0.01);
+        Assert.assertEquals(0.094, distribution.prob(VectorUtils.of(0, 2)), 0.01);
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/stat/StatsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/stat/StatsTestSuite.java
new file mode 100644 (file)
index 0000000..5b4c80e
--- /dev/null
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.stat;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite for stat package.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+    DistributionMixtureTest.class,
+    MultivariateGaussianDistributionTest.class
+})
+public class StatsTestSuite {
+}
index 71ed387..768f783 100644 (file)
@@ -16,7 +16,9 @@ Running Ignite Benchmarks Locally
 The simplest way to start with benchmarking is to use one of the executable scripts available under `benchmarks\bin`
 directory:
 
+modules/yardstick/target
 ./bin/benchmark-run-all.sh config/benchmark-sample.properties
+modules/yardstick/target/assembly/bin/benchmark-run-all.sh modules/yardstick/target/assembly/config/benchmark-ml.properties
 
 The command above will benchmark the cache put operation for a distributed atomic cache. The results of the
 benchmark will be added to an auto-generated `output/results-{DATE-TIME}` directory.