[WIP] Add support for BNLJ
Prashant Singh committed Apr 26, 2024
1 parent 4da74d8 commit 5e9e8c5
Showing 6 changed files with 152 additions and 4 deletions.
22 changes: 21 additions & 1 deletion core/src/execution/datafusion/planner.rs
@@ -40,7 +40,7 @@ use datafusion::{
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
joins::{utils::JoinFilter, HashJoinExec, PartitionMode, SortMergeJoinExec},
joins::{utils::JoinFilter, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec},
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
@@ -960,6 +960,26 @@ impl PhysicalPlanner {
)?);
Ok((scans, join))
}

OpStruct::BroadcastNestedLoopJoin(join) => {
// Create the DataFusion physical operator.
// BNLJ has no equi-join keys, so both key lists are empty.
let empty_keys: &[Expr] = &[];
let (join_params, scans) = self.parse_join_parameters(
inputs,
children,
empty_keys,
empty_keys,
join.join_type,
&join.condition,
)?;
let join = Arc::new(NestedLoopJoinExec::try_new(
join_params.left,
join_params.right,
join_params.join_filter,
&join_params.join_type
)?);
Ok((scans, join))
}
}
}

7 changes: 7 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -42,6 +42,7 @@ message Operator {
Expand expand = 107;
SortMergeJoin sort_merge_join = 108;
HashJoin hash_join = 109;
BroadcastNestedLoopJoin broadcast_nested_loop_join = 110;
}
}

@@ -104,6 +105,12 @@ message SortMergeJoin {
repeated spark.spark_expression.Expr sort_options = 4;
}

message BroadcastNestedLoopJoin {
// Unlike the hash-based joins, BNLJ carries no join keys;
// only the join type and an optional condition.
JoinType join_type = 1;
optional spark.spark_expression.Expr condition = 2;
}

enum JoinType {
Inner = 0;
LeftOuter = 1;
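For reference, the JVM side consumes this message through the generated protobuf builder. A minimal sketch, assuming the generated classes live under org.apache.comet.serde.OperatorOuterClass as in the QueryPlanSerde change later in this commit:

import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.JoinType

// Build a BNLJ message with no condition; only the join type is required.
val bnlj = OperatorOuterClass.BroadcastNestedLoopJoin
  .newBuilder()
  .setJoinType(JoinType.Inner)
  .build()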
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -400,6 +400,25 @@ class CometSparkSessionExtensions
op
}

case op: BroadcastNestedLoopJoinExec
if isCometOperatorEnabled(conf, "broadcast_nested_loop_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometBroadcastNestedLoopJoinExec(
nativeOp,
op,
op.joinType,
op.condition,
op.buildSide,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
op.children.forall(isCometNative(_)) =>
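For a quick sanity check that the new rule fired, one can scan the executed plan for the wrapper operator. A sketch, given some DataFrame df built with a broadcast hint (CometBroadcastNestedLoopJoinExec is added to operators.scala later in this commit):

val plan = df.queryExecution.executedPlan
// Collect any Comet BNLJ operators the extension substituted in.
val cometJoins = plan.collect { case j: CometBroadcastNestedLoopJoinExec => j }
assert(cometJoins.nonEmpty, s"expected CometBroadcastNestedLoopJoinExec in:\n$plan")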
43 changes: 42 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -2008,6 +2008,47 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
None
}

case join: BroadcastNestedLoopJoinExec
if isCometOperatorEnabled(op.conf, "broadcast_nested_loop_join") =>
// Only a subset of join types is supported for each build (broadcast) side.
if (join.buildSide == BuildRight) {
if (join.joinType != Inner && join.joinType != LeftOuter
&& join.joinType != LeftSemi && join.joinType != LeftAnti) {
return None
}
} else {
if (join.joinType != RightOuter && join.joinType != FullOuter) {
return None
}
}

val joinType = join.joinType match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ => return None // other join types are not supported
}

val condition = join.condition.map { cond =>
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
if (condProto.isEmpty) {
return None
}
condProto.get
}

if (childOp.nonEmpty) {
val joinBuilder = OperatorOuterClass.BroadcastNestedLoopJoin
.newBuilder()
.setJoinType(joinType)
condition.foreach(joinBuilder.setCondition)
Some(result.setBroadcastNestedLoopJoin(joinBuilder).build())
} else {
None
}

case join: SortMergeJoinExec if isCometOperatorEnabled(op.conf, "sort_merge_join") =>
// `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec.
def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
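The build-side gate above amounts to a small support matrix. Restated as a standalone predicate (a hypothetical helper for illustration, not part of the patch):

import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}

// Which (build side, join type) pairs the serde path accepts.
def bnljSupported(buildSide: BuildSide, joinType: JoinType): Boolean =
  (buildSide, joinType) match {
    case (BuildRight, Inner | LeftOuter | LeftSemi | LeftAnti) => true
    case (BuildLeft, RightOuter | FullOuter) => true
    case _ => false
  }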
33 changes: 33 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -694,6 +694,39 @@ case class CometBroadcastHashJoinExec(
Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
}

case class CometBroadcastNestedLoopJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
joinType: JoinType,
condition: Option[Expression],
buildSide: BuildSide,
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {
override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

override def stringArgs: Iterator[Any] =
Iterator(joinType, condition, left, right)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastNestedLoopJoinExec =>
this.joinType == other.joinType &&
this.condition == other.condition &&
this.buildSide == other.buildSide &&
this.left == other.left &&
this.right == other.right &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int =
Objects.hashCode(joinType, condition, buildSide, left, right)
}

case class CometSortMergeJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
30 changes: 29 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -23,7 +23,7 @@ import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometBroadcastNestedLoopJoinExec}
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf
@@ -232,4 +232,32 @@
}
}
}

test("BroadcastNestedLoopJoin without filter") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4")
withSQLConf(
CometConf.COMET_BATCH_SIZE.key -> "100",
SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
"spark.sql.join.forceApplyShuffledHashJoin" -> "true",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") {
withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") {
// Inner join: build right
val df1 =
sql("SELECT /*+ BROADCAST(tbl_b) */ * FROM tbl_a JOIN tbl_b")
checkSparkAnswerAndOperator(
df1,
Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastNestedLoopJoinExec]))

// Right join: build left
val df2 =
sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b")
checkSparkAnswerAndOperator(
df2,
Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastNestedLoopJoinExec]))
}
}
}
}
}
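Outside the test harness, the operator should kick in with something like the following, assuming Comet's usual per-operator flag convention (spark.comet.exec.<name>.enabled, which is what isCometOperatorEnabled checks; the exact key name here is inferred, not confirmed by this diff):

spark.conf.set("spark.comet.exec.enabled", "true")
// Hypothetical flag name, derived from the "broadcast_nested_loop_join" key above.
spark.conf.set("spark.comet.exec.broadcast_nested_loop_join.enabled", "true")
val df = spark.sql("SELECT /*+ BROADCAST(tbl_b) */ * FROM tbl_a JOIN tbl_b")
df.explain() // look for CometBroadcastNestedLoopJoinExec in the physical plan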
