Probabilistic diff to sample partitions for diff testing based on probability trunk
author Jyothsna Konisa <jkonisa@apple.com>
Thu, 23 Sep 2021 19:47:38 +0000 (12:47 -0700)
committer Yifan Cai <ycai@apache.org>
Mon, 25 Oct 2021 15:40:08 +0000 (08:40 -0700)
Patch by Jyothsna Konisa; reviewed by Dinesh Joshi, Yifan Cai for CASSANDRA-16967

common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java

index 7a20b30376ed2f8e415f16f2cedc80787b28803d..8d74de8d50a0727d12a93a03813c23baa6709557 100644 (file)
@@ -87,6 +87,13 @@ public interface JobConfiguration extends Serializable {
 
     MetadataKeyspaceOptions metadataOptions();
 
+    /**
+     * Probability with which a partition is sampled for diffing when running a probabilistic diff.
+     * The value must be greater than 0 and at most 1; {@code Differ} rejects values such as 0,
+     * negative numbers or anything above 1 with an {@link IllegalArgumentException}.
+     * The default value is 1, which means all the partitions are diffed.
+     *
+     * @return partitionSamplingProbability
+     */
+    double partitionSamplingProbability();
+
     /**
      * Contains the options that specify the retry strategy for retrieving data at the application level.
      * Note that it is different than cassandra java driver's {@link com.datastax.driver.core.policies.RetryPolicy},
index 359466a33d342a6c4d13ad9db5290ec0ff83ed2e..7d60403e1344ff9075998687d0bb5b1d05537a88 100644 (file)
@@ -48,6 +48,7 @@ public class YamlJobConfiguration implements JobConfiguration {
     public String specific_tokens = null;
     public String disallowed_tokens = null;
     public RetryOptions retry_options;
+    public double partition_sampling_probability = 1;
 
     public static YamlJobConfiguration load(InputStream inputStream) {
         Yaml yaml = new Yaml(new CustomClassLoaderConstructor(YamlJobConfiguration.class,
@@ -103,6 +104,11 @@ public class YamlJobConfiguration implements JobConfiguration {
         return metadata_options;
     }
 
+    @Override
+    public double partitionSamplingProbability() {
+        return partition_sampling_probability; // public field of this class, initialised to 1
+    }
+
     public RetryOptions retryOptions() {
         return retry_options;
     }
@@ -130,6 +136,7 @@ public class YamlJobConfiguration implements JobConfiguration {
                ", keyspace_tables=" + keyspace_tables +
                ", buckets=" + buckets +
                ", rate_limit=" + rate_limit +
+               ", partition_sampling_probability=" + partition_sampling_probability +
                ", job_id='" + job_id + '\'' +
                ", token_scan_fetch_size=" + token_scan_fetch_size +
                ", partition_read_fetch_size=" + partition_read_fetch_size +
index cf1c9a58bf230903c14774c2e553978a387f9790..11794c5682c51f4308e12a138bf75063c44dd40d 100644 (file)
@@ -27,10 +27,12 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.UUID;
 import java.util.concurrent.Callable;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -63,6 +65,7 @@ public class Differ implements Serializable
     private final double reverseReadProbability;
     private final SpecificTokens specificTokens;
     private final RetryStrategyProvider retryStrategyProvider;
+    private final double partitionSamplingProbability;
 
     private static DiffCluster srcDiffCluster;
     private static DiffCluster targetDiffCluster;
@@ -103,6 +106,7 @@ public class Differ implements Serializable
         this.reverseReadProbability = config.reverseReadProbability();
         this.specificTokens = config.specificTokens();
         this.retryStrategyProvider = retryStrategyProvider;
+        this.partitionSamplingProbability = config.partitionSamplingProbability();
         synchronized (Differ.class)
         {
             /*
@@ -225,12 +229,28 @@ public class Differ implements Serializable
                                                               mismatchReporter,
                                                               journal,
                                                               COMPARISON_EXECUTOR);
-
-        final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider);
+        final Predicate<PartitionKey> partitionSamplingFunction = shouldIncludePartition(jobId, partitionSamplingProbability);
+        final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider, partitionSamplingFunction);
         logger.debug("Table [{}] stats - ({})", context.table.getTable(), tableStats);
         return tableStats;
     }
 
+    /**
+     * Builds the predicate that decides whether a partition should be included for diffing.
+     *
+     * With a probability of exactly 1 the predicate accepts every partition. Otherwise each
+     * invocation draws the next value from a {@link Random} seeded with {@code jobId.hashCode()};
+     * the partition key itself is not consulted, so the outcome depends on the order of the
+     * invocations and is the same sequence for every predicate built from the same job id.
+     *
+     * @param jobId id of the diff job, used to seed the random sequence
+     * @param partitionSamplingProbability probability of including a partition, in the range (0, 1]
+     * @return predicate returning true when the partition should be diffed
+     * @throws IllegalArgumentException if the probability is NaN or outside the range (0, 1]
+     */
+    @VisibleForTesting
+    static Predicate<PartitionKey> shouldIncludePartition(final UUID jobId, final double partitionSamplingProbability) {
+        // Expressed as a negated in-range check so that NaN, for which every comparison is false,
+        // is rejected as well instead of producing a sampler that never includes a partition.
+        if (!(partitionSamplingProbability > 0 && partitionSamplingProbability <= 1)) {
+            logger.error("Invalid partition sampling property {}, it should be between 0 and 1", partitionSamplingProbability);
+            throw new IllegalArgumentException("Invalid partition sampling property, it should be between 0 and 1");
+        }
+        if (partitionSamplingProbability == 1) {
+            // No sampling requested, skip the random number generator altogether
+            return partitionKey -> true;
+        }
+        final Random random = new Random(jobId.hashCode());
+        return partitionKey -> random.nextDouble() <= partitionSamplingProbability;
+    }
+
     private Iterator<Row> fetchRows(DiffContext context, PartitionKey key, boolean shouldReverse, DiffCluster.Type type) {
         Callable<Iterator<Row>> rows = () -> type == DiffCluster.Type.SOURCE
                                              ? context.source.getPartition(context.table, key, shouldReverse)
index 5d6710ed4db849e90ee0518fead8fac7b892759b..280fbd570ee898f6012001accbb8cd17456be5e7 100644 (file)
@@ -27,6 +27,7 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import com.google.common.base.Verify;
 import org.slf4j.Logger;
@@ -57,6 +58,22 @@ public class RangeComparator {
     public RangeStats compare(Iterator<PartitionKey> sourceKeys,
                               Iterator<PartitionKey> targetKeys,
                               Function<PartitionKey, PartitionComparator> partitionTaskProvider) {
+        return compare(sourceKeys,targetKeys,partitionTaskProvider, partitionKey -> true);
+    }
+
+    /**
+     * Compares partitions in src and target clusters.
+     *
+     * @param sourceKeys partition keys in the source cluster
+     * @param targetKeys partition keys in the target cluster
+     * @param partitionTaskProvider provides the comparison task for a partition key
+     * @param partitionSampler samples partitions based on the probability for probabilistic diff
+     * @return stats about the diff
+     */
+    public RangeStats compare(Iterator<PartitionKey> sourceKeys,
+                              Iterator<PartitionKey> targetKeys,
+                              Function<PartitionKey, PartitionComparator> partitionTaskProvider,
+                              Predicate<PartitionKey> partitionSampler) {
 
         final RangeStats rangeStats = RangeStats.newStats();
         // We can catch this condition earlier, but it doesn't hurt to also check here
@@ -115,11 +132,16 @@ public class RangeComparator {
 
                     BigInteger token = sourceKey.getTokenAsBigInteger();
                     try {
-                        PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
-                        comparisonExecutor.submit(comparisonTask,
-                                                  onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
-                                                  onError(rangeStats, token, errorReporter),
-                                                  phaser);
+                        // Use partitionSampler for sampling partitions: skip the partition
+                        // if the sampler returns false, otherwise run the diff on that partition
+                        if (partitionSampler.test(sourceKey)) {
+                            PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
+                            comparisonExecutor.submit(comparisonTask,
+                                                      onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
+                                                      onError(rangeStats, token, errorReporter),
+                                                      phaser);
+                        }
+
                     } catch (Throwable t) {
                         // Handle errors thrown when creating the comparison task. This should trap timeouts and
                         // unavailables occurring when performing the initial query to read the full partition.
index 1bf656d4e1bcf640789223fac939a4bea675da19..49c1f113d68c236dbef49858606185b8fffde461 100644 (file)
@@ -108,5 +108,10 @@ public class DiffJobTest
         public Optional<UUID> jobId() {
             return Optional.of(UUID.randomUUID());
         }
+
+        @Override
+        public double partitionSamplingProbability() {
+            return 1; // 1 means every partition is diffed, i.e. no sampling
+        }
     }
 }
index e5885750dbfdaec14f82b401d0bedcaab73754a3..b1b524d22a901bfc4ec10a7df0e932def778e03b 100644 (file)
@@ -21,16 +21,65 @@ package org.apache.cassandra.diff;
 
 import java.math.BigInteger;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import com.google.common.base.VerifyException;
 import com.google.common.collect.Lists;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 
 public class DifferTest {
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    @Test
+    public void testIncludeAllPartitions() {
+        // A sampling probability of 1 must accept the partition
+        final PartitionKey key = new RangeComparatorTest.TestPartitionKey(0);
+        final UUID jobId = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        final Predicate<PartitionKey> sampler = Differ.shouldIncludePartition(jobId, 1);
+        final boolean included = sampler.test(key);
+        assertTrue(included);
+    }
+
+    @Test
+    public void shouldIncludePartitionWithProbabilityInvalidProbability() {
+        final PartitionKey key = new RangeComparatorTest.TestPartitionKey(0);
+        final UUID jobId = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        // A negative probability is expected to be rejected with this exact message
+        expectedException.expect(IllegalArgumentException.class);
+        expectedException.expectMessage("Invalid partition sampling property, it should be between 0 and 1");
+        final Predicate<PartitionKey> sampler = Differ.shouldIncludePartition(jobId, -1);
+        sampler.test(key);
+    }
+
+    @Test
+    public void shouldIncludePartitionWithProbabilityHalf() {
+        final PartitionKey key = new RangeComparatorTest.TestPartitionKey(0);
+        final UUID jobId = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        final Predicate<PartitionKey> sampler = Differ.shouldIncludePartition(jobId, 0.5);
+        // Draw 20 samples and count how many of them were accepted
+        int included = 0;
+        for (int draw = 20; draw > 0; draw--) {
+            included += sampler.test(key) ? 1 : 0;
+        }
+        // With probability 0.5 the accepted count is expected to stay within [5, 15]
+        assertTrue(included <= 15);
+        assertTrue(included >= 5);
+    }
+
+    @Test
+    public void shouldIncludePartitionShouldGenerateSameSequenceForGivenJobId() {
+        final UUID jobId = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        final PartitionKey key = new RangeComparatorTest.TestPartitionKey(0);
+        final Predicate<PartitionKey> first = Differ.shouldIncludePartition(jobId, 0.5);
+        final Predicate<PartitionKey> second = Differ.shouldIncludePartition(jobId, 0.5);
+        // Two samplers built from the same job id must make the same decision on every draw
+        int round = 0;
+        while (round < 10) {
+            final boolean expected = second.test(key);
+            final boolean actual = first.test(key);
+            assertEquals(expected, actual);
+            round++;
+        }
+    }
 
     @Test(expected = VerifyException.class)
     public void rejectNullStartOfRange() {
index fd2926bf2334a0aff29493724a6a98663283ec9c..e09f68f7202ac3c0b6460782e560ea2f1f304b40 100644 (file)
@@ -56,6 +56,38 @@ public class RangeComparatorTest {
     private ComparisonExecutor executor = ComparisonExecutor.newExecutor(1, new MetricRegistry());
     private RetryStrategyProvider mockRetryStrategyFactory = RetryStrategyProvider.create(null); // create a NoRetry provider
 
+    @Test
+    public void probabilisticDiffIncludeAllPartitions() {
+        RangeComparator comparator = comparator(context(0L, 100L));
+        RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6), keys(0,1, 2, 3, 4, 5, 7), this::alwaysMatch); // three-argument overload: no sampler is passed
+        assertFalse(stats.isEmpty());
+        assertEquals(1, stats.getOnlyInSource());
+        assertEquals(1, stats.getOnlyInTarget());
+        assertEquals(6, stats.getMatchedPartitions()); // keys 0-5 are present in both source and target
+        assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
+        assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
+        assertNothingReported(errors, journal);
+        assertCompared(0, 1, 2, 3, 4, 5);
+    }
+
+    @Test
+    public void probabilisticDiffProbabilityHalf() {
+        RangeComparator comparator = comparator(context(0L, 100L));
+        RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6),
+                                              keys(0, 1, 2, 3, 4, 5, 7),
+                                              this::alwaysMatch,
+                                              key -> key.getTokenAsBigInteger().intValue() % 2 == 0); // deterministic sampler: accepts even tokens only
+        assertFalse(stats.isEmpty());
+        assertEquals(1, stats.getOnlyInSource());
+        assertEquals(1, stats.getOnlyInTarget());
+        assertEquals(3, stats.getMatchedPartitions()); // of the shared keys 0-5 only 0, 2 and 4 pass the sampler
+        assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
+        assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches); // 7 is odd, yet it is still expected to be reported
+        assertNothingReported(errors, journal);
+        assertCompared(0, 2, 4);
+    }
+
+
     @Test
     public void emptyRange() {
         RangeComparator comparator = comparator(context(100L, 100L));
index 17dc67c003615b8e02f109067423efe994b2ed3e..b94d22c8fd8071f9c66c55da30c84c42c6d1e35d 100644 (file)
@@ -29,6 +29,11 @@ public class SchemaTest {
         public List<String> disallowedKeyspaces() {
             return disallowedKeyspaces;
         }
+
+        @Override
+        public double partitionSamplingProbability() {
+            return 1; // 1 means every partition is diffed, i.e. no sampling
+        }
     }
 
     @Test