IGNITE-11137: [ML] IgniteModelStorage improvements for IgniteModel and
authorAnton Dmitriev <dmitrievanthony@gmail.com>
Mon, 4 Feb 2019 09:20:00 +0000 (12:20 +0300)
committerYury Babak <ybabak@gridgain.com>
Mon, 4 Feb 2019 09:20:00 +0000 (12:20 +0300)
SQL functionality

This closes #6001

16 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/IgniteModelStorageUtil.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/descriptor/IgniteModelDescriptorStorage.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/descriptor/LocalModelDescriptorStorage.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/descriptor/ModelDescriptorStorage.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/descriptor/ModelDescriptorStorageFactory.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/storage/model/ModelStorageFactory.java
modules/ml/src/main/java/org/apache/ignite/ml/sql/SQLFeatureLabelExtractor.java
modules/ml/src/main/java/org/apache/ignite/ml/sql/SQLFunctions.java
modules/ml/src/main/java/org/apache/ignite/ml/util/LRUCache.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/util/LRUCacheExpirationListener.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/inference/IgniteModelStorageUtilTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/inference/InferenceTestSuite.java
modules/ml/src/test/java/org/apache/ignite/ml/util/LRUCacheTest.java [new file with mode: 0644]
modules/ml/src/test/java/org/apache/ignite/ml/util/UtilTestSuite.java [new file with mode: 0644]

index c8f0596..185c2c2 100644 (file)
@@ -116,7 +116,7 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
 
             // Model storage is used to store raw serialized model.
             System.out.println("Saving model into model storage...");
-            IgniteModelStorageUtil.saveModel(mdl, "titanik_model_tree");
+            IgniteModelStorageUtil.saveModel(ignite, mdl, "titanik_model_tree");
 
             // Making inference using saved model.
             System.out.println("Inference...");
@@ -130,6 +130,8 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
                 for (List<?> row : cursor)
                     System.out.println("|     " + row.get(0) + " |        " + row.get(1) + " |");
             }
+
+            IgniteModelStorageUtil.removeModel(ignite, "titanik_model_tree");
         }
     }
 }
index af0a1a5..9bc86b1 100644 (file)
@@ -24,7 +24,6 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import org.apache.ignite.Ignite;
-import org.apache.ignite.Ignition;
 import org.apache.ignite.ml.IgniteModel;
 import org.apache.ignite.ml.inference.builder.AsyncModelBuilder;
 import org.apache.ignite.ml.inference.builder.SingleModelBuilder;
@@ -49,37 +48,65 @@ public class IgniteModelStorageUtil {
     /**
      * Saved specified model with specified name.
      *
+     * @param ignite Ignite instance.
      * @param mdl Model to be saved.
      * @param name Model name to be used.
      */
-    public static void saveModel(IgniteModel<Vector, Double> mdl, String name) {
+    public static void saveModel(Ignite ignite, IgniteModel<Vector, Double> mdl, String name) {
         IgniteModel<byte[], byte[]> mdlWrapper = wrapIgniteModel(mdl);
         byte[] serializedMdl = Utils.serialize(mdlWrapper);
         UUID mdlId = UUID.randomUUID();
 
-        saveModelStorage(serializedMdl, mdlId);
-        saveModelDescriptorStorage(name, mdlId);
+        saveModelDescriptor(ignite, name, mdlId);
+
+        try {
+            saveModelEntity(ignite, serializedMdl, mdlId);
+        }
+        catch (Exception e) {
+            // Here we need to do a rollback and remove descriptor from correspondent storage.
+            removeModelEntity(ignite, mdlId);
+            throw e;
+        }
+    }
+
+    /**
+     * Removes model with specified name.
+     *
+     * @param ignite Ignite instance.
+     * @param name Mode name to be removed.
+     */
+    public static void removeModel(Ignite ignite, String name) {
+        ModelDescriptor desc = getModelDescriptor(ignite, name);
+        if (desc == null)
+            return;
+
+        UUID mdlId = UUID.fromString(desc.getName());
+        removeModel(ignite, IGNITE_MDL_FOLDER + "/" + mdlId);
+        removeModelDescriptor(ignite, name);
     }
 
     /**
      * Retrieves Ignite model by name using {@link SingleModelBuilder}.
      *
+     * @param ignite Ignite instance.
      * @param name Model name.
      * @return Synchronous model built using {@link SingleModelBuilder}.
      */
-    public static Model<Vector, Double> getModel(String name) {
-        return getSyncModel(name, new SingleModelBuilder());
+    public static Model<Vector, Double> getModel(Ignite ignite, String name) {
+        return getSyncModel(ignite, name, new SingleModelBuilder());
     }
 
     /**
      * Retrieves Ignite model by name using synchronous model builder.
      *
+     * @param ignite Ignite instance.
      * @param name Model name.
      * @param mdlBldr Synchronous model builder.
      * @return Synchronous model built using specified model builder.
      */
-    public static Model<Vector, Double> getSyncModel(String name, SyncModelBuilder mdlBldr) {
-        ModelDescriptor desc = Objects.requireNonNull(getModelDescriptor(name), "Model not found [name=" + name + "]");
+    public static Model<Vector, Double> getSyncModel(Ignite ignite, String name, SyncModelBuilder mdlBldr) {
+        ModelDescriptor desc = Objects.requireNonNull(getModelDescriptor(ignite, name),
+            "Model not found [name=" + name + "]");
 
         Model<byte[], byte[]> infMdl = mdlBldr.build(desc.getReader(), desc.getParser());
 
@@ -89,12 +116,14 @@ public class IgniteModelStorageUtil {
     /**
      * Retrieves Ignite model by name using asynchronous model builder.
      *
+     * @param ignite Ignite instance.
      * @param name Model name.
      * @param mdlBldr Asynchronous model builder.
      * @return Asynchronous model built using specified model builder.
      */
-    public static Model<Vector, Future<Double>> getAsyncModel(String name, AsyncModelBuilder mdlBldr) {
-        ModelDescriptor desc = Objects.requireNonNull(getModelDescriptor(name), "Model not found [name=" + name + "]");
+    public static Model<Vector, Future<Double>> getAsyncModel(Ignite ignite, String name, AsyncModelBuilder mdlBldr) {
+        ModelDescriptor desc = Objects.requireNonNull(getModelDescriptor(ignite, name),
+            "Model not found [name=" + name + "]");
 
         Model<byte[], Future<byte[]>> infMdl = mdlBldr.build(desc.getReader(), desc.getParser());
 
@@ -104,45 +133,71 @@ public class IgniteModelStorageUtil {
     /**
      * Saves specified serialized model into storage as a file.
      *
+     * @param ignite Ignite instance.
      * @param serializedMdl Serialized model represented as a byte array.
      * @param mdlId Model identifier.
      */
-    private static void saveModelStorage(byte[] serializedMdl, UUID mdlId) {
-        Ignite ignite = Ignition.ignite();
-
+    private static void saveModelEntity(Ignite ignite, byte[] serializedMdl, UUID mdlId) {
         ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
         storage.mkdirs(IGNITE_MDL_FOLDER);
-        storage.putFile(IGNITE_MDL_FOLDER + "/" + mdlId, serializedMdl);
+        storage.putFile(IGNITE_MDL_FOLDER + "/" + mdlId, serializedMdl, true);
+    }
+
+    /**
+     * Removes model with specified identifier from model storage.
+     *
+     * @param ignite Ignite instance.
+     * @param mdlId Model identifier.
+     */
+    private static void removeModelEntity(Ignite ignite, UUID mdlId) {
+        ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
+        storage.remove(IGNITE_MDL_FOLDER + "/" + mdlId);
     }
 
     /**
-     * Saves model descriptor into descriptor storage.
+     * Saves model descriptor into descriptor storage if a model with given name is not saved yet, otherwise throws
+     * exception. To save model with the same name remove old model first.
      *
+     * @param ignite Ignite instance.
      * @param name Model name.
      * @param mdlId Model identifier used to find model in model storage (only with {@link ModelStorageModelReader}).
+     * @throws IllegalArgumentException If model with given name was already saved.
      */
-    private static void saveModelDescriptorStorage(String name, UUID mdlId) {
-        Ignite ignite = Ignition.ignite();
-
+    private static void saveModelDescriptor(Ignite ignite, String name, UUID mdlId) {
         ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
-        descStorage.put(name, new ModelDescriptor(
-            name,
+
+        boolean saved = descStorage.putIfAbsent(name, new ModelDescriptor(
+            mdlId.toString(),
             null,
             new ModelSignature(null, null, null),
             new ModelStorageModelReader(IGNITE_MDL_FOLDER + "/" + mdlId),
             new IgniteModelParser<>()
         ));
+
+        if (!saved)
+            throw new IllegalArgumentException("Model descriptor with given name already exists [name=" + name + "]");
+    }
+
+    /**
+     * Removes model descriptor from descriptor storage.
+     *
+     * @param ignite Ignite instance.
+     * @param name Model name.
+     */
+    private static void removeModelDescriptor(Ignite ignite, String name) {
+        ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
+
+        descStorage.remove(name);
     }
 
     /**
      * Retirieves model descriptor.
      *
+     * @param ignite Ignite instance.
      * @param name Model name.
      * @return Model descriptor.
      */
-    private static ModelDescriptor getModelDescriptor(String name) {
-        Ignite ignite = Ignition.ignite();
-
+    private static ModelDescriptor getModelDescriptor(Ignite ignite, String name) {
         ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
 
         return descStorage.get(name);
index 07503c5..a240885 100644 (file)
@@ -47,6 +47,11 @@ public class IgniteModelDescriptorStorage implements ModelDescriptorStorage {
     }
 
     /** {@inheritDoc} */
+    @Override public boolean putIfAbsent(String mdlId, ModelDescriptor mdl) {
+        return models.putIfAbsent(mdlId, mdl);
+    }
+
+    /** {@inheritDoc} */
     @Override public ModelDescriptor get(String mdlId) {
         return models.get(mdlId);
     }
index df54ab7..c8add1d 100644 (file)
@@ -36,6 +36,11 @@ public class LocalModelDescriptorStorage implements ModelDescriptorStorage {
     }
 
     /** {@inheritDoc} */
+    @Override public boolean putIfAbsent(String mdlId, ModelDescriptor mdl) {
+        return models.putIfAbsent(mdlId, mdl) == null;
+    }
+
+    /** {@inheritDoc} */
     @Override public ModelDescriptor get(String name) {
         return models.get(name);
     }
index 4bb8cfd..92b351e 100644 (file)
@@ -34,6 +34,15 @@ public interface ModelDescriptorStorage extends Iterable<IgniteBiTuple<String, M
     public void put(String mdlId, ModelDescriptor mdl);
 
     /**
+     * Saves the specified model descriptor with the specified model identifier if it's not saved yet.
+     *
+     * @param mdlId Model identifier.
+     * @param mdl Model descriptor.
+     * @return {@code true} if model descriptor has been successfully saved, otherwise {@code false}.
+     */
+    public boolean putIfAbsent(String mdlId, ModelDescriptor mdl);
+
+    /**
      * Returns model descriptor saved for the specified model identifier.
      *
      * @param mdlId Model identifier.
index 7f5daf4..7333c55 100644 (file)
@@ -37,6 +37,9 @@ public class ModelDescriptorStorageFactory {
     public ModelDescriptorStorage getModelDescriptorStorage(Ignite ignite) {
         IgniteCache<String, ModelDescriptor> cache = ignite.cache(MODEL_DESCRIPTOR_STORAGE_CACHE_NAME);
 
+        if (cache == null)
+            throw new IllegalStateException("Model descriptor storage doesn't exists. Enable ML plugin to create it.");
+
         return new IgniteModelDescriptorStorage(cache);
     }
 
index 96246ff..c1074f0 100644 (file)
@@ -35,6 +35,10 @@ public class ModelStorageFactory {
      */
     public ModelStorage getModelStorage(Ignite ignite) {
         IgniteCache<String, FileOrDirectory> cache = ignite.cache(MODEL_STORAGE_CACHE_NAME);
+
+        if (cache == null)
+            throw new IllegalStateException("Model storage doesn't exists. Enable ML plugin to create it.");
+
         ModelStorageProvider storageProvider = new IgniteModelStorageProvider(cache);
 
         return new DefaultModelStorage(storageProvider);
index 4ed3a6a..312335a 100644 (file)
@@ -30,6 +30,9 @@ import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
  * SQL feature label extractor that should be used to extract features and label from binary objects in SQL table.
  */
 public class SQLFeatureLabelExtractor implements FeatureLabelExtractor<Object, BinaryObject, Double> {
+    /** */
+    private static final long serialVersionUID = 9040557299449762021L;
+
     /** Feature extractors for each needed fields as a list of functions. */
     private final List<Function<BinaryObject, Number>> featureExtractors = new ArrayList<>();
 
index a12d69b..a3e55cf 100644 (file)
 
 package org.apache.ignite.ml.sql;
 
-import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.locks.LockSupport;
+import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.query.annotations.QuerySqlFunction;
 import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
 import org.apache.ignite.ml.inference.Model;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.util.LRUCache;
 
 /**
  * SQL functions that should be defined and passed into cache configuration to extend list of functions available
  * in SQL interface.
  */
 public class SQLFunctions {
+    /** Default LRU model cache size. */
+    private static final int LRU_CACHE_SIZE = 10;
+
+    /** Cache clear interval in seconds. */
+    private static final long CACHE_CLEAR_INTERVAL_SEC = 60;
+
+    /** Default LRU model cache. */
+    // TODO: IGNITE-11163: Add hart beat tracker to DistributedInfModel.
+    private static final Map<String, Model<Vector, Double>> cache = new LRUCache<>(LRU_CACHE_SIZE, Model::close);
+
+    static {
+        Thread invalidationThread = new Thread(() -> {
+            while (Thread.currentThread().isInterrupted())
+                LockSupport.parkNanos(CACHE_CLEAR_INTERVAL_SEC * 1_000_000_000L);
+
+            synchronized (cache) {
+                for (Model<Vector, Double> mdl : cache.values())
+                    mdl.close();
+
+                cache.clear();
+            }
+        });
+
+        invalidationThread.setDaemon(true);
+        invalidationThread.start();
+    }
+
     /**
      * Makes prediction using specified model name to extract model from model storage and specified input values
      * as input object for prediction.
@@ -39,10 +69,15 @@ public class SQLFunctions {
      */
     @QuerySqlFunction
     public static double predict(String mdl, Double... x) {
-        System.out.println("Prediction for " + Arrays.toString(x));
+        Model<Vector, Double> infMdl;
 
-        try (Model<Vector, Double> infMdl = IgniteModelStorageUtil.getModel(mdl)) {
-            return infMdl.predict(VectorUtils.of(x));
+        synchronized (cache) {
+            infMdl = cache.computeIfAbsent(
+                mdl,
+                key -> IgniteModelStorageUtil.getModel(Ignition.ignite(), mdl)
+            );
         }
+
+        return infMdl.predict(VectorUtils.of(x));
     }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/LRUCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/LRUCache.java
new file mode 100644 (file)
index 0000000..6e480c0
--- /dev/null
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.util;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+/**
+ * LRU cache with fixed size and expiration listener.
+ *
+ * @param <K> Type of a key.
+ * @param <V> Type of a value.
+ */
+public class LRUCache<K, V> extends LinkedHashMap<K, V> {
+    /** */
+    private static final long serialVersionUID = 4266640700294024306L;
+
+    /** Cache size. */
+    private final int cacheSize;
+
+    /** Removal listeners. */
+    private final LRUCacheExpirationListener<V> expirationLsnr;
+
+    /**
+     * Constructs a new instance of LRU cache.
+     *
+     * @param cacheSize Cache size.
+     */
+    public LRUCache(int cacheSize) {
+        this(cacheSize, e -> {});
+    }
+
+    /**
+     * Constructs a new instance of LRU cache.
+     *
+     * @param cacheSize Cache size.
+     * @param expirationLsnr Expiration listener.
+     */
+    public LRUCache(int cacheSize, LRUCacheExpirationListener<V> expirationLsnr) {
+        super(10, 0.75f, true);
+        this.cacheSize = cacheSize;
+        this.expirationLsnr = expirationLsnr;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
+        if(size() > cacheSize) {
+            expirationLsnr.entryExpired(eldest.getValue());
+            return true;
+        }
+
+        return false;
+    }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/LRUCacheExpirationListener.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/LRUCacheExpirationListener.java
new file mode 100644 (file)
index 0000000..9385cfc
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.util;
+
+/**
+ * LRU cache expiration listener.
+ *
+ * @param <V> Type of a value.
+ */
+@FunctionalInterface
+public interface LRUCacheExpirationListener<V> {
+    /**
+     * Handles entry expiration, is called before value is moved from cache.
+     *
+     * @param val Value to be expired and removed.
+     */
+    public void entryExpired(V val);
+}
index 076a81d..f33c9ff 100644 (file)
@@ -35,6 +35,7 @@ import org.apache.ignite.ml.selection.SelectionTestSuite;
 import org.apache.ignite.ml.structures.StructuresTestSuite;
 import org.apache.ignite.ml.svm.SVMTestSuite;
 import org.apache.ignite.ml.tree.DecisionTreeTestSuite;
+import org.apache.ignite.ml.util.UtilTestSuite;
 import org.apache.ignite.ml.util.generators.DataStreamGeneratorTestSuite;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
@@ -59,6 +60,7 @@ import org.junit.runners.Suite;
     CommonTestSuite.class,
     MultiClassTestSuite.class,
     DataStreamGeneratorTestSuite.class,
+    UtilTestSuite.class,
 
     /** JUnit 3 tests. */
     DecisionTreeTestSuite.class,
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/inference/IgniteModelStorageUtilTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/inference/IgniteModelStorageUtilTest.java
new file mode 100644 (file)
index 0000000..2feca69
--- /dev/null
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.inference;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.configuration.IgniteConfiguration;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.util.plugin.MLPluginConfiguration;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.junit.Test;
+
+/**
+ * Tests for {@link IgniteModelStorageUtil}.
+ */
+public class IgniteModelStorageUtilTest extends GridCommonAbstractTest {
+    /** Ignite configuration. */
+    private final IgniteConfiguration cfg;
+
+    /**
+     * Constructs a new instance of Ignite model storage util test.
+     */
+    public IgniteModelStorageUtilTest() {
+        cfg = new IgniteConfiguration();
+
+        MLPluginConfiguration mlCfg = new MLPluginConfiguration();
+        mlCfg.setWithMdlDescStorage(true);
+        mlCfg.setWithMdlStorage(true);
+
+        cfg.setPluginConfigurations(mlCfg);
+    }
+
+    /** */
+    @Test
+    public void testSaveAndGet() throws Exception {
+        try (Ignite ignite = startGrid(cfg)) {
+            IgniteModelStorageUtil.saveModel(ignite, i -> 0.42, "mdl");
+            Model<Vector, Double> infMdl = IgniteModelStorageUtil.getModel(ignite, "mdl");
+
+            assertEquals(0.42, infMdl.predict(VectorUtils.of()));
+        }
+    }
+
+    /** */
+    @Test(expected = IllegalArgumentException.class)
+    public void testSaveModelWithTheSameName() throws Exception {
+        try (Ignite ignite = startGrid(cfg)) {
+            IgniteModelStorageUtil.saveModel(ignite, i -> 0.42, "mdl");
+            IgniteModelStorageUtil.saveModel(ignite, i -> 0.42, "mdl");
+        }
+    }
+
+    /** */
+    @Test
+    public void testSaveRemoveSaveModel() throws Exception {
+        try (Ignite ignite = startGrid(cfg)) {
+            IgniteModelStorageUtil.saveModel(ignite, i -> 0.42, "mdl");
+            IgniteModelStorageUtil.removeModel(ignite, "mdl");
+            IgniteModelStorageUtil.saveModel(ignite, i -> 0.43, "mdl");
+
+            Model<Vector, Double> infMdl = IgniteModelStorageUtil.getModel(ignite, "mdl");
+
+            assertEquals(0.43, infMdl.predict(VectorUtils.of()));
+        }
+    }
+}
index b4f10db..27179de 100644 (file)
@@ -34,7 +34,8 @@ import org.junit.runners.Suite;
     ThreadedModelBuilderTest.class,
     DirectorySerializerTest.class,
     DefaultModelStorageTest.class,
-    IgniteDistributedModelBuilderTest.class
+    IgniteDistributedModelBuilderTest.class,
+    IgniteModelStorageUtilTest.class
 })
 public class InferenceTestSuite {
 }
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/LRUCacheTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/LRUCacheTest.java
new file mode 100644 (file)
index 0000000..2b8d01d
--- /dev/null
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.util;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@link LRUCache}.
+ */
+public class LRUCacheTest {
+    /** */
+    @Test
+    public void testSize() {
+        LRUCache<Integer, Integer> cache = new LRUCache<>(10);
+        for (int i = 0; i < 100; i++)
+            cache.put(i, i);
+
+        assertEquals(10, cache.size());
+    }
+
+    /** */
+    @Test
+    public void testValues() {
+        LRUCache<Integer, Integer> cache = new LRUCache<>(10);
+        for (int i = 0; i < 100; i++) {
+            cache.get(0);
+            cache.put(i, i);
+        }
+
+        assertTrue(cache.containsKey(0));
+
+        for (int i = 91; i < 100; i++)
+            assertTrue(cache.containsKey(i));
+    }
+
+    /** */
+    @Test
+    public void testExpirationListener() {
+        List<Integer> expired = new ArrayList<>();
+
+        LRUCache<Integer, Integer> cache = new LRUCache<>(10, expired::add);
+        for (int i = 0; i < 100; i++)
+            cache.put(i, i);
+
+        for (int i = 0; i < 90; i++)
+            assertEquals(i, expired.get(i).longValue());
+    }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/UtilTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/UtilTestSuite.java
new file mode 100644 (file)
index 0000000..201ab5d
--- /dev/null
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.util;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite for all tests located in {@link org.apache.ignite.ml.util} package.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+    LRUCacheTest.class
+})
+public class UtilTestSuite {
+}