Skip to content

Commit

Permalink
Merge pull request #20 from rdeits/rd/binarize-kwarg
Browse files Browse the repository at this point in the history
Make `@cse(@binarize(expr))` actually work
  • Loading branch information
rdeits authored Jun 29, 2020
2 parents 7e59d77 + 2ff3077 commit 02e6fb8
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 24 deletions.
85 changes: 61 additions & 24 deletions src/CommonSubexpressions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module CommonSubexpressions

using MacroTools: @capture, postwalk
using MacroTools: @capture, postwalk, MacroTools
using Base.Iterators: drop

export @cse, cse, @binarize
Expand All @@ -25,7 +25,7 @@ disqualify!(cache::Cache, s::Symbol) = push!(cache.disqualified_symbols, s)
disqualify!(cache::Cache, expr::Expr) = foreach(arg -> disqualify!(cache, arg), expr.args)

# fallback for non-Expr arguments
combine_subexprs!(setup, x, warn_enabled::Bool) = x
combine_subexprs!(setup, x; warn=true, mod=nothing) = x

const standard_expression_forms = Set{Symbol}(
(:call,
Expand All @@ -50,37 +50,42 @@ const assignment_expression_forms = Set{Symbol}(
:(*=),
:(/=)))

function combine_subexprs!(cache::Cache, expr::Expr, warn_enabled::Bool)
function combine_subexprs!(cache::Cache, expr::Expr;
warn::Bool=true, mod::Union{Module, Nothing}=nothing)
if expr.head == :macrocall
# We don't recursively expand other macros, but we can perform CSE on
# the expression inside the macro call.
for i in 2:length(expr.args)
expr.args[i] = combine_subexprs!(expr.args[i], warn_enabled)
if (mod === nothing)
error("""
`cse` cannot expand macro calls unless you explicitly pass in
a `Module` in which to perform that expansion. You can pass
`mod=@__MODULE__` to expand in the current module, or you can use
the `@cse` macro which handles this automatically.""")
end
return combine_subexprs!(cache, macroexpand(mod, expr);
warn=warn, mod=mod)
elseif expr.head == :function
# We can't continue CSE through a function definition, but we can
# start over inside the body of the function:
for i in 2:length(expr.args)
expr.args[i] = combine_subexprs!(expr.args[i], warn_enabled)
expr.args[i] = combine_subexprs!(expr.args[i]; warn=warn, mod=mod)
end
elseif expr.head == :line
# nothing
elseif expr.head in assignment_expression_forms
disqualify!(cache, expr.args[1])
for i in 2:length(expr.args)
expr.args[i] = combine_subexprs!(cache, expr.args[i], warn_enabled)
expr.args[i] = combine_subexprs!(cache, expr.args[i]; warn=warn, mod=mod)
end
elseif expr.head == :generator
for i in vcat(2:length(expr.args), 1)
expr.args[i] = combine_subexprs!(cache, expr.args[i], warn_enabled)
expr.args[i] = combine_subexprs!(cache, expr.args[i]; warn=warn, mod=mod)
end
elseif expr.head in standard_expression_forms
for (i, child) in enumerate(expr.args)
expr.args[i] = combine_subexprs!(cache, child, warn_enabled)
expr.args[i] = combine_subexprs!(cache, child; warn=warn, mod=mod)
end
if expr.head == :call
for (i, child) in enumerate(expr.args)
expr.args[i] = combine_subexprs!(cache, child, warn_enabled)
expr.args[i] = combine_subexprs!(cache, child, warn=warn, mod=mod)
end
if all(!isa(arg, Expr) && !(arg in cache.disqualified_symbols) for arg in drop(expr.args, 1))
combined_args = Symbol(expr.args...)
Expand All @@ -94,38 +99,70 @@ function combine_subexprs!(cache::Cache, expr::Expr, warn_enabled::Bool)
end
end
else
warn_enabled && @warn("CommonSubexpressions can't yet handle expressions of this form: $(expr.head)")
warn && @warn("CommonSubexpressions can't yet handle expressions of this form: $(expr.head)")
end
return expr
end

combine_subexprs!(x, warn_enabled::Bool = true) = x
combine_subexprs!(x; warn=true, mod=nothing) = x

function combine_subexprs!(expr::Expr, warn_enabled::Bool)
function combine_subexprs!(expr::Expr; warn=true, mod=nothing)
cache = Cache()
expr = combine_subexprs!(cache, expr, warn_enabled)
expr = combine_subexprs!(cache, expr; warn=warn, mod=mod)
Expr(:block, cache.setup..., expr)
end

function parse_cse_args(args)
# Overly complicated way to look for `warn=true` or `warn=false`,
# but should be easier to expand for other arguments later.
params = Dict(:warn => true)
for (i, arg) in enumerate(args)
if @capture(arg, key_Symbol = val_Bool)
if key in keys(params)
params[key] = val
else
error("Unrecognized key: $key")
end
elseif i == 1 && arg isa Bool
Base.depwarn("The `warn_enabled` positional argument is deprecated. Please use `warn=true` or `warn=false` instead", :cse_macro_kwargs)
else
error("Unrecognized argument: $arg. Expected `warn=true` or `warn=false`")

end
end
params
end

"""
@cse(expr, warn_enabled = true)
@cse(expr; warn=true)
Perform naive common subexpression elimination under the assumption
that all functions called withing the body of the macro are pure,
meaning that they have no side effects. See [Readme.md](https://github.com/rdeits/CommonSubexpressions.jl/blob/master/Readme.md)
for more details.
If `warn_enabled == true`, then this macro will warn whenever it encounters
an expression type that it does not know how to transform. Otherwise that
expression will be silently left unmodified.
This macro will recursively expand macro calls within the expression before
performing subexpression elimination. A useful macro to combine with this is
`@binarize`, which will turn n-ary function calls into nested binary calls and
can therefore provide more opportunities for subexpression elimination. Usage:
@cse(@binarize(<your code here>))
If the macro encounters an expression which it does not know how to handle,
it will return that expression unmodified. If `warn=true`, then it
will also log a warning in that event.
"""
macro cse(expr, warn_enabled::Bool = true)
result = combine_subexprs!(expr, warn_enabled)
# println(result)
macro cse(expr, args...)
params = parse_cse_args(args)
result = combine_subexprs!(expr, warn=params[:warn], mod=__module__)
esc(result)
end

cse(expr, warn_enabled::Bool = true) = combine_subexprs!(copy(expr), warn_enabled)
Base.@deprecate cse(expr, warn_enabled::Bool) cse(expr, warn=warn_enabled)

function cse(expr; warn::Bool=true, mod::Union{Module, Nothing}=nothing)
combine_subexprs!(copy(expr); warn=warn, mod=mod)
end

function _binarize(expr::Expr)
if @capture(expr, f_(a_, b_, c_, args__))
Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ end

@testset "warnings" begin
@test_logs (:warn, "CommonSubexpressions can't yet handle expressions of this form: foo") cse(Expr(:foo, 1, 2, 3))

@cse(1 + 2 + 3, false)
@cse(1 + 2 + 3, warn=true)
end

@testset "inplace" begin
Expand Down Expand Up @@ -221,5 +224,10 @@ module NestedMacroTest
special_plus_calls[] = 0
@test(@special_math(@cse((2 + 2) + (2 + 2))) == 8)
@test special_plus_calls[] == 2

special_plus_calls[] = 0
@test(@cse(@binarize(@special_math((1 + 2 + 3) + (1 + 2 + 4) + (1 + 2 + 5)))) == 21)
# Test that the duplicate calls to `1 + 2` were eliminated
@test special_plus_calls[] == 6
end
end

2 comments on commit 02e6fb8

@rdeits
Copy link
Owner Author

@rdeits rdeits commented on 02e6fb8 Jun 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

  • Add @binarize to convert n-ary functions to nested binary calls
  • Recursively expand macros within @cse to allow for helpers like @binarize (previously, nested macro calls were incorrectly dropped altogether)
  • Deprecate warn_enabled positional argument in favor of warn=true keyword argument.
  • Require Julia 1.0

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17147

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" 02e6fb8a0a8f934f30784b4c193d1cdf4cd87c5a
git push origin v0.3.0

Please sign in to comment.