Skip to content

Commit

Permalink
Implement ModArithType for mod_arith dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Nov 15, 2024
1 parent ddfddb4 commit a8eb769
Show file tree
Hide file tree
Showing 21 changed files with 769 additions and 522 deletions.
1 change: 1 addition & 0 deletions lib/Dialect/ModArith/Conversions/ModArithToArith/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cc_library(
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Utils/ConversionUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
Expand Down
213 changes: 202 additions & 11 deletions lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.h"

#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Utils/ConversionUtils/ConversionUtils.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
Expand All @@ -14,23 +16,207 @@ namespace mod_arith {
#define GEN_PASS_DEF_MODARITHTOARITH
#include "lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.h.inc"

/// Returns a possibly extended modulus necessary to compute the given operation
/// without overflow.
template <typename ValueOrOpResult>
TypedAttr modulusHelper(IntegerAttr mod, ValueOrOpResult op, bool mul = false,
bool reduce = false) {
auto width = getElementTypeOrSelf(op).getIntOrFloatBitWidth();
auto modWidth = (mod.getValue() - 1).getActiveBits();
width = reduce ? width : std::max(width, mul ? 2 * modWidth : modWidth + 1);
IntegerType convertModArithType(ModArithType type) {
APInt modulus = type.getModulus().getValue();
return IntegerType::get(type.getContext(), modulus.getBitWidth());
}

Type convertModArithLikeType(ShapedType type) {
if (auto modArithType = llvm::dyn_cast<ModArithType>(type.getElementType())) {
return type.cloneWith(type.getShape(), convertModArithType(modArithType));
}
return type;
}

class ModArithToArithTypeConverter : public TypeConverter {
public:
ModArithToArithTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion(
[](ModArithType type) -> Type { return convertModArithType(type); });
addConversion(
[](ShapedType type) -> Type { return convertModArithLikeType(type); });
}
};

// A herlper function to generate the attribute or type
// needed to represent the result of modarith op as an integer
// before applying a remainder operation
template <typename Op>
TypedAttr modulusAttr(Op op, bool mul = false) {
auto type = op.getResult().getType();
auto modArithType = getResultModArithType(op);
APInt modulus = modArithType.getModulus().getValue();

auto width = modulus.getBitWidth();
if (mul) {
width *= 2;
}

auto intType = IntegerType::get(op.getContext(), width);
auto truncmod = mod.getValue().zextOrTrunc(width);
if (auto st = mlir::dyn_cast<ShapedType>(op.getType())) {
auto truncmod = modulus.zextOrTrunc(width);

if (auto st = mlir::dyn_cast<ShapedType>(type)) {
auto containerType = st.cloneWith(st.getShape(), intType);
return DenseElementsAttr::get(containerType, truncmod);
}
return IntegerAttr::get(intType, truncmod);
}

// used for extui/trunci
template <typename Op>
inline Type modulusType(Op op, bool mul = false) {
return modulusAttr(op, mul).getType();
}

struct ConvertEncapsulate : public OpConversionPattern<EncapsulateOp> {
ConvertEncapsulate(mlir::MLIRContext *context)
: OpConversionPattern<EncapsulateOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
EncapsulateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceAllUsesWith(op.getResult(), adaptor.getOperands()[0]);
rewriter.eraseOp(op);
return success();
}
};

struct ConvertExtract : public OpConversionPattern<ExtractOp> {
ConvertExtract(mlir::MLIRContext *context)
: OpConversionPattern<ExtractOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
ExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceAllUsesWith(op.getResult(), adaptor.getOperands()[0]);
rewriter.eraseOp(op);
return success();
}
};

struct ConvertReduce : public OpConversionPattern<ReduceOp> {
ConvertReduce(mlir::MLIRContext *context)
: OpConversionPattern<ReduceOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
// ModArithType ensures cmod can be correctly interpreted as a signed number
auto rems = b.create<arith::RemSIOp>(adaptor.getOperands()[0], cmod);
auto add = b.create<arith::AddIOp>(rems, cmod);
// TODO(#710): better with a subifge
auto remu = b.create<arith::RemUIOp>(add, cmod);
rewriter.replaceOp(op, remu);
return success();
}
};

// It is assumed inputs are canonical representatives
// ModArithType ensures add/sub result can not overflow
struct ConvertAdd : public OpConversionPattern<AddOp> {
ConvertAdd(mlir::MLIRContext *context)
: OpConversionPattern<AddOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
auto add = b.create<arith::AddIOp>(adaptor.getLhs(), adaptor.getRhs());
auto remu = b.create<arith::RemUIOp>(add, cmod);

rewriter.replaceOp(op, remu);
return success();
}
};

struct ConvertSub : public OpConversionPattern<SubOp> {
ConvertSub(mlir::MLIRContext *context)
: OpConversionPattern<SubOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
SubOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
auto sub = b.create<arith::SubIOp>(adaptor.getLhs(), adaptor.getRhs());
auto add = b.create<arith::AddIOp>(sub, cmod);
auto remu = b.create<arith::RemUIOp>(add, cmod);

rewriter.replaceOp(op, remu);
return success();
}
};

struct ConvertMul : public OpConversionPattern<MulOp> {
ConvertMul(mlir::MLIRContext *context)
: OpConversionPattern<MulOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmod = b.create<arith::ConstantOp>(modulusAttr(op, true));
auto lhs =
b.create<arith::ExtUIOp>(modulusType(op, true), adaptor.getLhs());
auto rhs =
b.create<arith::ExtUIOp>(modulusType(op, true), adaptor.getRhs());
auto mul = b.create<arith::MulIOp>(lhs, rhs);
auto remu = b.create<arith::RemUIOp>(mul, cmod);
auto trunc = b.create<arith::TruncIOp>(modulusType(op), remu);

rewriter.replaceOp(op, trunc);
return success();
}
};

struct ConvertMac : public OpConversionPattern<MacOp> {
ConvertMac(mlir::MLIRContext *context)
: OpConversionPattern<MacOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
MacOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmod = b.create<arith::ConstantOp>(modulusAttr(op, true));
auto x = b.create<arith::ExtUIOp>(modulusType(op, true),
adaptor.getOperands()[0]);
auto y = b.create<arith::ExtUIOp>(modulusType(op, true),
adaptor.getOperands()[1]);
auto acc = b.create<arith::ExtUIOp>(modulusType(op, true),
adaptor.getOperands()[2]);
auto mul = b.create<arith::MulIOp>(x, y);
auto add = b.create<arith::AddIOp>(mul, acc);
auto remu = b.create<arith::RemUIOp>(add, cmod);
auto trunc = b.create<arith::TruncIOp>(modulusType(op), remu);

rewriter.replaceOp(op, trunc);
return success();
}
};

namespace rewrites {
// In an inner namespace to avoid conflicts with canonicalization patterns
#include "lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp.inc"
Expand Down Expand Up @@ -104,14 +290,19 @@ struct ModArithToArith : impl::ModArithToArithBase<ModArithToArith> {
void ModArithToArith::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
ModArithToArithTypeConverter typeConverter(context);

ConversionTarget target(*context);
target.addIllegalDialect<ModArithDialect>();
target.addLegalDialect<arith::ArithDialect>();

RewritePatternSet patterns(context);
rewrites::populateWithGenerated(patterns);
patterns.add<ConvertBarrettReduce>(context);
patterns.add<ConvertEncapsulate, ConvertExtract, ConvertReduce, ConvertAdd,
ConvertSub, ConvertMul, ConvertMac, ConvertBarrettReduce>(
typeConverter, context);

addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
Expand Down
129 changes: 0 additions & 129 deletions lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,133 +33,4 @@ def ConvertSubIfGE : Pattern<
]
>;


def HasEnoughSpaceAddSub: Constraint<CPred<"llvm::cast<IntegerType>(getElementTypeOrSelf($_self.getType())).getWidth() >= ($0.getValue() - 1).getActiveBits() + 1">,
"underlying type is sufficient for modular add/sub operation without overflow">;

def HasEnoughSpaceMul: Constraint<CPred<"llvm::cast<IntegerType>(getElementTypeOrSelf($_self.getType())).getWidth() >= 2 * ($0.getValue() - 1).getActiveBits()">,
"underlying type is sufficient for modular mul operation without overflow">;

def CastModulusAttributeAddSub : NativeCodeCall<"modulusHelper($0,$1,false)">;
def CastModulusAttributeMul : NativeCodeCall<"modulusHelper($0,$1,true)">;
def CastModulusAttributeReduce : NativeCodeCall<"modulusHelper($0,$1,false,true)">;

def ConvertAddSimple : Pattern<
(ModArith_AddOp:$op $x, $y, $mod),
[
(Arith_AddIOp:$add $x, $y, DefOverflow),
(Arith_RemUIOp $add, (Arith_ConstantOp (CastModulusAttributeAddSub $mod, $x)))
],
[(HasEnoughSpaceAddSub:$op $mod)],
[],
(addBenefit 2)
>;

def ConvertSubSimple : Pattern<
(ModArith_SubOp:$op $x, $y, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeAddSub $mod, $x)),
(Arith_SubIOp:$sub $x, $y, DefOverflow),
(Arith_AddIOp:$shift $sub, $newmod, DefOverflow),
(Arith_RemUIOp $shift, $newmod)
],
[(HasEnoughSpaceAddSub:$op $mod)],
[],
(addBenefit 2)
>;

def ConvertMulSimple : Pattern<
(ModArith_MulOp:$op $x, $y, $mod),
[
(Arith_MulIOp:$mul $x, $y, DefOverflow),
(Arith_RemUIOp $mul, (Arith_ConstantOp (CastModulusAttributeMul $mod, $x)))
],
[(HasEnoughSpaceMul:$op $mod)],
[],
(addBenefit 2)
>;

def ConvertMacSimple : Pattern<
(ModArith_MacOp:$op $x, $y, $acc, $mod),
[
(Arith_MulIOp:$mul $x, $y, DefOverflow),
(Arith_AddIOp:$add $mul, $acc, DefOverflow),
(Arith_RemUIOp $add, (Arith_ConstantOp (CastModulusAttributeMul $mod, $x)))
],
[(HasEnoughSpaceMul:$op $mod)],
[],
(addBenefit 2)
>;

def ConvertAdd : Pattern<
(ModArith_AddOp $x, $y, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeAddSub $mod, $x)),
(Arith_AddIOp:$add
(Arith_ExtUIOp $x,
(returnType $newmod)),
(Arith_ExtUIOp $y,
(returnType $newmod)),
DefOverflow),
(Arith_TruncIOp:$res
(Arith_RemUIOp $add, $newmod))
]
>;

def ConvertSub : Pattern<
(ModArith_SubOp $x, $y, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeAddSub $mod, $x)),
(Arith_SubIOp:$sub
(Arith_ExtUIOp $x,
(returnType $newmod)),
(Arith_ExtUIOp $y,
(returnType $newmod)),
DefOverflow),
(Arith_AddIOp:$shift $sub, $newmod, DefOverflow),
(Arith_TruncIOp:$res
(Arith_RemUIOp $shift, $newmod))
]
>;

def ConvertMul : Pattern<
(ModArith_MulOp $x, $y, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeMul $mod, $x)),
(Arith_MulIOp:$mul
(Arith_ExtUIOp $x,
(returnType $newmod)),
(Arith_ExtUIOp $y,
(returnType $newmod)),
DefOverflow),
(Arith_TruncIOp:$res
(Arith_RemUIOp $mul, $newmod))
]
>;

def ConvertMac : Pattern<
(ModArith_MacOp $x, $y, $acc, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeMul $mod, $x)),
(Arith_MulIOp:$mul
(Arith_ExtUIOp $x,
(returnType $newmod)),
(Arith_ExtUIOp $y,
(returnType $newmod)),
DefOverflow),
(Arith_AddIOp:$add $mul,
(Arith_ExtUIOp:$extacc $acc, (returnType $newmod)), DefOverflow),
(Arith_TruncIOp:$res
(Arith_RemUIOp $add, $newmod))
]
>;

def ConvertReduce : Pattern<
(ModArith_ReduceOp $x, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeReduce $mod, $x)),
(Arith_RemUIOp (Arith_AddIOp (Arith_RemSIOp $x, $newmod), $newmod, DefOverflow), $newmod)
]
>;

#endif // LIB_DIALECT_MODARITH_CONVERSIONS_MODARITHTOARITH_MODARITHTOARITH_TD_
Loading

0 comments on commit a8eb769

Please sign in to comment.