diff --git a/src/CommonSubexpressions.jl b/src/CommonSubexpressions.jl index 32723ab..29b7d2f 100644 --- a/src/CommonSubexpressions.jl +++ b/src/CommonSubexpressions.jl @@ -1,6 +1,6 @@ module CommonSubexpressions -using MacroTools: @capture, postwalk +using MacroTools: @capture, postwalk, MacroTools using Base.Iterators: drop export @cse, cse, @binarize @@ -25,7 +25,7 @@ disqualify!(cache::Cache, s::Symbol) = push!(cache.disqualified_symbols, s) disqualify!(cache::Cache, expr::Expr) = foreach(arg -> disqualify!(cache, arg), expr.args) # fallback for non-Expr arguments -combine_subexprs!(setup, x, warn_enabled::Bool) = x +combine_subexprs!(setup, x; warn=true, mod=nothing) = x const standard_expression_forms = Set{Symbol}( (:call, @@ -50,37 +50,42 @@ const assignment_expression_forms = Set{Symbol}( :(*=), :(/=))) -function combine_subexprs!(cache::Cache, expr::Expr, warn_enabled::Bool) +function combine_subexprs!(cache::Cache, expr::Expr; + warn::Bool=true, mod::Union{Module, Nothing}=nothing) if expr.head == :macrocall - # We don't recursively expand other macros, but we can perform CSE on - # the expression inside the macro call. - for i in 2:length(expr.args) - expr.args[i] = combine_subexprs!(expr.args[i], warn_enabled) + if (mod === nothing) + error(""" + `cse` cannot expand macro calls unless you explicitly pass in + a `Module` in which to perform that expansion. You can pass + `mod=@__MODULE__` to expand in the current module, or you can use + the `@cse` macro which handles this automatically.""") end + return combine_subexprs!(cache, macroexpand(mod, expr); + warn=warn, mod=mod) elseif expr.head == :function # We can't continue CSE through a function definition, but we can # start over inside the body of the function: for i in 2:length(expr.args) - expr.args[i] = combine_subexprs!(expr.args[i], warn_enabled) + expr.args[i] = combine_subexprs!(expr.args[i]; warn=warn, mod=mod) end elseif expr.head == :line # nothing elseif expr.head in assignment_expression_forms disqualify!(cache, expr.args[1]) for i in 2:length(expr.args) - expr.args[i] = combine_subexprs!(cache, expr.args[i], warn_enabled) + expr.args[i] = combine_subexprs!(cache, expr.args[i]; warn=warn, mod=mod) end elseif expr.head == :generator for i in vcat(2:length(expr.args), 1) - expr.args[i] = combine_subexprs!(cache, expr.args[i], warn_enabled) + expr.args[i] = combine_subexprs!(cache, expr.args[i]; warn=warn, mod=mod) end elseif expr.head in standard_expression_forms for (i, child) in enumerate(expr.args) - expr.args[i] = combine_subexprs!(cache, child, warn_enabled) + expr.args[i] = combine_subexprs!(cache, child; warn=warn, mod=mod) end if expr.head == :call for (i, child) in enumerate(expr.args) - expr.args[i] = combine_subexprs!(cache, child, warn_enabled) + expr.args[i] = combine_subexprs!(cache, child, warn=warn, mod=mod) end if all(!isa(arg, Expr) && !(arg in cache.disqualified_symbols) for arg in drop(expr.args, 1)) combined_args = Symbol(expr.args...) @@ -94,38 +99,70 @@ function combine_subexprs!(cache::Cache, expr::Expr, warn_enabled::Bool) end end else - warn_enabled && @warn("CommonSubexpressions can't yet handle expressions of this form: $(expr.head)") + warn && @warn("CommonSubexpressions can't yet handle expressions of this form: $(expr.head)") end return expr end -combine_subexprs!(x, warn_enabled::Bool = true) = x +combine_subexprs!(x; warn=true, mod=nothing) = x -function combine_subexprs!(expr::Expr, warn_enabled::Bool) +function combine_subexprs!(expr::Expr; warn=true, mod=nothing) cache = Cache() - expr = combine_subexprs!(cache, expr, warn_enabled) + expr = combine_subexprs!(cache, expr; warn=warn, mod=mod) Expr(:block, cache.setup..., expr) end +function parse_cse_args(args) + # Overly complicated way to look for `warn=true` or `warn=false`, + # but should be easier to expand for other arguments later. + params = Dict(:warn => true) + for (i, arg) in enumerate(args) + if @capture(arg, key_Symbol = val_Bool) + if key in keys(params) + params[key] = val + else + error("Unrecognized key: $key") + end + elseif i == 1 && arg isa Bool + Base.depwarn("The `warn_enabled` positional argument is deprecated. Please use `warn=true` or `warn=false` instead", :cse_macro_kwargs) + else + error("Unrecognized argument: $arg. Expected `warn=true` or `warn=false`") + + end + end + params +end + """ - @cse(expr, warn_enabled = true) + @cse(expr; warn=true) Perform naive common subexpression elimination under the assumption that all functions called withing the body of the macro are pure, meaning that they have no side effects. See [Readme.md](https://github.com/rdeits/CommonSubexpressions.jl/blob/master/Readme.md) for more details. -If `warn_enabled == true`, then this macro will warn whenever it encounters -an expression type that it does not know how to transform. Otherwise that -expression will be silently left unmodified. +This macro will recursively expand macro calls within the expression before +performing subexpression elimination. A useful macro to combine with this is +`@binarize`, which will turn n-ary function calls into nested binary calls and +can therefore provide more opportunities for subexpression elimination. Usage: + + @cse(@binarize()) + +If the macro encounters an expression which it does not know how to handle, +it will return that expression unmodified. If `warn=true`, then it +will also log a warning in that event. """ -macro cse(expr, warn_enabled::Bool = true) - result = combine_subexprs!(expr, warn_enabled) - # println(result) +macro cse(expr, args...) + params = parse_cse_args(args) + result = combine_subexprs!(expr, warn=params[:warn], mod=__module__) esc(result) end -cse(expr, warn_enabled::Bool = true) = combine_subexprs!(copy(expr), warn_enabled) +Base.@deprecate cse(expr, warn_enabled::Bool) cse(expr, warn=warn_enabled) + +function cse(expr; warn::Bool=true, mod::Union{Module, Nothing}=nothing) + combine_subexprs!(copy(expr); warn=warn, mod=mod) +end function _binarize(expr::Expr) if @capture(expr, f_(a_, b_, c_, args__)) diff --git a/test/runtests.jl b/test/runtests.jl index dde9277..e6b266c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -136,6 +136,9 @@ end @testset "warnings" begin @test_logs (:warn, "CommonSubexpressions can't yet handle expressions of this form: foo") cse(Expr(:foo, 1, 2, 3)) + + @cse(1 + 2 + 3, false) + @cse(1 + 2 + 3, warn=true) end @testset "inplace" begin @@ -221,5 +224,10 @@ module NestedMacroTest special_plus_calls[] = 0 @test(@special_math(@cse((2 + 2) + (2 + 2))) == 8) @test special_plus_calls[] == 2 + + special_plus_calls[] = 0 + @test(@cse(@binarize(@special_math((1 + 2 + 3) + (1 + 2 + 4) + (1 + 2 + 5)))) == 21) + # Test that the duplicate calls to `1 + 2` were eliminated + @test special_plus_calls[] == 6 end end