Skip to content

Commit

Permalink
chore: Don't transform the HashAggregate to CometHashAggregate if Comet shuffle is disabled (#991)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored Oct 1, 2024
1 parent a1599e2 commit 18150fb
Showing 1 changed file with 13 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,11 @@ class CometSparkSessionExtensions

case op: BaseAggregateExec
if op.isInstanceOf[HashAggregateExec] ||
op.isInstanceOf[ObjectHashAggregateExec] =>
op.isInstanceOf[ObjectHashAggregateExec] &&
// When Comet shuffle is disabled, we don't want to transform the HashAggregate
// to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
// and final Spark aggregation.
isCometShuffleEnabled(conf) =>
val groupingExprs = op.groupingExpressions
val aggExprs = op.aggregateExpressions
val resultExpressions = op.resultExpressions
Expand All @@ -451,8 +455,10 @@ class CometSparkSessionExtensions
// Fallback to Spark nevertheless here.
op
} else {
// For a final mode HashAggregate, we only need to transform the HashAggregate
// if there is Comet partial aggregation.
val sparkFinalMode = {
!modes.isEmpty && modes.head == Final && findPartialAgg(child).isEmpty
!modes.isEmpty && modes.head == Final && findCometPartialAgg(child).isEmpty
}

if (sparkFinalMode) {
Expand Down Expand Up @@ -995,13 +1001,15 @@ class CometSparkSessionExtensions
* Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate
* or ObjectHashAggregate with partial mode, it will return None.
*/
def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
plan.collectFirst {
case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
Some(agg)
case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
case a: AQEShuffleReadExec => findPartialAgg(a.child)
case s: ShuffleQueryStageExec => findPartialAgg(s.plan)
case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
None
case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
}.flatten
}

Expand Down

0 comments on commit 18150fb

Please sign in to comment.