IGNITE-10718: [ML] Merge XGBoost and Ignite ML trees together
authordmitrievanthony <dmitrievanthony@gmail.com>
Tue, 25 Dec 2018 15:34:48 +0000 (18:34 +0300)
committerYury Babak <ybabak@gridgain.com>
Tue, 25 Dec 2018 15:34:48 +0000 (18:34 +0300)
This closes #5691

38 files changed:
examples/src/main/java/org/apache/ignite/examples/ml/inference/IgniteFunctionDistributedInferenceExample.java
examples/src/main/java/org/apache/ignite/examples/ml/inference/ModelStorageExample.java
examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java
examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java
examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java
examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java
modules/ml/src/main/java/org/apache/ignite/ml/Model.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/InfModel.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/ModelDescriptor.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/AsyncInfModelBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/IgniteDistributedInfModelBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/SyncInfModelBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilder.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/IgniteFunctionInfModelParser.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/InfModelParser.java
modules/ml/src/main/java/org/apache/ignite/ml/inference/parser/TensorFlowBaseInfModelParser.java
modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerPreprocessor.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/InfModelBuilderTestUtil.java
modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/SingleInfModelBuilderTest.java
modules/ml/src/test/java/org/apache/ignite/ml/inference/builder/ThreadedInfModelBuilderTest.java
modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerPreprocessorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java
modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/MapBasedXGObject.java [deleted file]
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGLeafNode.java [deleted file]
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGModelComposition.java [new file with mode: 0644]
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGNode.java [deleted file]
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGObject.java [deleted file]
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGSplitNode.java [deleted file]
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/parser/XGModelParser.java
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/parser/visitor/XGModelVisitor.java
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/parser/visitor/XGTreeDictionaryVisitor.java [moved from modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGModel.java with 53% similarity]
modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/parser/visitor/XGTreeVisitor.java
modules/ml/xgboost-model-parser/src/test/java/org/apache/ignite/ml/xgboost/parser/XGBoostModelParserTest.java

index 380b8e6..58ddde7 100644 (file)
@@ -70,7 +70,7 @@ public class IgniteFunctionDistributedInferenceExample {
 
             System.out.println(">>> Preparing model reader and model parser.");
             InfModelReader reader = new InMemoryInfModelReader(mdl);
-            InfModelParser<Vector, Double> parser = new IgniteFunctionInfModelParser<>();
+            InfModelParser<Vector, Double, ?> parser = new IgniteFunctionInfModelParser<>();
             try (InfModel<Vector, Future<Double>> infMdl = new IgniteDistributedInfModelBuilder(ignite, 4, 4)
                 .build(reader, parser)) {
                 System.out.println(">>> Inference model is ready.");
@@ -85,7 +85,7 @@ public class IgniteFunctionDistributedInferenceExample {
                         Vector inputs = val.copyOfRange(1, val.size());
                         double groundTruth = val.get(0);
 
-                        double prediction = infMdl.predict(inputs).get();
+                        double prediction = infMdl.apply(inputs).get();
 
                         System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
                     }
index 5b704f3..a32d137 100644 (file)
@@ -78,7 +78,7 @@ public class ModelStorageExample {
 
                 System.out.println("Make inference...");
                 for (int i = 0; i < 10; i++) {
-                    Integer res = deserialize(infMdl.predict(serialize(i)));
+                    Integer res = deserialize(infMdl.apply(serialize(i)));
                     System.out.println(i + " -> " + res);
                 }
             }
index 48e8df1..a1e3b21 100644 (file)
@@ -61,7 +61,7 @@ public class TensorFlowDistributedInferenceExample {
 
             InfModelReader reader = new FileSystemInfModelReader(mdlRsrc.getPath());
 
-            InfModelParser<double[], Long> parser = new TensorFlowSavedModelInfModelParser<double[], Long>("serve")
+            InfModelParser<double[], Long, ?> parser = new TensorFlowSavedModelInfModelParser<double[], Long>("serve")
 
                 .withInput("Placeholder", doubles -> {
                     float[][][] reshaped = new float[1][28][28];
@@ -86,7 +86,7 @@ public class TensorFlowDistributedInferenceExample {
                 .build(reader, parser)) {
                 List<Future<?>> futures = new ArrayList<>(images.size());
                 for (MnistUtils.MnistLabeledImage image : images)
-                    futures.add(threadedMdl.predict(image.getPixels()));
+                    futures.add(threadedMdl.apply(image.getPixels()));
                 for (Future<?> f : futures)
                     f.get();
             }
index c907778..d5ccbd7 100644 (file)
@@ -54,7 +54,7 @@ public class TensorFlowLocalInferenceExample {
 
         InfModelReader reader = new FileSystemInfModelReader(mdlRsrc.getPath());
 
-        InfModelParser<double[], Long> parser = new TensorFlowSavedModelInfModelParser<double[], Long>("serve")
+        InfModelParser<double[], Long, ?> parser = new TensorFlowSavedModelInfModelParser<double[], Long>("serve")
             .withInput("Placeholder", doubles -> {
                 float[][][] reshaped = new float[1][28][28];
                 for (int i = 0; i < doubles.length; i++)
@@ -75,7 +75,7 @@ public class TensorFlowLocalInferenceExample {
 
         try (InfModel<double[], Long> locMdl = new SingleInfModelBuilder().build(reader, parser)) {
             for (MnistUtils.MnistLabeledImage image : images)
-                locMdl.predict(image.getPixels());
+                locMdl.apply(image.getPixels());
         }
 
         long t1 = System.currentTimeMillis();
index 93dadea..14051f4 100644 (file)
@@ -58,7 +58,7 @@ public class TensorFlowThreadedInferenceExample {
 
         InfModelReader reader = new FileSystemInfModelReader(mdlRsrc.getPath());
 
-        InfModelParser<double[], Long> parser = new TensorFlowSavedModelInfModelParser<double[], Long>("serve")
+        InfModelParser<double[], Long, ?> parser = new TensorFlowSavedModelInfModelParser<double[], Long>("serve")
 
             .withInput("Placeholder", doubles -> {
                 float[][][] reshaped = new float[1][28][28];
@@ -83,7 +83,7 @@ public class TensorFlowThreadedInferenceExample {
             .build(reader, parser)) {
             List<Future<?>> futures = new ArrayList<>(images.size());
             for (MnistUtils.MnistLabeledImage image : images)
-                futures.add(threadedMdl.predict(image.getPixels()));
+                futures.add(threadedMdl.apply(image.getPixels()));
             for (Future<?> f : futures)
                 f.get();
         }
index 0cbde9c..adb7e44 100644 (file)
@@ -50,8 +50,14 @@ public class Step_1_Read_and_Learn {
             try {
                 IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
 
-                IgniteBiFunction<Integer, Object[], Vector> featureExtractor
-                    = (k, v) -> VectorUtils.of((double) v[0], (double) v[5], (double) v[6]);
+                IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+                    double[] data = new double[]{(double) v[0], (double) v[5], (double) v[6]};
+                    data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+                    data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+                    data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+
+                    return VectorUtils.of(data);
+                };
 
                 IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
 
index 40f10d8..68f27c4 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.ignite.examples.ml.xgboost;
 
 import java.io.File;
 import java.io.FileNotFoundException;
+import java.util.HashMap;
 import java.util.Scanner;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
@@ -30,8 +31,6 @@ import org.apache.ignite.ml.inference.builder.AsyncInfModelBuilder;
 import org.apache.ignite.ml.inference.builder.IgniteDistributedInfModelBuilder;
 import org.apache.ignite.ml.inference.reader.FileSystemInfModelReader;
 import org.apache.ignite.ml.inference.reader.InfModelReader;
-import org.apache.ignite.ml.xgboost.MapBasedXGObject;
-import org.apache.ignite.ml.xgboost.XGObject;
 import org.apache.ignite.ml.xgboost.parser.XGModelParser;
 
 /**
@@ -70,7 +69,7 @@ public class XGBoostModelParserExample {
             if (testExpRes == null)
                 throw new IllegalArgumentException("File not found [resource_path=" + TEST_ER_RES + "]");
 
-            try (InfModel<XGObject, Future<Double>> mdl = mdlBuilder.build(reader, parser);
+            try (InfModel<HashMap<String, Double>, Future<Double>> mdl = mdlBuilder.build(reader, parser);
                  Scanner testDataScanner = new Scanner(testData);
                  Scanner testExpResultsScanner = new Scanner(testExpRes)) {
 
@@ -78,7 +77,7 @@ public class XGBoostModelParserExample {
                     String testDataStr = testDataScanner.nextLine();
                     String testExpResultsStr = testExpResultsScanner.nextLine();
 
-                    MapBasedXGObject testObj = new MapBasedXGObject();
+                    HashMap<String, Double> testObj = new HashMap<>();
 
                     for (String keyValueString : testDataStr.split(" ")) {
                         String[] keyVal = keyValueString.split(":");
@@ -87,7 +86,7 @@ public class XGBoostModelParserExample {
                             testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
                     }
 
-                    double prediction = mdl.predict(testObj).get();
+                    double prediction = mdl.apply(testObj).get();
 
                     double expPrediction = Double.parseDouble(testExpResultsStr);
 
index 87458a1..6453108 100644 (file)
 package org.apache.ignite.ml;
 
 import java.util.function.BiFunction;
+import org.apache.ignite.ml.inference.InfModel;
 import org.apache.ignite.ml.math.functions.IgniteFunction;
 
 /** Basic interface for all models. */
-public interface Model<T, V> extends IgniteFunction<T, V> {
+public interface Model<T, V> extends InfModel<T, V>, IgniteFunction<T, V> {
     /**
      * Combines this model with other model via specified combiner
      *
@@ -50,4 +51,9 @@ public interface Model<T, V> extends IgniteFunction<T, V> {
     public default String toString(boolean pretty) {
         return getClass().getSimpleName();
     }
+
+    /** {@inheritDoc} */
+    @Override public default void close() {
+        // Do nothing.
+    }
 }
index ea96407..c5b95e4 100644 (file)
 
 package org.apache.ignite.ml.inference;
 
+import java.util.function.Function;
+
 /**
  * Inference model that can be used to make predictions.
  *
  * @param <I> Type of model input.
  * @param <O> Type of model output.
  */
-public interface InfModel<I, O> extends AutoCloseable {
+public interface InfModel<I, O> extends Function<I, O>, AutoCloseable {
     /**
      * Make a prediction for the specified input arguments.
      *
      * @param input Input arguments.
      * @return Prediction result.
      */
-    public O predict(I input);
+    public O apply(I input);
 
     /** {@inheritDoc} */
-    @Override public void close();
+    public void close();
 }
\ No newline at end of file
index 49a2593..8c8980c 100644 (file)
@@ -39,7 +39,7 @@ public class ModelDescriptor implements Serializable {
     private final InfModelReader reader;
 
     /** Model parser. */
-    private final InfModelParser<byte[], byte[]> parser;
+    private final InfModelParser<byte[], byte[], ?> parser;
 
     /**
      * Constructs a new instance of model descriptor.
@@ -51,7 +51,7 @@ public class ModelDescriptor implements Serializable {
      * @param parser Model parser.
      */
     public ModelDescriptor(String name, String desc, ModelSignature signature, InfModelReader reader,
-        InfModelParser<byte[], byte[]> parser) {
+        InfModelParser<byte[], byte[], ?> parser) {
         this.name = name;
         this.desc = desc;
         this.signature = signature;
@@ -80,7 +80,7 @@ public class ModelDescriptor implements Serializable {
     }
 
     /** */
-    public InfModelParser<byte[], byte[]> getParser() {
+    public InfModelParser<byte[], byte[], ?> getParser() {
         return parser;
     }
 
index adf4659..e8b7e86 100644 (file)
@@ -39,5 +39,5 @@ public interface AsyncInfModelBuilder {
      * @return Inference model.
      */
     public <I extends Serializable, O extends Serializable> InfModel<I, Future<O>> build(InfModelReader reader,
-        InfModelParser<I, O> parser);
+        InfModelParser<I, O, ?> parser);
 }
index 2c6d917..8347b7c 100644 (file)
@@ -44,7 +44,7 @@ import org.apache.ignite.services.ServiceContext;
  * When the {@link #build(InfModelReader, InfModelParser)} method is called Apache Ignite starts the specified number of
  * service instances and request/response queues. Each service instance reads request queue, processes inbound requests
  * and writes responses to response queue. The facade returned by the {@link #build(InfModelReader, InfModelParser)}
- * method operates with request/response queues. When the {@link InfModel#predict(Object)} method is called the argument
+ * method operates with request/response queues. When the {@link InfModel#apply(Object)} method is called the argument
  * is sent as a request to the request queue. When the response is appeared in the response queue the {@link Future}
  * correspondent to the previously sent request is completed and the processing finishes.
  *
@@ -93,7 +93,7 @@ public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder {
      * Starts the specified in constructor number of service instances and request/response queues. Each service
      * instance reads request queue, processes inbound requests and writes responses to response queue. The returned
      * facade is represented by the {@link InfModel} operates with request/response queues, but hides these details
-     * behind {@link InfModel#predict(Object)} method of {@link InfModel}.
+     * behind {@link InfModel#apply(Object)} method of {@link InfModel}.
      *
      * Be aware that {@link InfModel#close()} method must be called to clear allocated resources, stop services and
      * remove queues.
@@ -105,13 +105,13 @@ public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder {
      * @return Facade represented by {@link InfModel}.
      */
     @Override public <I extends Serializable, O extends Serializable> InfModel<I, Future<O>> build(
-        InfModelReader reader, InfModelParser<I, O> parser) {
+        InfModelReader reader, InfModelParser<I, O, ?> parser) {
         return new DistributedInfModel<>(ignite, UUID.randomUUID().toString(), reader, parser, instances, maxPerNode);
     }
 
     /**
      * Facade that operates with request/response queues to make distributed inference, but hides these details
-     * behind {@link InfModel#predict(Object)} method of {@link InfModel}.
+     * behind {@link InfModel#apply(Object)} method of {@link InfModel}.
      *
      * Be aware that {@link InfModel#close()} method must be called to clear allocated resources, stop services and
      * remove queues.
@@ -155,7 +155,7 @@ public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder {
          * @param instances Number of service instances maintaining to make distributed inference.
          * @param maxPerNode Max per node number of instances.
          */
-        DistributedInfModel(Ignite ignite, String suffix, InfModelReader reader, InfModelParser<I, O> parser,
+        DistributedInfModel(Ignite ignite, String suffix, InfModelReader reader, InfModelParser<I, O, ?> parser,
             int instances, int maxPerNode) {
             this.ignite = ignite;
             this.suffix = suffix;
@@ -172,7 +172,7 @@ public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder {
         }
 
         /** {@inheritDoc} */
-        @Override public Future<O> predict(I input) {
+        @Override public Future<O> apply(I input) {
             if (!running.get())
                 throw new IllegalStateException("Inference model is not running");
 
@@ -198,7 +198,7 @@ public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder {
          * @param instances Number of service instances maintaining to make distributed inference.
          * @param maxPerNode Max per node number of instances.
          */
-        private void startService(InfModelReader reader, InfModelParser<I, O> parser, int instances, int maxPerNode) {
+        private void startService(InfModelReader reader, InfModelParser<I, O, ?> parser, int instances, int maxPerNode) {
             ignite.services().deployMultiple(
                 String.format(INFERENCE_SERVICE_NAME_PATTERN, suffix),
                 new IgniteDistributedInfModelService<>(reader, parser, suffix),
@@ -294,7 +294,7 @@ public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder {
         private final InfModelReader reader;
 
         /** Inference model parser. */
-        private final InfModelParser<I, O> parser;
+        private final InfModelParser<I, O, ?> parser;
 
         /** Suffix that with correspondent templates formats service and queue names. */
         private final String suffix;
@@ -315,7 +315,7 @@ public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder {
          * @param parser Inference model parser.
          * @param suffix Suffix that with correspondent templates formats service and queue names.
          */
-        IgniteDistributedInfModelService(InfModelReader reader, InfModelParser<I, O> parser, String suffix) {
+        IgniteDistributedInfModelService(InfModelReader reader, InfModelParser<I, O, ?> parser, String suffix) {
             this.reader = reader;
             this.parser = parser;
             this.suffix = suffix;
@@ -347,7 +347,7 @@ public class IgniteDistributedInfModelBuilder implements AsyncInfModelBuilder {
                     continue;
                 }
 
-                O res = mdl.predict(req);
+                O res = mdl.apply(req);
 
                 try {
                     resQueue.put(res);
index f756f45..032ebab 100644 (file)
@@ -27,8 +27,8 @@ import org.apache.ignite.ml.inference.reader.InfModelReader;
  */
 public class SingleInfModelBuilder implements SyncInfModelBuilder {
     /** {@inheritDoc} */
-    @Override public <I extends Serializable, O extends Serializable> InfModel<I, O> build(InfModelReader reader,
-        InfModelParser<I, O> parser) {
+    @Override public <I extends Serializable, O extends Serializable, M extends InfModel<I, O>> M build(InfModelReader reader,
+        InfModelParser<I, O, M> parser) {
         return parser.parse(reader.read());
     }
 }
index 7aed8b8..f9883fc 100644 (file)
@@ -37,6 +37,6 @@ public interface SyncInfModelBuilder {
      * @param <O> Type of model output.
      * @return Inference model.
      */
-    public <I extends Serializable, O extends Serializable> InfModel<I, O> build(InfModelReader reader,
-        InfModelParser<I, O> parser);
+    public <I extends Serializable, O extends Serializable, M extends InfModel<I, O>> M build(InfModelReader reader,
+        InfModelParser<I, O, M> parser);
 }
index ff538de..b39cb8d 100644 (file)
@@ -44,7 +44,7 @@ public class ThreadedInfModelBuilder implements AsyncInfModelBuilder {
 
     /** {@inheritDoc} */
     @Override public <I extends Serializable, O extends Serializable> InfModel<I, Future<O>> build(
-        InfModelReader reader, InfModelParser<I, O> parser) {
+        InfModelReader reader, InfModelParser<I, O, ?> parser) {
         return new ThreadedInfModel<>(parser.parse(reader.read()), threads);
     }
 
@@ -74,8 +74,8 @@ public class ThreadedInfModelBuilder implements AsyncInfModelBuilder {
         }
 
         /** {@inheritDoc} */
-        @Override public Future<O> predict(I input) {
-            return threadPool.submit(() -> mdl.predict(input));
+        @Override public Future<O> apply(I input) {
+            return threadPool.submit(() -> mdl.apply(input));
         }
 
         /** {@inheritDoc} */
index a4f1377..9c8a862 100644 (file)
@@ -29,7 +29,7 @@ import org.apache.ignite.ml.math.functions.IgniteFunction;
  * @param <I> Type of model input.
  * @param <O> Type of model output.
  */
-public class IgniteFunctionInfModelParser<I, O> implements InfModelParser<I, O> {
+public class IgniteFunctionInfModelParser<I, O> implements InfModelParser<I, O, InfModel<I, O>> {
     /** */
     private static final long serialVersionUID = -4624683614990816434L;
 
@@ -64,7 +64,7 @@ public class IgniteFunctionInfModelParser<I, O> implements InfModelParser<I, O>
         }
 
         /** {@inheritDoc} */
-        @Override public O predict(I input) {
+        @Override public O apply(I input) {
             return function.apply(input);
         }
 
index fa62558..df5659c 100644 (file)
@@ -27,12 +27,12 @@ import org.apache.ignite.ml.inference.InfModel;
  * @param <O> Type of model output.
  */
 @FunctionalInterface
-public interface InfModelParser<I, O> extends Serializable {
+public interface InfModelParser<I, O, M extends InfModel<I, O>> extends Serializable {
     /**
      * Accepts serialized model represented by byte array, parses it and returns {@link InfModel}.
      *
      * @param mdl Serialized model represented by byte array.
      * @return Inference model.
      */
-    public InfModel<I, O> parse(byte[] mdl);
+    public M parse(byte[] mdl);
 }
index acc521f..d79b5de 100644 (file)
@@ -33,7 +33,7 @@ import org.tensorflow.Tensor;
  * @param <I> Type of model input.
  * @param <O> Type of model output.
  */
-public abstract class TensorFlowBaseInfModelParser<I, O> implements InfModelParser<I, O> {
+public abstract class TensorFlowBaseInfModelParser<I, O> implements InfModelParser<I, O, InfModel<I, O>> {
     /** */
     private static final long serialVersionUID = 5574259553625871456L;
 
@@ -143,7 +143,7 @@ public abstract class TensorFlowBaseInfModelParser<I, O> implements InfModelPars
         }
 
         /** {@inheritDoc} */
-        @Override public O predict(I input) {
+        @Override public O apply(I input) {
             Session.Runner runner = ses.runner();
 
             runner = feedAll(runner, input);
index 87e369f..8d94e15 100644 (file)
@@ -71,8 +71,16 @@ public class MinMaxScalerPreprocessor<K, V> implements IgniteBiFunction<K, V, Ve
         assert res.size() == min.length;
         assert res.size() == max.length;
 
-        for (int i = 0; i < res.size(); i++)
-            res.set(i, (res.get(i) - min[i]) / (max[i] - min[i]));
+        for (int i = 0; i < res.size(); i++) {
+            double num = res.get(i) - min[i];
+            double denom = max[i] - min[i];
+            double scaled = num / denom;
+
+            if (Double.isNaN(scaled))
+                res.set(i, num);
+            else
+                res.set(i, scaled);
+        }
 
         return res;
     }
index 573759e..35d1ea4 100644 (file)
@@ -156,7 +156,8 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset
             splitPnt.col,
             splitPnt.threshold,
             split(dataset, updatePredicateForThenNode(filter, splitPnt), deep + 1, impurityCalc),
-            split(dataset, updatePredicateForElseNode(filter, splitPnt), deep + 1, impurityCalc)
+            split(dataset, updatePredicateForElseNode(filter, splitPnt), deep + 1, impurityCalc),
+            null
         );
     }
 
index f598165..ef4d115 100644 (file)
@@ -33,10 +33,13 @@ public class DecisionTreeConditionalNode implements DecisionTreeNode {
     private final double threshold;
 
     /** Node that will be used in case tested value is greater then threshold. */
-    private final DecisionTreeNode thenNode;
+    private DecisionTreeNode thenNode;
 
     /** Node that will be used in case tested value is not greater then threshold. */
-    private final DecisionTreeNode elseNode;
+    private DecisionTreeNode elseNode;
+
+    /** Node that will be used in case tested value is not presented. */
+    private DecisionTreeNode missingNode;
 
     /**
      * Constructs a new instance of decision tree conditional node.
@@ -45,17 +48,29 @@ public class DecisionTreeConditionalNode implements DecisionTreeNode {
      * @param threshold Threshold.
      * @param thenNode Node that will be used in case tested value is greater then threshold.
      * @param elseNode Node that will be used in case tested value is not greater then threshold.
+     * @param missingNode Node that will be used in case tested value is not presented.
      */
-    DecisionTreeConditionalNode(int col, double threshold, DecisionTreeNode thenNode, DecisionTreeNode elseNode) {
+    public DecisionTreeConditionalNode(int col, double threshold, DecisionTreeNode thenNode, DecisionTreeNode elseNode,
+        DecisionTreeNode missingNode) {
         this.col = col;
         this.threshold = threshold;
         this.thenNode = thenNode;
         this.elseNode = elseNode;
+        this.missingNode = missingNode;
     }
 
     /** {@inheritDoc} */
     @Override public Double apply(Vector features) {
-        return features.get(col) > threshold ? thenNode.apply(features) : elseNode.apply(features);
+        double val = features.get(col);
+
+        if (Double.isNaN(val)) {
+            if (missingNode == null)
+                throw new IllegalArgumentException("Feature must not be null or missing node should be specified");
+
+            return missingNode.apply(features);
+        }
+
+        return val > threshold ? thenNode.apply(features) : elseNode.apply(features);
     }
 
     /** */
@@ -74,10 +89,30 @@ public class DecisionTreeConditionalNode implements DecisionTreeNode {
     }
 
     /** */
+    public void setThenNode(DecisionTreeNode thenNode) {
+        this.thenNode = thenNode;
+    }
+
+    /** */
     public DecisionTreeNode getElseNode() {
         return elseNode;
     }
 
+    /** */
+    public void setElseNode(DecisionTreeNode elseNode) {
+        this.elseNode = elseNode;
+    }
+
+    /** */
+    public DecisionTreeNode getMissingNode() {
+        return missingNode;
+    }
+
+    /** */
+    public void setMissingNode(DecisionTreeNode missingNode) {
+        this.missingNode = missingNode;
+    }
+
     /** {@inheritDoc} */
     @Override public String toString() {
         return toString(false);
index b95e759..6b20fc1 100644 (file)
@@ -39,9 +39,9 @@ class InfModelBuilderTestUtil {
      *
      * @return Dummy model parser used in tests.
      */
-    static InfModelParser<Integer, Integer> getParser() {
+    static InfModelParser<Integer, Integer, InfModel<Integer, Integer>> getParser() {
         return m -> new InfModel<Integer, Integer>() {
-            @Override public Integer predict(Integer input) {
+            @Override public Integer apply(Integer input) {
                 return input;
             }
 
index 22596f2..b0bae25 100644 (file)
@@ -37,6 +37,6 @@ public class SingleInfModelBuilderTest {
         );
 
         for (int i = 0; i < 100; i++)
-            assertEquals(Integer.valueOf(i), infMdl.predict(i));
+            assertEquals(Integer.valueOf(i), infMdl.apply(i));
     }
 }
index 6d2f344..b4207f5 100644 (file)
@@ -39,6 +39,6 @@ public class ThreadedInfModelBuilderTest {
         );
 
         for (int i = 0; i < 100; i++)
-            assertEquals(Integer.valueOf(i), infMdl.predict(i).get());
+            assertEquals(Integer.valueOf(i), infMdl.apply(i).get());
     }
 }
index ce59112..f97120c 100644 (file)
@@ -53,4 +53,21 @@ public class MinMaxScalerPreprocessorTest {
        for (int i = 0; i < data.length; i++)
            assertArrayEquals(standardData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-8);
     }
+
+    /** Test {@code apply()} method with division by zero. */
+    @Test
+    public void testApplyDivisionByZero() {
+        double[][] data = new double[][]{{1.}, {1.}, {1.}, {1.}};
+
+        MinMaxScalerPreprocessor<Integer, Vector> preprocessor = new MinMaxScalerPreprocessor<>(
+            new double[] {1.},
+            new double[] {1.},
+            (k, v) -> v
+        );
+
+        double[][] standardData = new double[][]{{0.}, {0.}, {0.}, {0.}};
+
+        for (int i = 0; i < data.length; i++)
+            assertArrayEquals(standardData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-8);
+    }
 }
index 16750d3..5025460 100644 (file)
@@ -206,9 +206,9 @@ public class EvaluatorTest extends GridCommonAbstractTest {
         assertTrue(res.toString().length() > 0);
         assertEquals("Best maxDeep", 1.0, res.getBest("maxDeep"));
         assertEquals("Best minImpurityDecrease", 0.0, res.getBest("minImpurityDecrease"));
-        assertArrayEquals("Best score", new double[] {0.6666666666666666, 0.4, 0}, res.getBestScore(), 0);
+        assertArrayEquals("Best score", new double[] {0.6666666666666666, 0.6, 0}, res.getBestScore(), 0);
         assertEquals("Best hyper params size", 2, res.getBestHyperParams().size());
-        assertEquals("Best average score", 0.35555555555555557, res.getBestAvgScore());
+        assertEquals("Best average score", 0.4222222222222222, res.getBestAvgScore());
 
         assertEquals("Scores amount", 18, scores.size());
 
index 7d282df..e05139f 100644 (file)
@@ -27,6 +27,7 @@ import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredicti
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
@@ -53,9 +54,10 @@ public class RandomForestClassifierTrainerTest extends TrainerTest {
         ArrayList<FeatureMeta> meta = new ArrayList<>();
         for (int i = 0; i < 4; i++)
             meta.add(new FeatureMeta("", i, false));
-        RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
+        DatasetTrainer<ModelsComposition, Double> trainer = new RandomForestClassifierTrainer(meta)
             .withAmountOfTrees(5)
-            .withFeaturesCountSelectionStrgy(x -> 2);
+            .withFeaturesCountSelectionStrgy(x -> 2)
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder());
 
         ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
 
@@ -80,7 +82,7 @@ public class RandomForestClassifierTrainerTest extends TrainerTest {
         ArrayList<FeatureMeta> meta = new ArrayList<>();
         for (int i = 0; i < 4; i++)
             meta.add(new FeatureMeta("", i, false));
-        RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
+        DatasetTrainer<ModelsComposition, Double> trainer = new RandomForestClassifierTrainer(meta)
             .withAmountOfTrees(100)
             .withFeaturesCountSelectionStrgy(x -> 2)
             .withEnvironmentBuilder(TestUtils.testEnvBuilder());
diff --git a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/MapBasedXGObject.java b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/MapBasedXGObject.java
deleted file mode 100644 (file)
index 20bdf1f..0000000
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.xgboost;
-
-import java.util.HashMap;
-import java.util.Map;
-
-/** Map based implementation of {@link XGObject}. */
-public class MapBasedXGObject implements XGObject {
-    /** */
-    private static final long serialVersionUID = 4378979710350902592L;
-
-    /** Key-value map. */
-    private final Map<String, Double> map;
-
-    /**
-     * Constructs a new instance of map based {@link XGObject} with empty map.
-     */
-    public MapBasedXGObject() {
-        this(new HashMap<>());
-    }
-
-    /**
-     * Constructs a new instance of map based {@link XGObject} with the specified map.
-     *
-     * @param map Map.
-     */
-    public MapBasedXGObject(Map<String, Double> map) {
-        this.map = map;
-    }
-
-    /** {@inheritDoc} */
-    @Override public Double getFeature(String featureName) {
-        return map.get(featureName);
-    }
-
-    /**
-     * Puts feature value with the specified feature name.
-     *
-     * @param featureName Feature name.
-     * @param val Feature value.
-     */
-    public void put(String featureName, Double val) {
-        map.put(featureName, val);
-    }
-}
diff --git a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGLeafNode.java b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGLeafNode.java
deleted file mode 100644 (file)
index da572db..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.xgboost;
-
-/** XGBoost model leaf node. */
-public class XGLeafNode implements XGNode {
-    /** Value. */
-    private final double val;
-
-    /**
-     * Constructs a new instance of leaf node.
-     *
-     * @param val Value.
-     */
-    public XGLeafNode(double val) {
-        this.val = val;
-    }
-
-    /** {@inheritDoc} */
-    @Override public double predict(XGObject obj) {
-        return val;
-    }
-}
diff --git a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGModelComposition.java b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGModelComposition.java
new file mode 100644 (file)
index 0000000..b8ccff7
--- /dev/null
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.xgboost;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+
+import static org.apache.ignite.ml.math.StorageConstants.RANDOM_ACCESS_MODE;
+
+/**
+ * XGBoost model composition.
+ */
+public class XGModelComposition implements Model<HashMap<String, Double>, Double> {
+    /** Dictionary used for matching feature names and indexes. */
+    private final Map<String, Integer> dict;
+
+    /** Composition of decision trees. */
+    private ModelsComposition modelsComposition;
+
+    /**
+     * Constructs a new instance of composition of models.
+     *
+     * @param models Basic models.
+     */
+    public XGModelComposition(Map<String, Integer> dict, List<DecisionTreeNode> models) {
+        this.dict = dict;
+        this.modelsComposition = new ModelsComposition(models, new XGModelPredictionsAggregator());
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double apply(HashMap<String, Double> map) {
+        return modelsComposition.apply(toVector(map));
+    }
+
+    /** */
+    public Map<String, Integer> getDict() {
+        return dict;
+    }
+
+    /** */
+    public ModelsComposition getModelsComposition() {
+        return modelsComposition;
+    }
+
+    /** */
+    public void setModelsComposition(ModelsComposition modelsComposition) {
+        this.modelsComposition = modelsComposition;
+    }
+
+    /**
+     * Converts hash map into sparse vector using dictionary.
+     *
+     * @param input Hash map with pairs of feature name and feature value.
+     * @return Sparse vector.
+     */
+    private Vector toVector(Map<String, Double> input) {
+        Vector inputVector = new SparseVector(dict.size(), RANDOM_ACCESS_MODE);
+        for (int i = 0; i < dict.size(); i++)
+            inputVector.set(i, Double.NaN);
+
+        for (Map.Entry<String, Double> feature : input.entrySet()) {
+            Integer idx = dict.get(feature.getKey());
+
+            if (idx != null)
+                inputVector.set(idx, feature.getValue());
+
+        }
+
+        return inputVector;
+    }
+
+    /**
+     * XG model predictions aggregator.
+     */
+    private static class XGModelPredictionsAggregator implements PredictionsAggregator {
+        /** {@inheritDoc} */
+        @Override public Double apply(double[] predictions) {
+            double res = 0;
+
+            for (double prediction : predictions)
+                res += prediction;
+
+            return (1.0 / (1.0 + Math.exp(-res)));
+        }
+    }
+}
\ No newline at end of file
diff --git a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGNode.java b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGNode.java
deleted file mode 100644 (file)
index f3401a7..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.xgboost;
-
-/** XGBoost model node. */
-public interface XGNode {
-    /**
-     * Predicts label for the specified object.
-     *
-     * @param obj Object.
-     * @return Label.
-     */
-    public double predict(XGObject obj);
-}
\ No newline at end of file
diff --git a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGObject.java b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGObject.java
deleted file mode 100644 (file)
index 408c170..0000000
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.xgboost;
-
-import java.io.Serializable;
-
-/**
- * Base interface for objects processed by XGBoost model.
- */
-public interface XGObject extends Serializable {
-    /**
-     * Returns feature value by the specified name.
-     *
-     * @param featureName Feature name.
-     * @return Feature value.
-     */
-    public Double getFeature(String featureName);
-}
diff --git a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGSplitNode.java b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGSplitNode.java
deleted file mode 100644 (file)
index 16c0ed4..0000000
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.xgboost;
-
-/** XGBoost model split node. */
-public class XGSplitNode implements XGNode {
-    /** Feature name. */
-    private final String featureName;
-
-    /** Threshold. */
-    private final double threshold;
-
-    /** "Yes" child node. */
-    private XGNode yesNode;
-
-    /** "No" child node. */
-    private XGNode noNode;
-
-    /** "Missing" child node. */
-    private XGNode missingNode;
-
-    /**
-     * Constructs a new instance of XGBoost model split node.
-     *
-     * @param featureName Feature name.
-     * @param threshold Threshold.
-     */
-    public XGSplitNode(String featureName, double threshold) {
-        this.featureName = featureName;
-        this.threshold = threshold;
-    }
-
-    /** {@inheritDoc} */
-    @Override public double predict(XGObject obj) {
-        Double featureVal = obj.getFeature(featureName);
-
-        if (featureVal == null)
-            return missingNode.predict(obj);
-        else if (featureVal < threshold)
-            return yesNode.predict(obj);
-        else
-            return noNode.predict(obj);
-    }
-
-    /** */
-    public void setYesNode(XGNode yesNode) {
-        this.yesNode = yesNode;
-    }
-
-    /** */
-    public void setNoNode(XGNode noNode) {
-        this.noNode = noNode;
-    }
-
-    /** */
-    public void setMissingNode(XGNode missingNode) {
-        this.missingNode = missingNode;
-    }
-}
\ No newline at end of file
index a0b124f..8d40a7e 100644 (file)
@@ -19,12 +19,12 @@ package org.apache.ignite.ml.xgboost.parser;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
+import java.util.HashMap;
 import org.antlr.v4.runtime.CharStream;
 import org.antlr.v4.runtime.CharStreams;
 import org.antlr.v4.runtime.CommonTokenStream;
 import org.apache.ignite.ml.inference.parser.InfModelParser;
-import org.apache.ignite.ml.xgboost.XGModel;
-import org.apache.ignite.ml.xgboost.XGObject;
+import org.apache.ignite.ml.xgboost.XGModelComposition;
 import org.apache.ignite.ml.xgboost.parser.visitor.XGModelVisitor;
 
 /** XGBoost model parser. Uses the following ANTLR grammar to parse file:
@@ -63,13 +63,12 @@ import org.apache.ignite.ml.xgboost.parser.visitor.XGModelVisitor;
  * xgModel : xgTree+ ;
  * </pre>
  */
-// TODO: IGNITE-10718 Merge XGBoost and Ignite ML trees together.
-public class XGModelParser implements InfModelParser<XGObject, Double> {
+public class XGModelParser implements InfModelParser<HashMap<String, Double>, Double, XGModelComposition> {
     /** */
     private static final long serialVersionUID = -5819843559270294718L;
 
     /** {@inheritDoc} */
-    @Override public XGModel parse(byte[] mdl) {
+    @Override public XGModelComposition parse(byte[] mdl) {
         try (ByteArrayInputStream bais = new ByteArrayInputStream(mdl)) {
             CharStream cStream = CharStreams.fromStream(bais);
             XGBoostModelLexer lexer = new XGBoostModelLexer(cStream);
index 70433c7..76de3af 100644 (file)
 package org.apache.ignite.ml.xgboost.parser.visitor;
 
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
-import org.apache.ignite.ml.xgboost.XGModel;
-import org.apache.ignite.ml.xgboost.XGNode;
+import java.util.Map;
+import java.util.Set;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.xgboost.XGModelComposition;
 import org.apache.ignite.ml.xgboost.parser.XGBoostModelBaseVisitor;
 import org.apache.ignite.ml.xgboost.parser.XGBoostModelParser;
 
 /**
  * XGBoost model visitor that parses model.
  */
-public class XGModelVisitor extends XGBoostModelBaseVisitor<XGModel> {
-    /** Tree visitor. */
-    private final XGTreeVisitor treeVisitor = new XGTreeVisitor();
+public class XGModelVisitor extends XGBoostModelBaseVisitor<XGModelComposition> {
+    /** Tree dictionary visitor. */
+    private final XGTreeDictionaryVisitor treeDictionaryVisitor = new XGTreeDictionaryVisitor();
 
     /** {@inheritDoc} */
-    @Override public XGModel visitXgModel(XGBoostModelParser.XgModelContext ctx) {
-        List<XGNode> trees = new ArrayList<>();
+    @Override public XGModelComposition visitXgModel(XGBoostModelParser.XgModelContext ctx) {
+        List<DecisionTreeNode> trees = new ArrayList<>();
 
+        Set<String> featureNames = new HashSet<>();
         for (XGBoostModelParser.XgTreeContext treeCtx : ctx.xgTree())
-            trees.add(treeVisitor.visitXgTree(treeCtx));
+            featureNames.addAll(treeDictionaryVisitor.visitXgTree(treeCtx));
 
-        return new XGModel(trees);
+        Map<String, Integer> dict = buildDictionary(featureNames);
+
+        XGTreeVisitor treeVisitor = new XGTreeVisitor(dict);
+
+        for (XGBoostModelParser.XgTreeContext treeCtx : ctx.xgTree()) {
+            DecisionTreeNode treeNode = treeVisitor.visitXgTree(treeCtx);
+            trees.add(treeNode);
+        }
+
+        return new XGModelComposition(dict, trees);
+    }
+
+    /**
+     * Build dictionary using specified feature names.
+     *
+     * @param featureNames Feature names.
+     * @return Dictionary.
+     */
+    private Map<String, Integer> buildDictionary(Set<String> featureNames) {
+        List<String> orderedFeatureNames = new ArrayList<>(featureNames);
+        Collections.sort(orderedFeatureNames);
+
+        Map<String, Integer> dict = new HashMap<>();
+        for (int i = 0; i < orderedFeatureNames.size(); i++)
+            dict.put(orderedFeatureNames.get(i), i);
+
+        return dict;
     }
 }
\ No newline at end of file
  * limitations under the License.
  */
 
-package org.apache.ignite.ml.xgboost;
+package org.apache.ignite.ml.xgboost.parser.visitor;
 
-import java.util.List;
-import org.apache.ignite.ml.inference.InfModel;
+import java.util.HashSet;
+import java.util.Set;
+import org.apache.ignite.ml.xgboost.parser.XGBoostModelBaseVisitor;
+import org.apache.ignite.ml.xgboost.parser.XGBoostModelParser;
 
 /**
- * XGBoost model.
+ * Tree dictionary visitor that collects all feature names.
  */
-public class XGModel implements InfModel<XGObject, Double> {
-    /** List of decision trees. */
-    private final List<XGNode> trees;
-
-    /**
-     * Constructs a new XGBoost model.
-     *
-     * @param trees List of XGBoost trees.
-     */
-    public XGModel(List<XGNode> trees) {
-        this.trees = trees;
-    }
-
+public class XGTreeDictionaryVisitor extends XGBoostModelBaseVisitor<Set<String>> {
     /** {@inheritDoc} */
-    @Override public Double predict(XGObject obj) {
-        double res = 0;
+    @Override public Set<String> visitXgTree(XGBoostModelParser.XgTreeContext ctx) {
+        Set<String> featureNames = new HashSet<>();
 
-        for (XGNode tree : trees)
-            res += tree.predict(obj);
+        for (XGBoostModelParser.XgNodeContext nodeCtx : ctx.xgNode()) {
+            String featureName = nodeCtx.STRING().getText();
+            featureNames.add(featureName);
+        }
 
-        return (1.0 / (1.0 + Math.exp(-res)));
-    }
-
-    /** {@inheritDoc} */
-    @Override public void close() {
-        // Do nothing.
+        return featureNames;
     }
-}
\ No newline at end of file
+}
index 60268ff..f031ed7 100644 (file)
@@ -20,37 +20,49 @@ package org.apache.ignite.ml.xgboost.parser.visitor;
 import java.util.HashMap;
 import java.util.Map;
 import org.antlr.v4.runtime.tree.TerminalNode;
-import org.apache.ignite.ml.xgboost.XGLeafNode;
-import org.apache.ignite.ml.xgboost.XGNode;
-import org.apache.ignite.ml.xgboost.XGSplitNode;
+import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
+import org.apache.ignite.ml.tree.DecisionTreeLeafNode;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
 import org.apache.ignite.ml.xgboost.parser.XGBoostModelBaseVisitor;
 import org.apache.ignite.ml.xgboost.parser.XGBoostModelParser;
 
 /**
  * XGBoost tree visitor that parses tree.
  */
-public class XGTreeVisitor extends XGBoostModelBaseVisitor<XGNode> {
+public class XGTreeVisitor extends XGBoostModelBaseVisitor<DecisionTreeNode> {
     /** Index of the root node. */
     private static final int ROOT_NODE_IDX = 0;
 
+    /** Dictionary for matching column name and index. */
+    private final Map<String, Integer> dict;
+
+    /**
+     * Constructs a new instance of tree visitor.
+     *
+     * @param dict Dictionary for matching column name and index.
+     */
+    public XGTreeVisitor(Map<String, Integer> dict) {
+        this.dict = dict;
+    }
+
     /** {@inheritDoc} */
-    @Override public XGNode visitXgTree(XGBoostModelParser.XgTreeContext ctx) {
-        Map<Integer, XGSplitNode> splitNodes = new HashMap<>();
-        Map<Integer, XGLeafNode> leafNodes = new HashMap<>();
+    @Override public DecisionTreeNode visitXgTree(XGBoostModelParser.XgTreeContext ctx) {
+        Map<Integer, DecisionTreeConditionalNode> splitNodes = new HashMap<>();
+        Map<Integer, DecisionTreeLeafNode> leafNodes = new HashMap<>();
 
         for (XGBoostModelParser.XgNodeContext nodeCtx : ctx.xgNode()) {
             int idx = Integer.valueOf(nodeCtx.INT(0).getText());
             String featureName = nodeCtx.STRING().getText();
             double threshold = parseXgValue(nodeCtx.xgValue());
 
-            splitNodes.put(idx, new XGSplitNode(featureName, threshold));
+            splitNodes.put(idx, new DecisionTreeConditionalNode(dict.get(featureName), threshold, null, null, null));
         }
 
         for (XGBoostModelParser.XgLeafContext leafCtx : ctx.xgLeaf()) {
             int idx = Integer.valueOf(leafCtx.INT().getText());
             double val = parseXgValue(leafCtx.xgValue());
 
-            leafNodes.put(idx, new XGLeafNode(val));
+            leafNodes.put(idx, new DecisionTreeLeafNode(val));
         }
 
         for (XGBoostModelParser.XgNodeContext nodeCtx : ctx.xgNode()) {
@@ -59,9 +71,10 @@ public class XGTreeVisitor extends XGBoostModelBaseVisitor<XGNode> {
             int noIdx = Integer.valueOf(nodeCtx.INT(2).getText());
             int missIdx = Integer.valueOf(nodeCtx.INT(3).getText());
 
-            XGSplitNode node = splitNodes.get(idx);
-            node.setYesNode(splitNodes.containsKey(yesIdx) ? splitNodes.get(yesIdx) : leafNodes.get(yesIdx));
-            node.setNoNode(splitNodes.containsKey(noIdx) ? splitNodes.get(noIdx) : leafNodes.get(noIdx));
+            DecisionTreeConditionalNode node = splitNodes.get(idx);
+
+            node.setElseNode(splitNodes.containsKey(yesIdx) ? splitNodes.get(yesIdx) : leafNodes.get(yesIdx));
+            node.setThenNode(splitNodes.containsKey(noIdx) ? splitNodes.get(noIdx) : leafNodes.get(noIdx));
             node.setMissingNode(splitNodes.containsKey(missIdx) ? splitNodes.get(missIdx) : leafNodes.get(missIdx));
         }
 
index abee78f..1439a5a 100644 (file)
 package org.apache.ignite.ml.xgboost.parser;
 
 import java.net.URL;
+import java.util.HashMap;
 import java.util.Scanner;
-import org.apache.ignite.ml.inference.InfModel;
 import org.apache.ignite.ml.inference.builder.SingleInfModelBuilder;
 import org.apache.ignite.ml.inference.builder.SyncInfModelBuilder;
 import org.apache.ignite.ml.inference.reader.FileSystemInfModelReader;
 import org.apache.ignite.ml.inference.reader.InfModelReader;
-import org.apache.ignite.ml.xgboost.MapBasedXGObject;
-import org.apache.ignite.ml.xgboost.XGObject;
+import org.apache.ignite.ml.xgboost.XGModelComposition;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
@@ -53,7 +52,7 @@ public class XGBoostModelParserTest {
 
         InfModelReader reader = new FileSystemInfModelReader(url.getPath());
 
-        try (InfModel<XGObject, Double> mdl = mdlBuilder.build(reader, parser);
+        try (XGModelComposition mdl = mdlBuilder.build(reader, parser);
              Scanner testDataScanner = new Scanner(XGBoostModelParserTest.class.getClassLoader()
                  .getResourceAsStream("datasets/agaricus-test-data.txt"));
              Scanner testExpResultsScanner = new Scanner(XGBoostModelParserTest.class.getClassLoader()
@@ -65,7 +64,7 @@ public class XGBoostModelParserTest {
                 String testDataStr = testDataScanner.nextLine();
                 String testExpResultsStr = testExpResultsScanner.nextLine();
 
-                MapBasedXGObject testObj = new MapBasedXGObject();
+                HashMap<String, Double> testObj = new HashMap<>();
 
                 for (String keyValueString : testDataStr.split(" ")) {
                     String[] keyVal = keyValueString.split(":");
@@ -74,7 +73,7 @@ public class XGBoostModelParserTest {
                         testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
                 }
 
-                double prediction = mdl.predict(testObj);
+                double prediction = mdl.apply(testObj);
 
                 double expPrediction = Double.parseDouble(testExpResultsStr);