IGNITE-8532: [ML] GA Grid: Implement Roulette Wheel Selection
authorTurik Campbell <admin@techbysample.com>
Fri, 18 Jan 2019 13:51:24 +0000 (16:51 +0300)
committerYury Babak <ybabak@gridgain.com>
Fri, 18 Jan 2019 13:51:24 +0000 (16:51 +0300)
This closes #5842

examples/src/main/java/org/apache/ignite/examples/ml/genetic/helloworld/HelloWorldGAExample.java
modules/ml/src/main/java/org/apache/ignite/ml/genetic/GAGrid.java
modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionJob.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionTask.java [new file with mode: 0644]
modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/GAGridConstants.java

index 585cbb5..7e8bb8a 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.genetic.Chromosome;
 import org.apache.ignite.ml.genetic.GAGrid;
 import org.apache.ignite.ml.genetic.Gene;
 import org.apache.ignite.ml.genetic.parameter.GAConfiguration;
+import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
 
 /**
  * This example demonstrates how to use the {@link GAGrid} framework. In this example, we want to evolve a string
@@ -37,6 +38,14 @@ import org.apache.ignite.ml.genetic.parameter.GAConfiguration;
  * <p>
  * You can change the test data and parameters of GA grid used in this example and re-run it to explore
  * this functionality further.</p>
+ * 
+ * For example, you may change the some basic genetic parameters on the GAConfiguration object:
+ * 
+ *  Mutation Rate
+ *  Crossover Rate
+ *  Population Size
+ *  Selection Method
+ *  
  * <p>
  * How to run from command line:</p>
  * <p>
@@ -72,7 +81,19 @@ public class HelloWorldGAExample {
 
             // Initialize gene pool.
             gaCfg.setGenePool(genes);
-
+             
+            // Set CrossOver Rate.
+            gaCfg.setCrossOverRate(.05);
+            
+            // Set Mutation Rate.
+            gaCfg.setMutationRate(.05);
+           
+            // Set Selection Method.
+            gaCfg.setSelectionMtd(GAGridConstants.SELECTION_METHOD.SELECTION_METHOD_ROULETTE_WHEEL);
+            
+            // Set Population Size.
+            gaCfg.setPopulationSize(2000);
+            
             // Create and set Fitness function.
             HelloWorldFitnessFunction function = new HelloWorldFitnessFunction();
             gaCfg.setFitnessFunction(function);
index 92eab5e..5531241 100644 (file)
 package org.apache.ignite.ml.genetic;
 
 import java.util.ArrayList;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Random;
+import java.util.stream.Collectors;
 import javax.cache.Cache.Entry;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
@@ -48,8 +50,7 @@ public class GAGrid {
     private IgniteCache<Long, Chromosome> populationCache;
     /** Gene cache */
     private IgniteCache<Long, Gene> geneCache;
-    /** population keys */
-    private List<Long> populationKeys = new ArrayList<Long>();
+   
 
     /**
      * @param cfg GAConfiguration
@@ -76,12 +77,11 @@ public class GAGrid {
      * @return Average fitness score
      */
     private Double calculateAverageFitness() {
-
         double avgFitnessScore = 0;
 
         IgniteCache<Long, Gene> cache = ignite.cache(GAGridConstants.POPULATION_CACHE);
 
-        // Execute query to get names of all employees.
+        // Execute query calculate average fitness
         SqlFieldsQuery sql = new SqlFieldsQuery("select AVG(FITNESSSCORE) from Chromosome");
 
         // Iterate over the result set.
@@ -110,7 +110,7 @@ public class GAGrid {
     private Boolean copyFitterChromosomesToPopulation(List<Long> fittestKeys, List<Long> selectedKeys) {
         double truncatePercentage = this.cfg.getTruncateRate();
 
-        int totalSize = this.populationKeys.size();
+        int totalSize = this.cfg.getPopulationSize();
 
         int truncateCnt = (int)(truncatePercentage * totalSize);
 
@@ -118,7 +118,6 @@ public class GAGrid {
 
         return this.ignite.compute()
             .execute(new TruncateSelectionTask(fittestKeys, numOfCopies), selectedKeys);
-
     }
 
     /**
@@ -137,7 +136,7 @@ public class GAGrid {
             if (!(keys.contains(key))) {
                 genes[k] = key;
                 keys.add(key);
-                k = k + 1;
+                k += 1;
             }
         }
         return new Chromosome(genes);
@@ -165,27 +164,28 @@ public class GAGrid {
 
         initializeGenePopulation();
 
-        intializePopulation();
+        initializePopulation();
 
         // Calculate Fitness
-        calculateFitness(this.populationKeys);
+        calculateFitness(getPopulationKeys());
 
         // Retrieve chromosomes in order by fitness value
-        List<Long> keys = getChromosomesByFittest();
+        LinkedHashMap<Long, Double> map = getChromosomesByFittest();
 
         // Calculate average fitness value of population
         double averageFitnessScore = calculateAverageFitness();
 
-        fittestChomosome = populationCache.get(keys.get(0));
+        Long key = map.keySet().iterator().next();
+               
+        fittestChomosome = populationCache.get(key);
 
         // while NOT terminateCondition met
         while (!(cfg.getTerminateCriteria().isTerminationConditionMet(fittestChomosome, averageFitnessScore,
             generationCnt))) {
-            generationCnt = generationCnt + 1;
+            generationCnt += 1;
 
             // We will crossover/mutate over chromosomes based on selection method
-
-            List<Long> selectedKeysforCrossMutaton = selection(keys);
+            List<Long> selectedKeysforCrossMutaton = selection(map);
 
             // Cross Over
             crossover(selectedKeysforCrossMutaton);
@@ -197,10 +197,12 @@ public class GAGrid {
             calculateFitness(selectedKeysforCrossMutaton);
 
             // Retrieve chromosomes in order by fitness value
-            keys = getChromosomesByFittest();
+            map = getChromosomesByFittest();
 
+            key = map.keySet().iterator().next();
+            
             // Retreive the first chromosome from the list
-            fittestChomosome = populationCache.get(keys.get(0));
+            fittestChomosome = populationCache.get(key);
 
             // Calculate average fitness value of population
             averageFitnessScore = calculateAverageFitness();
@@ -214,27 +216,29 @@ public class GAGrid {
     /**
      * helper routine to retrieve Chromosome keys in order of fittest
      *
-     * @return List of primary keys for chromosomes.
+     * @return Map of primary key/fitness score pairs for chromosomes.
      */
-    private List<Long> getChromosomesByFittest() {
-        List<Long> orderChromKeysByFittest = new ArrayList<Long>();
+    private LinkedHashMap<Long,Double> getChromosomesByFittest() {
+       LinkedHashMap<Long, Double> orderChromKeysByFittest  = new LinkedHashMap<>();
+       
         String orderDirection = "desc";
 
         if (!cfg.isHigherFitnessValFitter())
             orderDirection = "asc";
 
-        String fittestSQL = "select _key from Chromosome order by fitnessScore " + orderDirection;
+        String fittestSQL = "select _key, fitnessScore from Chromosome order by fitnessScore " + orderDirection;
 
         // Execute query to retrieve keys for ALL Chromosomes by fittnessScore
         QueryCursor<List<?>> cursor = populationCache.query(new SqlFieldsQuery(fittestSQL));
-
+    
         List<List<?>> res = cursor.getAll();
-
+                       
         for (List row : res) {
-            Long key = (Long)row.get(0);
-            orderChromKeysByFittest.add(key);
+               Long key = (Long)row.get(0);
+               Double fitnessScore= (Double)row.get(1);
+               orderChromKeysByFittest.put(key, fitnessScore);
         }
-
+        
         return orderChromKeysByFittest;
     }
 
@@ -272,25 +276,11 @@ public class GAGrid {
         for (int j = 0; j < populationSize; j++) {
             Chromosome chromosome = createChromosome(cfg.getChromosomeLen());
             populationCache.put(chromosome.id(), chromosome);
-            populationKeys.add(chromosome.id());
         }
 
     }
 
-    /**
-     * initialize the population of Chromosomes based on GAConfiguration
-     */
-    void intializePopulation() {
-        int populationSize = cfg.getPopulationSize();
-        populationCache.clear();
-
-        for (int j = 0; j < populationSize; j++) {
-            Chromosome chromosome = createChromosome(cfg.getChromosomeLen());
-            populationCache.put(chromosome.id(), chromosome);
-            populationKeys.add(chromosome.id());
-        }
-
-    }
+  
 
     /**
      * Perform mutation
@@ -330,7 +320,7 @@ public class GAGrid {
      * Truncation selection simply retains the fittest x% of the population. These fittest individuals are duplicated so
      * that the population size is maintained.
      *
-     * @param keys
+     * @param keys Keys.
      * @return List of keys
      */
     private List<Long> selectByTruncation(List<Long> keys) {
@@ -340,6 +330,18 @@ public class GAGrid {
 
         return keys.subList(truncateCnt, keys.size());
     }
+    
+    /**
+     * Roulette Wheel selection 
+     *
+     * @param map Map of keys/fitness scores
+     * @return List of primary Keys for respective chromosomes that will breed
+     */
+    private List<Long> selectByRouletteWheel(LinkedHashMap map) {
+       List<Long> populationKeys = this.ignite.compute().execute(new RouletteWheelSelectionTask(this.cfg), map);
+       
+        return populationKeys;
+    }
 
     /**
      * @param k Gene index in Chromosome.
@@ -359,7 +361,7 @@ public class GAGrid {
      * @return Primary key of respective Gene
      */
     private long selectGeneByChromsomeCriteria(int k) {
-        List<Gene> genes = new ArrayList();
+        List<Gene> genes = new ArrayList<>();
 
         StringBuffer sbSqlClause = new StringBuffer("_val like '");
         sbSqlClause.append("%");
@@ -393,11 +395,14 @@ public class GAGrid {
     /**
      * Select chromosomes
      *
-     * @param chromosomeKeys List of population primary keys for respective Chromsomes
+     * @param map Map of keys/fitness scores for respective Chromsomes
      * @return List of primary keys for respective Chromsomes
      */
-    private List<Long> selection(List<Long> chromosomeKeys) {
-        List<Long> selectedKeys = new ArrayList();
+    private List<Long> selection(LinkedHashMap map) {
+        List<Long> selectedKeys = new ArrayList<>();
+
+        // We will crossover/mutate over chromosomes based on selection method
+        List<Long> chromosomeKeys = new ArrayList<>(map.keySet());
 
         GAGridConstants.SELECTION_METHOD selectionMtd = cfg.getSelectionMtd();
 
@@ -413,8 +418,10 @@ public class GAGrid {
                 copyFitterChromosomesToPopulation(fittestKeys, selectedKeys);
 
                 // copy more fit keys to rest of population
-                break;
-
+                break; 
+            case SELECTION_METHOD_ROULETTE_WHEEL:
+              selectedKeys = this.selectByRouletteWheel(map);
+               
             default:
                 break;
         }
@@ -428,6 +435,14 @@ public class GAGrid {
      * @return List of Chromosome primary keys
      */
     List<Long> getPopulationKeys() {
-        return populationKeys;
+        String fittestSQL = "select _key from Chromosome";
+
+         // Execute query to retrieve keys for ALL Chromosomes
+         QueryCursor<List<?>> cursor = populationCache.query(new SqlFieldsQuery(fittestSQL));
+
+         List<List<?>> res = cursor.getAll();
+
+        return (List<Long>) res.stream().map(x -> x.get(0)).collect(Collectors.toList());
     }
+
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionJob.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionJob.java
new file mode 100644 (file)
index 0000000..5b288af
--- /dev/null
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.genetic;
+
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.stream.Collectors;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteException;
+import org.apache.ignite.IgniteLogger;
+import org.apache.ignite.compute.ComputeJobAdapter;
+import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
+import org.apache.ignite.resources.IgniteInstanceResource;
+import org.apache.ignite.resources.LoggerResource;
+
+/**
+ * Responsible for performing Roulette Wheel selection
+ */
+public class RouletteWheelSelectionJob extends ComputeJobAdapter {
+    /** Ignite instance */
+    @IgniteInstanceResource
+    private Ignite ignite = null;
+
+    /** Ignite logger */
+    @LoggerResource
+    private IgniteLogger log = null;
+
+    /** Total Fitness score */
+    Double totalFitnessScore = null;
+
+    /** Chromosome key/fitness score pair */
+    LinkedHashMap<Long, Double> map = null;
+
+    /**
+     * @param totalFitnessScore Total fitness score
+     * @param map Chromosome key / fitness score map
+     */
+    public RouletteWheelSelectionJob(Double totalFitnessScore, LinkedHashMap<Long, Double> map) {
+        this.totalFitnessScore = totalFitnessScore;
+        this.map = map;
+    }
+
+    /**
+     * Perform Roulette Wheel selection
+     *
+     * @return Chromosome parent chosen after 'spinning' the wheel.
+     */
+    @Override public Chromosome execute() throws IgniteException {
+
+        IgniteCache<Long, Chromosome> populationCache = ignite.cache(GAGridConstants.POPULATION_CACHE);
+
+        int value = spintheWheel(this.totalFitnessScore);
+
+        double partialSum = 0;
+        boolean notFound = true;
+
+        //sort map in ascending order by fitness score
+        Map<Long, Double> sortedAscendingMap = map.entrySet().stream()
+            .sorted(Map.Entry.comparingByValue())
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
+
+        Iterator<Entry<Long, Double>> entries = sortedAscendingMap.entrySet().iterator();
+
+        Long chromosomeKey = (long)-1;
+
+        while (entries.hasNext() && notFound) {
+            Entry<Long, Double> entry = entries.next();
+            Long key = entry.getKey();
+            Double fitnessScore = entry.getValue();
+            partialSum = partialSum + fitnessScore;
+
+            if (partialSum >= value) {
+                notFound = false;
+                chromosomeKey = key;
+            }
+        }
+
+        return populationCache.get(chromosomeKey);
+    }
+
+    /**
+     * Spin the wheel.
+     *
+     * @param fitnessScore Size of Gene pool
+     * @return value
+     */
+    private int spintheWheel(Double fitnessScore) {
+        Random randomGenerator = new Random();
+        return randomGenerator.nextInt(fitnessScore.intValue());
+    }
+
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionTask.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionTask.java
new file mode 100644 (file)
index 0000000..9d81471
--- /dev/null
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.genetic;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteException;
+import org.apache.ignite.cache.affinity.Affinity;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.SqlFieldsQuery;
+import org.apache.ignite.cluster.ClusterNode;
+import org.apache.ignite.compute.ComputeJob;
+import org.apache.ignite.compute.ComputeJobResult;
+import org.apache.ignite.compute.ComputeJobResultPolicy;
+import org.apache.ignite.compute.ComputeLoadBalancer;
+import org.apache.ignite.compute.ComputeTaskAdapter;
+import org.apache.ignite.ml.genetic.parameter.GAConfiguration;
+import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
+import org.apache.ignite.resources.IgniteInstanceResource;
+import org.apache.ignite.resources.LoadBalancerResource;
+
+/**
+ * Responsible for performing Roulette Wheel selection.
+ */
+public class RouletteWheelSelectionTask extends ComputeTaskAdapter<LinkedHashMap<Long, Double>, List<Long>> {
+    /** Ignite resource. */
+    @IgniteInstanceResource
+    private Ignite ignite = null;
+
+    // Inject load balancer.
+    @LoadBalancerResource
+    ComputeLoadBalancer balancer;
+
+    /** GAConfiguration */
+    private GAConfiguration cfg = null;
+
+    /**
+     * @param cfg GAConfiguration
+     */
+    public RouletteWheelSelectionTask(GAConfiguration cfg) {
+        this.cfg = cfg;
+    }
+
+    /**
+     * Calculate total fitness of population
+     *
+     * @return Double value representing total fitness score of population
+     */
+    private Double calculateTotalFitness() {
+        double totalFitnessScore = 0;
+
+        IgniteCache<Long, Chromosome> cache = ignite.cache(GAGridConstants.POPULATION_CACHE);
+
+        SqlFieldsQuery sql = new SqlFieldsQuery("select SUM(FITNESSSCORE) from Chromosome");
+
+        // Iterate over the result set.
+        try (QueryCursor<List<?>> cursor = cache.query(sql)) {
+            for (List<?> row : cursor)
+                totalFitnessScore = (Double)row.get(0);
+        }
+
+        return totalFitnessScore;
+    }
+
+    /**
+     * @param nodes List of ClusterNode.
+     * @param chromosomeKeyFitness Map of key/fitness score pairs.
+     * @return Map of nodes to jobs.
+     */
+    @Override public Map<ComputeJob, ClusterNode> map(List<ClusterNode> nodes,
+        LinkedHashMap<Long, Double> chromosomeKeyFitness) throws IgniteException {
+        Map<ComputeJob, ClusterNode> map = new HashMap<>();
+
+        Affinity affinity = ignite.affinity(GAGridConstants.POPULATION_CACHE);
+        Double totalFitness = this.calculateTotalFitness();
+
+        int populationSize = this.cfg.getPopulationSize();
+
+        for (int i = 0; i < populationSize; i++) {
+            // Pick the next best balanced node for the job.
+            RouletteWheelSelectionJob job = new RouletteWheelSelectionJob(totalFitness, chromosomeKeyFitness);
+            map.put(job, balancer.getBalancedNode(job, null));
+        }
+
+        return map;
+    }
+
+    /**
+     * Return list of parent Chromosomes.
+     *
+     * @param list List of ComputeJobResult.
+     * @return List of Chromosome keys.
+     */
+    @Override public List<Long> reduce(List<ComputeJobResult> list) throws IgniteException {
+        List<Chromosome> parents = list.stream().map((x) -> (Chromosome)x.getData()).collect(Collectors.toList());
+
+        return createParents(parents);
+    }
+
+    /**
+     * Create new parents and add to populationCache
+     *
+     * @param parents Chromosomes chosen to breed
+     * @return List of Chromosome keys.
+     */
+    private List<Long> createParents(List<Chromosome> parents) {
+        IgniteCache<Long, Chromosome> cache = ignite.cache(GAGridConstants.POPULATION_CACHE);
+        cache.clear();
+
+        List<Long> keys = new ArrayList();
+
+        parents.stream().forEach((x) -> {
+            long[] genes = x.getGenes();
+            Chromosome newparent = new Chromosome(genes);
+            cache.put(newparent.id(), newparent);
+            keys.add(newparent.id());
+        });
+
+        return keys;
+    }
+
+    /** {@inheritDoc} */
+    @Override public ComputeJobResultPolicy result(ComputeJobResult res, List<ComputeJobResult> rcvd) {
+        IgniteException err = res.getException();
+
+        if (err != null)
+            return ComputeJobResultPolicy.FAILOVER;
+
+        // If there is no exception, wait for all job results.
+        return ComputeJobResultPolicy.WAIT;
+    }
+}
index 6d1645f..a44a802 100644 (file)
@@ -29,6 +29,11 @@ public interface GAGridConstants {
 
     /** Selection Method type **/
     public enum SELECTION_METHOD {
-        SELECTON_METHOD_ELETISM, SELECTION_METHOD_TRUNCATION
+        /** Selecton method eletism. */
+        SELECTON_METHOD_ELETISM,
+        /** Selection method truncation. */
+        SELECTION_METHOD_TRUNCATION,
+        /** Selection method roulette wheel. */
+        SELECTION_METHOD_ROULETTE_WHEEL
     }
 }