IGNITE-11138: [ML] Predict from SQL
authordmitrievanthony <dmitrievanthony@gmail.com>
Fri, 1 Feb 2019 10:04:44 +0000 (13:04 +0300)
committerYury Babak <ybabak@gridgain.com>
Fri, 1 Feb 2019 10:04:44 +0000 (13:04 +0300)
This closes #5977

examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/IgniteModelStorageUtil.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/sql/SQLFeatureLabelExtractor.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/sql/SQLFunctions.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/sql/SqlDatasetBuilder.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/sql/package-info.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java

index b7ae1de..c8f0596 100644 (file)
 
 package org.apache.ignite.examples.ml.sql;
 
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.Serializable;
 import java.util.List;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.binary.BinaryObject;
 import org.apache.ignite.cache.query.QueryCursor;
 import org.apache.ignite.cache.query.SqlFieldsQuery;
-import org.apache.ignite.cache.query.annotations.QuerySqlFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
-import org.apache.ignite.ml.inference.Model;
-import org.apache.ignite.ml.inference.ModelDescriptor;
-import org.apache.ignite.ml.inference.ModelSignature;
-import org.apache.ignite.ml.inference.builder.SingleModelBuilder;
-import org.apache.ignite.ml.inference.parser.IgniteModelParser;
-import org.apache.ignite.ml.inference.reader.ModelStorageModelReader;
-import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorage;
-import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorageFactory;
-import org.apache.ignite.ml.inference.storage.model.ModelStorage;
-import org.apache.ignite.ml.inference.storage.model.ModelStorageFactory;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
+import org.apache.ignite.ml.sql.SQLFeatureLabelExtractor;
+import org.apache.ignite.ml.sql.SQLFunctions;
+import org.apache.ignite.ml.sql.SqlDatasetBuilder;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
 
@@ -65,7 +47,7 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
     private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanik_test.csv";
 
     /** Run example. */
-    public static void main(String[] args) throws IOException {
+    public static void main(String[] args) {
         System.out.println(">>> Decision tree classification trainer example started.");
 
         // Start ignite grid.
@@ -122,59 +104,25 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
             DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
 
             System.out.println(">>> Perform training...");
-            IgniteCache<Integer, BinaryObject> titanicTrainCache = ignite.cache("SQL_PUBLIC_TITANIK_TRAIN");
             DecisionTreeNode mdl = trainer.fit(
-                // We have to specify ".withKeepBinary(true)" because SQL caches contains only binary objects and this
-                // information has to be passed into the trainer.
-                new CacheBasedDatasetBuilder<>(ignite, titanicTrainCache).withKeepBinary(true),
-                (k, v) -> VectorUtils.of(
-                    // We have to handle null values here to avoid NpE during unboxing.
-                    replaceNull(v.<Integer>field("pclass")),
-                    "male".equals(v.<String>field("sex")) ? 1 : 0,
-                    replaceNull(v.<Double>field("age")),
-                    replaceNull(v.<Integer>field("sibsp")),
-                    replaceNull(v.<Integer>field("parch")),
-                    replaceNull(v.<Double>field("fare"))
-                ),
-                (k, v) -> replaceNull(v.<Integer>field("survived"))
+                new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"),
+                new SQLFeatureLabelExtractor()
+                    .withFeatureFields("pclass", "age", "sibsp", "parch", "fare")
+                    .withFeatureField("sex", e -> "male".equals(e) ? 1 : 0)
+                    .withLabelField("survived")
             );
 
             System.out.println(">>> Saving model...");
 
             // Model storage is used to store raw serialized model.
             System.out.println("Saving model into model storage...");
-            byte[] serializedMdl = serialize((IgniteModel<byte[], byte[]>)i -> {
-                // Here we need to wrap model so that it accepts and returns byte array.
-                try {
-                    Vector input = deserialize(i);
-                    return serialize(mdl.predict(input));
-                }
-                catch (IOException | ClassNotFoundException e) {
-                    throw new RuntimeException(e);
-                }
-            });
-
-            ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
-            storage.mkdirs("/");
-            storage.putFile("/my_model", serializedMdl);
-
-            // Model descriptor storage that is used to store model metadata.
-            System.out.println("Saving model descriptor into model descriptor storage...");
-            ModelDescriptor desc = new ModelDescriptor(
-                "MyModel",
-                "My Cool Model",
-                new ModelSignature("", "", ""),
-                new ModelStorageModelReader("/my_model"),
-                new IgniteModelParser<>()
-            );
-            ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
-            descStorage.put("my_model", desc);
+            IgniteModelStorageUtil.saveModel(mdl, "titanik_model_tree");
 
             // Making inference using saved model.
             System.out.println("Inference...");
             try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
                 "survived as truth, " +
-                "predict('my_model', pclass, case sex when 'male' then 1 else 0 end, age, sibsp, parch, fare) as prediction " +
+                "predict('titanik_model_tree', pclass, age, sibsp, parch, fare, case sex when 'male' then 1 else 0 end) as prediction " +
                 "from titanik_train"))) {
                 // Print inference result.
                 System.out.println("| Truth | Prediction |");
@@ -184,91 +132,4 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
             }
         }
     }
-
-    /**
-     * Replaces NULL values by 0.
-     *
-     * @param obj Input value.
-     * @param <T> Type of value.
-     * @return Input value of 0 if value is null.
-     */
-    private static <T extends Number> double replaceNull(T obj) {
-        if (obj == null)
-            return 0;
-
-        return obj.doubleValue();
-    }
-
-    /**
-     * SQL functions that should be defined and passed into cache configuration to extend list of functions available
-     * in SQL interface.
-     */
-    public static class SQLFunctions {
-        /**
-         * Makes prediction using specified model name to extract model from model storage and specified input values
-         * as input object for prediction.
-         *
-         * @param mdl Pretrained model.
-         * @param x Input values.
-         * @return Prediction.
-         */
-        @QuerySqlFunction
-        public static double predict(String mdl, Double... x) {
-            // Pretrained models work with vector of doubles so we need to replace null by 0 (or any other double).
-            for (int i = 0; i < x.length; i++)
-                if (x[i] == null)
-                    x[i] = 0.0;
-
-            Ignite ignite = Ignition.ignite();
-
-            ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
-            ModelDescriptor desc = descStorage.get(mdl);
-
-            Model<byte[], byte[]> infMdl = new SingleModelBuilder().build(desc.getReader(), desc.getParser());
-
-            Vector input = VectorUtils.of(x);
-
-            try {
-                return deserialize(infMdl.predict(serialize(input)));
-            }
-            catch (IOException | ClassNotFoundException e) {
-                throw new RuntimeException(e);
-            }
-        }
-    }
-
-    /**
-     * Serialized the specified object.
-     *
-     * @param o Object to be serialized.
-     * @return Serialized object as byte array.
-     * @throws IOException In case of exception.
-     */
-    private static <T extends Serializable> byte[] serialize(T o) throws IOException {
-        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
-             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
-            oos.writeObject(o);
-            oos.flush();
-
-            return baos.toByteArray();
-        }
-    }
-
-    /**
-     * Deserialized object represented as a byte array.
-     *
-     * @param o Serialized object.
-     * @param <T> Type of serialized object.
-     * @return Deserialized object.
-     * @throws IOException In case of exception.
-     * @throws ClassNotFoundException In case of exception.
-     */
-    @SuppressWarnings("unchecked")
-    private static <T extends Serializable> T deserialize(byte[] o) throws IOException, ClassNotFoundException {
-        try (ByteArrayInputStream bais = new ByteArrayInputStream(o);
-             ObjectInputStream ois = new ObjectInputStream(bais)) {
-
-            return (T)ois.readObject();
-        }
-    }
 }
index d7ff059..a4f9a2d 100644 (file)
@@ -21,14 +21,14 @@ import java.util.List;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.Ignition;
-import org.apache.ignite.binary.BinaryObject;
 import org.apache.ignite.cache.query.QueryCursor;
 import org.apache.ignite.cache.query.SqlFieldsQuery;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.sql.SQLFeatureLabelExtractor;
+import org.apache.ignite.ml.sql.SqlDatasetBuilder;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
 
@@ -102,21 +102,12 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
             DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
 
             System.out.println(">>> Perform training...");
-            IgniteCache<Integer, BinaryObject> titanicTrainCache = ignite.cache("SQL_PUBLIC_TITANIK_TRAIN");
             DecisionTreeNode mdl = trainer.fit(
-                // We have to specify ".withKeepBinary(true)" because SQL caches contains only binary objects and this
-                // information has to be passed into the trainer.
-                new CacheBasedDatasetBuilder<>(ignite, titanicTrainCache).withKeepBinary(true),
-                (k, v) -> VectorUtils.of(
-                    // We have to handle null values here to avoid NpE during unboxing.
-                    replaceNull(v.<Integer>field("pclass")),
-                    "male".equals(v.<String>field("sex")) ? 1 : 0,
-                    replaceNull(v.<Double>field("age")),
-                    replaceNull(v.<Integer>field("sibsp")),
-                    replaceNull(v.<Integer>field("parch")),
-                    replaceNull(v.<Double>field("fare"))
-                ),
-                (k, v) -> replaceNull(v.<Integer>field("survived"))
+                new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"),
+                new SQLFeatureLabelExtractor()
+                    .withFeatureFields("pclass", "age", "sibsp", "parch", "fare")
+                    .withFeatureField("sex", e -> "male".equals(e) ? 1 : 0)
+                    .withLabelField("survived")
             );
 
             System.out.println(">>> Perform inference...");
@@ -128,15 +119,14 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
                 "parch, " +
                 "fare from titanik_test"))) {
                 for (List<?> passenger : cursor) {
-                    Vector input = VectorUtils.of(
-                        // We have to handle null values here to avoid NpE during unboxing.
-                        replaceNull((Integer)passenger.get(0)),
-                        "male".equals(passenger.get(1)) ? 1 : 0,
-                        replaceNull((Double)passenger.get(2)),
-                        replaceNull((Integer)passenger.get(3)),
-                        replaceNull((Integer)passenger.get(4)),
-                        replaceNull((Double)passenger.get(5))
-                    );
+                    Vector input = VectorUtils.of(new Double[]{
+                        asDouble(passenger.get(0)),
+                        "male".equals(passenger.get(1)) ? 1.0 : 0.0,
+                        asDouble(passenger.get(2)),
+                        asDouble(passenger.get(3)),
+                        asDouble(passenger.get(4)),
+                        asDouble(passenger.get(5))
+                    });
 
                     double prediction = mdl.predict(input);
 
@@ -149,16 +139,22 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
     }
 
     /**
-     * Replaces NULL values by 0.
+     * Converts specified number into double.
      *
-     * @param obj Input value.
-     * @param <T> Type of value.
-     * @return Input value of 0 if value is null.
+     * @param obj Number.
+     * @param <T> Type of number.
+     * @return Double.
      */
-    private static <T extends Number> double replaceNull(T obj) {
+    private static <T extends Number> Double asDouble(Object obj) {
         if (obj == null)
-            return 0;
+            return null;
 
-        return obj.doubleValue();
+        if (obj instanceof Number) {
+            Number num = (Number) obj;
+
+            return num.doubleValue();
+        }
+
+        throw new IllegalArgumentException("Object is expected to be a number [obj=" + obj + "]");
     }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/IgniteModelStorageUtil.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/IgniteModelStorageUtil.java
new file mode 100644 (file)
index 0000000..af0a1a5
--- /dev/null
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.inference;
+
+import java.util.Objects;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.inference.builder.AsyncModelBuilder;
+import org.apache.ignite.ml.inference.builder.SingleModelBuilder;
+import org.apache.ignite.ml.inference.builder.SyncModelBuilder;
+import org.apache.ignite.ml.inference.parser.IgniteModelParser;
+import org.apache.ignite.ml.inference.reader.ModelStorageModelReader;
+import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorage;
+import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorageFactory;
+import org.apache.ignite.ml.inference.storage.model.ModelStorage;
+import org.apache.ignite.ml.inference.storage.model.ModelStorageFactory;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.util.Utils;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Utils class that helps to operate with model storage and Ignite models.
+ */
+public class IgniteModelStorageUtil {
+    /** Folder to be used to store Ignite models. */
+    private static final String IGNITE_MDL_FOLDER = "/ignite_models";
+
+    /**
+     * Saved specified model with specified name.
+     *
+     * @param mdl Model to be saved.
+     * @param name Model name to be used.
+     */
+    public static void saveModel(IgniteModel<Vector, Double> mdl, String name) {
+        IgniteModel<byte[], byte[]> mdlWrapper = wrapIgniteModel(mdl);
+        byte[] serializedMdl = Utils.serialize(mdlWrapper);
+        UUID mdlId = UUID.randomUUID();
+
+        saveModelStorage(serializedMdl, mdlId);
+        saveModelDescriptorStorage(name, mdlId);
+    }
+
+    /**
+     * Retrieves Ignite model by name using {@link SingleModelBuilder}.
+     *
+     * @param name Model name.
+     * @return Synchronous model built using {@link SingleModelBuilder}.
+     */
+    public static Model<Vector, Double> getModel(String name) {
+        return getSyncModel(name, new SingleModelBuilder());
+    }
+
+    /**
+     * Retrieves Ignite model by name using synchronous model builder.
+     *
+     * @param name Model name.
+     * @param mdlBldr Synchronous model builder.
+     * @return Synchronous model built using specified model builder.
+     */
+    public static Model<Vector, Double> getSyncModel(String name, SyncModelBuilder mdlBldr) {
+        ModelDescriptor desc = Objects.requireNonNull(getModelDescriptor(name), "Model not found [name=" + name + "]");
+
+        Model<byte[], byte[]> infMdl = mdlBldr.build(desc.getReader(), desc.getParser());
+
+        return unwrapIgniteSyncModel(infMdl);
+    }
+
+    /**
+     * Retrieves Ignite model by name using asynchronous model builder.
+     *
+     * @param name Model name.
+     * @param mdlBldr Asynchronous model builder.
+     * @return Asynchronous model built using specified model builder.
+     */
+    public static Model<Vector, Future<Double>> getAsyncModel(String name, AsyncModelBuilder mdlBldr) {
+        ModelDescriptor desc = Objects.requireNonNull(getModelDescriptor(name), "Model not found [name=" + name + "]");
+
+        Model<byte[], Future<byte[]>> infMdl = mdlBldr.build(desc.getReader(), desc.getParser());
+
+        return unwrapIgniteAsyncModel(infMdl);
+    }
+
+    /**
+     * Saves specified serialized model into storage as a file.
+     *
+     * @param serializedMdl Serialized model represented as a byte array.
+     * @param mdlId Model identifier.
+     */
+    private static void saveModelStorage(byte[] serializedMdl, UUID mdlId) {
+        Ignite ignite = Ignition.ignite();
+
+        ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
+        storage.mkdirs(IGNITE_MDL_FOLDER);
+        storage.putFile(IGNITE_MDL_FOLDER + "/" + mdlId, serializedMdl);
+    }
+
+    /**
+     * Saves model descriptor into descriptor storage.
+     *
+     * @param name Model name.
+     * @param mdlId Model identifier used to find model in model storage (only with {@link ModelStorageModelReader}).
+     */
+    private static void saveModelDescriptorStorage(String name, UUID mdlId) {
+        Ignite ignite = Ignition.ignite();
+
+        ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
+        descStorage.put(name, new ModelDescriptor(
+            name,
+            null,
+            new ModelSignature(null, null, null),
+            new ModelStorageModelReader(IGNITE_MDL_FOLDER + "/" + mdlId),
+            new IgniteModelParser<>()
+        ));
+    }
+
+    /**
+     * Retirieves model descriptor.
+     *
+     * @param name Model name.
+     * @return Model descriptor.
+     */
+    private static ModelDescriptor getModelDescriptor(String name) {
+        Ignite ignite = Ignition.ignite();
+
+        ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
+
+        return descStorage.get(name);
+    }
+
+    /**
+     * Wraps Ignite model so that model accepts and returns serialized objects (byte arrays).
+     *
+     * @param mdl Ignite model.
+     * @return Ignite model that accepts and returns serialized objects (byte arrays).
+     */
+    private static IgniteModel<byte[], byte[]> wrapIgniteModel(IgniteModel<Vector, Double> mdl) {
+        return input -> {
+            Vector deserializedInput = Utils.deserialize(input);
+            Double output = mdl.predict(deserializedInput);
+
+            return Utils.serialize(output);
+        };
+    }
+
+    /**
+     * Unwraps Ignite model so that model accepts and returns deserialized objects ({@link Vector} and {@link Double}).
+     *
+     * @param mdl Ignite model.
+     * @return Ignite model that accepts and returns deserialized objects ({@link Vector} and {@link Double}).
+     */
+    private static Model<Vector, Double> unwrapIgniteSyncModel(Model<byte[], byte[]> mdl) {
+        return new Model<Vector, Double>() {
+            /** {@inheritDoc} */
+            @Override public Double predict(Vector input) {
+                byte[] serializedInput = Utils.serialize(input);
+                byte[] serializedOutput = mdl.predict(serializedInput);
+
+                return Utils.deserialize(serializedOutput);
+            }
+
+            /** {@inheritDoc} */
+            @Override public void close() {
+                mdl.close();
+            }
+        };
+    }
+
+    /**
+     * Unwraps Ignite model so that model accepts and returns deserialized objects ({@link Vector} and {@link Double}).
+     *
+     * @param mdl Ignite model.
+     * @return Ignite model that accepts and returns deserialized objects ({@link Vector} and {@link Double}).
+     */
+    private static Model<Vector, Future<Double>> unwrapIgniteAsyncModel(Model<byte[], Future<byte[]>> mdl) {
+        return new Model<Vector, Future<Double>>() {
+            /** {@inheritDoc} */
+            @Override public Future<Double> predict(Vector input) {
+                byte[] serializedInput = Utils.serialize(input);
+                Future<byte[]> serializedOutput = mdl.predict(serializedInput);
+
+                return new FutureDeserializationWrapper<>(serializedOutput);
+            }
+
+            /** {@inheritDoc} */
+            @Override public void close() {
+                mdl.close();
+            }
+        };
+    }
+
+    /**
+     * Future deserialization wrapper that accepts future that returns serialized object and turns it into future that
+     * returns deserialized object.
+     *
+     * @param <T> Type of return value.
+     */
+    private static class FutureDeserializationWrapper<T> implements Future<T> {
+        /** Delegate. */
+        private final Future<byte[]> delegate;
+
+        /**
+         * Constructs a new instance of future deserialization wrapper.
+         *
+         * @param delegate Delegate.
+         */
+        public FutureDeserializationWrapper(Future<byte[]> delegate) {
+            this.delegate = delegate;
+        }
+
+        /** {@inheritDoc} */
+        @Override public boolean cancel(boolean mayInterruptIfRunning) {
+            return delegate.cancel(mayInterruptIfRunning);
+        }
+
+        /** {@inheritDoc} */
+        @Override public boolean isCancelled() {
+            return delegate.isCancelled();
+        }
+
+        /** {@inheritDoc} */
+        @Override public boolean isDone() {
+            return delegate.isDone();
+        }
+
+        /** {@inheritDoc} */
+        @Override public T get() throws InterruptedException, ExecutionException {
+            return (T)Utils.deserialize(delegate.get());
+        }
+
+        /** {@inheritDoc} */
+        @Override public T get(long timeout, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException,
+            TimeoutException {
+            return (T)Utils.deserialize(delegate.get(timeout, unit));
+        }
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/sql/SQLFeatureLabelExtractor.java b/modules/ml/src/main/java/org/apache/ignite/ml/sql/SQLFeatureLabelExtractor.java
new file mode 100644 (file)
index 0000000..4ed3a6a
--- /dev/null
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.sql;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+/**
+ * SQL feature label extractor that should be used to extract features and label from binary objects in SQL table.
+ */
+public class SQLFeatureLabelExtractor implements FeatureLabelExtractor<Object, BinaryObject, Double> {
+    /** Feature extractors for each needed fields as a list of functions. */
+    private final List<Function<BinaryObject, Number>> featureExtractors = new ArrayList<>();
+
+    /** Label extractor as a function. */
+    private Function<BinaryObject, Number> lbExtractor;
+
+    /** {@inheritDoc} */
+    @Override public LabeledVector<Double> extract(Object o, BinaryObject obj) {
+        Vector features = new DenseVector(featureExtractors.size());
+
+        int i = 0;
+        for (Function<BinaryObject, Number> featureExtractor : featureExtractors) {
+            Number val = featureExtractor.apply(obj);
+
+            if (val != null)
+                features.set(i, val.doubleValue());
+
+            i++;
+        }
+
+        Number lb = lbExtractor.apply(obj);
+
+        return new LabeledVector<>(features, lb == null ? null : lb.doubleValue());
+    }
+
+    /**
+     * Adds feature extractor for the field with specified name and value transformer.
+     *
+     * @param name Field name.
+     * @param transformer Field value transformer.
+     * @param <T> Field type.
+     * @return This SQL feature label extractor.
+     */
+    public <T> SQLFeatureLabelExtractor withFeatureField(String name, Function<T, Number> transformer) {
+        featureExtractors.add(obj -> transformer.apply(obj.<T>field(name)));
+
+        return this;
+    }
+
+    /**
+     * Adds feature extractor for the field with specified name. Field should be numeric (subclass of {@link Number}).
+     *
+     * @param name Field name.
+     * @return This SQL feature label extractor.
+     */
+    public SQLFeatureLabelExtractor withFeatureField(String name) {
+        featureExtractors.add(obj -> obj.field(name));
+
+        return this;
+    }
+
+    /**
+     * Adds feature extractor for the field with specified name. Field should be numeric (subclass of {@link Number}).
+     *
+     * @param names Field names.
+     * @return This SQL feature label extractor.
+     */
+    public SQLFeatureLabelExtractor withFeatureFields(String... names) {
+        for (String name : names)
+            withFeatureField(name);
+
+        return this;
+    }
+
+    /**
+     * Adds label extractor.
+     *
+     * @param name Field name.
+     * @return This SQL feature label extractor.
+     */
+    public SQLFeatureLabelExtractor withLabelField(String name) {
+        lbExtractor = obj -> obj.field(name);
+
+        return this;
+    }
+
+    /**
+     * Adds label extractor.
+     *
+     * @param name Field name.
+     * @param transformer Field value transformer.
+     * @param <T> Type of field.
+     * @return This SQL feature label extractor.
+     */
+    public <T> SQLFeatureLabelExtractor withLabelField(String name, Function<T, Number> transformer) {
+        lbExtractor = obj -> transformer.apply(obj.<T>field(name));
+
+        return this;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/sql/SQLFunctions.java b/modules/ml/src/main/java/org/apache/ignite/ml/sql/SQLFunctions.java
new file mode 100644 (file)
index 0000000..a12d69b
--- /dev/null
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.sql;
+
+import java.util.Arrays;
+import org.apache.ignite.cache.query.annotations.QuerySqlFunction;
+import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
+import org.apache.ignite.ml.inference.Model;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+
+/**
+ * SQL functions that should be defined and passed into cache configuration to extend list of functions available
+ * in SQL interface.
+ */
+public class SQLFunctions {
+    /**
+     * Makes prediction using specified model name to extract model from model storage and specified input values
+     * as input object for prediction.
+     *
+     * @param mdl Pretrained model.
+     * @param x Input values.
+     * @return Prediction.
+     */
+    @QuerySqlFunction
+    public static double predict(String mdl, Double... x) {
+        System.out.println("Prediction for " + Arrays.toString(x));
+
+        try (Model<Vector, Double> infMdl = IgniteModelStorageUtil.getModel(mdl)) {
+            return infMdl.predict(VectorUtils.of(x));
+        }
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/sql/SqlDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/sql/SqlDatasetBuilder.java
new file mode 100644 (file)
index 0000000..4d31ca3
--- /dev/null
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.sql;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+
+/**
+ * Subclass of {@link CacheBasedDatasetBuilder} that should be used to work with SQL tables.
+ */
+public class SqlDatasetBuilder extends CacheBasedDatasetBuilder<Object, BinaryObject> {
+    /**
+     * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default
+     * predicate that passes all upstream entries to dataset.
+     *
+     * @param ignite Ignite instance.
+     * @param upstreamCache Name of Ignite Cache with {@code upstream} data.
+     */
+    public SqlDatasetBuilder(Ignite ignite, String upstreamCache) {
+        this(ignite, upstreamCache, (a, b) -> true);
+    }
+
+    /**
+     * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset}.
+     *
+     * @param ignite Ignite instance.
+     * @param upstreamCache Name of Ignite Cache with {@code upstream} data.
+     * @param filter Filter for {@code upstream} data.
+     */
+    public SqlDatasetBuilder(Ignite ignite, String upstreamCache, IgniteBiPredicate<Object, BinaryObject> filter) {
+        this(ignite, upstreamCache, filter, UpstreamTransformerBuilder.identity());
+    }
+
+    /**
+     * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset}.
+     *
+     * @param ignite Ignite instance.
+     * @param upstreamCache Name of Ignite Cache with {@code upstream} data.
+     * @param filter Filter for {@code upstream} data.
+     */
+    public SqlDatasetBuilder(Ignite ignite, String upstreamCache, IgniteBiPredicate<Object, BinaryObject> filter,
+        UpstreamTransformerBuilder transformerBuilder) {
+        super(ignite, ignite.cache(upstreamCache), filter, transformerBuilder, true);
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/sql/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/sql/package-info.java
new file mode 100644 (file)
index 0000000..9bc590d
--- /dev/null
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains util classes that help to work with machine learning models in SQL and train models on data from SQL tables.
+ */
+package org.apache.ignite.ml.sql;
\ No newline at end of file
index 63a9f3c..016c468 100644 (file)
@@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.io.Serializable;
 import java.util.Iterator;
 import java.util.Random;
 import java.util.Spliterator;
@@ -130,4 +131,42 @@ public class Utils {
                 Spliterators.spliteratorUnknownSize(iter, Spliterator.ORDERED),
                 false);
     }
+
+    /**
+     * Serialized the specified object.
+     *
+     * @param o Object to be serialized.
+     * @return Serialized object as byte array.
+     */
+    public static <T extends Serializable> byte[] serialize(T o) {
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+            oos.writeObject(o);
+            oos.flush();
+
+            return baos.toByteArray();
+        }
+        catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Deserialized object represented as a byte array.
+     *
+     * @param o Serialized object.
+     * @param <T> Type of serialized object.
+     * @return Deserialized object.
+     */
+    @SuppressWarnings("unchecked")
+    public static <T extends Serializable> T deserialize(byte[] o) {
+        try (ByteArrayInputStream bais = new ByteArrayInputStream(o);
+             ObjectInputStream ois = new ObjectInputStream(bais)) {
+
+            return (T)ois.readObject();
+        }
+        catch (IOException | ClassNotFoundException e) {
+            throw new RuntimeException(e);
+        }
+    }
 }