package org.apache.ignite.examples.ml.sql;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.Serializable;
import java.util.List;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.binary.BinaryObject;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.SqlFieldsQuery;
-import org.apache.ignite.cache.query.annotations.QuerySqlFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
-import org.apache.ignite.ml.inference.Model;
-import org.apache.ignite.ml.inference.ModelDescriptor;
-import org.apache.ignite.ml.inference.ModelSignature;
-import org.apache.ignite.ml.inference.builder.SingleModelBuilder;
-import org.apache.ignite.ml.inference.parser.IgniteModelParser;
-import org.apache.ignite.ml.inference.reader.ModelStorageModelReader;
-import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorage;
-import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorageFactory;
-import org.apache.ignite.ml.inference.storage.model.ModelStorage;
-import org.apache.ignite.ml.inference.storage.model.ModelStorageFactory;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
+import org.apache.ignite.ml.sql.SQLFeatureLabelExtractor;
+import org.apache.ignite.ml.sql.SQLFunctions;
+import org.apache.ignite.ml.sql.SqlDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanik_test.csv";
/** Run example. */
- public static void main(String[] args) throws IOException {
+ public static void main(String[] args) {
System.out.println(">>> Decision tree classification trainer example started.");
// Start ignite grid.
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
System.out.println(">>> Perform training...");
- IgniteCache<Integer, BinaryObject> titanicTrainCache = ignite.cache("SQL_PUBLIC_TITANIK_TRAIN");
DecisionTreeNode mdl = trainer.fit(
- // We have to specify ".withKeepBinary(true)" because SQL caches contains only binary objects and this
- // information has to be passed into the trainer.
- new CacheBasedDatasetBuilder<>(ignite, titanicTrainCache).withKeepBinary(true),
- (k, v) -> VectorUtils.of(
- // We have to handle null values here to avoid NpE during unboxing.
- replaceNull(v.<Integer>field("pclass")),
- "male".equals(v.<String>field("sex")) ? 1 : 0,
- replaceNull(v.<Double>field("age")),
- replaceNull(v.<Integer>field("sibsp")),
- replaceNull(v.<Integer>field("parch")),
- replaceNull(v.<Double>field("fare"))
- ),
- (k, v) -> replaceNull(v.<Integer>field("survived"))
+ new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"),
+ new SQLFeatureLabelExtractor()
+ .withFeatureFields("pclass", "age", "sibsp", "parch", "fare")
+ .withFeatureField("sex", e -> "male".equals(e) ? 1 : 0)
+ .withLabelField("survived")
);
System.out.println(">>> Saving model...");
// Model storage is used to store raw serialized model.
System.out.println("Saving model into model storage...");
- byte[] serializedMdl = serialize((IgniteModel<byte[], byte[]>)i -> {
- // Here we need to wrap model so that it accepts and returns byte array.
- try {
- Vector input = deserialize(i);
- return serialize(mdl.predict(input));
- }
- catch (IOException | ClassNotFoundException e) {
- throw new RuntimeException(e);
- }
- });
-
- ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
- storage.mkdirs("/");
- storage.putFile("/my_model", serializedMdl);
-
- // Model descriptor storage that is used to store model metadata.
- System.out.println("Saving model descriptor into model descriptor storage...");
- ModelDescriptor desc = new ModelDescriptor(
- "MyModel",
- "My Cool Model",
- new ModelSignature("", "", ""),
- new ModelStorageModelReader("/my_model"),
- new IgniteModelParser<>()
- );
- ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
- descStorage.put("my_model", desc);
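+        // saveModel(...) wraps the model into a byte[] -> byte[] form, serializes it into the model storage
+        // and registers a model descriptor under the given name (see IgniteModelStorageUtil).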
+ IgniteModelStorageUtil.saveModel(mdl, "titanik_model_tree");
// Making inference using saved model.
System.out.println("Inference...");
try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
"survived as truth, " +
- "predict('my_model', pclass, case sex when 'male' then 1 else 0 end, age, sibsp, parch, fare) as prediction " +
+ "predict('titanik_model_tree', pclass, age, sibsp, parch, fare, case sex when 'male' then 1 else 0 end) as prediction " +
"from titanik_train"))) {
// Print inference result.
System.out.println("| Truth | Prediction |");
}
}
}
-
- /**
- * Replaces NULL values by 0.
- *
- * @param obj Input value.
- * @param <T> Type of value.
- * @return Input value of 0 if value is null.
- */
- private static <T extends Number> double replaceNull(T obj) {
- if (obj == null)
- return 0;
-
- return obj.doubleValue();
- }
-
- /**
- * SQL functions that should be defined and passed into cache configuration to extend list of functions available
- * in SQL interface.
- */
- public static class SQLFunctions {
- /**
- * Makes prediction using specified model name to extract model from model storage and specified input values
- * as input object for prediction.
- *
- * @param mdl Pretrained model.
- * @param x Input values.
- * @return Prediction.
- */
- @QuerySqlFunction
- public static double predict(String mdl, Double... x) {
- // Pretrained models work with vector of doubles so we need to replace null by 0 (or any other double).
- for (int i = 0; i < x.length; i++)
- if (x[i] == null)
- x[i] = 0.0;
-
- Ignite ignite = Ignition.ignite();
-
- ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
- ModelDescriptor desc = descStorage.get(mdl);
-
- Model<byte[], byte[]> infMdl = new SingleModelBuilder().build(desc.getReader(), desc.getParser());
-
- Vector input = VectorUtils.of(x);
-
- try {
- return deserialize(infMdl.predict(serialize(input)));
- }
- catch (IOException | ClassNotFoundException e) {
- throw new RuntimeException(e);
- }
- }
- }
-
- /**
- * Serialized the specified object.
- *
- * @param o Object to be serialized.
- * @return Serialized object as byte array.
- * @throws IOException In case of exception.
- */
- private static <T extends Serializable> byte[] serialize(T o) throws IOException {
- try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
- ObjectOutputStream oos = new ObjectOutputStream(baos)) {
- oos.writeObject(o);
- oos.flush();
-
- return baos.toByteArray();
- }
- }
-
- /**
- * Deserialized object represented as a byte array.
- *
- * @param o Serialized object.
- * @param <T> Type of serialized object.
- * @return Deserialized object.
- * @throws IOException In case of exception.
- * @throws ClassNotFoundException In case of exception.
- */
- @SuppressWarnings("unchecked")
- private static <T extends Serializable> T deserialize(byte[] o) throws IOException, ClassNotFoundException {
- try (ByteArrayInputStream bais = new ByteArrayInputStream(o);
- ObjectInputStream ois = new ObjectInputStream(bais)) {
-
- return (T)ois.readObject();
- }
- }
}
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.binary.BinaryObject;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.SqlFieldsQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.sql.SQLFeatureLabelExtractor;
+import org.apache.ignite.ml.sql.SqlDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
System.out.println(">>> Perform training...");
- IgniteCache<Integer, BinaryObject> titanicTrainCache = ignite.cache("SQL_PUBLIC_TITANIK_TRAIN");
DecisionTreeNode mdl = trainer.fit(
- // We have to specify ".withKeepBinary(true)" because SQL caches contains only binary objects and this
- // information has to be passed into the trainer.
- new CacheBasedDatasetBuilder<>(ignite, titanicTrainCache).withKeepBinary(true),
- (k, v) -> VectorUtils.of(
- // We have to handle null values here to avoid NpE during unboxing.
- replaceNull(v.<Integer>field("pclass")),
- "male".equals(v.<String>field("sex")) ? 1 : 0,
- replaceNull(v.<Double>field("age")),
- replaceNull(v.<Integer>field("sibsp")),
- replaceNull(v.<Integer>field("parch")),
- replaceNull(v.<Double>field("fare"))
- ),
- (k, v) -> replaceNull(v.<Integer>field("survived"))
+ new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"),
+ new SQLFeatureLabelExtractor()
+ .withFeatureFields("pclass", "age", "sibsp", "parch", "fare")
+ .withFeatureField("sex", e -> "male".equals(e) ? 1 : 0)
+ .withLabelField("survived")
);
System.out.println(">>> Perform inference...");
"parch, " +
"fare from titanik_test"))) {
for (List<?> passenger : cursor) {
- Vector input = VectorUtils.of(
- // We have to handle null values here to avoid NpE during unboxing.
- replaceNull((Integer)passenger.get(0)),
- "male".equals(passenger.get(1)) ? 1 : 0,
- replaceNull((Double)passenger.get(2)),
- replaceNull((Integer)passenger.get(3)),
- replaceNull((Integer)passenger.get(4)),
- replaceNull((Double)passenger.get(5))
- );
+ Vector input = VectorUtils.of(new Double[]{
+ asDouble(passenger.get(0)),
+ "male".equals(passenger.get(1)) ? 1.0 : 0.0,
+ asDouble(passenger.get(2)),
+ asDouble(passenger.get(3)),
+ asDouble(passenger.get(4)),
+ asDouble(passenger.get(5))
+ });
double prediction = mdl.predict(input);
}
/**
- * Replaces NULL values by 0.
+     * Converts the specified value into a {@code Double}.
*
- * @param obj Input value.
- * @param <T> Type of value.
- * @return Input value of 0 if value is null.
+     * @param obj Value to be converted (expected to be a {@link Number} or {@code null}).
+     * @return Value as {@code Double}, or {@code null} if the input is {@code null}.
*/
- private static <T extends Number> double replaceNull(T obj) {
+    private static Double asDouble(Object obj) {
if (obj == null)
- return 0;
+ return null;
- return obj.doubleValue();
+ if (obj instanceof Number) {
+ Number num = (Number) obj;
+
+ return num.doubleValue();
+ }
+
+ throw new IllegalArgumentException("Object is expected to be a number [obj=" + obj + "]");
}
}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.inference;
+
+import java.util.Objects;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.inference.builder.AsyncModelBuilder;
+import org.apache.ignite.ml.inference.builder.SingleModelBuilder;
+import org.apache.ignite.ml.inference.builder.SyncModelBuilder;
+import org.apache.ignite.ml.inference.parser.IgniteModelParser;
+import org.apache.ignite.ml.inference.reader.ModelStorageModelReader;
+import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorage;
+import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorageFactory;
+import org.apache.ignite.ml.inference.storage.model.ModelStorage;
+import org.apache.ignite.ml.inference.storage.model.ModelStorageFactory;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.util.Utils;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Utils class that helps to operate with model storage and Ignite models.
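+ *
+ * <p>A minimal usage sketch (assuming an Ignite node is already started; {@code mdl} and {@code features}
+ * are illustrative, the model name follows the Titanic example):
+ * <pre>{@code
+ * IgniteModelStorageUtil.saveModel(mdl, "titanik_model_tree");
+ *
+ * try (Model<Vector, Double> infMdl = IgniteModelStorageUtil.getModel("titanik_model_tree")) {
+ *     Double prediction = infMdl.predict(features);
+ * }
+ * }</pre>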
+ */
+public class IgniteModelStorageUtil {
+ /** Folder to be used to store Ignite models. */
+ private static final String IGNITE_MDL_FOLDER = "/ignite_models";
+
+ /**
+     * Saves the specified model under the specified name.
+ *
+ * @param mdl Model to be saved.
+ * @param name Model name to be used.
+ */
+ public static void saveModel(IgniteModel<Vector, Double> mdl, String name) {
+ IgniteModel<byte[], byte[]> mdlWrapper = wrapIgniteModel(mdl);
+ byte[] serializedMdl = Utils.serialize(mdlWrapper);
+ UUID mdlId = UUID.randomUUID();
+
+ saveModelStorage(serializedMdl, mdlId);
+ saveModelDescriptorStorage(name, mdlId);
+ }
+
+ /**
+ * Retrieves Ignite model by name using {@link SingleModelBuilder}.
+ *
+ * @param name Model name.
+ * @return Synchronous model built using {@link SingleModelBuilder}.
+ */
+ public static Model<Vector, Double> getModel(String name) {
+ return getSyncModel(name, new SingleModelBuilder());
+ }
+
+ /**
+ * Retrieves Ignite model by name using synchronous model builder.
+ *
+ * @param name Model name.
+ * @param mdlBldr Synchronous model builder.
+ * @return Synchronous model built using specified model builder.
+ */
+ public static Model<Vector, Double> getSyncModel(String name, SyncModelBuilder mdlBldr) {
+ ModelDescriptor desc = Objects.requireNonNull(getModelDescriptor(name), "Model not found [name=" + name + "]");
+
+ Model<byte[], byte[]> infMdl = mdlBldr.build(desc.getReader(), desc.getParser());
+
+ return unwrapIgniteSyncModel(infMdl);
+ }
+
+ /**
+ * Retrieves Ignite model by name using asynchronous model builder.
+ *
+ * @param name Model name.
+ * @param mdlBldr Asynchronous model builder.
+ * @return Asynchronous model built using specified model builder.
+ */
+ public static Model<Vector, Future<Double>> getAsyncModel(String name, AsyncModelBuilder mdlBldr) {
+ ModelDescriptor desc = Objects.requireNonNull(getModelDescriptor(name), "Model not found [name=" + name + "]");
+
+ Model<byte[], Future<byte[]>> infMdl = mdlBldr.build(desc.getReader(), desc.getParser());
+
+ return unwrapIgniteAsyncModel(infMdl);
+ }
+
+ /**
+ * Saves specified serialized model into storage as a file.
+ *
+ * @param serializedMdl Serialized model represented as a byte array.
+ * @param mdlId Model identifier.
+ */
+ private static void saveModelStorage(byte[] serializedMdl, UUID mdlId) {
+ Ignite ignite = Ignition.ignite();
+
+ ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
+ storage.mkdirs(IGNITE_MDL_FOLDER);
+ storage.putFile(IGNITE_MDL_FOLDER + "/" + mdlId, serializedMdl);
+ }
+
+ /**
+ * Saves model descriptor into descriptor storage.
+ *
+ * @param name Model name.
+ * @param mdlId Model identifier used to find model in model storage (only with {@link ModelStorageModelReader}).
+ */
+ private static void saveModelDescriptorStorage(String name, UUID mdlId) {
+ Ignite ignite = Ignition.ignite();
+
+ ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
+ descStorage.put(name, new ModelDescriptor(
+ name,
+ null,
+ new ModelSignature(null, null, null),
+ new ModelStorageModelReader(IGNITE_MDL_FOLDER + "/" + mdlId),
+ new IgniteModelParser<>()
+ ));
+ }
+
+ /**
+     * Retrieves model descriptor.
+ *
+ * @param name Model name.
+ * @return Model descriptor.
+ */
+ private static ModelDescriptor getModelDescriptor(String name) {
+ Ignite ignite = Ignition.ignite();
+
+ ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
+
+ return descStorage.get(name);
+ }
+
+ /**
+ * Wraps Ignite model so that model accepts and returns serialized objects (byte arrays).
+ *
+ * @param mdl Ignite model.
+ * @return Ignite model that accepts and returns serialized objects (byte arrays).
+ */
+ private static IgniteModel<byte[], byte[]> wrapIgniteModel(IgniteModel<Vector, Double> mdl) {
+ return input -> {
+ Vector deserializedInput = Utils.deserialize(input);
+ Double output = mdl.predict(deserializedInput);
+
+ return Utils.serialize(output);
+ };
+ }
+
+ /**
+ * Unwraps Ignite model so that model accepts and returns deserialized objects ({@link Vector} and {@link Double}).
+ *
+ * @param mdl Ignite model.
+ * @return Ignite model that accepts and returns deserialized objects ({@link Vector} and {@link Double}).
+ */
+ private static Model<Vector, Double> unwrapIgniteSyncModel(Model<byte[], byte[]> mdl) {
+ return new Model<Vector, Double>() {
+ /** {@inheritDoc} */
+ @Override public Double predict(Vector input) {
+ byte[] serializedInput = Utils.serialize(input);
+ byte[] serializedOutput = mdl.predict(serializedInput);
+
+ return Utils.deserialize(serializedOutput);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() {
+ mdl.close();
+ }
+ };
+ }
+
+ /**
+ * Unwraps Ignite model so that model accepts and returns deserialized objects ({@link Vector} and {@link Double}).
+ *
+ * @param mdl Ignite model.
+ * @return Ignite model that accepts and returns deserialized objects ({@link Vector} and {@link Double}).
+ */
+ private static Model<Vector, Future<Double>> unwrapIgniteAsyncModel(Model<byte[], Future<byte[]>> mdl) {
+ return new Model<Vector, Future<Double>>() {
+ /** {@inheritDoc} */
+ @Override public Future<Double> predict(Vector input) {
+ byte[] serializedInput = Utils.serialize(input);
+ Future<byte[]> serializedOutput = mdl.predict(serializedInput);
+
+ return new FutureDeserializationWrapper<>(serializedOutput);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() {
+ mdl.close();
+ }
+ };
+ }
+
+ /**
+     * Future deserialization wrapper that accepts a future returning a serialized object and turns it into a future
+     * returning the deserialized object.
+ *
+ * @param <T> Type of return value.
+ */
+ private static class FutureDeserializationWrapper<T> implements Future<T> {
+ /** Delegate. */
+ private final Future<byte[]> delegate;
+
+ /**
+ * Constructs a new instance of future deserialization wrapper.
+ *
+ * @param delegate Delegate.
+ */
+ public FutureDeserializationWrapper(Future<byte[]> delegate) {
+ this.delegate = delegate;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean cancel(boolean mayInterruptIfRunning) {
+ return delegate.cancel(mayInterruptIfRunning);
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isCancelled() {
+ return delegate.isCancelled();
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isDone() {
+ return delegate.isDone();
+ }
+
+ /** {@inheritDoc} */
+ @Override public T get() throws InterruptedException, ExecutionException {
+ return (T)Utils.deserialize(delegate.get());
+ }
+
+ /** {@inheritDoc} */
+ @Override public T get(long timeout, @NotNull TimeUnit unit) throws InterruptedException, ExecutionException,
+ TimeoutException {
+ return (T)Utils.deserialize(delegate.get(timeout, unit));
+ }
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.sql;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+/**
+ * SQL feature label extractor that should be used to extract features and a label from binary objects stored in an SQL table.
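+ *
+ * <p>A short configuration sketch (the field names follow the Titanic example and are illustrative):
+ * <pre>{@code
+ * new SQLFeatureLabelExtractor()
+ *     .withFeatureFields("pclass", "age", "sibsp", "parch", "fare")
+ *     .withFeatureField("sex", e -> "male".equals(e) ? 1 : 0)
+ *     .withLabelField("survived");
+ * }</pre>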
+ */
+public class SQLFeatureLabelExtractor implements FeatureLabelExtractor<Object, BinaryObject, Double> {
+    /** Feature extractors for each required field, as a list of functions. */
+ private final List<Function<BinaryObject, Number>> featureExtractors = new ArrayList<>();
+
+ /** Label extractor as a function. */
+ private Function<BinaryObject, Number> lbExtractor;
+
+ /** {@inheritDoc} */
+ @Override public LabeledVector<Double> extract(Object o, BinaryObject obj) {
+ Vector features = new DenseVector(featureExtractors.size());
+
+ int i = 0;
+ for (Function<BinaryObject, Number> featureExtractor : featureExtractors) {
+ Number val = featureExtractor.apply(obj);
+
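+            // Null field values are skipped, so the corresponding feature keeps the DenseVector default of 0.0.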
+ if (val != null)
+ features.set(i, val.doubleValue());
+
+ i++;
+ }
+
+ Number lb = lbExtractor.apply(obj);
+
+ return new LabeledVector<>(features, lb == null ? null : lb.doubleValue());
+ }
+
+ /**
+ * Adds feature extractor for the field with specified name and value transformer.
+ *
+ * @param name Field name.
+ * @param transformer Field value transformer.
+ * @param <T> Field type.
+ * @return This SQL feature label extractor.
+ */
+ public <T> SQLFeatureLabelExtractor withFeatureField(String name, Function<T, Number> transformer) {
+ featureExtractors.add(obj -> transformer.apply(obj.<T>field(name)));
+
+ return this;
+ }
+
+ /**
+ * Adds feature extractor for the field with specified name. Field should be numeric (subclass of {@link Number}).
+ *
+ * @param name Field name.
+ * @return This SQL feature label extractor.
+ */
+ public SQLFeatureLabelExtractor withFeatureField(String name) {
+ featureExtractors.add(obj -> obj.field(name));
+
+ return this;
+ }
+
+ /**
+     * Adds feature extractors for the fields with specified names. Fields should be numeric (subclasses of {@link Number}).
+ *
+ * @param names Field names.
+ * @return This SQL feature label extractor.
+ */
+ public SQLFeatureLabelExtractor withFeatureFields(String... names) {
+ for (String name : names)
+ withFeatureField(name);
+
+ return this;
+ }
+
+ /**
+     * Sets the label extractor for the field with specified name.
+ *
+ * @param name Field name.
+ * @return This SQL feature label extractor.
+ */
+ public SQLFeatureLabelExtractor withLabelField(String name) {
+ lbExtractor = obj -> obj.field(name);
+
+ return this;
+ }
+
+ /**
+     * Sets the label extractor for the field with specified name, applying the specified value transformer.
+ *
+ * @param name Field name.
+ * @param transformer Field value transformer.
+ * @param <T> Type of field.
+ * @return This SQL feature label extractor.
+ */
+ public <T> SQLFeatureLabelExtractor withLabelField(String name, Function<T, Number> transformer) {
+ lbExtractor = obj -> transformer.apply(obj.<T>field(name));
+
+ return this;
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.sql;
+
+import java.util.Arrays;
+import org.apache.ignite.cache.query.annotations.QuerySqlFunction;
+import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
+import org.apache.ignite.ml.inference.Model;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+
+/**
+ * SQL functions that should be defined and passed into cache configuration to extend the list of functions available
+ * in the SQL interface.
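+ *
+ * <p>A minimal registration sketch (the cache name is illustrative):
+ * <pre>{@code
+ * CacheConfiguration<Object, Object> cacheCfg = new CacheConfiguration<>("MY_CACHE");
+ * cacheCfg.setSqlFunctionClasses(SQLFunctions.class);
+ * }</pre>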
+ */
+public class SQLFunctions {
+ /**
+ * Makes prediction using specified model name to extract model from model storage and specified input values
+ * as input object for prediction.
+ *
+ * @param mdl Pretrained model.
+ * @param x Input values.
+ * @return Prediction.
+ */
+ @QuerySqlFunction
+ public static double predict(String mdl, Double... x) {
+ System.out.println("Prediction for " + Arrays.toString(x));
+
+ try (Model<Vector, Double> infMdl = IgniteModelStorageUtil.getModel(mdl)) {
+ return infMdl.predict(VectorUtils.of(x));
+ }
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.sql;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+
+/**
+ * Subclass of {@link CacheBasedDatasetBuilder} that should be used to work with SQL tables.
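+ *
+ * <p>A usage sketch with the SQL table from the example above (caches backing SQL tables in the PUBLIC schema are
+ * named {@code SQL_PUBLIC_<TABLE_NAME>} by default):
+ * <pre>{@code
+ * SqlDatasetBuilder datasetBuilder = new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN");
+ * }</pre>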
+ */
+public class SqlDatasetBuilder extends CacheBasedDatasetBuilder<Object, BinaryObject> {
+ /**
+     * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with a default
+     * predicate that passes all upstream entries to the dataset.
+ *
+ * @param ignite Ignite instance.
+ * @param upstreamCache Name of Ignite Cache with {@code upstream} data.
+ */
+ public SqlDatasetBuilder(Ignite ignite, String upstreamCache) {
+ this(ignite, upstreamCache, (a, b) -> true);
+ }
+
+ /**
+ * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset}.
+ *
+ * @param ignite Ignite instance.
+ * @param upstreamCache Name of Ignite Cache with {@code upstream} data.
+ * @param filter Filter for {@code upstream} data.
+ */
+ public SqlDatasetBuilder(Ignite ignite, String upstreamCache, IgniteBiPredicate<Object, BinaryObject> filter) {
+ this(ignite, upstreamCache, filter, UpstreamTransformerBuilder.identity());
+ }
+
+ /**
+ * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset}.
+ *
+ * @param ignite Ignite instance.
+ * @param upstreamCache Name of Ignite Cache with {@code upstream} data.
+     * @param filter Filter for {@code upstream} data.
+     * @param transformerBuilder Upstream transformer builder.
+ */
+ public SqlDatasetBuilder(Ignite ignite, String upstreamCache, IgniteBiPredicate<Object, BinaryObject> filter,
+ UpstreamTransformerBuilder transformerBuilder) {
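+        // SQL caches contain only binary objects, so keep-binary mode is enabled (the last constructor argument).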
+ super(ignite, ignite.cache(upstreamCache), filter, transformerBuilder, true);
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains util classes that help to work with machine learning models in SQL and train models on data from SQL tables.
+ */
+package org.apache.ignite.ml.sql;
\ No newline at end of file
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.io.Serializable;
import java.util.Iterator;
import java.util.Random;
import java.util.Spliterator;
Spliterators.spliteratorUnknownSize(iter, Spliterator.ORDERED),
false);
}
+
+ /**
+     * Serializes the specified object.
+ *
+ * @param o Object to be serialized.
+ * @return Serialized object as byte array.
+ */
+ public static <T extends Serializable> byte[] serialize(T o) {
+ try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+ oos.writeObject(o);
+ oos.flush();
+
+ return baos.toByteArray();
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+     * Deserializes an object represented as a byte array.
+ *
+ * @param o Serialized object.
+ * @param <T> Type of serialized object.
+ * @return Deserialized object.
+ */
+ @SuppressWarnings("unchecked")
+ public static <T extends Serializable> T deserialize(byte[] o) {
+ try (ByteArrayInputStream bais = new ByteArrayInputStream(o);
+ ObjectInputStream ois = new ObjectInputStream(bais)) {
+
+ return (T)ois.readObject();
+ }
+ catch (IOException | ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ }
}