IGNITE-7438: LSQR solver for Linear Regression
author dmitrievanthony <dmitrievanthony@gmail.com>
Fri, 9 Feb 2018 11:17:27 +0000 (14:17 +0300)
committer Yury Babak <ybabak@gridgain.com>
Fri, 9 Feb 2018 11:17:27 +0000 (14:17 +0300)
this closes #3494

20 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java [new file with mode: 0644]
examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java [moved from examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithQRTrainer.java with 99% similarity]
examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java [moved from examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionExampleWithSGDTrainer.java with 99% similarity]
modules/ml/src/main/java/org/apache/ignite/ml/DatasetTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/IterativeSolverResult.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataBuilderOnHeap.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataOnHeap.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQRPartitionContext.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQRResult.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java
modules/ml/src/test/java/org/apache/ignite/ml/math/MathImplLocalTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java [new file with mode: 0644]

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java
new file mode 100644 (file)
index 0000000..20e0653
--- /dev/null
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression.linear;
+
+import java.util.Arrays;
+import java.util.UUID;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run linear regression model over distributed matrix.
+ *
+ * @see LinearRegressionLSQRTrainer
+ */
+public class DistributedLinearRegressionWithLSQRTrainerExample {
+    /** */
+    private static final double[][] data = {
+        {8, 78, 284, 9.100000381, 109},
+        {9.300000191, 68, 433, 8.699999809, 144},
+        {7.5, 70, 739, 7.199999809, 113},
+        {8.899999619, 96, 1792, 8.899999619, 97},
+        {10.19999981, 74, 477, 8.300000191, 206},
+        {8.300000191, 111, 362, 10.89999962, 124},
+        {8.800000191, 77, 671, 10, 152},
+        {8.800000191, 168, 636, 9.100000381, 162},
+        {10.69999981, 82, 329, 8.699999809, 150},
+        {11.69999981, 89, 634, 7.599999905, 134},
+        {8.5, 149, 631, 10.80000019, 292},
+        {8.300000191, 60, 257, 9.5, 108},
+        {8.199999809, 96, 284, 8.800000191, 111},
+        {7.900000095, 83, 603, 9.5, 182},
+        {10.30000019, 130, 686, 8.699999809, 129},
+        {7.400000095, 145, 345, 11.19999981, 158},
+        {9.600000381, 112, 1357, 9.699999809, 186},
+        {9.300000191, 131, 544, 9.600000381, 177},
+        {10.60000038, 80, 205, 9.100000381, 127},
+        {9.699999809, 130, 1264, 9.199999809, 179},
+        {11.60000038, 140, 688, 8.300000191, 80},
+        {8.100000381, 154, 354, 8.399999619, 103},
+        {9.800000191, 118, 1632, 9.399999619, 101},
+        {7.400000095, 94, 348, 9.800000191, 117},
+        {9.399999619, 119, 370, 10.39999962, 88},
+        {11.19999981, 153, 648, 9.899999619, 78},
+        {9.100000381, 116, 366, 9.199999809, 102},
+        {10.5, 97, 540, 10.30000019, 95},
+        {11.89999962, 176, 680, 8.899999619, 80},
+        {8.399999619, 75, 345, 9.600000381, 92},
+        {5, 134, 525, 10.30000019, 126},
+        {9.800000191, 161, 870, 10.39999962, 108},
+        {9.800000191, 111, 669, 9.699999809, 77},
+        {10.80000019, 114, 452, 9.600000381, 60},
+        {10.10000038, 142, 430, 10.69999981, 71},
+        {10.89999962, 238, 822, 10.30000019, 86},
+        {9.199999809, 78, 190, 10.69999981, 93},
+        {8.300000191, 196, 867, 9.600000381, 106},
+        {7.300000191, 125, 969, 10.5, 162},
+        {9.399999619, 82, 499, 7.699999809, 95},
+        {9.399999619, 125, 925, 10.19999981, 91},
+        {9.800000191, 129, 353, 9.899999619, 52},
+        {3.599999905, 84, 288, 8.399999619, 110},
+        {8.399999619, 183, 718, 10.39999962, 69},
+        {10.80000019, 119, 540, 9.199999809, 57},
+        {10.10000038, 180, 668, 13, 106},
+        {9, 82, 347, 8.800000191, 40},
+        {10, 71, 345, 9.199999809, 50},
+        {11.30000019, 118, 463, 7.800000191, 35},
+        {11.30000019, 121, 728, 8.199999809, 86},
+        {12.80000019, 68, 383, 7.400000095, 57},
+        {10, 112, 316, 10.39999962, 57},
+        {6.699999809, 109, 388, 8.899999619, 94}
+    };
+
+    /** Run example. */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
+            // because we create ignite cache internally.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                SparseDistributedMatrixExample.class.getSimpleName(), () -> {
+                IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+
+                System.out.println(">>> Create new linear regression trainer object.");
+                LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>();
+
+                System.out.println(">>> Perform the training to get the model.");
+                LinearRegressionModel mdl = trainer.fit(
+                    new CacheBasedDatasetBuilder<>(ignite, dataCache),
+                    (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+                    (k, v) -> v[0],
+                    4
+                );
+
+                System.out.println(">>> Linear regression model: " + mdl);
+
+                System.out.println(">>> ---------------------------------");
+                System.out.println(">>> | Prediction\t| Ground Truth\t|");
+                System.out.println(">>> ---------------------------------");
+
+                try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+                    for (Cache.Entry<Integer, double[]> observation : observations) {
+                        double[] val = observation.getValue();
+                        double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+                        double groundTruth = val[0];
+
+                        double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs));
+
+                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+                    }
+                }
+
+                System.out.println(">>> ---------------------------------");
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+
+    /**
+     * Fills cache with data and returns it.
+     *
+     * @param ignite Ignite instance.
+     * @return Filled Ignite Cache.
+     */
+    private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+        CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
+        cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+        IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
+
+        for (int i = 0; i < data.length; i++)
+            cache.put(i, data[i]);
+
+        return cache;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/DatasetTrainer.java
new file mode 100644 (file)
index 0000000..aa04d8e
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml;
+
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+
/**
 * Interface for trainers. Trainer is just a function which produces a model from the data: it consumes a dataset
 * (built from an upstream key-value source) and returns a trained {@link Model}.
 *
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @param <M> Type of a produced model.
 */
public interface DatasetTrainer<K, V, M extends Model> {
    /**
     * Trains model based on the specified data.
     *
     * @param datasetBuilder Dataset builder that assembles the learning dataset from the upstream data.
     * @param featureExtractor Feature extractor that maps an upstream (key, value) pair to a feature row.
     * @param lbExtractor Label extractor that maps an upstream (key, value) pair to a label.
     * @param cols Number of columns (features) in every extracted feature row.
     * @return Trained model.
     */
    public M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor, int cols);
}
index a95a1cc..4e0a570 100644 (file)
@@ -26,6 +26,8 @@ import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer
  * @param <M> Type of produced model.
  * @param <T> Type of data needed for model producing.
  */
+// TODO: IGNITE-7659: Reduce multiple Trainer interfaces to one
+@Deprecated
 public interface Trainer<M extends Model, T> {
     /**
      * Returns model based on data
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/IterativeSolverResult.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/IterativeSolverResult.java
new file mode 100644 (file)
index 0000000..fe39ad7
--- /dev/null
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.isolve;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
/**
 * Base class for iterative solver results. Holds the solution vector found by the solver together with the
 * number of iterations it took to terminate.
 */
public class IterativeSolverResult implements Serializable {
    /** */
    private static final long serialVersionUID = 8084061028708491097L;

    /** The final solution. */
    private final double[] x;

    /** Iteration number upon termination. */
    private final int iterations;

    /**
     * Constructs a new instance of iterative solver result.
     *
     * @param x The final solution.
     * @param iterations Iteration number upon termination.
     */
    public IterativeSolverResult(double[] x, int iterations) {
        this.x = x;
        this.iterations = iterations;
    }

    /** Returns the number of iterations performed before termination. */
    public int getIterations() {
        return iterations;
    }

    /** Returns the final solution vector. */
    public double[] getX() {
        return x;
    }

    /** {@inheritDoc} */
    @Override public String toString() {
        return String.format("IterativeSolverResult{x=%s, iterations=%d}", Arrays.toString(x), iterations);
    }
}
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataBuilderOnHeap.java
new file mode 100644 (file)
index 0000000..1c2e2cf
--- /dev/null
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.isolve;
+
+import java.io.Serializable;
+import java.util.Iterator;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+
+/**
+ * Linear system partition data builder that builds {@link LinSysPartitionDataOnHeap}.
+ *
+ * @param <K> Type of a key in <tt>upstream</tt> data.
+ * @param <V> Type of a value in <tt>upstream</tt> data.
+ * @param <C> Type of a partition <tt>context</tt>.
+ */
+public class LinSysPartitionDataBuilderOnHeap<K, V, C extends Serializable>
+    implements PartitionDataBuilder<K, V, C, LinSysPartitionDataOnHeap> {
+    /** */
+    private static final long serialVersionUID = -7820760153954269227L;
+
+    /** Extractor of X matrix row. */
+    private final IgniteBiFunction<K, V, double[]> xExtractor;
+
+    /** Extractor of Y vector value. */
+    private final IgniteBiFunction<K, V, Double> yExtractor;
+
+    /** Number of columns. */
+    private final int cols;
+
+    /**
+     * Constructs a new instance of linear system partition data builder.
+     *
+     * @param xExtractor Extractor of X matrix row.
+     * @param yExtractor Extractor of Y vector value.
+     * @param cols Number of columns.
+     */
+    public LinSysPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, double[]> xExtractor,
+        IgniteBiFunction<K, V, Double> yExtractor, int cols) {
+        this.xExtractor = xExtractor;
+        this.yExtractor = yExtractor;
+        this.cols = cols;
+    }
+
+    /** {@inheritDoc} */
+    @Override public LinSysPartitionDataOnHeap build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize,
+        C ctx) {
+        // Prepares the matrix of features in flat column-major format.
+        double[] x = new double[Math.toIntExact(upstreamDataSize * cols)];
+        double[] y = new double[Math.toIntExact(upstreamDataSize)];
+
+        int ptr = 0;
+        while (upstreamData.hasNext()) {
+            UpstreamEntry<K, V> entry = upstreamData.next();
+            double[] row = xExtractor.apply(entry.getKey(), entry.getValue());
+
+            assert row.length == cols : "X extractor must return exactly " + cols + " columns";
+
+            for (int i = 0; i < cols; i++)
+                x[Math.toIntExact(i * upstreamDataSize) + ptr] = row[i];
+
+            y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue());
+
+            ptr++;
+        }
+
+        return new LinSysPartitionDataOnHeap(x, Math.toIntExact(upstreamDataSize), cols, y);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/LinSysPartitionDataOnHeap.java
new file mode 100644 (file)
index 0000000..e0b8f46
--- /dev/null
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.isolve;
+
/**
 * On Heap partition data that keeps part of a linear system: a chunk of the X matrix (in flat format) together
 * with the matching chunk of the Y vector.
 */
public class LinSysPartitionDataOnHeap implements AutoCloseable {
    /** Part of X matrix kept by this partition. */
    private final double[] xMtx;

    /** Number of rows in this part. */
    private final int rowCnt;

    /** Number of columns in this part. */
    private final int colCnt;

    /** Part of Y vector kept by this partition. */
    private final double[] yVec;

    /**
     * Constructs a new instance of linear system partition data.
     *
     * @param x Part of X matrix.
     * @param rows Number of rows.
     * @param cols Number of columns.
     * @param y Part of Y vector.
     */
    public LinSysPartitionDataOnHeap(double[] x, int rows, int cols, double[] y) {
        xMtx = x;
        rowCnt = rows;
        colCnt = cols;
        yVec = y;
    }

    /** Returns the part of the X matrix kept by this partition. */
    public double[] getX() {
        return xMtx;
    }

    /** Returns the number of rows. */
    public int getRows() {
        return rowCnt;
    }

    /** Returns the number of columns. */
    public int getCols() {
        return colCnt;
    }

    /** Returns the part of the Y vector kept by this partition. */
    public double[] getY() {
        return yVec;
    }

    /** {@inheritDoc} */
    @Override public void close() {
        // No-op: all data is on-heap and will be reclaimed by the GC.
    }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java
new file mode 100644 (file)
index 0000000..8d190cd
--- /dev/null
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.isolve.lsqr;
+
+import com.github.fommil.netlib.BLAS;
+import java.util.Arrays;
+
/**
 * Basic implementation of the LSQR algorithm without assumptions about dataset storage format or data processing
 * device.
 *
 * This implementation is based on SciPy implementation.
 * SciPy implementation: https://github.com/scipy/scipy/blob/master/scipy/sparse/linalg/isolve/lsqr.py#L98.
 */
// TODO: IGNITE-7660: Refactor LSQR algorithm
public abstract class AbstractLSQR {
    /** The smallest representable positive number such that 1.0 + eps != 1.0. */
    private static final double eps = Double.longBitsToDouble(Double.doubleToLongBits(1.0) | 1) - 1.0;

    /** BLAS (Basic Linear Algebra Subprograms) instance. */
    // NOTE(review): could be declared final — it is assigned exactly once.
    private static BLAS blas = BLAS.getInstance();

    /**
     * Solves given Sparse Linear Systems.
     *
     * @param damp Damping coefficient.
     * @param atol Stopping tolerances, if both (atol and btol) are 1.0e-9 (say), the final residual norm should be
     * accurate to about 9 digits.
     * @param btol Stopping tolerances, if both (atol and btol) are 1.0e-9 (say), the final residual norm should be
     * accurate to about 9 digits.
     * @param conlim Another stopping tolerance, LSQR terminates if an estimate of cond(A) exceeds conlim.
     * @param iterLim Explicit limitation on number of iterations (for safety). A negative value means "no explicit
     * limit": a default of 2 * n iterations is used instead.
     * @param calcVar Whether to estimate diagonals of (A'A + damp^2*I)^{-1}.
     * @param x0 Initial value of x, or {@code null} to start from the zero vector. NOTE(review): when non-null,
     * x0 is used as the working solution vector and is therefore mutated in place — confirm callers expect this.
     * @return Solver result.
     */
    public LSQRResult solve(double damp, double atol, double btol, double conlim, double iterLim, boolean calcVar,
        double[] x0) {
        int n = getColumns();

        // Negative iteration limit means "use the default" of twice the number of columns.
        if (iterLim < 0)
            iterLim = 2 * n;

        double[] var = new double[n];
        int itn = 0;
        // istop encodes the termination reason (0 = still running; 1..7 assigned by the convergence tests
        // below, mirroring the SciPy lsqr convention).
        int istop = 0;
        double ctol = 0;

        if (conlim > 0)
            ctol = 1 / conlim;

        double anorm = 0;
        double acond = 0;
        double dampsq = Math.pow(damp, 2.0);
        double ddnorm = 0;
        double res2 = 0;
        double xnorm = 0;
        double xxnorm = 0;
        double z = 0;
        double cs2 = -1;
        double sn2 = 0;

        // Set up the first vectors u and v for the bidiagonalization.
        // These satisfy  beta*u = b - A*x,  alfa*v = A'*u.
        double bnorm = bnorm();
        double[] x;
        double beta;

        if (x0 == null) {
            x = new double[n];
            beta = bnorm;
        }
        else {
            // x aliases x0: the caller-supplied array is updated in place by the daxpy calls below.
            x = x0;
            beta = beta(x, -1.0, 1.0);
        }

        double[] v = new double[n];
        double alfa;

        if (beta > 0) {
            v = iter(beta, v);
            alfa = blas.dnrm2(v.length, v, 1);
        }
        else {
            System.arraycopy(x, 0, v, 0, v.length);
            alfa = 0;
        }

        if (alfa > 0)
            blas.dscal(v.length, 1 / alfa, v, 1);

        double[] w = Arrays.copyOf(v, v.length);

        double rhobar = alfa;
        double phibar = beta;
        double rnorm = beta;
        double r1norm = rnorm;
        double r2norm = rnorm;
        double arnorm = alfa * beta;
        double[] dk = new double[w.length];

        // arnorm == 0 means the starting point already satisfies the normal equations exactly.
        if (arnorm == 0)
            return new LSQRResult(x, itn, istop, r1norm, r2norm, anorm, acond, arnorm, xnorm, var);

        // Main iteration loop.
        while (itn < iterLim) {
            itn = itn + 1;

            // Perform the next step of the bidiagonalization to obtain the
            // next  beta, u, alfa, v.  These satisfy the relations
            //            beta*u  =  A*v   -  alfa*u,
            //            alfa*v  =  A'*u  -  beta*v.
            beta = beta(v, 1.0, -alfa);
            if (beta > 0) {
                anorm = Math.sqrt(Math.pow(anorm, 2) + Math.pow(alfa, 2) + Math.pow(beta, 2) + Math.pow(damp, 2));

                blas.dscal(v.length, -beta, v, 1);

                iter(beta, v);

                alfa = blas.dnrm2(v.length, v, 1);

                if (alfa > 0)
                    blas.dscal(v.length, 1 / alfa, v, 1);
            }

            // Use a plane rotation to eliminate the damping parameter.
            // This alters the diagonal (rhobar) of the lower-bidiagonal matrix.
            double rhobar1 = Math.sqrt(Math.pow(rhobar, 2) + Math.pow(damp, 2));
            double cs1 = rhobar / rhobar1;
            double sn1 = damp / rhobar1;
            double psi = sn1 * phibar;
            phibar = cs1 * phibar;

            // Use a plane rotation to eliminate the subdiagonal element (beta)
            // of the lower-bidiagonal matrix, giving an upper-bidiagonal matrix.
            double[] symOrtho = symOrtho(rhobar1, beta);
            double cs = symOrtho[0];
            double sn = symOrtho[1];
            double rho = symOrtho[2];

            double theta = sn * alfa;
            rhobar = -cs * alfa;
            double phi = cs * phibar;
            phibar = sn * phibar;
            double tau = sn * phi;

            // Update x and w.
            double t1 = phi / rho;
            double t2 = -theta / rho;

            blas.dcopy(w.length, w, 1, dk, 1);
            blas.dscal(dk.length, 1 / rho, dk, 1);

            // x = x + t1*w
            blas.daxpy(w.length, t1, w, 1, x, 1);
            // w = v + t2*w
            blas.dscal(w.length, t2, w, 1);
            blas.daxpy(w.length, 1, v, 1, w, 1);

            ddnorm = ddnorm + Math.pow(blas.dnrm2(dk.length, dk, 1), 2);

            // Accumulate dk^2 into the variance estimate when requested (calcVar).
            if (calcVar)
                blas.daxpy(var.length, 1.0, pow2(dk), 1, var, 1);

            // Use a plane rotation on the right to eliminate the
            // super-diagonal element (theta) of the upper-bidiagonal matrix.
            // Then use the result to estimate norm(x).
            double delta = sn2 * rho;
            double gambar = -cs2 * rho;
            double rhs = phi - delta * z;
            double zbar = rhs / gambar;
            xnorm = Math.sqrt(xxnorm + Math.pow(zbar, 2));
            double gamma = Math.sqrt(Math.pow(gambar, 2) + Math.pow(theta, 2));
            cs2 = gambar / gamma;
            sn2 = theta / gamma;
            z = rhs / gamma;
            xxnorm = xxnorm + Math.pow(z, 2);

            // Test for convergence.
            // First, estimate the condition of the matrix  Abar,
            // and the norms of  rbar  and  Abar'rbar.
            acond = anorm * Math.sqrt(ddnorm);
            double res1 = Math.pow(phibar, 2);
            res2 = res2 + Math.pow(psi, 2);
            rnorm = Math.sqrt(res1 + res2);
            arnorm = alfa * Math.abs(tau);

            // Distinguish between
            //    r1norm = ||b - Ax|| and
            //    r2norm = rnorm in current code
            //           = sqrt(r1norm^2 + damp^2*||x||^2).
            //    Estimate r1norm from
            //    r1norm = sqrt(r2norm^2 - damp^2*||x||^2).
            // Although there is cancellation, it might be accurate enough.
            double r1sq = Math.pow(rnorm, 2) - dampsq * xxnorm;
            r1norm = Math.sqrt(Math.abs(r1sq));

            if (r1sq < 0)
                r1norm = -r1norm;

            r2norm = rnorm;

            // Now use these norms to estimate certain other quantities,
            // some of which will be small near a solution.
            double test1 = rnorm / bnorm;
            double test2 = arnorm / (anorm * rnorm + eps);
            double test3 = 1 / (acond + eps);
            t1 = test1 / (1 + anorm * xnorm / bnorm);
            double rtol = btol + atol * anorm * xnorm / bnorm;

            // The following tests guard against extremely small values of
            // atol, btol  or  ctol.  (The user may have set any or all of
            // the parameters  atol, btol, conlim  to 0.)
            // The effect is equivalent to the normal tests using
            // atol = eps,  btol = eps,  conlim = 1/eps.
            // Checks run from weakest (7) to strongest (1) reason, so the lowest applicable code wins.
            if (itn >= iterLim)
                istop = 7;

            if (1 + test3 <= 1)
                istop = 6;

            if (1 + test2 <= 1)
                istop = 5;

            if (1 + t1 <= 1)
                istop = 4;

            // Allow for tolerances set by the user.
            if (test3 <= ctol)
                istop = 3;

            if (test2 <= atol)
                istop = 2;

            if (test1 <= rtol)
                istop = 1;

            if (istop != 0)
                break;
        }

        return new LSQRResult(x, itn, istop, r1norm, r2norm, anorm, acond, arnorm, xnorm, var);
    }

    /**
     * Calculates bnorm, the norm of the right-hand side vector b of the linear system (used both as the initial
     * beta when no starting point is given and as the denominator of the relative-residual convergence test).
     *
     * @return bnorm
     */
    protected abstract double bnorm();

    /**
     * Calculates beta.
     *
     * NOTE(review): judging by the call sites ({@code beta(x, -1.0, 1.0)} for the initial residual and
     * {@code beta(v, 1.0, -alfa)} inside the bidiagonalization step), this appears to compute the norm of
     * {@code alfa*A*x + beta*u}, updating u as a side effect — confirm against concrete subclasses.
     *
     * @param x X value.
     * @param alfa Alfa value.
     * @param beta Beta value.
     * @return Beta.
     */
    protected abstract double beta(double[] x, double alfa, double beta);

    /**
     * Perform LSQR iteration.
     *
     * NOTE(review): despite the parameter name, call sites pass the current beta value as the first argument.
     *
     * @param bnorm Bnorm value.
     * @param target Target value.
     * @return Iteration result.
     */
    protected abstract double[] iter(double bnorm, double[] target);

    /** Returns the number of columns of the system matrix (the dimension of the solution vector x). */
    protected abstract int getColumns();

    /**
     * Stable computation of plane (Givens) rotation parameters, as in SciPy's {@code _sym_ortho} helper.
     * Scales by the larger of |a| and |b| to avoid overflow/underflow in the intermediate squares.
     *
     * @param a First component.
     * @param b Second component.
     * @return Array {c, s, r}.
     */
    private static double[] symOrtho(double a, double b) {
        if (b == 0)
            return new double[] {Math.signum(a), 0, Math.abs(a)};
        else if (a == 0)
            return new double[] {0, Math.signum(b), Math.abs(b)};
        else {
            double c, s, r;

            if (Math.abs(b) > Math.abs(a)) {
                double tau = a / b;
                s = Math.signum(b) / Math.sqrt(1 + tau * tau);
                c = s * tau;
                r = b / s;
            }
            else {
                double tau = b / a;
                c = Math.signum(a) / Math.sqrt(1 + tau * tau);
                s = c * tau;
                r = a / c;
            }

            return new double[] {c, s, r};
        }
    }

    /**
     * Returns a new vector whose elements are the squares of the elements of {@code a}. Contrary to what the
     * original comment claimed, this is NOT an in-place operation: the input array is left unmodified and a
     * freshly allocated array is returned.
     *
     * @param a Vector of doubles.
     * @return New vector with each element squared.
     */
    private static double[] pow2(double[] a) {
        double[] res = new double[a.length];

        for (int i = 0; i < res.length; i++)
            res[i] = Math.pow(a[i], 2);

        return res;
    }

}
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
new file mode 100644 (file)
index 0000000..fa8e713
--- /dev/null
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.isolve.lsqr;
+
+import com.github.fommil.netlib.BLAS;
+import java.util.Arrays;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.math.isolve.LinSysPartitionDataOnHeap;
+
+/**
+ * Distributed implementation of LSQR algorithm based on {@link AbstractLSQR} and {@link Dataset}.
+ */
+public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
+    /** Dataset of linear system partitions; each partition holds a local block of the matrix X and vector y. */
+    private final Dataset<LSQRPartitionContext, LinSysPartitionDataOnHeap> dataset;
+
+    /**
+     * Constructs a new instance of OnHeap LSQR algorithm implementation.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param partDataBuilder Partition data builder.
+     */
+    public LSQROnHeap(DatasetBuilder<K, V> datasetBuilder,
+        PartitionDataBuilder<K, V, LSQRPartitionContext, LinSysPartitionDataOnHeap> partDataBuilder) {
+        this.dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new LSQRPartitionContext(),
+            partDataBuilder
+        );
+    }
+
+    /** {@inheritDoc} */
+    @Override protected double bnorm() {
+        return dataset.computeWithCtx((ctx, data) -> {
+            // Initialize the per-partition part of the u vector with a copy of the local right-hand side y.
+            ctx.setU(Arrays.copyOf(data.getY(), data.getY().length));
+
+            return BLAS.getInstance().dnrm2(data.getY().length, data.getY(), 1);
+        }, (a, b) -> a == null ? b : Math.sqrt(a * a + b * b)); // Merge partial 2-norms: sqrt(a^2 + b^2).
+    }
+
+    /** {@inheritDoc} */
+    @Override protected double beta(double[] x, double alfa, double beta) {
+        return dataset.computeWithCtx((ctx, data) -> {
+            // Local update u := alfa * X * x + beta * u; lda is at least 1 even for empty partitions.
+            BLAS.getInstance().dgemv("N", data.getRows(), data.getCols(), alfa, data.getX(),
+                Math.max(1, data.getRows()), x, 1, beta, ctx.getU(), 1);
+
+            return BLAS.getInstance().dnrm2(ctx.getU().length, ctx.getU(), 1);
+        }, (a, b) -> a == null ? b : Math.sqrt(a * a + b * b)); // Merge partial 2-norms: sqrt(a^2 + b^2).
+    }
+
+    /** {@inheritDoc} */
+    @Override protected double[] iter(double bnorm, double[] target) {
+        double[] res = dataset.computeWithCtx((ctx, data) -> {
+            // Normalize the local part of u by the global norm, then compute the partial product v = X^T * u.
+            BLAS.getInstance().dscal(ctx.getU().length, 1 / bnorm, ctx.getU(), 1);
+            double[] v = new double[data.getCols()];
+            BLAS.getInstance().dgemv("T", data.getRows(), data.getCols(), 1.0, data.getX(),
+                Math.max(1, data.getRows()), ctx.getU(), 1, 0, v, 1);
+
+            return v;
+        }, (a, b) -> {
+            if (a == null)
+                return b;
+            else {
+                // Element-wise sum of the partial X^T * u products: b += a.
+                BLAS.getInstance().daxpy(a.length, 1.0, a, 1, b, 1);
+
+                return b;
+            }
+        });
+        // target += aggregated X^T * u ("in place" update; the same array is returned to the caller).
+        BLAS.getInstance().daxpy(res.length, 1.0, res, 1, target, 1);
+        return target;
+    }
+
+    /**
+     * Returns number of columns in dataset. All partitions are built with the same column count,
+     * so the first available (non-null) partition value is returned.
+     *
+     * @return Number of columns.
+     */
+    @Override protected int getColumns() {
+        return dataset.compute(LinSysPartitionDataOnHeap::getCols, (a, b) -> a == null ? b : a);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void close() throws Exception {
+        dataset.close();
+    }
+}
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQRPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQRPartitionContext.java
new file mode 100644 (file)
index 0000000..0ec9805
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.isolve.lsqr;
+
+import java.io.Serializable;
+
+/**
+ * Partition context of the LSQR algorithm.
+ */
+public class LSQRPartitionContext implements Serializable {
+    /** */
+    private static final long serialVersionUID = -8159608186899430315L;
+
+    /** Part of U vector local to this partition. */
+    private double[] u;
+
+    /**
+     * Returns the locally stored part of the U vector.
+     *
+     * @return Part of U vector (may be {@code null} until initialized by the solver).
+     */
+    public double[] getU() {
+        return u;
+    }
+
+    /**
+     * Sets the locally stored part of the U vector.
+     *
+     * @param u Part of U vector.
+     */
+    public void setU(double[] u) {
+        this.u = u;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQRResult.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQRResult.java
new file mode 100644 (file)
index 0000000..47beddb
--- /dev/null
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.isolve.lsqr;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.math.isolve.IterativeSolverResult;
+
+/**
+ * LSQR iterative solver result.
+ */
+public class LSQRResult extends IterativeSolverResult {
+    /** */
+    private static final long serialVersionUID = -8866269808589635947L;
+
+    /**
+     * Gives the reason for termination. 1 means x is an approximate solution to Ax = b. 2 means x approximately solves
+     * the least-squares problem.
+     * NOTE(review): the name looks like a typo of "istop" (the stop-flag name used in reference LSQR
+     * implementations); kept as is because the public getter is part of the API.
+     */
+    private final int isstop;
+
+    /**
+     * Represents norm(r), where r = b - Ax.
+     * NOTE(review): the name looks like a typo of "r1norm"; kept as is because the public getter is part of the API.
+     */
+    private final double r1norn;
+
+    /** Represents sqrt( norm(r)^2  +  damp^2 * norm(x)^2 ). Equal to r1norm if damp == 0. */
+    private final double r2norm;
+
+    /** Estimate of Frobenius norm of Abar = [[A]; [damp*I]]. */
+    private final double anorm;
+
+    /** Estimate of cond(Abar). */
+    private final double acond;
+
+    /** Estimate of norm(A'*r - damp^2*x). */
+    private final double arnorm;
+
+    /** Represents norm(x). */
+    private final double xnorm;
+
+    /**
+     * If the solver was asked to calculate the variance, estimates all diagonals of (A'A)^{-1} (if damp == 0)
+     * or more generally (A'A + damp^2*I)^{-1}. This is well defined if A has full column rank or damp > 0.
+     */
+    private final double[] var;
+
+    /**
+     * Constructs a new instance of LSQR result.
+     *
+     * @param x X value (solution vector).
+     * @param iterations Number of performed iterations.
+     * @param isstop Stop reason.
+     * @param r1norn R1 norm value.
+     * @param r2norm R2 norm value.
+     * @param anorm A norm value.
+     * @param acond A cond value.
+     * @param arnorm AR norm value.
+     * @param xnorm X norm value.
+     * @param var Var value (diagonal variance estimates).
+     */
+    public LSQRResult(double[] x, int iterations, int isstop, double r1norn, double r2norm, double anorm, double acond,
+        double arnorm, double xnorm, double[] var) {
+        super(x, iterations);
+        this.isstop = isstop;
+        this.r1norn = r1norn;
+        this.r2norm = r2norm;
+        this.anorm = anorm;
+        this.acond = acond;
+        this.arnorm = arnorm;
+        this.xnorm = xnorm;
+        this.var = var;
+    }
+
+    /**
+     * Returns the stop reason.
+     *
+     * @return Stop reason.
+     */
+    public int getIsstop() {
+        return isstop;
+    }
+
+    /**
+     * Returns norm(r), where r = b - Ax.
+     *
+     * @return R1 norm value.
+     */
+    public double getR1norn() {
+        return r1norn;
+    }
+
+    /**
+     * Returns sqrt( norm(r)^2  +  damp^2 * norm(x)^2 ).
+     *
+     * @return R2 norm value.
+     */
+    public double getR2norm() {
+        return r2norm;
+    }
+
+    /**
+     * Returns the estimate of the Frobenius norm of Abar.
+     *
+     * @return A norm value.
+     */
+    public double getAnorm() {
+        return anorm;
+    }
+
+    /**
+     * Returns the estimate of cond(Abar).
+     *
+     * @return A cond value.
+     */
+    public double getAcond() {
+        return acond;
+    }
+
+    /**
+     * Returns the estimate of norm(A'*r - damp^2*x).
+     *
+     * @return AR norm value.
+     */
+    public double getArnorm() {
+        return arnorm;
+    }
+
+    /**
+     * Returns norm(x).
+     *
+     * @return X norm value.
+     */
+    public double getXnorm() {
+        return xnorm;
+    }
+
+    /**
+     * Returns the diagonal variance estimates.
+     *
+     * @return Var value.
+     */
+    public double[] getVar() {
+        return var;
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        return "LSQRResult{" +
+            "isstop=" + isstop +
+            ", r1norn=" + r1norn +
+            ", r2norm=" + r2norm +
+            ", anorm=" + anorm +
+            ", acond=" + acond +
+            ", arnorm=" + arnorm +
+            ", xnorm=" + xnorm +
+            ", var=" + Arrays.toString(var) +
+            '}';
+    }
+}
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/package-info.java
new file mode 100644 (file)
index 0000000..a667eb7
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains LSQR algorithm implementation.
+ */
+package org.apache.ignite.ml.math.isolve.lsqr;
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/package-info.java
new file mode 100644 (file)
index 0000000..5e0155f
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains iterative algorithms for solving linear systems.
+ */
+package org.apache.ignite.ml.math.isolve;
\ No newline at end of file
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
new file mode 100644 (file)
index 0000000..d7d587e
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.DatasetTrainer;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap;
+import org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR;
+import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap;
+import org.apache.ignite.ml.math.isolve.lsqr.LSQRResult;
+
+/**
+ * Trainer of the linear regression model based on LSQR algorithm.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ *
+ * @see AbstractLSQR
+ */
+public class LinearRegressionLSQRTrainer<K, V> implements DatasetTrainer<K, V, LinearRegressionModel> {
+    /** {@inheritDoc} */
+    @Override public LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, int cols) {
+
+        LSQRResult res;
+
+        try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(
+            datasetBuilder,
+            new LinSysPartitionDataBuilderOnHeap<>(
+                (k, v) -> {
+                    // Append a constant 1.0 column to every feature row so that the last element of the
+                    // LSQR solution corresponds to the intercept of the linear model.
+                    double[] row = Arrays.copyOf(featureExtractor.apply(k, v), cols + 1);
+
+                    row[cols] = 1.0;
+
+                    return row;
+                },
+                lbExtractor,
+                cols + 1
+            )
+        )) {
+            res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
+        }
+        catch (Exception e) {
+            // LSQROnHeap.close() declares a checked Exception; rethrow as unchecked preserving the cause.
+            throw new RuntimeException(e);
+        }
+
+        // The first "cols" elements of the solution are the weights; the last one is the intercept.
+        Vector weights = new DenseLocalOnHeapVector(Arrays.copyOfRange(res.getX(), 0, cols));
+
+        return new LinearRegressionModel(weights, res.getX()[cols]);
+    }
+}
index 5efdf57..b4f83d9 100644 (file)
@@ -20,6 +20,8 @@ package org.apache.ignite.ml.trainers;
 import org.apache.ignite.ml.Model;
 
 /** Trainer interface. */
+@Deprecated
+// TODO: IGNITE-7659: Reduce multiple Trainer interfaces to one
 public interface Trainer<M extends Model, T> {
     /**
      * Train the model based on provided data.
index bb41239..926d872 100644 (file)
@@ -61,6 +61,7 @@ import org.apache.ignite.ml.math.impls.vector.VectorIterableTest;
 import org.apache.ignite.ml.math.impls.vector.VectorNormTest;
 import org.apache.ignite.ml.math.impls.vector.VectorToMatrixTest;
 import org.apache.ignite.ml.math.impls.vector.VectorViewTest;
+import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeapTest;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
@@ -120,7 +121,8 @@ import org.junit.runners.Suite;
     QRDecompositionTest.class,
     SingularValueDecompositionTest.class,
     QRDSolverTest.class,
-    DistanceTest.class
+    DistanceTest.class,
+    LSQROnHeapTest.class
 })
 public class MathImplLocalTestSuite {
     // No-op.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java
new file mode 100644 (file)
index 0000000..4892ff8
--- /dev/null
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.isolve.lsqr;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Tests for {@link LSQROnHeap}.
+ */
+@RunWith(Parameterized.class)
+public class LSQROnHeapTest {
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions")
+    public static Iterable<Integer[]> data() {
+        return Arrays.asList(
+            new Integer[] {1},
+            new Integer[] {2},
+            new Integer[] {3},
+            new Integer[] {5},
+            new Integer[] {7},
+            new Integer[] {100},
+            new Integer[] {1000}
+        );
+    }
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /** Tests solving simple linear system. */
+    @Test
+    public void testSolveLinearSystem() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[]{3, 2, -1, 1});
+        data.put(1, new double[]{2, -2, 4, -2});
+        data.put(2, new double[]{-1, 0.5, -1, 0});
+
+        DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
+
+        LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
+            datasetBuilder,
+            new LinSysPartitionDataBuilderOnHeap<>(
+                (k, v) -> Arrays.copyOf(v, v.length - 1),
+                (k, v) -> v[3],
+                3
+            )
+        );
+
+        LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
+
+        assertArrayEquals(new double[]{1, -2, -2}, res.getX(), 1e-6);
+    }
+
+    /** Tests solving simple linear system with specified x0. */
+    @Test
+    public void testSolveLinearSystemWithX0() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[]{3, 2, -1, 1});
+        data.put(1, new double[]{2, -2, 4, -2});
+        data.put(2, new double[]{-1, 0.5, -1, 0});
+
+        DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
+
+        LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
+            datasetBuilder,
+            new LinSysPartitionDataBuilderOnHeap<>(
+                (k, v) -> Arrays.copyOf(v, v.length - 1),
+                (k, v) -> v[3],
+                3
+            )
+        );
+
+        LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false,
+            new double[] {999, 999, 999});
+
+        assertArrayEquals(new double[]{1, -2, -2}, res.getX(), 1e-6);
+    }
+
+    /** Tests solving least squares problem. */
+    @Test
+    public void testSolveLeastSquares() throws Exception {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107});
+        data.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867});
+        data.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728});
+        data.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991});
+        data.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611});
+        data.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197});
+        data.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012});
+        data.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889});
+        data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949});
+        data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583});
+
+        DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, 1);
+
+        try (LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>(
+            datasetBuilder,
+            new LinSysPartitionDataBuilderOnHeap<>(
+                (k, v) -> Arrays.copyOf(v, v.length - 1),
+                (k, v) -> v[4],
+                4
+            )
+        )) {
+            LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
+
+            assertArrayEquals(new double[]{72.26948107,  15.95144674,  24.07403921,  66.73038781}, res.getX(), 1e-6);
+        }
+    }
+}
index 5c79c8f..82b3a1b 100644 (file)
@@ -21,6 +21,7 @@ import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionQ
 import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionSGDTrainerTest;
 import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionQRTrainerTest;
 import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionSGDTrainerTest;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest;
 import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest;
 import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionQRTrainerTest;
 import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionSGDTrainerTest;
@@ -38,7 +39,8 @@ import org.junit.runners.Suite;
     DistributedLinearRegressionQRTrainerTest.class,
     DistributedLinearRegressionSGDTrainerTest.class,
     BlockDistributedLinearRegressionQRTrainerTest.class,
-    BlockDistributedLinearRegressionSGDTrainerTest.class
+    BlockDistributedLinearRegressionSGDTrainerTest.class,
+    LinearRegressionLSQRTrainerTest.class
 })
 public class RegressionsTestSuite {
     // No-op.
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
new file mode 100644 (file)
index 0000000..3bb3ee7
--- /dev/null
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.regressions.linear;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link LinearRegressionLSQRTrainer}.
+ */
+@RunWith(Parameterized.class)
+public class LinearRegressionLSQRTrainerTest {
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions")
+    public static Iterable<Integer[]> data() {
+        return Arrays.asList(
+            new Integer[] {1},
+            new Integer[] {2},
+            new Integer[] {3},
+            new Integer[] {5},
+            new Integer[] {7},
+            new Integer[] {100},
+            new Integer[] {1000}
+        );
+    }
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /**
+     * Tests {@code fit()} method on a simple small dataset.
+     */
+    @Test
+    public void testSmallDataFit() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107});
+        data.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867});
+        data.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728});
+        data.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991});
+        data.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611});
+        data.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197});
+        data.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012});
+        data.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889});
+        data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949});
+        data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583});
+
+        LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>();
+
+        LinearRegressionModel mdl = trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
+            (k, v) -> v[4],
+            4
+        );
+
+        assertArrayEquals(
+            new double[]{72.26948107,  15.95144674,  24.07403921,  66.73038781},
+            mdl.getWeights().getStorage().data(),
+            1e-6
+        );
+
+        assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-6);
+    }
+
+    /**
+     * Tests {@code fit()} method on a big (100000 x 100) dataset.
+     */
+    @Test
+    public void testBigDataFit() {
+        Random rnd = new Random(0);
+        Map<Integer, double[]> data = new HashMap<>();
+        double[] coef = new double[100];
+        double intercept = rnd.nextDouble() * 10;
+
+        for (int i = 0; i < 100000; i++) {
+            double[] x = new double[coef.length + 1];
+
+            for (int j = 0; j < coef.length; j++)
+                x[j] = rnd.nextDouble() * 10;
+
+            x[coef.length] = intercept;
+
+            data.put(i, x);
+        }
+
+        LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>();
+
+        LinearRegressionModel mdl = trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
+            (k, v) -> v[coef.length],
+            coef.length
+        );
+
+        assertArrayEquals(coef, mdl.getWeights().getStorage().data(), 1e-6);
+
+        assertEquals(intercept, mdl.getIntercept(), 1e-6);
+    }
+}