[SPARK-26450][SQL] Avoid rebuilding map of schema for every column in projection
authorBruce Robbins <bersprockets@gmail.com>
Sun, 13 Jan 2019 22:54:19 +0000 (23:54 +0100)
committerHerman van Hovell <hvanhovell@databricks.com>
Sun, 13 Jan 2019 22:54:19 +0000 (23:54 +0100)
## What changes were proposed in this pull request?

When creating some unsafe projections, Spark rebuilds the map of schema attributes once for each expression in the projection. Some file format readers create one unsafe projection per input file, others create one per task. ProjectExec also creates one unsafe projection per task. As a result, for wide queries on wide tables, Spark might build the map of schema attributes hundreds of thousands of times.

This PR changes two functions to reuse the same AttributeSeq instance when creating BoundReference objects for each expression in the projection. This avoids the repeated rebuilding of the map of schema attributes.

### Benchmarks

The time saved by this PR depends on size of the schema, size of the projection, number of input files (or number of file splits), number of tasks, and file format. I chose a couple of example cases.

In the following tests, I ran the query
```sql
select * from table where id1 = 1
```

Matching rows are about 0.2% of the table.

#### Orc table 6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
1.772306 min | 1.487267 min | 16.082943%

#### Orc table 6000 columns, 500K rows, *17* input files

baseline | pr | improvement
----|----|----
 1.656400 min | 1.423550 min | 14.057595%

#### Orc table 60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
0.299878 min | 0.290339 min | 3.180926%

#### Parquet table 6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
1.478306 min | 1.373728 min | 7.074165%

Note: The parquet reader does not create an unsafe projection. However, the filter operation in the query causes the planner to add a ProjectExec, which does create an unsafe projection for each task. So these results have nothing to do with Parquet itself.

#### Parquet table 60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
0.245006 min | 0.242200 min | 1.145099%

#### CSV table 6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
2.390117 min | 2.182778 min | 8.674844%

#### CSV table 60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
1.520911 min | 1.510211 min | 0.703526%

## How was this patch tested?

SQL unit tests
Python core and SQL test

Closes #23392 from bersprockets/norebuild.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
17 files changed:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala

index ea8c369..7ae5924 100644 (file)
@@ -86,4 +86,13 @@ object BindReferences extends Logging {
       }
     }.asInstanceOf[A] // Kind of a hack, but safe.  TODO: Tighten return type when possible.
   }
+
+  /**
+   * A helper function to bind given expressions to an input schema.
+   */
+  def bindReferences[A <: Expression](
+      expressions: Seq[A],
+      input: AttributeSeq): Seq[A] = {
+    expressions.map(BindReferences.bindReference(_, input))
+  }
 }
index 122a564..5c8aa4e 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 
@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
  */
 class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(expressions, inputSchema))
+    this(bindReferences(expressions, inputSchema))
 
   private[this] val buffer = new Array[Any](expressions.size)
 
index b48f7ba..eaaf94b 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
  */
 class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(expressions, inputSchema))
 
   override def initialize(partitionIndex: Int): Unit = {
     expressions.foreach(_.foreach {
@@ -99,7 +100,7 @@ object MutableProjection
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
 
@@ -162,7 +163,7 @@ object UnsafeProjection
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
 
@@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
index d588e7f..838bd1c 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 // MutableProjection is not accessible in Java
@@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
index 283fd2a..b66b80a 100644 (file)
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
 
   protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   /**
    * Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
   extends Ordering[InternalRow] with KryoSerializable {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))
 
   @transient
   private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
index 3977866..e285398 100644 (file)
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   private def createCodeForStruct(
       ctx: CodegenContext,
index 0ecd0de..fb1d8a3 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.types._
 
@@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
index e24a3de..c8d6671 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types._
 
 
@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
 class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))
 
   def compare(a: InternalRow, b: InternalRow): Int = {
     var i = 0
index bf18e8b..932c364 100644 (file)
@@ -86,13 +86,6 @@ package object expressions  {
   }
 
   /**
-   * A helper function to bind given expressions to an input schema.
-   */
-  def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
-    exprs.map(BindReferences.bindReference(_, inputSchema))
-  }
-
-  /**
    * Helper functions for working with `Seq[Attribute]`.
    */
   implicit class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable {
index 5b4edf5..85f4914 100644 (file)
@@ -145,11 +145,12 @@ case class ExpandExec(
     // Part 1: declare variables for each column
     // If a column has the same value for all output rows, then we also generate its computation
     // right after declaration. Otherwise its value is computed in the part 2.
+    lazy val attributeSeq: AttributeSeq = child.output
     val outputColumns = output.indices.map { col =>
       val firstExpr = projections.head(col)
       if (sameOutput(col)) {
         // This column is the same across all output rows. Just generate code for it here.
-        BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
+        BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
       } else {
         val isNull = ctx.freshName("isNull")
         val value = ctx.freshName("value")
@@ -170,7 +171,7 @@ case class ExpandExec(
       var updateCode = ""
       for (col <- exprs.indices) {
         if (!sameOutput(col)) {
-          val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+          val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
           updateCode +=
             s"""
                |${ev.code}
index 98c4a51..a1fb23d 100644 (file)
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
     val expressionsLength = expressions.length
     val functions = new Array[AggregateFunction](expressionsLength)
     var i = 0
+    val inputAttributeSeq: AttributeSeq = inputAttributes
     while (i < expressionsLength) {
       val func = expressions(i).aggregateFunction
       val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@ abstract class AggregationIterator(
           // this function is Partial or Complete because we will call eval of this
           // function's children in the update method of this aggregate function.
           // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
+          BindReferences.bindReference(func, inputAttributeSeq)
         case _ =>
           // We only need to set inputBufferOffset for aggregate functions with mode
           // PartialMerge and Final.
index 2355d30..220a4b0 100644 (file)
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -199,15 +200,13 @@ case class HashAggregateExec(
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        functions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
-      }
+      val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
       (resultVars, s"""
         |$evaluateAggResults
         |${evaluateVariables(resultVars)}
@@ -264,7 +263,7 @@ case class HashAggregateExec(
       }
     }
     ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
     val effectiveCodes = subExprs.codes.mkString("\n")
     val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -456,16 +455,16 @@ case class HashAggregateExec(
       val evaluateBufferVars = evaluateVariables(bufferVars)
       // evaluate the aggregation result
       ctx.currentVars = bufferVars
-      val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        declFunctions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // generate the final result
       ctx.currentVars = keyVars ++ aggResults
       val inputAttrs = groupingAttributes ++ aggregateAttributes
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
       s"""
        $evaluateKeyVars
        $evaluateBufferVars
@@ -494,9 +493,9 @@ case class HashAggregateExec(
 
       ctx.currentVars = keyVars ++ resultBufferVars
       val inputAttrs = resultExpressions.map(_.toAttribute)
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
       s"""
        $evaluateKeyVars
        $evaluateResultBufferVars
@@ -506,9 +505,9 @@ case class HashAggregateExec(
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
       ctx.currentVars = null
-      val eval = resultExpressions.map{ e =>
-        BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
-      }
+      val eval = bindReferences[Expression](
+        resultExpressions,
+        groupingAttributes).map(_.genCode(ctx))
       consume(ctx, eval)
     }
     ctx.addNewFunction(funcName,
@@ -730,9 +729,9 @@ case class HashAggregateExec(
   private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      ctx, bindReferences[Expression](groupingExpressions, child.output))
     val fastRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      bindReferences[Expression](groupingExpressions, child.output))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
     val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -825,7 +824,7 @@ case class HashAggregateExec(
 
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+      val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
       val effectiveCodes = subExprs.codes.mkString("\n")
       val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +848,7 @@ case class HashAggregateExec(
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
           ctx.INPUT_ROW = fastRowBuffer
-          val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+          val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
           val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
           val effectiveCodes = subExprs.codes.mkString("\n")
           val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
index 2570b36..318dca0 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+    val exprs = bindReferences[Expression](projectList, child.output)
     val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
index 774fe38..260ad97 100644 (file)
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,8 @@ object FileFormatWriter extends Logging {
         // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
         // the physical plan may have different attribute ids due to optimizer removing some
         // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-        val orderingExpr = requiredOrdering
-          .map(SortOrder(_, Ascending))
-          .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+        val orderingExpr = bindReferences(
+          requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
         SortExec(
           orderingExpr,
           global = false,
index 1aef5f6..5ee4c7f 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
@@ -63,9 +64,8 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
-    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
-      .map(BindReferences.bindReference(_, right.output))
+    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
index d7d3f6d..f829f07 100644 (file)
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
@@ -393,7 +394,7 @@ case class SortMergeJoinExec(
       input: Seq[Attribute]): Seq[ExprCode] = {
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    keys.map(BindReferences.bindReference(_, input).genCode(ctx))
+    bindReferences(keys, input).map(_.genCode(ctx))
   }
 
   private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
index a560189..d5f2ffa 100644 (file)
@@ -21,6 +21,7 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 
@@ -103,9 +104,8 @@ final class OffsetWindowFunctionFrame(
   private[this] val projection = {
     // Collect the expressions and bind them.
     val inputAttrs = inputSchema.map(_.withNullability(true))
-    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
-      BindReferences.bindReference(e.input, inputAttrs)
-    }
+    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
+      expressions.toSeq.map(_.input), inputAttrs)
 
     // Create the projection.
     newMutableProjection(boundExpressions, Nil).target(target)
@@ -114,7 +114,7 @@ final class OffsetWindowFunctionFrame(
   /** Create the projection used when the offset row DOES NOT exists. */
   private[this] val fillDefaultValue = {
     // Collect the expressions and bind them.
-    val inputAttrs = inputSchema.map(_.withNullability(true))
+    val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
       if (e.default == null || e.default.foldable && e.default.eval() == null) {
         // The default value is null.