[SPARK-26065][FOLLOW-UP][SQL] Revert hint behavior in join reordering
authormaryannxue <maryannxue@apache.org>
Sun, 13 Jan 2019 23:30:45 +0000 (15:30 -0800)
committergatorsmile <gatorsmile@gmail.com>
Sun, 13 Jan 2019 23:30:45 +0000 (15:30 -0800)
## What changes were proposed in this pull request?

This is to fix a bug in #23036 that would cause a join hint to be applied on node it is not supposed to after join reordering. For example,
```
  val join = df.join(df, "id")
  val broadcasted = join.hint("broadcast")
  val join2 = join.join(broadcasted, "id").join(broadcasted, "id")
```
There should only be 2 broadcast hints on `join2`, but after join reordering there would be 4. It is because the hint application in join reordering compares the attribute set for testing relation equivalency.
Moreover, it could still be problematic even if the child relations were used in testing relation equivalency, due to the potential exprId conflict in nested self-join.

As a result, this PR simply reverts the join reorder hint behavior change introduced in #23036, which means if a join hint is present, the join node itself will not participate in the join reordering, while the sub-joins within its children still can.

## How was this patch tested?

Added new tests

Closes #23524 from maryannxue/query-hint-followup-2.

Authored-by: maryannxue <maryannxue@apache.org>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala

index 743d3ce..6540e95 100644 (file)
@@ -31,40 +31,6 @@ import org.apache.spark.sql.internal.SQLConf
  * Cost-based join reorder.
  * We may have several join reorder algorithms in the future. This class is the entry of these
  * algorithms, and chooses which one to use.
- *
- * Note that join strategy hints, e.g. the broadcast hint, do not interfere with the reordering.
- * Such hints will be applied on the equivalent counterparts (i.e., join between the same relations
- * regardless of the join order) of the original nodes after reordering.
- * For example, the plan before reordering is like:
- *
- *            Join
- *            /   \
- *          Hint1 t4
- *          /
- *        Join
- *        / \
- *      Join t3
- *      /   \
- *    Hint2 t2
- *    /
- *   t1
- *
- * The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) JOIN t4", and after
- * reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the new plan will be like:
- *
- *            Join
- *            /   \
- *          Hint1 t4
- *          /
- *        Join
- *        / \
- *      Join t2
- *      / \
- *    t1  t3
- *
- * "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the original hinted node,
- * "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan since there is no
- * equivalent node to "t1 JOIN t2".
  */
 object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
 
@@ -74,30 +40,24 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
     if (!conf.cboEnabled || !conf.joinReorderEnabled) {
       plan
     } else {
-      // Use a map to track the hints on the join items.
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
       val result = plan transformDown {
         // Start reordering with a joinable item, which is an InnerLike join with conditions.
-        case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
-          reorder(j, j.output, hintMap)
-        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _))
-          if projectList.forall(_.isInstanceOf[Attribute]) =>
-          reorder(p, p.output, hintMap)
+        // Avoid reordering if a join hint is present.
+        case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE =>
+          reorder(j, j.output)
+        case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint))
+          if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE =>
+          reorder(p, p.output)
       }
       // After reordering is finished, convert OrderedJoin back to Join.
       result transform {
-        case OrderedJoin(left, right, jt, cond) =>
-          val joinHint = JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))
-          Join(left, right, jt, cond, joinHint)
+        case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE)
       }
     }
   }
 
-  private def reorder(
-      plan: LogicalPlan,
-      output: Seq[Attribute],
-      hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
-    val (items, conditions) = extractInnerJoins(plan, hintMap)
+  private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = {
+    val (items, conditions) = extractInnerJoins(plan)
     val result =
       // Do reordering if the number of items is appropriate and join conditions exist.
       // We also need to check if costs of all items can be evaluated.
@@ -115,20 +75,16 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
    * Extracts items of consecutive inner joins and join conditions.
    * This method works for bushy trees and left/right deep trees.
    */
-  private def extractInnerJoins(
-      plan: LogicalPlan,
-      hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = {
+  private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
     plan match {
-      case Join(left, right, _: InnerLike, Some(cond), hint) =>
-        hint.leftHint.foreach(hintMap.put(left.outputSet, _))
-        hint.rightHint.foreach(hintMap.put(right.outputSet, _))
-        val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap)
-        val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap)
+      case Join(left, right, _: InnerLike, Some(cond), _) =>
+        val (leftPlans, leftConditions) = extractInnerJoins(left)
+        val (rightPlans, rightConditions) = extractInnerJoins(right)
         (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
           leftConditions ++ rightConditions)
       case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _))
         if projectList.forall(_.isInstanceOf[Attribute]) =>
-        extractInnerJoins(j, hintMap)
+        extractInnerJoins(j)
       case _ =>
         (Seq(plan), Set())
     }
index 82aefca..251ece3 100644 (file)
@@ -43,13 +43,11 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
    *
    * @param input a list of LogicalPlans to inner join and the type of inner join.
    * @param conditions a list of condition for join.
-   * @param hintMap a map of relation output attribute sets to their corresponding hints.
    */
   @tailrec
   final def createOrderedJoin(
       input: Seq[(LogicalPlan, InnerLike)],
-      conditions: Seq[Expression],
-      hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = {
+      conditions: Seq[Expression]): LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
       val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
@@ -58,8 +56,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
         case (Inner, Inner) => Inner
         case (_, _) => Cross
       }
-      val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
-        JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
+      val join = Join(left, right, innerJoinType,
+        joinConditions.reduceLeftOption(And), JoinHint.NONE)
       if (others.nonEmpty) {
         Filter(others.reduceLeft(And), join)
       } else {
@@ -82,27 +80,27 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
         e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
-      val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
-        JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
+      val joined = Join(left, right, innerJoinType,
+        joinConditions.reduceLeftOption(And), JoinHint.NONE)
 
       // should not have reference to same logical plan
-      createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others, hintMap)
+      createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others)
     }
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap)
+    case p @ ExtractFiltersAndInnerJoins(input, conditions)
         if input.size > 2 && conditions.nonEmpty =>
       val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
         val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
         if (starJoinPlan.nonEmpty) {
           val rest = input.filterNot(starJoinPlan.contains(_))
-          createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap)
+          createOrderedJoin(starJoinPlan ++ rest, conditions)
         } else {
-          createOrderedJoin(input, conditions, hintMap)
+          createOrderedJoin(input, conditions)
         }
       } else {
-        createOrderedJoin(input, conditions, hintMap)
+        createOrderedJoin(input, conditions)
       }
 
       if (p.sameOutput(reordered)) {
index 95be0a5..a816922 100644 (file)
@@ -166,35 +166,27 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
    * was involved in an explicit cross join. Also returns the entire list of join conditions for
    * the left-deep tree.
    */
-  def flattenJoin(
-      plan: LogicalPlan,
-      hintMap: mutable.HashMap[AttributeSet, HintInfo],
-      parentJoinType: InnerLike = Inner)
+  def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
       : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
-    case Join(left, right, joinType: InnerLike, cond, hint) =>
-      val (plans, conditions) = flattenJoin(left, hintMap, joinType)
-      hint.leftHint.map(hintMap.put(left.outputSet, _))
-      hint.rightHint.map(hintMap.put(right.outputSet, _))
+    case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE =>
+      val (plans, conditions) = flattenJoin(left, joinType)
       (plans ++ Seq((right, joinType)), conditions ++
         cond.toSeq.flatMap(splitConjunctivePredicates))
-    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
-      val (plans, conditions) = flattenJoin(j, hintMap)
+    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE =>
+      val (plans, conditions) = flattenJoin(j)
       (plans, conditions ++ splitConjunctivePredicates(filterCondition))
 
     case _ => (Seq((plan, parentJoinType)), Seq.empty)
   }
 
   def unapply(plan: LogicalPlan)
-      : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[AttributeSet, HintInfo])]
+      : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
       = plan match {
-    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) =>
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
-      val flattened = flattenJoin(f, hintMap)
-      Some((flattened._1, flattened._2, hintMap.toMap))
-    case j @ Join(_, _, joinType, _, _) =>
-      val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
-      val flattened = flattenJoin(j, hintMap)
-      Some((flattened._1, flattened._2, hintMap.toMap))
+    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, hint))
+        if hint == JoinHint.NONE =>
+      Some(flattenJoin(f))
+    case j @ Join(_, _, joinType, _, hint) if hint == JoinHint.NONE =>
+      Some(flattenJoin(j))
     case _ => None
   }
 }
index 9093d7f..c570643 100644 (file)
@@ -66,7 +66,7 @@ class JoinOptimizationSuite extends PlanTest {
     def testExtractCheckCross
         (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) {
       assert(
-        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2, Map.empty)))
+        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2)))
     }
 
     testExtract(x, None)
index 0dee846..f1da0a8 100644 (file)
@@ -292,77 +292,56 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     assertEqualPlans(originalPlan, bestPlan)
   }
 
-  test("hints preservation") {
-    // Apply hints if we find an equivalent node in the new plan, otherwise discard them.
+  test("don't reorder if hints present") {
     val originalPlan =
-      t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast"))
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
-          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
-
-    val bestPlan =
-    t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
-      .hint("broadcast")
-      .join(
-        t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
-          .hint("broadcast"),
-        Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+      t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .hint("broadcast")
+        .join(
+          t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+            .hint("broadcast"),
+          Inner,
+          Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
 
-    assertEqualPlans(originalPlan, bestPlan)
+    assertEqualPlans(originalPlan, originalPlan)
 
     val originalPlan2 =
-      t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast"))
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
-          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
-
-    val bestPlan2 =
       t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
         .hint("broadcast")
-        .join(
-          t4.hint("broadcast")
-            .join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
-          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
-        .select(outputsOf(t1, t2, t3, t4): _*)
+        .join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+        .hint("broadcast")
+        .join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
 
-    assertEqualPlans(originalPlan2, bestPlan2)
+    assertEqualPlans(originalPlan2, originalPlan2)
+  }
 
-    val originalPlan3 =
-      t1.join(t4).hint("broadcast")
-        .join(t2.hint("broadcast")).hint("broadcast")
-        .join(t3.hint("broadcast"))
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
-        (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-        (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+  test("reorder below and above the hint node") {
+    val originalPlan =
+      t1.join(t2).join(t3)
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+        .hint("broadcast").join(t4)
 
-    val bestPlan3 =
-      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
-        .join(
-          t4.join(t3.hint("broadcast"),
-            Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
-          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
-        .select(outputsOf(t1, t4, t2, t3): _*)
+    val bestPlan =
+    t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+      .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+      .select(outputsOf(t1, t2, t3): _*)
+      .hint("broadcast").join(t4)
 
-    assertEqualPlans(originalPlan3, bestPlan3)
+    assertEqualPlans(originalPlan, bestPlan)
 
-    val originalPlan4 =
-      t2.hint("broadcast")
-        .join(t4).hint("broadcast")
-        .join(t3.hint("broadcast")).hint("broadcast")
-        .join(t1)
-        .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
-          (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
-          (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+    val originalPlan2 =
+      t1.join(t2).join(t3)
+        .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+          (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+        .join(t4.hint("broadcast"))
 
-    val bestPlan4 =
-      t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
-        .join(
-          t4.join(t3.hint("broadcast"),
-            Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
-          Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
-        .select(outputsOf(t2, t4, t3, t1): _*)
+    val bestPlan2 =
+      t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+        .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .select(outputsOf(t1, t2, t3): _*)
+        .join(t4.hint("broadcast"))
 
-    assertEqualPlans(originalPlan4, bestPlan4)
+    assertEqualPlans(originalPlan2, bestPlan2)
   }
 
   private def assertEqualPlans(
index 55f210c..30a3d54 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
 class JoinHintSuite extends PlanTest with SharedSQLContext {
@@ -100,7 +101,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
     }
   }
 
-  test("hint preserved after join reorder") {
+  test("hints prevent join reorder") {
     withTempView("a", "b", "c") {
       df1.createOrReplaceTempView("a")
       df2.createOrReplaceTempView("b")
@@ -118,12 +119,10 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
       verifyJoinHint(
         sql("select /*+ broadcast(a, c)*/ * from a, c, b " +
           "where a.a1 = b.b1 and b.b1 = c.c1"),
-        JoinHint(
-          None,
-          Some(HintInfo(broadcast = true))) ::
+        JoinHint.NONE ::
           JoinHint(
             Some(HintInfo(broadcast = true)),
-            None):: Nil
+            Some(HintInfo(broadcast = true))):: Nil
       )
       verifyJoinHint(
         sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
@@ -199,4 +198,21 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
         None) :: Nil
     )
   }
+
+  test("hints prevent cost-based join reorder") {
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") {
+      val join = df.join(df, "id")
+      val broadcasted = join.hint("broadcast")
+      verifyJoinHint(
+        join.join(broadcasted, "id").join(broadcasted, "id"),
+        JoinHint(
+          None,
+          Some(HintInfo(broadcast = true))) ::
+          JoinHint(
+            None,
+            Some(HintInfo(broadcast = true))) ::
+          JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil
+      )
+    }
+  }
 }