LUCENE-8687: Optimise radix partitioning for points on heap
authoriverase <ivera@apache.org>
Mon, 11 Feb 2019 07:11:23 +0000 (08:11 +0100)
committeriverase <ivera@apache.org>
Mon, 11 Feb 2019 07:11:23 +0000 (08:11 +0100)
lucene/CHANGES.txt
lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java
lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java
lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java

index 03e5781..76cf703 100644 (file)
@@ -22,6 +22,8 @@ Improvements
 * LUCENE-8673: Use radix partitioning when merging dimensional points instead
   of sorting all dimensions before hand. (Ignacio Vera, Adrien Grand)
 
+* LUCENE-8687: Optimise radix partitioning for points on heap. (Ignacio Vera)
+
 Other
 
 * LUCENE-8680: Refactor EdgeTree#relateTriangle method. (Ignacio Vera)
index a6276ea..8464284 100644 (file)
@@ -179,16 +179,8 @@ final class SimpleTextBKDWriter implements Closeable {
     // dimensional values (numDims * bytesPerDim) +  docID (int)
     bytesPerDoc = packedBytesLength + Integer.BYTES;
 
-    // As we recurse, we compute temporary partitions of the data, halving the
-    // number of points at each recursion.  Once there are few enough points,
-    // we can switch to sorting in heap instead of offline (on disk).  At any
-    // time in the recursion, we hold the number of points at that level, plus
-    // all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X
-    // what that level would consume, so we multiply by 0.5 to convert from
-    // bytes to points here.  In addition the radix partitioning may sort on memory
-    // double of this size so we multiply by another 0.5.
-
-    maxPointsSortInHeap = (int) (0.25 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims));
+    // Maximum number of points we hold in memory at any time
+    maxPointsSortInHeap = (int) ((maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims));
 
     // Finally, we must be able to hold at least the leaf node in heap during build:
     if (maxPointsSortInHeap < maxPointsInLeafNode) {
@@ -577,25 +569,21 @@ final class SimpleTextBKDWriter implements Closeable {
   // encoding and not have our own ByteSequencesReader/Writer
 
   /** Sort the heap writer by the specified dim */
-  private void sortHeapPointWriter(final HeapPointWriter writer, int dim) {
-    final int pointCount = Math.toIntExact(writer.count());
+  private void sortHeapPointWriter(final HeapPointWriter writer, int from, int to, int dim, int commonPrefixLength) {
     // Tie-break by docID:
-
-    // No need to tie break on ord, for the case where the same doc has the same value in a given dimension indexed more than once: it
-    // can't matter at search time since we don't write ords into the index:
-    new MSBRadixSorter(bytesPerDim + Integer.BYTES) {
+    new MSBRadixSorter(bytesPerDim + Integer.BYTES - commonPrefixLength) {
 
       @Override
       protected int byteAt(int i, int k) {
         assert k >= 0;
-        if (k < bytesPerDim) {
+        if (k + commonPrefixLength < bytesPerDim) {
           // dim bytes
           int block = i / writer.valuesPerBlock;
           int index = i % writer.valuesPerBlock;
-          return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k] & 0xff;
+          return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k + commonPrefixLength] & 0xff;
         } else {
           // doc id
-          int s = 3 - (k - bytesPerDim);
+          int s = 3 - (k + commonPrefixLength - bytesPerDim);
           return (writer.docIDs[i] >>> (s * 8)) & 0xff;
         }
       }
@@ -605,7 +593,7 @@ final class SimpleTextBKDWriter implements Closeable {
         writer.swap(i, j);
       }
 
-    }.sort(0, pointCount);
+    }.sort(from, to);
   }
 
   private void checkMaxLeafNodeCount(int numLeaves) {
@@ -625,14 +613,13 @@ final class SimpleTextBKDWriter implements Closeable {
       throw new IllegalStateException("already finished");
     }
 
-    PointWriter data;
-
+    BKDRadixSelector.PathSlice writer;
     if (offlinePointWriter != null) {
       offlinePointWriter.close();
-      data = offlinePointWriter;
+      writer = new BKDRadixSelector.PathSlice(offlinePointWriter, 0, pointCount);
       tempInput = null;
     } else {
-      data = heapPointWriter;
+      writer =  new BKDRadixSelector.PathSlice(heapPointWriter, 0, pointCount);
       heapPointWriter = null;
     }
 
@@ -671,7 +658,7 @@ final class SimpleTextBKDWriter implements Closeable {
     try {
 
 
-      build(1, numLeaves, data, out,
+      build(1, numLeaves, writer, out,
           radixSelector, minPackedValue, maxPackedValue,
             splitPackedValues, leafBlockFPs);
 
@@ -1017,7 +1004,7 @@ final class SimpleTextBKDWriter implements Closeable {
 
   /** The array (sized numDims) of PathSlice describe the cell we have currently recursed to. */
   private void build(int nodeID, int leafNodeOffset,
-                     PointWriter data,
+                     BKDRadixSelector.PathSlice points,
                      IndexOutput out,
                      BKDRadixSelector radixSelector,
                      byte[] minPackedValue, byte[] maxPackedValue,
@@ -1030,14 +1017,17 @@ final class SimpleTextBKDWriter implements Closeable {
       // We can write the block in any order so by default we write it sorted by the dimension that has the
       // least number of unique bytes at commonPrefixLengths[dim], which makes compression more efficient
 
-      if (data instanceof HeapPointWriter == false) {
+      HeapPointWriter heapSource;
+      if (points.writer instanceof HeapPointWriter == false) {
         // Adversarial cases can cause this, e.g. very lopsided data, all equal points, such that we started
         // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer
-        data = switchToHeap(data);
+        heapSource  = switchToHeap(points.writer);
+      } else {
+        heapSource = (HeapPointWriter) points.writer;
       }
 
-      // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point:
-      HeapPointWriter heapSource = (HeapPointWriter) data;
+      int from = Math.toIntExact(points.start);
+      int to = Math.toIntExact(points.start + points.count);
 
       //we store common prefix on scratch1
       computeCommonPrefixLength(heapSource, scratch1);
@@ -1068,7 +1058,8 @@ final class SimpleTextBKDWriter implements Closeable {
         }
       }
 
-      sortHeapPointWriter(heapSource, sortedDim);
+      // sort the chosen dimension
+      sortHeapPointWriter(heapSource, from, to, sortedDim, commonPrefixLengths[sortedDim]);
 
       // Save the block file pointer:
       leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer();
@@ -1076,9 +1067,9 @@ final class SimpleTextBKDWriter implements Closeable {
 
       // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o
       // loading the values:
-      int count = Math.toIntExact(heapSource.count());
+      int count = to - from;
       assert count > 0: "nodeID=" + nodeID + " leafNodeOffset=" + leafNodeOffset;
-      writeLeafBlockDocs(out, heapSource.docIDs, 0, count);
+      writeLeafBlockDocs(out, heapSource.docIDs, from, count);
 
       // TODO: minor opto: we don't really have to write the actual common prefixes, because BKDReader on recursing can regenerate it for us
       // from the index, much like how terms dict does so from the FST:
@@ -1093,12 +1084,12 @@ final class SimpleTextBKDWriter implements Closeable {
 
         @Override
         public BytesRef apply(int i) {
-          heapSource.getPackedValueSlice(i, scratch);
+          heapSource.getPackedValueSlice(from + i, scratch);
           return scratch;
         }
       };
       assert valuesInOrderAndBounds(count, sortedDim, minPackedValue, maxPackedValue, packedValues,
-          heapSource.docIDs, Math.toIntExact(0));
+          heapSource.docIDs, from);
       writeLeafBlockPackedValues(out, commonPrefixLengths, count, sortedDim, packedValues);
 
     } else {
@@ -1111,26 +1102,23 @@ final class SimpleTextBKDWriter implements Closeable {
         splitDim = 0;
       }
 
-
       assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length;
 
       // How many points will be in the left tree:
-      long rightCount = data.count() / 2;
-      long leftCount = data.count() - rightCount;
-
-      PointWriter leftPointWriter;
-      PointWriter rightPointWriter;
-      byte[] splitValue;
-
-      try (PointWriter leftPointWriter2 = getPointWriter(leftCount, "left" + splitDim);
-           PointWriter rightPointWriter2 = getPointWriter(rightCount, "right" + splitDim)) {
-        splitValue = radixSelector.select(data, leftPointWriter2, rightPointWriter2, 0, data.count(),  leftCount, splitDim);
-        leftPointWriter = leftPointWriter2;
-        rightPointWriter = rightPointWriter2;
-      } catch (Throwable t) {
-        throw verifyChecksum(t, data);
+      long rightCount = points.count / 2;
+      long leftCount = points.count - rightCount;
+
+      int commonPrefixLen = FutureArrays.mismatch(minPackedValue, splitDim * bytesPerDim,
+          splitDim * bytesPerDim + bytesPerDim, maxPackedValue, splitDim * bytesPerDim,
+          splitDim * bytesPerDim + bytesPerDim);
+      if (commonPrefixLen == -1) {
+        commonPrefixLen = bytesPerDim;
       }
 
+      BKDRadixSelector.PathSlice[] pathSlices = new BKDRadixSelector.PathSlice[2];
+
+      byte[] splitValue =  radixSelector.select(points, pathSlices, points.start, points.start + points.count,  points.start + leftCount, splitDim, commonPrefixLen);
+
       int address = nodeID * (1 + bytesPerDim);
       splitPackedValues[address] = (byte) splitDim;
       System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim);
@@ -1144,15 +1132,13 @@ final class SimpleTextBKDWriter implements Closeable {
       System.arraycopy(splitValue, 0, minSplitPackedValue, splitDim * bytesPerDim, bytesPerDim);
       System.arraycopy(splitValue, 0, maxSplitPackedValue, splitDim * bytesPerDim, bytesPerDim);
 
-
-
       // Recurse on left tree:
-      build(2*nodeID, leafNodeOffset, leftPointWriter, out, radixSelector,
+      build(2*nodeID, leafNodeOffset, pathSlices[0], out, radixSelector,
             minPackedValue, maxSplitPackedValue, splitPackedValues, leafBlockFPs);
 
       // TODO: we could "tail recurse" here?  have our parent discard its refs as we recurse right?
       // Recurse on right tree:
-      build(2*nodeID+1, leafNodeOffset, rightPointWriter, out, radixSelector,
+      build(2*nodeID+1, leafNodeOffset, pathSlices[1], out, radixSelector,
             minSplitPackedValue, maxPackedValue, splitPackedValues, leafBlockFPs);
     }
   }
@@ -1212,15 +1198,6 @@ final class SimpleTextBKDWriter implements Closeable {
     return true;
   }
 
-  PointWriter getPointWriter(long count, String desc) throws IOException {
-    if (count <= maxPointsSortInHeap) {
-      int size = Math.toIntExact(count);
-      return new HeapPointWriter(size, size, packedBytesLength);
-    } else {
-      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count);
-    }
-  }
-
   private void write(IndexOutput out, String s) throws IOException {
     SimpleTextUtil.write(out, s, scratch);
   }
index 8d6c852..3bc025c 100644 (file)
@@ -66,7 +66,7 @@ public final class BKDRadixSelector {
     this.bytesPerDim = bytesPerDim;
     this.packedBytesLength = numDim * bytesPerDim;
     this.bytesSorted = bytesPerDim + Integer.BYTES;
-    this.maxPointsSortInHeap = 2 * maxPointsSortInHeap;
+    this.maxPointsSortInHeap = maxPointsSortInHeap;
     int numberOfPointsOffline  = MAX_SIZE_OFFLINE_BUFFER / (packedBytesLength + Integer.BYTES);
     this.offlineBuffer = new byte[numberOfPointsOffline * (packedBytesLength + Integer.BYTES)];
     this.partitionBucket = new int[bytesSorted];
@@ -77,35 +77,54 @@ public final class BKDRadixSelector {
   }
 
   /**
+   *  It uses the provided {@code points} from the given {@code from} to the given {@code to}
+   *  to populate the {@code partitionSlices} array holder (length &gt; 1) with two path slices
+   *  so the path slice at position 0 contains {@code partition - from} points
+   *  where the value of the {@code dim} is lower or equal to the {@code to -from}
+   *  points on the slice at position 1.
    *
-   * Method to partition the input data. It returns the value of the dimension where
-   * the split happens. The method destroys the original writer.
+   *  The {@code dimCommonPrefix} provides a hint for the length of the common prefix length for
+   *  the {@code dim} where are partitioning the points.
    *
+   *  It return the value of the {@code dim} at the partition point.
+   *
+   *  If the provided {@code points} is wrapping an {@link OfflinePointWriter}, the
+   *  writer is destroyed in the process to save disk space.
    */
-  public byte[] select(PointWriter points, PointWriter left, PointWriter right, long from, long to, long partitionPoint, int dim) throws IOException {
+  public byte[] select(PathSlice points, PathSlice[] partitionSlices, long from, long to, long partitionPoint, int dim, int dimCommonPrefix) throws IOException {
     checkArgs(from, to, partitionPoint);
 
+    assert partitionSlices.length > 1;
+
     //If we are on heap then we just select on heap
-    if (points instanceof HeapPointWriter) {
-      return heapSelect((HeapPointWriter) points, left, right, dim, Math.toIntExact(from), Math.toIntExact(to),  Math.toIntExact(partitionPoint), 0);
+    if (points.writer instanceof HeapPointWriter) {
+      byte[] partition = heapRadixSelect((HeapPointWriter) points.writer, dim, Math.toIntExact(from), Math.toIntExact(to),  Math.toIntExact(partitionPoint), dimCommonPrefix);
+      partitionSlices[0] = new PathSlice(points.writer, from, partitionPoint - from);
+      partitionSlices[1] = new PathSlice(points.writer, partitionPoint, to - partitionPoint);
+      return partition;
     }
 
     //reset histogram
     for (int i = 0; i < bytesSorted; i++) {
       Arrays.fill(histogram[i], 0);
     }
-    OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points;
+    OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points.writer;
 
-    //find common prefix, it does already set histogram values if needed
-    int commonPrefix = findCommonPrefix(offlinePointWriter, from, to, dim);
+    //find common prefix from dimCommonPrefix, it does already set histogram values if needed
+    int commonPrefix = findCommonPrefix(offlinePointWriter, from, to, dim, dimCommonPrefix);
 
-    //if all equals we just partition the data
-    if (commonPrefix ==  bytesSorted) {
-      partition(offlinePointWriter, left,  right, null, from, to, dim, commonPrefix - 1, partitionPoint);
-      return partitionPointFromCommonPrefix();
+    try (PointWriter left = getPointWriter(partitionPoint - from, "left" + dim);
+         PointWriter right = getPointWriter(to - partitionPoint, "right" + dim)) {
+      partitionSlices[0] = new PathSlice(left, 0, partitionPoint - from);
+      partitionSlices[1] = new PathSlice(right, 0, to - partitionPoint);
+      //if all equals we just partition the points
+      if (commonPrefix == bytesSorted) {
+        offlinePartition(offlinePointWriter, left, right, null, from, to, dim, commonPrefix - 1, partitionPoint);
+        return partitionPointFromCommonPrefix();
+      }
+      //let's rock'n'roll
+      return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, commonPrefix, dim);
     }
-    //let's rock'n'roll
-    return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, commonPrefix, dim);
   }
 
   void checkArgs(long from, long to, long partitionPoint) {
@@ -117,11 +136,12 @@ public final class BKDRadixSelector {
     }
   }
 
-  private int findCommonPrefix(OfflinePointWriter points, long from, long to, int dim) throws IOException{
+  private int findCommonPrefix(OfflinePointWriter points, long from, long to, int dim, int dimCommonPrefix) throws IOException{
     //find common prefix
     byte[] commonPrefix = new byte[bytesSorted];
     int commonPrefixPosition = bytesSorted;
     try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
+      assert commonPrefixPosition > dimCommonPrefix;
       reader.next();
       reader.packedValueWithDocId(bytesRef1);
       // copy dimension
@@ -131,21 +151,22 @@ public final class BKDRadixSelector {
       for (long i = from + 1; i< to; i++) {
         reader.next();
         reader.packedValueWithDocId(bytesRef1);
-        int startIndex =  dim * bytesPerDim;
-        int endIndex  = (commonPrefixPosition > bytesPerDim) ? startIndex + bytesPerDim :  startIndex + commonPrefixPosition;
-        int j = FutureArrays.mismatch(commonPrefix, 0, endIndex - startIndex, bytesRef1.bytes, bytesRef1.offset + startIndex, bytesRef1.offset + endIndex);
+        int startIndex =  (dimCommonPrefix > bytesPerDim) ? bytesPerDim : dimCommonPrefix;
+        int endIndex  = (commonPrefixPosition > bytesPerDim) ? bytesPerDim :  commonPrefixPosition;
+        int j = FutureArrays.mismatch(commonPrefix, startIndex, endIndex, bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim + startIndex, bytesRef1.offset + dim * bytesPerDim + endIndex);
         if (j == 0) {
-          return 0;
+          commonPrefixPosition = dimCommonPrefix;
+          break;
         } else if (j == -1) {
           if (commonPrefixPosition > bytesPerDim) {
             //tie-break on docID
-            int k = FutureArrays.mismatch(commonPrefix, bytesPerDim, commonPrefixPosition, bytesRef1.bytes, bytesRef1.offset + packedBytesLength, bytesRef1.offset + packedBytesLength + commonPrefixPosition - bytesPerDim );
+            int k = FutureArrays.mismatch(commonPrefix, bytesPerDim, commonPrefixPosition, bytesRef1.bytes, bytesRef1.offset + packedBytesLength, bytesRef1.offset + packedBytesLength + commonPrefixPosition - bytesPerDim);
             if (k != -1) {
               commonPrefixPosition = bytesPerDim + k;
             }
           }
         } else {
-          commonPrefixPosition = j;
+          commonPrefixPosition = dimCommonPrefix + j;
         }
       }
     }
@@ -196,33 +217,29 @@ public final class BKDRadixSelector {
     //special case when be have lot of points that are equal
     if (commonPrefix == bytesSorted - 1) {
       long tieBreakCount =(partitionPoint - from - leftCount);
-      partition(points, left,  right, null, from, to, dim, commonPrefix, tieBreakCount);
+      offlinePartition(points, left,  right, null, from, to, dim, commonPrefix, tieBreakCount);
       return partitionPointFromCommonPrefix();
     }
 
     //create the delta points writer
     PointWriter deltaPoints;
-    if (delta <= maxPointsSortInHeap) {
-      deltaPoints =  new HeapPointWriter(Math.toIntExact(delta), Math.toIntExact(delta), packedBytesLength);
-    } else {
-      deltaPoints = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "delta" + iteration, delta);
+    try (PointWriter tempDeltaPoints = getDeltaPointWriter(left, right, delta, iteration)) {
+      //divide the points. This actually destroys the current writer
+      offlinePartition(points, left, right, tempDeltaPoints, from, to, dim, commonPrefix, 0);
+      deltaPoints = tempDeltaPoints;
     }
-    //divide the points. This actually destroys the current writer
-    partition(points, left, right, deltaPoints, from, to, dim, commonPrefix, 0);
-    //close delta point writer
-    deltaPoints.close();
 
     long newPartitionPoint = partitionPoint - from - leftCount;
 
     if (deltaPoints instanceof HeapPointWriter) {
-      return heapSelect((HeapPointWriter) deltaPoints, left, right, dim, 0, (int) deltaPoints.count(), Math.toIntExact(newPartitionPoint), ++commonPrefix);
+      return heapPartition((HeapPointWriter) deltaPoints, left, right, dim, 0, (int) deltaPoints.count(), Math.toIntExact(newPartitionPoint), ++commonPrefix);
     } else {
       return buildHistogramAndPartition((OfflinePointWriter) deltaPoints, left, right, 0, deltaPoints.count(), newPartitionPoint, ++iteration, ++commonPrefix, dim);
     }
   }
 
-  private void partition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints,
-                           long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException {
+  private void offlinePartition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints,
+                                long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException {
     assert bytePosition == bytesSorted -1 || deltaPoints != null;
     long tiebreakCounter = 0;
     try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
@@ -269,7 +286,24 @@ public final class BKDRadixSelector {
     return partition;
   }
 
-  private byte[] heapSelect(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException {
+  private byte[] heapPartition(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException {
+
+    byte[] partition = heapRadixSelect(points, dim, from, to, partitionPoint, commonPrefix);
+
+    for (int i = from; i < to; i++) {
+      points.getPackedValueSlice(i, bytesRef1);
+      int docID = points.docIDs[i];
+      if (i < partitionPoint) {
+        left.append(bytesRef1, docID);
+      } else {
+        right.append(bytesRef1, docID);
+      }
+    }
+
+    return partition;
+  }
+
+  private byte[] heapRadixSelect(HeapPointWriter points, int dim, int from, int to, int partitionPoint, int commonPrefix) {
     final int offset = dim * bytesPerDim + commonPrefix;
     new RadixSelector(bytesSorted - commonPrefix) {
 
@@ -294,18 +328,59 @@ public final class BKDRadixSelector {
       }
     }.select(from, to, partitionPoint);
 
-    for (int i = from; i < to; i++) {
-      points.getPackedValueSlice(i, bytesRef1);
-      int docID = points.docIDs[i];
-      if (i < partitionPoint) {
-        left.append(bytesRef1, docID);
-      } else {
-        right.append(bytesRef1, docID);
-      }
-    }
     byte[] partition = new byte[bytesPerDim];
     points.getPackedValueSlice(partitionPoint, bytesRef1);
     System.arraycopy(bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim, partition, 0, bytesPerDim);
     return partition;
   }
+
+  private PointWriter getDeltaPointWriter(PointWriter left, PointWriter right, long delta, int iteration) throws IOException {
+    if (delta <= getMaxPointsSortInHeap(left, right)) {
+      return  new HeapPointWriter(Math.toIntExact(delta), Math.toIntExact(delta), packedBytesLength);
+    } else {
+      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "delta" + iteration, delta);
+    }
+  }
+
+  private int getMaxPointsSortInHeap(PointWriter left, PointWriter right) {
+    int pointsUsed = 0;
+    if (left instanceof HeapPointWriter) {
+      pointsUsed += ((HeapPointWriter) left).maxSize;
+    }
+    if (right instanceof HeapPointWriter) {
+      pointsUsed += ((HeapPointWriter) right).maxSize;
+    }
+    assert maxPointsSortInHeap >= pointsUsed;
+    return maxPointsSortInHeap - pointsUsed;
+  }
+
+  PointWriter getPointWriter(long count, String desc) throws IOException {
+    //As we recurse, we hold two on-heap point writers at any point. Therefore the
+    //max size for these objects is half of the total points we can have on-heap.
+    if (count <= maxPointsSortInHeap / 2) {
+      int size = Math.toIntExact(count);
+      return new HeapPointWriter(size, size, packedBytesLength);
+    } else {
+      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count);
+    }
+  }
+
+  /** Sliced reference to points in an PointWriter. */
+  public static final class PathSlice {
+    public final PointWriter writer;
+    public final long start;
+    public final long count;
+
+    public PathSlice(PointWriter writer, long start, long count) {
+      this.writer = writer;
+      this.start = start;
+      this.count = count;
+    }
+
+    @Override
+    public String toString() {
+      return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")";
+    }
+  }
+
 }
index a8ee7c5..3e87378 100644 (file)
@@ -172,17 +172,8 @@ public class BKDWriter implements Closeable {
     // dimensional values (numDims * bytesPerDim) + docID (int)
     bytesPerDoc = packedBytesLength + Integer.BYTES;
 
-
-    // As we recurse, we compute temporary partitions of the data, halving the
-    // number of points at each recursion.  Once there are few enough points,
-    // we can switch to sorting in heap instead of offline (on disk).  At any
-    // time in the recursion, we hold the number of points at that level, plus
-    // all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X
-    // what that level would consume, so we multiply by 0.5 to convert from
-    // bytes to points here.  In addition the radix partitioning may sort on memory
-    // double of this size so we multiply by another 0.5.
-
-    maxPointsSortInHeap = (int) (0.25 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc));
+    // Maximum number of points we hold in memory at any time
+    maxPointsSortInHeap = (int) ((maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc));
 
     // Finally, we must be able to hold at least the leaf node in heap during build:
     if (maxPointsSortInHeap < maxPointsInLeafNode) {
@@ -402,7 +393,6 @@ public class BKDWriter implements Closeable {
     }
   }
 
-
   /* In the 2+D case, we recursively pick the split dimension, compute the
    * median value and partition other values around it. */
   private long writeFieldNDims(IndexOutput out, String fieldName, MutablePointValues values) throws IOException {
@@ -722,7 +712,7 @@ public class BKDWriter implements Closeable {
   // encoding and not have our own ByteSequencesReader/Writer
 
   /** Sort the heap writer by the specified dim */
-  private void sortHeapPointWriter(final HeapPointWriter writer, int pointCount, int dim, int commonPrefixLength) {
+  private void sortHeapPointWriter(final HeapPointWriter writer, int from, int to, int dim, int commonPrefixLength) {
     // Tie-break by docID:
     new MSBRadixSorter(bytesPerDim + Integer.BYTES - commonPrefixLength) {
 
@@ -746,7 +736,7 @@ public class BKDWriter implements Closeable {
         writer.swap(i, j);
       }
 
-    }.sort(0, pointCount);
+    }.sort(from, to);
   }
 
   // useful for debugging:
@@ -784,20 +774,20 @@ public class BKDWriter implements Closeable {
       throw new IllegalStateException("already finished");
     }
 
-    PointWriter writer;
+    if (pointCount == 0) {
+      throw new IllegalStateException("must index at least one point");
+    }
+
+    BKDRadixSelector.PathSlice points;
     if (offlinePointWriter != null) {
       offlinePointWriter.close();
-      writer = offlinePointWriter;
+      points = new BKDRadixSelector.PathSlice(offlinePointWriter, 0, pointCount);
       tempInput = null;
     } else {
-      writer = heapPointWriter;
+      points = new BKDRadixSelector.PathSlice(heapPointWriter, 0, pointCount);
       heapPointWriter = null;
     }
 
-    if (pointCount == 0) {
-      throw new IllegalStateException("must index at least one point");
-    }
-
     long countPerLeaf = pointCount;
     long innerNodeCount = 1;
 
@@ -829,7 +819,7 @@ public class BKDWriter implements Closeable {
     try {
 
       final int[] parentSplits = new int[numIndexDims];
-      build(1, numLeaves, writer,
+      build(1, numLeaves, points,
              out, radixSelector,
             minPackedValue, maxPackedValue,
             parentSplits,
@@ -1429,7 +1419,7 @@ public class BKDWriter implements Closeable {
   /** The point writer contains the data that is going to be splitted using radix selection.
   /*  This method is used when we are merging previously written segments, in the numDims > 1 case. */
   private void build(int nodeID, int leafNodeOffset,
-                     PointWriter points,
+                     BKDRadixSelector.PathSlice points,
                      IndexOutput out,
                      BKDRadixSelector radixSelector,
                      byte[] minPackedValue, byte[] maxPackedValue,
@@ -1442,18 +1432,19 @@ public class BKDWriter implements Closeable {
       // Leaf node: write block
       // We can write the block in any order so by default we write it sorted by the dimension that has the
       // least number of unique bytes at commonPrefixLengths[dim], which makes compression more efficient
-
-      if (points instanceof HeapPointWriter == false) {
+      HeapPointWriter heapSource;
+      if (points.writer instanceof HeapPointWriter == false) {
         // Adversarial cases can cause this, e.g. very lopsided data, all equal points, such that we started
         // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer
-        points = switchToHeap(points);
+        heapSource = switchToHeap(points.writer);
+      } else {
+        heapSource = (HeapPointWriter) points.writer;
       }
 
-      // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point:
-      HeapPointWriter heapSource = (HeapPointWriter) points;
-
+      int from = Math.toIntExact(points.start);
+      int to = Math.toIntExact(points.start + points.count);
       //we store common prefix on scratch1
-      computeCommonPrefixLength(heapSource, scratch1);
+      computeCommonPrefixLength(heapSource, scratch1, from, to);
 
       int sortedDim = 0;
       int sortedDimCardinality = Integer.MAX_VALUE;
@@ -1468,7 +1459,7 @@ public class BKDWriter implements Closeable {
         int prefix = commonPrefixLengths[dim];
         if (prefix < bytesPerDim) {
           int offset = dim * bytesPerDim;
-          for (int i = 0; i < heapSource.count(); ++i) {
+          for (int i = from; i < to; ++i) {
             heapSource.getPackedValueSlice(i, scratchBytesRef1);
             int bucket = scratchBytesRef1.bytes[scratchBytesRef1.offset + offset + prefix] & 0xff;
             usedBytes[dim].set(bucket);
@@ -1482,7 +1473,7 @@ public class BKDWriter implements Closeable {
       }
 
       // sort the chosen dimension
-      sortHeapPointWriter(heapSource, Math.toIntExact(heapSource.count()), sortedDim, commonPrefixLengths[sortedDim]);
+      sortHeapPointWriter(heapSource, from, to, sortedDim, commonPrefixLengths[sortedDim]);
 
       // Save the block file pointer:
       leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer();
@@ -1490,9 +1481,9 @@ public class BKDWriter implements Closeable {
 
       // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o
       // loading the values:
-      int count = Math.toIntExact(heapSource.count());
+      int count = to - from;
       assert count > 0: "nodeID=" + nodeID + " leafNodeOffset=" + leafNodeOffset;
-      writeLeafBlockDocs(out, heapSource.docIDs, Math.toIntExact(0), count);
+      writeLeafBlockDocs(out, heapSource.docIDs, from, count);
 
       // TODO: minor opto: we don't really have to write the actual common prefixes, because BKDReader on recursing can regenerate it for us
       // from the index, much like how terms dict does so from the FST:
@@ -1510,12 +1501,12 @@ public class BKDWriter implements Closeable {
 
         @Override
         public BytesRef apply(int i) {
-          heapSource.getPackedValueSlice(Math.toIntExact(i), scratch);
+          heapSource.getPackedValueSlice(from + i, scratch);
           return scratch;
         }
       };
       assert valuesInOrderAndBounds(count, sortedDim, minPackedValue, maxPackedValue, packedValues,
-          heapSource.docIDs, Math.toIntExact(0));
+          heapSource.docIDs, from);
       writeLeafBlockPackedValues(out, commonPrefixLengths, count, sortedDim, packedValues);
 
     } else {
@@ -1528,25 +1519,23 @@ public class BKDWriter implements Closeable {
         splitDim = 0;
       }
 
-
       assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length;
 
       // How many points will be in the left tree:
-      long rightCount = points.count() / 2;
-      long leftCount = points.count() - rightCount;
-
-      PointWriter leftPointWriter;
-      PointWriter rightPointWriter;
-      byte[] splitValue;
-      try (PointWriter tempLeftPointWriter = getPointWriter(leftCount, "left" + splitDim);
-           PointWriter tempRightPointWriter = getPointWriter(rightCount, "right" + splitDim)) {
-        splitValue = radixSelector.select(points, tempLeftPointWriter, tempRightPointWriter, 0, points.count(),  leftCount, splitDim);
-        leftPointWriter = tempLeftPointWriter;
-        rightPointWriter = tempRightPointWriter;
-      } catch (Throwable t) {
-        throw verifyChecksum(t, points);
+      long rightCount = points.count / 2;
+      long leftCount = points.count - rightCount;
+
+      BKDRadixSelector.PathSlice[] slices = new BKDRadixSelector.PathSlice[2];
+
+      int commonPrefixLen = FutureArrays.mismatch(minPackedValue, splitDim * bytesPerDim,
+          splitDim * bytesPerDim + bytesPerDim, maxPackedValue, splitDim * bytesPerDim,
+          splitDim * bytesPerDim + bytesPerDim);
+      if (commonPrefixLen == -1) {
+        commonPrefixLen = bytesPerDim;
       }
 
+      byte[] splitValue = radixSelector.select(points, slices, points.start, points.start + points.count,  points.start + leftCount, splitDim, commonPrefixLen);
+
       int address = nodeID * (1 + bytesPerDim);
       splitPackedValues[address] = (byte) splitDim;
       System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim);
@@ -1562,12 +1551,12 @@ public class BKDWriter implements Closeable {
 
       parentSplits[splitDim]++;
       // Recurse on left tree:
-      build(2 * nodeID, leafNodeOffset, leftPointWriter,
+      build(2 * nodeID, leafNodeOffset, slices[0],
           out, radixSelector, minPackedValue, maxSplitPackedValue,
           parentSplits, splitPackedValues, leafBlockFPs);
 
       // Recurse on right tree:
-      build(2 * nodeID + 1, leafNodeOffset, rightPointWriter,
+      build(2 * nodeID + 1, leafNodeOffset, slices[1],
           out, radixSelector, minSplitPackedValue, maxPackedValue
           , parentSplits, splitPackedValues, leafBlockFPs);
 
@@ -1575,14 +1564,14 @@ public class BKDWriter implements Closeable {
     }
   }
 
-  private void computeCommonPrefixLength(HeapPointWriter heapPointWriter, byte[] commonPrefix) {
+  private void computeCommonPrefixLength(HeapPointWriter heapPointWriter, byte[] commonPrefix, int from, int to) {
     Arrays.fill(commonPrefixLengths, bytesPerDim);
     scratchBytesRef1.length = packedBytesLength;
-    heapPointWriter.getPackedValueSlice(0, scratchBytesRef1);
+    heapPointWriter.getPackedValueSlice(from, scratchBytesRef1);
     for (int dim = 0; dim < numDataDims; dim++) {
       System.arraycopy(scratchBytesRef1.bytes, scratchBytesRef1.offset + dim * bytesPerDim, commonPrefix, dim * bytesPerDim, bytesPerDim);
     }
-    for (int i = 1; i < heapPointWriter.count(); i++) {
+    for (int i = from + 1; i < to; i++) {
       heapPointWriter.getPackedValueSlice(i, scratchBytesRef1);
       for (int dim = 0; dim < numDataDims; dim++) {
         if (commonPrefixLengths[dim] != 0) {
@@ -1629,14 +1618,4 @@ public class BKDWriter implements Closeable {
     System.arraycopy(packedValue, packedValueOffset, lastPackedValue, 0, packedBytesLength);
     return true;
   }
-
-  PointWriter getPointWriter(long count, String desc) throws IOException {
-    if (count <= maxPointsSortInHeap) {
-      int size = Math.toIntExact(count);
-      return new HeapPointWriter(size, size, packedBytesLength);
-    } else {
-      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count);
-    }
-  }
-
 }
index 01d05a0..6e3863f 100644 (file)
@@ -971,7 +971,7 @@ public class TestBKD extends LuceneTestCase {
         public IndexOutput createTempOutput(String prefix, String suffix, IOContext context) throws IOException {
           IndexOutput out = in.createTempOutput(prefix, suffix, context);
           //System.out.println("prefix=" + prefix + " suffix=" + suffix);
-          if (corrupted == false && suffix.equals("bkd_left1")) {
+          if (corrupted == false && suffix.equals("bkd_left0")) {
             //System.out.println("now corrupt byte=" + x + " prefix=" + prefix + " suffix=" + suffix);
             corrupted = true;
             return new CorruptingIndexOutput(dir0, 22072, out);
index ca61b02..558b9f2 100644 (file)
@@ -48,7 +48,8 @@ public class TestBKDRadixSelector extends LuceneTestCase {
     NumericUtils.intToSortableBytes(4, bytes, 0);
     points.append(bytes, 3);
     points.close();
-    verify(dir, points, dimensions, 0, values, middle, packedLength, bytesPerDimensions, 0);
+    PointWriter copy = copyPoints(dir,points, packedLength);
+    verify(dir, copy, dimensions, 0, values, middle, packedLength, bytesPerDimensions, 0);
     dir.close();
   }
 
@@ -183,24 +184,31 @@ public class TestBKDRadixSelector extends LuceneTestCase {
   private void verify(Directory dir, PointWriter points, int dimensions, long start, long end, long middle, int packedLength, int bytesPerDimensions, int sortedOnHeap) throws IOException{
     for (int splitDim =0; splitDim < dimensions; splitDim++) {
       PointWriter copy = copyPoints(dir, points, packedLength);
-      PointWriter leftPointWriter = getRandomPointWriter(dir, middle - start, packedLength);
-      PointWriter rightPointWriter = getRandomPointWriter(dir, end - middle, packedLength);
+      BKDRadixSelector.PathSlice[] slices = new BKDRadixSelector.PathSlice[2];
       BKDRadixSelector radixSelector = new BKDRadixSelector(dimensions, bytesPerDimensions, sortedOnHeap, dir, "test");
-      byte[] partitionPoint = radixSelector.select(copy, leftPointWriter, rightPointWriter, start, end, middle, splitDim);
-      leftPointWriter.close();
-      rightPointWriter.close();
-      byte[] max = getMax(leftPointWriter, middle - start, bytesPerDimensions, splitDim);
-      byte[] min = getMin(rightPointWriter, end - middle, bytesPerDimensions, splitDim);
+      BKDRadixSelector.PathSlice copySlice = new BKDRadixSelector.PathSlice(copy, 0, copy.count());
+      byte[] pointsMax = getMax(copySlice, bytesPerDimensions, splitDim);
+      byte[] pointsMin = getMin(copySlice, bytesPerDimensions, splitDim);
+      int commonPrefixLength = FutureArrays.mismatch(pointsMin, 0, bytesPerDimensions, pointsMax, 0, bytesPerDimensions);
+      if (commonPrefixLength == -1) {
+        commonPrefixLength = bytesPerDimensions;
+      }
+      int commonPrefixLengthInput = (random().nextBoolean()) ? commonPrefixLength : commonPrefixLength == 0 ? 0 : random().nextInt(commonPrefixLength);
+      byte[] partitionPoint = radixSelector.select(copySlice, slices, start, end, middle, splitDim, commonPrefixLengthInput);
+      assertEquals(middle - start, slices[0].count);
+      assertEquals(end - middle, slices[1].count);
+      byte[] max = getMax(slices[0], bytesPerDimensions, splitDim);
+      byte[] min = getMin(slices[1], bytesPerDimensions, splitDim);
       int cmp = FutureArrays.compareUnsigned(max, 0, bytesPerDimensions, min, 0, bytesPerDimensions);
       assertTrue(cmp <= 0);
       if (cmp == 0) {
-        int maxDocID = getMaxDocId(leftPointWriter, middle - start, bytesPerDimensions, splitDim, partitionPoint);
-        int minDocId = getMinDocId(rightPointWriter, end - middle, bytesPerDimensions, splitDim, partitionPoint);
+        int maxDocID = getMaxDocId(slices[0], bytesPerDimensions, splitDim, partitionPoint);
+        int minDocId = getMinDocId(slices[1], bytesPerDimensions, splitDim, partitionPoint);
         assertTrue(minDocId >= maxDocID);
       }
       assertTrue(Arrays.equals(partitionPoint, min));
-      leftPointWriter.destroy();
-      rightPointWriter.destroy();
+      slices[0].writer.destroy();
+      slices[1].writer.destroy();
     }
     points.destroy();
   }
@@ -236,10 +244,10 @@ public class TestBKDRadixSelector extends LuceneTestCase {
     return dir;
   }
 
-  private byte[] getMin(PointWriter p, long size, int bytesPerDimension, int dimension) throws  IOException {
+  private byte[] getMin(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension) throws  IOException {
     byte[] min = new byte[bytesPerDimension];
     Arrays.fill(min, (byte) 0xff);
-    try (PointReader reader = p.getReader(0, size)) {
+    try (PointReader reader = p.writer.getReader(p.start, p.count)) {
       byte[] value = new byte[bytesPerDimension];
       BytesRef packedValue = new BytesRef();
       while (reader.next()) {
@@ -253,9 +261,9 @@ public class TestBKDRadixSelector extends LuceneTestCase {
     return min;
   }
 
-  private int getMinDocId(PointWriter p, long size, int bytesPerDimension, int dimension, byte[] partitionPoint) throws  IOException {
+  private int getMinDocId(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension, byte[] partitionPoint) throws  IOException {
    int docID = Integer.MAX_VALUE;
-    try (PointReader reader = p.getReader(0, size)) {
+    try (PointReader reader = p.writer.getReader(p.start, p.count)) {
       BytesRef packedValue = new BytesRef();
       while (reader.next()) {
         reader.packedValue(packedValue);
@@ -271,10 +279,10 @@ public class TestBKDRadixSelector extends LuceneTestCase {
     return docID;
   }
 
-  private byte[] getMax(PointWriter p, long size, int bytesPerDimension, int dimension) throws  IOException {
+  private byte[] getMax(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension) throws  IOException {
     byte[] max = new byte[bytesPerDimension];
     Arrays.fill(max, (byte) 0);
-    try (PointReader reader = p.getReader(0, size)) {
+    try (PointReader reader = p.writer.getReader(p.start, p.count)) {
       byte[] value = new byte[bytesPerDimension];
       BytesRef packedValue = new BytesRef();
       while (reader.next()) {
@@ -288,9 +296,9 @@ public class TestBKDRadixSelector extends LuceneTestCase {
     return max;
   }
 
-  private int getMaxDocId(PointWriter p, long size, int bytesPerDimension, int dimension, byte[] partitionPoint) throws  IOException {
+  private int getMaxDocId(BKDRadixSelector.PathSlice p, int bytesPerDimension, int dimension, byte[] partitionPoint) throws  IOException {
     int docID = Integer.MIN_VALUE;
-    try (PointReader reader = p.getReader(0, size)) {
+    try (PointReader reader = p.writer.getReader(p.start, p.count)) {
       BytesRef packedValue = new BytesRef();
       while (reader.next()) {
         reader.packedValue(packedValue);