IGNITE-10803: [ML] Add prototype LogReg loading from PMML format
author zaleslaw <zaleslaw.sin@gmail.com>
Thu, 27 Dec 2018 13:17:00 +0000 (16:17 +0300)
committer Yury Babak <ybabak@gridgain.com>
Thu, 27 Dec 2018 13:17:00 +0000 (16:17 +0300)
This closes #5744

examples/pom.xml
examples/src/main/java/org/apache/ignite/examples/ml/inference/LogRegFromSparkThroughPMMLExample.java [new file with mode: 0644]
examples/src/main/resources/models/spark/iris.pmml [new file with mode: 0644]

index 429ec79..6320a0f 100644 (file)
             <version>${javassist.version}</version>
             <scope>test</scope>
         </dependency>
+        <!-- https://mvnrepository.com/artifact/org.jpmml/pmml-model -->
+        <dependency>
+            <groupId>org.jpmml</groupId>
+            <artifactId>pmml-model</artifactId>
+            <version>1.4.7</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-core</artifactId>
+            <version>2.7.3</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+            <version>2.7.3</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-annotations</artifactId>
+            <version>2.7.3</version>
+        </dependency>
+
     </dependencies>
 
     <properties>
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/LogRegFromSparkThroughPMMLExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/LogRegFromSparkThroughPMMLExample.java
new file mode 100644 (file)
index 0000000..30a4498
--- /dev/null
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import javax.xml.bind.JAXBException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.util.MLSandboxDatasets;
+import org.apache.ignite.ml.util.SandboxMLCache;
+import org.dmg.pmml.PMML;
+import org.dmg.pmml.regression.RegressionModel;
+import org.dmg.pmml.regression.RegressionTable;
+import org.jpmml.model.PMMLUtil;
+import org.xml.sax.SAXException;
+
+/**
+ * Run logistic regression model loaded from PMML file. The PMML file was generated by Spark MLLib toPMML operator.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the
+ * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set"></a>Iris dataset</a>).</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class LogRegFromSparkThroughPMMLExample {
+    /** Run example. */
+    public static void main(String[] args) throws FileNotFoundException {
+        System.out.println();
+        System.out.println(">>> Logistic regression model loaded from PMML over partitioned dataset usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+                .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+
+            LogisticRegressionModel mdl = PMMLParser.load("examples/src/main/resources/models/spark/iris.pmml");
+
+            System.out.println(">>> Logistic regression model: " + mdl);
+
+            double accuracy = BinaryClassificationEvaluator.evaluate(
+                dataCache,
+                mdl,
+                (k, v) -> v.copyOfRange(1, v.size()),
+                (k, v) -> v.get(0),
+                new Accuracy<>()
+            );
+
+            System.out.println("\n>>> Accuracy " + accuracy);
+            System.out.println("\n>>> Test Error " + (1 - accuracy));
+        }
+    }
+
+    /** Util class to build the LogReg model. */
+    private static class PMMLParser {
+        /**
+         * @param path Path.
+         */
+        public static LogisticRegressionModel load(String path) {
+            try (InputStream is = new FileInputStream(new File(path))) {
+                PMML pmml = PMMLUtil.unmarshal(is);
+
+                RegressionModel logRegMdl = (RegressionModel)pmml.getModels().get(0);
+
+                RegressionTable regTbl = logRegMdl.getRegressionTables().get(0);
+
+                Vector coefficients = new DenseVector(regTbl.getNumericPredictors().size());
+
+                for (int i = 0; i < regTbl.getNumericPredictors().size(); i++)
+                    coefficients.set(i, regTbl.getNumericPredictors().get(i).getCoefficient());
+
+                double interceptor = regTbl.getIntercept();
+
+                return new LogisticRegressionModel(coefficients, interceptor);
+            }
+            catch (IOException | JAXBException | SAXException e) {
+                e.printStackTrace();
+            }
+
+            return null;
+        }
+    }
+}
diff --git a/examples/src/main/resources/models/spark/iris.pmml b/examples/src/main/resources/models/spark/iris.pmml
new file mode 100644 (file)
index 0000000..78f310d
--- /dev/null
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+<PMML xmlns="http://www.dmg.org/PMML-4_2" version="4.2">
+    <Header description="logistic regression">
+        <Application name="Apache Spark MLlib" version="2.2.0"/>
+        <Timestamp>2018-12-25T15:09:09</Timestamp>
+    </Header>
+    <DataDictionary numberOfFields="5">
+        <DataField name="field_0" optype="continuous" dataType="double"/>
+        <DataField name="field_1" optype="continuous" dataType="double"/>
+        <DataField name="field_2" optype="continuous" dataType="double"/>
+        <DataField name="field_3" optype="continuous" dataType="double"/>
+        <DataField name="target" optype="categorical" dataType="string"/>
+    </DataDictionary>
+    <RegressionModel modelName="logistic regression" functionName="classification" normalizationMethod="logit">
+        <MiningSchema>
+            <MiningField name="field_0" usageType="active"/>
+            <MiningField name="field_1" usageType="active"/>
+            <MiningField name="field_2" usageType="active"/>
+            <MiningField name="field_3" usageType="active"/>
+            <MiningField name="target" usageType="target"/>
+        </MiningSchema>
+        <RegressionTable intercept="0.0" targetCategory="1">
+            <NumericPredictor name="field_0" coefficient="5.84520630732407"/>
+            <NumericPredictor name="field_1" coefficient="-19.36222130270906"/>
+            <NumericPredictor name="field_2" coefficient="5.66074235971065"/>
+            <NumericPredictor name="field_3" coefficient="16.110585062151788"/>
+        </RegressionTable>
+        <RegressionTable intercept="-0.0" targetCategory="0"/>
+    </RegressionModel>
+</PMML>