feat: Add mechanism for determining if regexp pattern is supported natively #771

Draft · wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions native/Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions native/Cargo.toml
@@ -53,6 +53,7 @@ chrono-tz = { version = "0.8" }
num = "0.4"
rand = "0.8"
regex = "1.9.6"
regex-syntax = "0.8.4"
thiserror = "1"

[profile.release]
23 changes: 20 additions & 3 deletions native/core/src/execution/jni_api.rs
@@ -61,6 +61,7 @@ use jni::{
use tokio::runtime::Runtime;

use crate::execution::operators::ScanExec;
use datafusion_comet_spark_expr::is_regexp_supported;
use log::info;

/// Comet native execution context. Kept alive across JNI calls.
@@ -93,7 +94,7 @@ struct ExecutionContext {

/// Accept serialized query plan and return the address of the native query plan.
/// # Safety
/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
e: JNIEnv,
@@ -300,7 +301,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> {
/// Accept serialized query plan and the addresses of Arrow Arrays from Spark,
/// then execute the query. Return addresses of arrow vector.
/// # Safety
/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
e: JNIEnv,
@@ -450,7 +451,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext {

/// Used by Comet shuffle external sorter to write sorted records to disk.
/// # Safety
/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative(
e: JNIEnv,
@@ -535,3 +536,19 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
Ok(())
})
}

/// Used by QueryPlanSerde to determine if a regular expression is supported natively.
/// # Safety
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_isRegexpPatternSupported(
e: JNIEnv,
_class: JClass,
pattern: jstring,
) -> jboolean {
try_unwrap_or_throw(&e, |mut env| {
let pattern: String = env.get_string(&JString::from_raw(pattern)).unwrap().into();
// if we hit an error parsing the regexp then just report it as unsupported
Ok(is_regexp_supported(&pattern).ok().unwrap_or(false) as jboolean)
})
}
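
Aside, not part of the diff: in the JNI wrapper above, is_regexp_supported(&pattern).ok().unwrap_or(false) collapses any parse error into false, so a pattern the parser rejects is simply reported as unsupported rather than throwing back through JNI. A minimal sketch of that behavior, assuming the is_regexp_supported function added later in this PR; the helper name supported_or_false is illustrative only:

    use datafusion_comet_spark_expr::is_regexp_supported;

    /// Mirrors the JNI error handling above: parse failures collapse to `false`.
    fn supported_or_false(pattern: &str) -> bool {
        is_regexp_supported(pattern).ok().unwrap_or(false)
    }

    fn main() {
        assert!(supported_or_false("[a-z]+")); // parses and is compatible
        assert!(!supported_or_false("["));     // unclosed class: parse error, so false
    }
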
2 changes: 2 additions & 0 deletions native/spark-expr/Cargo.toml
@@ -40,6 +40,7 @@ datafusion-physical-plan = { workspace = true }
chrono-tz = { workspace = true }
num = { workspace = true }
regex = { workspace = true }
regex-syntax = { workspace = true }
thiserror = { workspace = true }
unicode-segmentation = "1.11.0"

@@ -49,6 +50,7 @@ criterion = "0.5.1"
rand = { workspace = true}
twox-hash = "1.6.3"


[lib]
name = "datafusion_comet_spark_expr"
path = "src/lib.rs"
1 change: 1 addition & 0 deletions native/spark-expr/src/lib.rs
@@ -21,6 +21,7 @@ mod if_expr;

mod kernels;
mod regexp;
pub use regexp::is_regexp_supported;
pub mod scalar_funcs;
pub mod spark_hash;
mod structs;
47 changes: 46 additions & 1 deletion native/spark-expr/src/regexp.rs
@@ -22,10 +22,12 @@ use arrow_array::builder::BooleanBuilder;
use arrow_array::types::Int32Type;
use arrow_array::{Array, BooleanArray, DictionaryArray, RecordBatch, StringArray};
use arrow_schema::{DataType, Schema};
use datafusion_common::{internal_err, Result};
use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use regex::Regex;
use regex_syntax::hir::{Class, Hir, HirKind};
use regex_syntax::Parser;
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
@@ -47,6 +49,27 @@ pub struct RLike {
pattern: Regex,
}

/// Determine if a regex pattern is guaranteed to produce the same values as
/// Java's regular expression engine.
pub fn is_regexp_supported(pattern: &str) -> Result<bool> {
let ast = Parser::new()
.parse(pattern)
.map_err(|e| DataFusionError::Execution(format!("Error parsing regex pattern: {}", e)))?;
Ok(is_compat(&ast))
}

fn is_compat(ast: &Hir) -> bool {
match ast.kind() {
// character class such as `[a-z]` or `[^aeiou]`
HirKind::Class(Class::Unicode(c)) => c.is_ascii(),
// repetition quantifier such as `+`, `*`, `?`, `{1,3}`
HirKind::Repetition(r) => is_compat(r.sub.as_ref()),
// series of expressions such as `[A-Z][a-z]`
HirKind::Concat(items) => items.iter().all(is_compat),
_ => false,
}
}

impl Hash for RLike {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write(self.pattern_str.as_bytes());
@@ -168,3 +191,25 @@ impl PhysicalExpr for RLike {
self.hash(&mut s);
}
}

#[cfg(test)]
mod test {
use crate::regexp::is_regexp_supported;
use datafusion_common::Result;

#[test]
fn parse_supported_regex() -> Result<()> {
assert!(is_regexp_supported("[a-z]")?);
assert!(is_regexp_supported("[a-z]+")?);
assert!(is_regexp_supported("[a-z]*")?);
assert!(is_regexp_supported("[a-z]?")?);
Ok(())
}

#[test]
fn parse_unsupported_regex() -> Result<()> {
assert!(!is_regexp_supported("abc$")?);
assert!(!is_regexp_supported("^abc")?);
Ok(())
}
}
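
Aside, not part of the diff: a minimal sketch, assuming regex-syntax 0.8, of the HIR shapes that the is_compat match above distinguishes. Only ASCII character classes, repetitions of compatible sub-expressions, and concatenations of compatible parts are accepted; every other kind (literals, anchors, alternations, and so on) falls through to the catch-all false arm:

    use regex_syntax::hir::HirKind;
    use regex_syntax::Parser;

    fn main() {
        // "[a-z]+" parses to a Repetition wrapping an ASCII-only character
        // class, so the Repetition and Class arms both accept it.
        let hir = Parser::new().parse("[a-z]+").unwrap();
        assert!(matches!(hir.kind(), HirKind::Repetition(_)));

        // "abc$" parses to a Concat holding a Literal and an end-of-text
        // Look; neither of those kinds has an accepting arm, so the
        // whole pattern is reported as unsupported.
        let hir = Parser::new().parse("abc$").unwrap();
        assert!(matches!(hir.kind(), HirKind::Concat(_)));
    }

This also accounts for the tests above: the anchored patterns abc$ and ^abc are rejected because their Literal and Look components have no accepting arm in is_compat.
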
11 changes: 11 additions & 0 deletions spark/src/main/scala/org/apache/comet/Native.scala
@@ -121,4 +121,15 @@ class Native extends NativeBase {
* the size of the array.
*/
@native def sortRowPartitionsNative(addr: Long, size: Long): Unit

/**
* Determine if a regular expression pattern is supported natively and guaranteed to produce the
* same result as Spark.
*
* @param pattern
* The regular expression pattern
* @return
* true if supported
*/
@native def isRegexpPatternSupported(pattern: String): Boolean
}
32 changes: 0 additions & 32 deletions spark/src/main/scala/org/apache/comet/expressions/RegExp.scala

This file was deleted.

@@ -42,9 +42,9 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometConf
import org.apache.comet.{CometConf, Native}
import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark34Plus, withInfo}
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible, RegExp, Unsupported}
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible, Unsupported}
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
@@ -55,6 +55,9 @@ import org.apache.comet.shims.ShimQueryPlanSerde
* A utility object for query plan and expression serialization.
*/
object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim {

private val native = new Native()

def emitWarning(reason: String): Unit = {
logWarning(s"Comet native execution is disabled due to: $reason")
}
@@ -1237,7 +1240,7 @@
// we currently only support scalar regex patterns
right match {
case Literal(pattern, DataTypes.StringType) =>
if (!RegExp.isSupportedPattern(pattern.toString) &&
if (!native.isRegexpPatternSupported(pattern.toString) &&
!CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
withInfo(
expr,
@@ -625,11 +625,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// add repetitive data to trigger dictionary encoding
Range(0, 100).map(_ => "John Smith")
withParquetFile(data.zipWithIndex, withDictionary) { file =>
withSQLConf(CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key -> "true") {
spark.read.parquet(file).createOrReplaceTempView(table)
val query = sql(s"select _2 as id, _1 rlike 'R[a-z]+s [Rr]ose' from $table")
checkSparkAnswerAndOperator(query)
}
spark.read.parquet(file).createOrReplaceTempView(table)
val query = sql(s"select _2 as id, _1 rlike '[M-R]+[a-z]+' from $table")
checkSparkAnswerAndOperator(query)
}
}
}