diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index d52f31f54..bf09d6417 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -439,7 +439,11 @@ class CometSparkSessionExtensions
 
         case op: BaseAggregateExec
-            if op.isInstanceOf[HashAggregateExec] ||
-              op.isInstanceOf[ObjectHashAggregateExec] =>
+            if (op.isInstanceOf[HashAggregateExec] ||
+              op.isInstanceOf[ObjectHashAggregateExec]) &&
+              // When Comet shuffle is disabled, neither aggregate type is converted; otherwise
+              // we could get a partial Comet aggregation feeding a final Spark aggregation.
+              // The parentheses above matter: `&&` binds tighter than `||`.
+              isCometShuffleEnabled(conf) =>
           val groupingExprs = op.groupingExpressions
           val aggExprs = op.aggregateExpressions
           val resultExpressions = op.resultExpressions
@@ -451,8 +455,10 @@ class CometSparkSessionExtensions
             // Fallback to Spark nevertheless here.
             op
           } else {
+            // For a final mode HashAggregate, we only need to transform the HashAggregate
+            // if there is Comet partial aggregation.
             val sparkFinalMode = {
-              !modes.isEmpty && modes.head == Final && findPartialAgg(child).isEmpty
+              !modes.isEmpty && modes.head == Final && findCometPartialAgg(child).isEmpty
             }
 
             if (sparkFinalMode) {
@@ -995,13 +1001,15 @@ class CometSparkSessionExtensions
    * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate
    * with partial mode, it will return None.
    */
-  def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+  def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
     plan.collectFirst {
       case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
         Some(agg)
       case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
-      case a: AQEShuffleReadExec => findPartialAgg(a.child)
-      case s: ShuffleQueryStageExec => findPartialAgg(s.plan)
+      case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
+        None
+      case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
+      case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
     }.flatten
   }
 