diff --git a/lib/Dialect/ModArith/Conversions/ModArithToArith/BUILD b/lib/Dialect/ModArith/Conversions/ModArithToArith/BUILD index a1a2284c3..20a5c0aec 100644 --- a/lib/Dialect/ModArith/Conversions/ModArithToArith/BUILD +++ b/lib/Dialect/ModArith/Conversions/ModArithToArith/BUILD @@ -14,6 +14,7 @@ cc_library( deps = [ ":pass_inc_gen", "@heir//lib/Dialect/ModArith/IR:Dialect", + "@heir//lib/Utils/ConversionUtils", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp b/lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp index 39e0ac185..4cbb18a2c 100644 --- a/lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp +++ b/lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp @@ -1,6 +1,8 @@ #include "lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.h" #include "lib/Dialect/ModArith/IR/ModArithOps.h" +#include "lib/Dialect/ModArith/IR/ModArithTypes.h" +#include "lib/Utils/ConversionUtils/ConversionUtils.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project @@ -14,23 +16,207 @@ namespace mod_arith { #define GEN_PASS_DEF_MODARITHTOARITH #include "lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.h.inc" -/// Returns a possibly extended modulus necessary to compute the given operation -/// without overflow. -template -TypedAttr modulusHelper(IntegerAttr mod, ValueOrOpResult op, bool mul = false, - bool reduce = false) { - auto width = getElementTypeOrSelf(op).getIntOrFloatBitWidth(); - auto modWidth = (mod.getValue() - 1).getActiveBits(); - width = reduce ? width : std::max(width, mul ? 
2 * modWidth : modWidth + 1); +IntegerType convertModArithType(ModArithType type) { + APInt modulus = type.getModulus().getValue(); + return IntegerType::get(type.getContext(), modulus.getBitWidth()); +} + +Type convertModArithLikeType(ShapedType type) { + if (auto modArithType = llvm::dyn_cast(type.getElementType())) { + return type.cloneWith(type.getShape(), convertModArithType(modArithType)); + } + return type; +} + +class ModArithToArithTypeConverter : public TypeConverter { + public: + ModArithToArithTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); + addConversion( + [](ModArithType type) -> Type { return convertModArithType(type); }); + addConversion( + [](ShapedType type) -> Type { return convertModArithLikeType(type); }); + } +}; + +// A helper function to generate the attribute or type +// needed to represent the result of modarith op as an integer +// before applying a remainder operation +template +TypedAttr modulusAttr(Op op, bool mul = false) { + auto type = op.getResult().getType(); + auto modArithType = getResultModArithType(op); + APInt modulus = modArithType.getModulus().getValue(); + + auto width = modulus.getBitWidth(); + if (mul) { + width *= 2; + } + auto intType = IntegerType::get(op.getContext(), width); - auto truncmod = mod.getValue().zextOrTrunc(width); - if (auto st = mlir::dyn_cast(op.getType())) { + auto truncmod = modulus.zextOrTrunc(width); + + if (auto st = mlir::dyn_cast(type)) { auto containerType = st.cloneWith(st.getShape(), intType); return DenseElementsAttr::get(containerType, truncmod); } return IntegerAttr::get(intType, truncmod); } +// used for extui/trunci +template +inline Type modulusType(Op op, bool mul = false) { + return modulusAttr(op, mul).getType(); +} + +struct ConvertEncapsulate : public OpConversionPattern { + ConvertEncapsulate(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + 
EncapsulateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceAllUsesWith(op.getResult(), adaptor.getOperands()[0]); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertExtract : public OpConversionPattern { + ConvertExtract(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceAllUsesWith(op.getResult(), adaptor.getOperands()[0]); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertReduce : public OpConversionPattern { + ConvertReduce(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto cmod = b.create(modulusAttr(op)); + // ModArithType ensures cmod can be correctly interpreted as a signed number + auto rems = b.create(adaptor.getOperands()[0], cmod); + auto add = b.create(rems, cmod); + // TODO(#710): better with a subifge + auto remu = b.create(add, cmod); + rewriter.replaceOp(op, remu); + return success(); + } +}; + +// It is assumed inputs are canonical representatives +// ModArithType ensures add/sub result can not overflow +struct ConvertAdd : public OpConversionPattern { + ConvertAdd(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + AddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto cmod = b.create(modulusAttr(op)); + auto add = b.create(adaptor.getLhs(), adaptor.getRhs()); + auto remu = b.create(add, cmod); + + rewriter.replaceOp(op, 
remu); + return success(); + } +}; + +struct ConvertSub : public OpConversionPattern { + ConvertSub(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto cmod = b.create(modulusAttr(op)); + auto sub = b.create(adaptor.getLhs(), adaptor.getRhs()); + auto add = b.create(sub, cmod); + auto remu = b.create(add, cmod); + + rewriter.replaceOp(op, remu); + return success(); + } +}; + +struct ConvertMul : public OpConversionPattern { + ConvertMul(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + MulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto cmod = b.create(modulusAttr(op, true)); + auto lhs = + b.create(modulusType(op, true), adaptor.getLhs()); + auto rhs = + b.create(modulusType(op, true), adaptor.getRhs()); + auto mul = b.create(lhs, rhs); + auto remu = b.create(mul, cmod); + auto trunc = b.create(modulusType(op), remu); + + rewriter.replaceOp(op, trunc); + return success(); + } +}; + +struct ConvertMac : public OpConversionPattern { + ConvertMac(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + MacOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto cmod = b.create(modulusAttr(op, true)); + auto x = b.create(modulusType(op, true), + adaptor.getOperands()[0]); + auto y = b.create(modulusType(op, true), + adaptor.getOperands()[1]); + auto acc = b.create(modulusType(op, true), + adaptor.getOperands()[2]); + auto mul = b.create(x, y); + auto add = 
b.create(mul, acc); + auto remu = b.create(add, cmod); + auto trunc = b.create(modulusType(op), remu); + + rewriter.replaceOp(op, trunc); + return success(); + } +}; + namespace rewrites { // In an inner namespace to avoid conflicts with canonicalization patterns #include "lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp.inc" @@ -104,6 +290,7 @@ struct ModArithToArith : impl::ModArithToArithBase { void ModArithToArith::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); + ModArithToArithTypeConverter typeConverter(context); ConversionTarget target(*context); target.addIllegalDialect(); @@ -111,7 +298,11 @@ void ModArithToArith::runOnOperation() { RewritePatternSet patterns(context); rewrites::populateWithGenerated(patterns); - patterns.add(context); + patterns.add( + typeConverter, context); + + addStructuralConversionPatterns(typeConverter, patterns, target); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); diff --git a/lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.td b/lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.td index 45f40888d..39d1cdde2 100644 --- a/lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.td +++ b/lib/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.td @@ -33,133 +33,4 @@ def ConvertSubIfGE : Pattern< ] >; - -def HasEnoughSpaceAddSub: Constraint(getElementTypeOrSelf($_self.getType())).getWidth() >= ($0.getValue() - 1).getActiveBits() + 1">, -"underlying type is sufficient for modular add/sub operation without overflow">; - -def HasEnoughSpaceMul: Constraint(getElementTypeOrSelf($_self.getType())).getWidth() >= 2 * ($0.getValue() - 1).getActiveBits()">, - "underlying type is sufficient for modular mul operation without overflow">; - -def CastModulusAttributeAddSub : NativeCodeCall<"modulusHelper($0,$1,false)">; -def CastModulusAttributeMul : 
NativeCodeCall<"modulusHelper($0,$1,true)">; -def CastModulusAttributeReduce : NativeCodeCall<"modulusHelper($0,$1,false,true)">; - -def ConvertAddSimple : Pattern< - (ModArith_AddOp:$op $x, $y, $mod), - [ - (Arith_AddIOp:$add $x, $y, DefOverflow), - (Arith_RemUIOp $add, (Arith_ConstantOp (CastModulusAttributeAddSub $mod, $x))) - ], - [(HasEnoughSpaceAddSub:$op $mod)], - [], - (addBenefit 2) ->; - -def ConvertSubSimple : Pattern< - (ModArith_SubOp:$op $x, $y, $mod), - [ - (Arith_ConstantOp:$newmod (CastModulusAttributeAddSub $mod, $x)), - (Arith_SubIOp:$sub $x, $y, DefOverflow), - (Arith_AddIOp:$shift $sub, $newmod, DefOverflow), - (Arith_RemUIOp $shift, $newmod) - ], - [(HasEnoughSpaceAddSub:$op $mod)], - [], - (addBenefit 2) ->; - -def ConvertMulSimple : Pattern< - (ModArith_MulOp:$op $x, $y, $mod), - [ - (Arith_MulIOp:$mul $x, $y, DefOverflow), - (Arith_RemUIOp $mul, (Arith_ConstantOp (CastModulusAttributeMul $mod, $x))) - ], - [(HasEnoughSpaceMul:$op $mod)], - [], - (addBenefit 2) ->; - -def ConvertMacSimple : Pattern< - (ModArith_MacOp:$op $x, $y, $acc, $mod), - [ - (Arith_MulIOp:$mul $x, $y, DefOverflow), - (Arith_AddIOp:$add $mul, $acc, DefOverflow), - (Arith_RemUIOp $add, (Arith_ConstantOp (CastModulusAttributeMul $mod, $x))) - ], - [(HasEnoughSpaceMul:$op $mod)], - [], - (addBenefit 2) ->; - -def ConvertAdd : Pattern< - (ModArith_AddOp $x, $y, $mod), - [ - (Arith_ConstantOp:$newmod (CastModulusAttributeAddSub $mod, $x)), - (Arith_AddIOp:$add - (Arith_ExtUIOp $x, - (returnType $newmod)), - (Arith_ExtUIOp $y, - (returnType $newmod)), - DefOverflow), - (Arith_TruncIOp:$res - (Arith_RemUIOp $add, $newmod)) - ] ->; - -def ConvertSub : Pattern< - (ModArith_SubOp $x, $y, $mod), - [ - (Arith_ConstantOp:$newmod (CastModulusAttributeAddSub $mod, $x)), - (Arith_SubIOp:$sub - (Arith_ExtUIOp $x, - (returnType $newmod)), - (Arith_ExtUIOp $y, - (returnType $newmod)), - DefOverflow), - (Arith_AddIOp:$shift $sub, $newmod, DefOverflow), - (Arith_TruncIOp:$res - 
(Arith_RemUIOp $shift, $newmod)) - ] ->; - -def ConvertMul : Pattern< - (ModArith_MulOp $x, $y, $mod), - [ - (Arith_ConstantOp:$newmod (CastModulusAttributeMul $mod, $x)), - (Arith_MulIOp:$mul - (Arith_ExtUIOp $x, - (returnType $newmod)), - (Arith_ExtUIOp $y, - (returnType $newmod)), - DefOverflow), - (Arith_TruncIOp:$res - (Arith_RemUIOp $mul, $newmod)) - ] ->; - -def ConvertMac : Pattern< - (ModArith_MacOp $x, $y, $acc, $mod), - [ - (Arith_ConstantOp:$newmod (CastModulusAttributeMul $mod, $x)), - (Arith_MulIOp:$mul - (Arith_ExtUIOp $x, - (returnType $newmod)), - (Arith_ExtUIOp $y, - (returnType $newmod)), - DefOverflow), - (Arith_AddIOp:$add $mul, - (Arith_ExtUIOp:$extacc $acc, (returnType $newmod)), DefOverflow), - (Arith_TruncIOp:$res - (Arith_RemUIOp $add, $newmod)) - ] ->; - -def ConvertReduce : Pattern< - (ModArith_ReduceOp $x, $mod), - [ - (Arith_ConstantOp:$newmod (CastModulusAttributeReduce $mod, $x)), - (Arith_RemUIOp (Arith_AddIOp (Arith_RemSIOp $x, $newmod), $newmod, DefOverflow), $newmod) - ] ->; - #endif // LIB_DIALECT_MODARITH_CONVERSIONS_MODARITHTOARITH_MODARITHTOARITH_TD_ diff --git a/lib/Dialect/ModArith/IR/BUILD b/lib/Dialect/ModArith/IR/BUILD index 018fe5dcf..1ee9a0422 100644 --- a/lib/Dialect/ModArith/IR/BUILD +++ b/lib/Dialect/ModArith/IR/BUILD @@ -15,10 +15,12 @@ cc_library( hdrs = [ "ModArithDialect.h", "ModArithOps.h", + "ModArithTypes.h", ], deps = [ ":dialect_inc_gen", ":ops_inc_gen", + ":types_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", @@ -32,6 +34,7 @@ td_library( srcs = [ "ModArithDialect.td", "ModArithOps.td", + "ModArithTypes.td", ], # include from the heir-root to enable fully-qualified include-paths includes = ["../../../.."], @@ -67,6 +70,30 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "types_inc_gen", + tbl_outs = [ + ( + ["-gen-typedef-decls"], + "ModArithTypes.h.inc", + ), + ( + ["-gen-typedef-defs"], + "ModArithTypes.cpp.inc", + ), + ( + 
["-gen-typedef-doc"], + "ModArithTypes.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ModArithTypes.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) + gentbl_cc_library( name = "ops_inc_gen", tbl_outs = [ diff --git a/lib/Dialect/ModArith/IR/ModArithDialect.cpp b/lib/Dialect/ModArith/IR/ModArithDialect.cpp index d0401add1..ce09b56ac 100644 --- a/lib/Dialect/ModArith/IR/ModArithDialect.cpp +++ b/lib/Dialect/ModArith/IR/ModArithDialect.cpp @@ -2,20 +2,26 @@ #include -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project - -// NOLINTBEGIN(misc-include-cleaner): Required to define ModArithDialect and -// ModArithOps +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project + +// NOLINTBEGIN(misc-include-cleaner): Required to define ModArithDialect, +// ModArithTypes, ModArithOps #include "lib/Dialect/ModArith/IR/ModArithOps.h" +#include "lib/Dialect/ModArith/IR/ModArithTypes.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // NOLINTEND(misc-include-cleaner) // Generated definitions #include "lib/Dialect/ModArith/IR/ModArithDialect.cpp.inc" +#define GET_TYPEDEF_CLASSES +#include "lib/Dialect/ModArith/IR/ModArithTypes.cpp.inc" + #define GET_OP_CLASSES #include "lib/Dialect/ModArith/IR/ModArithOps.cpp.inc" @@ -24,6 +30,10 @@ namespace heir { namespace mod_arith { void ModArithDialect::initialize() { + addTypes< 
+#define GET_TYPEDEF_LIST +#include "lib/Dialect/ModArith/IR/ModArithTypes.cpp.inc" + >(); addOperations< #define GET_OP_LIST #include "lib/Dialect/ModArith/IR/ModArithOps.cpp.inc" @@ -32,45 +42,65 @@ void ModArithDialect::initialize() { /// Ensures that the underlying integer type is wide enough for the coefficient template -LogicalResult verifyModArithOpMod(OpType op, bool reduce = false) { - auto type = - llvm::cast(getElementTypeOrSelf(op.getResult().getType())); - unsigned bitWidth = type.getWidth(); - unsigned modWidth = (op.getModulus() - 1).getActiveBits(); - if (modWidth > bitWidth) - return op.emitOpError() - << "underlying type's bitwidth must be at least as " - << "large as the modulus bitwidth, but got " << bitWidth - << " while modulus requires width " << modWidth << "."; - if (reduce && modWidth == bitWidth) +LogicalResult verifyModArithType(OpType op, ModArithType type) { + APInt modulus = type.getModulus().getValue(); + unsigned bitWidth = modulus.getBitWidth(); + unsigned modWidth = modulus.getActiveBits(); + if (modWidth > bitWidth - 1) return op.emitOpError() - << "underlying type's bitwidth must be larger than " + << "underlying type's bitwidth must be 1 bit larger than " << "the modulus bitwidth, but got " << bitWidth << " while modulus requires width " << modWidth << "."; - if (!type.isUnsigned() && modWidth == bitWidth) - emitWarning(op.getLoc()) - << "for signed (or signless) underlying types, the bitwidth of the " - "underlying type must be at least as large as modulus bitwidth + " - "1 (for the sign bit), but found " - << bitWidth << " while modulus requires width " << modWidth << "."; - - if (op.getModulus().slt(0)) + return success(); +} + +template +LogicalResult verifySameWidth(OpType op, ModArithType modArithType, + IntegerType integerType) { + unsigned bitWidth = modArithType.getModulus().getValue().getBitWidth(); + unsigned intWidth = integerType.getWidth(); + if (intWidth != bitWidth) return op.emitOpError() - << "provided modulus 
" << op.getModulus().getSExtValue() - << " is not a positive integer."; + << "the result integer type should be of the same width as the " + << "mod arith type width, but got " << intWidth + << " while mod arith type width " << bitWidth << "."; return success(); } -LogicalResult AddOp::verify() { return verifyModArithOpMod(*this); } +LogicalResult EncapsulateOp::verify() { + auto modArithType = getResultModArithType(*this); + auto integerType = getOperandIntegerType(*this); + auto result = verifySameWidth(*this, modArithType, integerType); + if (result.failed()) return result; + return verifyModArithType(*this, getResultModArithType(*this)); +} -LogicalResult SubOp::verify() { return verifyModArithOpMod(*this); } +LogicalResult ExtractOp::verify() { + auto modArithType = getOperandModArithType(*this); + auto integerType = getResultIntegerType(*this); + auto result = verifySameWidth(*this, modArithType, integerType); + if (result.failed()) return result; + return verifyModArithType(*this, modArithType); +} -LogicalResult MulOp::verify() { return verifyModArithOpMod(*this); } +LogicalResult ReduceOp::verify() { + return verifyModArithType(*this, getResultModArithType(*this)); +} -LogicalResult MacOp::verify() { return verifyModArithOpMod(*this); } +LogicalResult AddOp::verify() { + return verifyModArithType(*this, getResultModArithType(*this)); +} -LogicalResult ReduceOp::verify() { - return verifyModArithOpMod(*this, true); +LogicalResult SubOp::verify() { + return verifyModArithType(*this, getResultModArithType(*this)); +} + +LogicalResult MulOp::verify() { + return verifyModArithType(*this, getResultModArithType(*this)); +} + +LogicalResult MacOp::verify() { + return verifyModArithType(*this, getResultModArithType(*this)); } LogicalResult BarrettReduceOp::verify() { diff --git a/lib/Dialect/ModArith/IR/ModArithDialect.td b/lib/Dialect/ModArith/IR/ModArithDialect.td index 9d35047e8..0573b2feb 100644 --- a/lib/Dialect/ModArith/IR/ModArithDialect.td +++ 
b/lib/Dialect/ModArith/IR/ModArithDialect.td @@ -10,6 +10,8 @@ def ModArith_Dialect : Dialect { }]; let cppNamespace = "::mlir::heir::mod_arith"; + let useDefaultTypePrinterParser = 1; + let dependentDialects = [ "arith::ArithDialect", ]; diff --git a/lib/Dialect/ModArith/IR/ModArithOps.h b/lib/Dialect/ModArith/IR/ModArithOps.h index 4850a0e3b..d9250a4bc 100644 --- a/lib/Dialect/ModArith/IR/ModArithOps.h +++ b/lib/Dialect/ModArith/IR/ModArithOps.h @@ -3,6 +3,7 @@ // NOLINTBEGIN(misc-include-cleaner): Required to define ModArithOps #include "lib/Dialect/ModArith/IR/ModArithDialect.h" +#include "lib/Dialect/ModArith/IR/ModArithTypes.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project // NOLINTEND(misc-include-cleaner) @@ -10,4 +11,32 @@ #define GET_OP_CLASSES #include "lib/Dialect/ModArith/IR/ModArithOps.h.inc" +namespace mlir { +namespace heir { +namespace mod_arith { + +template +inline ModArithType getResultModArithType(OpType op) { + return cast(getElementTypeOrSelf(op.getResult().getType())); +} + +template +inline ModArithType getOperandModArithType(OpType op) { + return cast(getElementTypeOrSelf(op.getOperand().getType())); +} + +template +inline IntegerType getResultIntegerType(OpType op) { + return cast(getElementTypeOrSelf(op.getResult().getType())); +} + +template +inline IntegerType getOperandIntegerType(OpType op) { + return cast(getElementTypeOrSelf(op.getOperand().getType())); +} + +} // namespace mod_arith +} // namespace heir +} // namespace mlir + #endif // LIB_DIALECT_MODARITH_IR_MODARITHOPS_H_ diff --git a/lib/Dialect/ModArith/IR/ModArithOps.td b/lib/Dialect/ModArith/IR/ModArithOps.td index 56e9cda91..0bb0bef04 100644 --- a/lib/Dialect/ModArith/IR/ModArithOps.td +++ b/lib/Dialect/ModArith/IR/ModArithOps.td @@ -2,6 +2,7 @@ #define LIB_DIALECT_MODARITH_IR_MODARITHOPS_TD_ include "lib/Dialect/ModArith/IR/ModArithDialect.td" +include 
"lib/Dialect/ModArith/IR/ModArithTypes.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/OpBase.td" @@ -15,10 +16,86 @@ class ModArith_Op traits = [Pure]> : let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } +// type conversion operations +def ModArith_EncapsulateOp : ModArith_Op<"encapsulate", [Pure, ElementwiseMappable]> { + let summary = "encapsulate an integer into a mod_arith type"; + + let description = [{ + `mod_arith.encapsulate` converts the integer to be of mod_arith type. + + It is required that the bitwidth of the input integer type is the same + as that of the storage type of the output mod_arith type. + + Examples: + ``` + mod_arith.encapsulate %c0 : i32 -> mod_arith.mod_arith<65537 : i32> + mod_arith.encapsulate %c1 : i64 -> mod_arith.mod_arith<65537> + ``` + }]; + + let arguments = (ins + SignlessIntegerLike:$input + ); + let results = (outs ModArithLike:$output); + let hasVerifier = 1; + let assemblyFormat = "operands attr-dict `:` type($input) `->` type($output)"; +} + +def ModArith_ExtractOp : ModArith_Op<"extract", [Pure, ElementwiseMappable]> { + let summary = "extract the integer stored inside mod_arith type"; + + let description = [{ + `mod_arith.extract` extracts the integer inside the mod_arith type. + + It is required that the bitwidth of the output integer type is the same + as that of the storage type of the input mod_arith type. 
+ + Examples: + ``` + %m0 = mod_arith.encapsulate %c0 : i32 -> mod_arith.mod_arith<65537 : i32> + %m1 = mod_arith.encapsulate %c1 : i64 -> mod_arith.mod_arith<65537> + %c2 = mod_arith.extract %m0 : mod_arith.mod_arith<65537 : i32> -> i32 + %c3 = mod_arith.extract %m1 : mod_arith.mod_arith<65537> -> i64 + ``` + }]; + + let arguments = (ins + ModArithLike:$input + ); + let results = (outs SignlessIntegerLike:$output); + let hasVerifier = 1; + let assemblyFormat = "operands attr-dict `:` type($input) `->` type($output)"; +} + +def ModArith_ReduceOp : ModArith_Op<"reduce", [Pure, ElementwiseMappable, SameOperandsAndResultType]> { + let summary = "reduce the mod arith type to its canonical representative"; + + let description = [{ + `mod_arith.reduce x` produces $y$, the canonical representative in $[0, q)$ + such that $x \equiv y \mod q$. + + Examples: + ``` + %c0 = arith.constant 65538 : i32 + %m0 = mod_arith.encapsulate %c0 : i32 -> mod_arith.mod_arith<65537 : i32> + // mod_arith.extract %m0 produces 65538 + %m1 = mod_arith.reduce %m0 : mod_arith.mod_arith<65537: i32> + // mod_arith.extract %m1 produces 1 + ``` + }]; + + let arguments = (ins + ModArithLike:$input + ); + let results = (outs ModArithLike:$output); + let hasVerifier = 1; + let assemblyFormat = "operands attr-dict `:` type($output)"; +} + class ModArith_BinaryOp traits = []> : ModArith_Op, - Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs, APIntAttr:$modulus)>, - Results<(outs SignlessIntegerLike:$output)> { + Arguments<(ins ModArithLike:$lhs, ModArithLike:$rhs)>, + Results<(outs ModArithLike:$output)> { let hasVerifier = 1; let assemblyFormat ="operands attr-dict `:` type($output)"; } @@ -26,21 +103,30 @@ class ModArith_BinaryOp traits = []> : def ModArith_AddOp : ModArith_BinaryOp<"add", [Commutative]> { let summary = "modular addition operation"; let description = [{ - Computes addition modulo a statically known modulus $q$. + Computes modular addition. 
+ + Unless otherwise specified, the operation assumes both inputs are canonical + representatives and guarantees the output being canonical representative. }]; } def ModArith_SubOp : ModArith_BinaryOp<"sub"> { let summary = "modular subtraction operation"; let description = [{ - Computes subtraction modulo a statically known modulus $q$. + Computes modular subtraction. + + Unless otherwise specified, the operation assumes both inputs are canonical + representatives and guarantees the output being canonical representative. }]; } def ModArith_MulOp : ModArith_BinaryOp<"mul", [Commutative]> { let summary = "modular multiplication operation"; let description = [{ - Computes multiplication modulo a statically known modulus $q$. + Computes modular multiplication. + + Unless otherwise specified, the operation assumes both inputs are canonical + representatives and guarantees the output being canonical representative. }]; } @@ -48,34 +134,18 @@ def ModArith_MacOp : ModArith_Op<"mac", [SameOperandsAndResultType, Pure, Elemen let summary = "modular multiplication-and-accumulation operation"; let description = [{ - `mod_arith.mac x, y, z {modulus = q}` computes $(x * y) + z \mod q$ - }]; - let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs, SignlessIntegerLike:$acc, APIntAttr:$modulus); - let results = (outs SignlessIntegerLike:$output); - let hasVerifier = 1; - let assemblyFormat = "operands attr-dict `:` type($output)"; -} - -def ModArith_ReduceOp : ModArith_Op<"reduce", [SameOperandsAndResultType, Pure, ElementwiseMappable]> { - let summary = "reduce a signed integer to its congruence modulo equivalent"; - - let description = [{ - `mod_arith.reduce x {modulus = q}` computes $y \in [0, q)$ such that - $x \equiv y \mod n$. + `mod_arith.mac x, y, z` computes $(x * y) + z$ - Note this will interpret `x` as a signed integer. It is required the bitwidth of `q` is smaller - than that of `x`. 
For an unsigned integer, equivalent functionality is: `y = arith.remui x`. + Unless otherwise specified, the operation assumes all inputs are canonical + representatives and guarantees the output being canonical representative. }]; - - let arguments = (ins - SignlessIntegerLike:$input, - APIntAttr:$modulus - ); - let results = (outs SignlessIntegerLike:$output); + let arguments = (ins ModArithLike:$lhs, ModArithLike:$rhs, ModArithLike:$acc); + let results = (outs ModArithLike:$output); let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` type($output)"; } +// TODO(#1084): migrate barrett/subifge to mod arith type def ModArith_BarrettReduceOp : ModArith_Op<"barrett_reduce", [SameOperandsAndResultType]> { let summary = "Compute the first step of the Barrett reduction."; let description = [{ @@ -83,7 +153,7 @@ def ModArith_BarrettReduceOp : ModArith_Op<"barrett_reduce", [SameOperandsAndRes smallest bit-width that contains the range $[0, q)$. The Barrett reduce operation computes `barret_reduce x = x - floor(x * floor(b / q) / b) * q`. - Given $0 <= x < q^2$, then this will compute $(x \mod q)$ or $(x \mod q) + p$. + Given $0 <= x < q^2$, then this will compute $(x \mod q)$ or $(x \mod q) + q$. 
}]; let arguments = (ins diff --git a/lib/Dialect/ModArith/IR/ModArithTypes.h b/lib/Dialect/ModArith/IR/ModArithTypes.h new file mode 100644 index 000000000..a4411a13d --- /dev/null +++ b/lib/Dialect/ModArith/IR/ModArithTypes.h @@ -0,0 +1,9 @@ +#ifndef LIB_DIALECT_MODARITH_IR_MODARITHTYPES_H_ +#define LIB_DIALECT_MODARITH_IR_MODARITHTYPES_H_ + +#include "lib/Dialect/ModArith/IR/ModArithDialect.h" + +#define GET_TYPEDEF_CLASSES +#include "lib/Dialect/ModArith/IR/ModArithTypes.h.inc" + +#endif // LIB_DIALECT_MODARITH_IR_MODARITHTYPES_H_ diff --git a/lib/Dialect/ModArith/IR/ModArithTypes.td b/lib/Dialect/ModArith/IR/ModArithTypes.td new file mode 100644 index 000000000..27e66d12b --- /dev/null +++ b/lib/Dialect/ModArith/IR/ModArithTypes.td @@ -0,0 +1,56 @@ +#ifndef LIB_TYPES_MODARITH_IR_MODARITHTYPES_TD_ +#define LIB_TYPES_MODARITH_IR_MODARITHTYPES_TD_ + +include "lib/Dialect/ModArith/IR/ModArithDialect.td" + +include "mlir/IR/DialectBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/AttrTypeBase.td" + +class ModArith_Type + : TypeDef { + let mnemonic = typeMnemonic; +} + +def ModArith_ModArith : ModArith_Type<"ModArith", "mod_arith"> { + let summary = "Integer type with modular arithmetic"; + let description = [{ + `mod_arith.mod_arith<p>` represents an element of the ring of integers modulo $p$. + The `modulus` attribute is the ring modulus, and `mod_arith` operations lower to + `arith` operations that produce results in the range `[0, modulus)`, often called + the _canonical representative_. + + `modulus` is specified with an integer type suffix, for example, + `mod_arith.mod_arith<65537 : i32>`. This corresponds to the storage type for the + modulus, and is `i64` by default. + + It is required that the underlying integer type should be larger than + twice the modulus (have one extra bit of storage space) to avoid signedness + issues. For example, when `modulus == 2 ** 32 - 1`, the underlying type + for the modulus should be at least `i33`, though `i64` is a natural choice. + + Passes may allow intermediate values that do not always produce a + canonical representative in `[0, modulus)`. For example, if the machine storage + type is `i64`, but the `modulus` fits within an `i32`, a lowering could + allow intermediate arithmetic values to grow to as large as an `i64` before + reducing them. However, all passes must ensure that values used outside + the local scope (e.g., function return values or arguments to calls to linked + functions) are appropriately reduced to the canonical representative. + `modulus` is the modulus the arithmetic is working with. 
+ + Examples: + ``` + !Zp1 = !mod_arith.mod_arith<7> // implicitly being i64 + !Zp2 = !mod_arith.mod_arith<65537 : i32> + !Zp3 = !mod_arith.mod_arith<536903681 : i64> + ``` + }]; + let parameters = (ins + "::mlir::IntegerAttr":$modulus + ); + let assemblyFormat = "`<` $modulus `>`"; +} + +def ModArithLike: TypeOrContainer; + +#endif // LIB_TYPES_MODARITH_IR_MODARITHTYPES_TD_ diff --git a/lib/Dialect/Polynomial/Transforms/NTTRewrites.cpp b/lib/Dialect/Polynomial/Transforms/NTTRewrites.cpp index bc8b3a5cd..dd839314c 100644 --- a/lib/Dialect/Polynomial/Transforms/NTTRewrites.cpp +++ b/lib/Dialect/Polynomial/Transforms/NTTRewrites.cpp @@ -21,7 +21,8 @@ struct PolyMulToNTT : impl::PolyMulToNTTBase { void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.add(patterns.getContext()); + // TODO(#1095): migrate to mod arith type + // patterns.add(patterns.getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; diff --git a/lib/Dialect/Polynomial/Transforms/NTTRewrites.td b/lib/Dialect/Polynomial/Transforms/NTTRewrites.td index 42d3bdb92..939cd5b51 100644 --- a/lib/Dialect/Polynomial/Transforms/NTTRewrites.td +++ b/lib/Dialect/Polynomial/Transforms/NTTRewrites.td @@ -30,24 +30,25 @@ def HasDegreePowerOfTwo : Constraint< def Nullptr : NativeCodeCall<"nullptr">; -def NTTRewritePolyMul : Pattern< - (Polynomial_MulOp:$mulOp $p1, $p2), - [ - // Transform to NTT point-value representation - (Polynomial_NTTOp:$p1NTT $p1, (Nullptr), - (returnType (InputTensorType (GetRingAttr $p1)))), - (Polynomial_NTTOp:$p2NTT $p2, (Nullptr), - (returnType (InputTensorType (GetRingAttr $p2)))), - - // Compute elementwise multiplication modulo cmod - (ModArith_MulOp:$mulNTT $p1NTT, $p2NTT, (GetRingModAttr $p1)), - - // Compute inverse transform back to coefficient representation - (Polynomial_INTTOp:$res $mulNTT, (Nullptr)) - ], - [ - (HasDegreePowerOfTwo $p1) - ] ->; +// TODO(#1095): migrate to mod 
arith type +// def NTTRewritePolyMul : Pattern< +// (Polynomial_MulOp:$mulOp $p1, $p2), +// [ +// // Transform to NTT point-value representation +// (Polynomial_NTTOp:$p1NTT $p1, (Nullptr), +// (returnType (InputTensorType (GetRingAttr $p1)))), +// (Polynomial_NTTOp:$p2NTT $p2, (Nullptr), +// (returnType (InputTensorType (GetRingAttr $p2)))), +// +// // Compute elementwise multiplication modulo cmod +// (ModArith_MulOp:$mulNTT $p1NTT, $p2NTT, (GetRingModAttr $p1)), +// +// // Compute inverse transform back to coefficient representation +// (Polynomial_INTTOp:$res $mulNTT, (Nullptr)) +// ], +// [ +// (HasDegreePowerOfTwo $p1) +// ] +// >; #endif // LIB_DIALECT_POLYNOMIAL_TRANSFORMS_NTTREWRITES_TD_ diff --git a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/mod-arith-to-arith.mlir b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/mod-arith-to-arith.mlir index fc64b847d..7c344d9b6 100644 --- a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/mod-arith-to-arith.mlir +++ b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/mod-arith-to-arith.mlir @@ -1,272 +1,182 @@ // RUN: heir-opt -mod-arith-to-arith --split-input-file %s | FileCheck %s --enable-var-scope -// CHECK-LABEL: @test_lower_simple_add -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_simple_add(%lhs : i8, %rhs : i8) -> i8 { - // CHECK-NOT: mod_arith.add - // CHECK: %[[ADD:.*]] = arith.addi %[[LHS]], %[[RHS]] : [[TYPE]] - // CHECK: %[[CMOD:.*]] = arith.constant 17 : [[TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TYPE]] - // CHECK: return %[[REM]] : [[TYPE]] - %res = mod_arith.add %lhs, %rhs {modulus = 17 }: i8 - return %res : i8 +!Zp = !mod_arith.mod_arith<65537 : i32> +!Zpv = tensor<4x!Zp> + +// CHECK-LABEL: @test_lower_encapsulate +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] { +func.func @test_lower_encapsulate(%lhs : i32) -> !Zp { + // CHECK-NOT: mod_arith.encapsulate + // CHECK: return %[[LHS]] : [[T]] + 
%res = mod_arith.encapsulate %lhs: i32 -> !Zp + return %res : !Zp } -// CHECK-LABEL: @test_lower_simple_add_vec -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_simple_add_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { - // CHECK-NOT: mod_arith.add - // CHECK: %[[ADD:.*]] = arith.addi %[[LHS]], %[[RHS]] : [[TYPE]] - // CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TYPE]] - // CHECK: return %[[REM]] : [[TYPE]] - %res = mod_arith.add %lhs, %rhs {modulus = 17}: tensor<4xi8> - return %res : tensor<4xi8> +// CHECK-LABEL: @test_lower_encapsulate_vec +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] { +func.func @test_lower_encapsulate_vec(%lhs : tensor<4xi32>) -> !Zpv { + // CHECK-NOT: mod_arith.encapsulate + // CHECK: return %[[LHS]] : [[T]] + %res = mod_arith.encapsulate %lhs: tensor<4xi32> -> !Zpv + return %res : !Zpv } -// CHECK-LABEL: @test_lower_add -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_add(%lhs : i8, %rhs : i8) -> i8 { - // CHECK-NOT: mod_arith.add - // CHECK: %[[CMOD:.*]] = arith.constant 217 : [[INTERMEDIATE_TYPE:.*]] - // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[ADD:.*]] = arith.addi %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]] - // CHECK: return %[[TRUNC]] : [[TYPE]] - %res = mod_arith.add %lhs, %rhs {modulus = 217 }: i8 - return %res : i8 +// CHECK-LABEL: @test_lower_extract +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] { +func.func @test_lower_extract(%lhs : !Zp) -> i32 { + // CHECK-NOT: mod_arith.extract + // CHECK: return %[[LHS]] : [[T]] + %res = 
mod_arith.extract %lhs: !Zp -> i32 + return %res : i32 } -// CHECK-LABEL: @test_lower_add_vec -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_add_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { - // CHECK-NOT: mod_arith.add - // CHECK: %[[CMOD:.*]] = arith.constant dense<217> : [[INTERMEDIATE_TYPE:.*]] - // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[ADD:.*]] = arith.addi %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]] - // CHECK: return %[[TRUNC]] : [[TYPE]] - %res = mod_arith.add %lhs, %rhs {modulus = 217 }: tensor<4xi8> - return %res : tensor<4xi8> +// CHECK-LABEL: @test_lower_extract_vec +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] { +func.func @test_lower_extract_vec(%lhs : !Zpv) -> tensor<4xi32> { + // CHECK-NOT: mod_arith.extract + // CHECK: return %[[LHS]] : [[T]] + %res = mod_arith.extract %lhs: !Zpv -> tensor<4xi32> + return %res : tensor<4xi32> } -// CHECK-LABEL: @test_lower_simple_sub -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_simple_sub(%lhs : i8, %rhs : i8) -> i8 { - // CHECK-NOT: mod_arith.sub - // CHECK: %[[CMOD:.*]] = arith.constant 17 : [[TYPE]] - // CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[TYPE]] - // CHECK: %[[SHIFT:.*]] = arith.addi %[[SUB]], %[[CMOD]] : [[TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[SHIFT]], %[[CMOD]] : [[TYPE]] - // CHECK: return %[[REM]] : [[TYPE]] - %res = mod_arith.sub %lhs, %rhs {modulus = 17}: i8 - return %res : i8 +// CHECK-LABEL: @test_lower_reduce +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] { +func.func @test_lower_reduce(%lhs : !Zp) -> !Zp { + // CHECK-NOT: 
mod_arith.reduce + // CHECK: %[[CMOD:.*]] = arith.constant 65537 : [[T]] + // CHECK: %[[REMS:.*]] = arith.remsi %[[LHS]], %[[CMOD]] : [[T]] + // CHECK: %[[ADD:.*]] = arith.addi %[[REMS]], %[[CMOD]] : [[T]] + // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]] + // CHECK: return %[[REM]] : [[T]] + %res = mod_arith.reduce %lhs: !Zp + return %res : !Zp +} + +// CHECK-LABEL: @test_lower_reduce_vec +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] { +func.func @test_lower_reduce_vec(%lhs : !Zpv) -> !Zpv { + // CHECK-NOT: mod_arith.reduce + // CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[T]] + // CHECK: %[[REMS:.*]] = arith.remsi %[[LHS]], %[[CMOD]] : [[T]] + // CHECK: %[[ADD:.*]] = arith.addi %[[REMS]], %[[CMOD]] : [[T]] + // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]] + // CHECK: return %[[REM]] : [[T]] + %res = mod_arith.reduce %lhs: !Zpv + return %res : !Zpv } -// CHECK-LABEL: @test_lower_simple_sub_vec -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_simple_sub_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { - // CHECK-NOT: mod_arith.sub - // CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TYPE]] - // CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[TYPE]] - // CHECK: %[[SHIFT:.*]] = arith.addi %[[SUB]], %[[CMOD]] : [[TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[SHIFT]], %[[CMOD]] : [[TYPE]] - // CHECK: return %[[REM]] : [[TYPE]] - %res = mod_arith.sub %lhs, %rhs {modulus = 17}: tensor<4xi8> - return %res : tensor<4xi8> +// CHECK-LABEL: @test_lower_add +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_add(%lhs : !Zp, %rhs : !Zp) -> !Zp { + // CHECK-NOT: mod_arith.add + // CHECK: %[[CMOD:.*]] = arith.constant 65537 : [[T]] + // CHECK: %[[ADD:.*]] = arith.addi %[[LHS]], %[[RHS]] : [[T]] + // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]] + // CHECK: return %[[REM]] : [[T]] + %res = 
mod_arith.add %lhs, %rhs : !Zp + return %res : !Zp +} + +// CHECK-LABEL: @test_lower_add_vec +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_add_vec(%lhs : !Zpv, %rhs : !Zpv) -> !Zpv { + // CHECK-NOT: mod_arith.add + // CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[T]] + // CHECK: %[[ADD:.*]] = arith.addi %[[LHS]], %[[RHS]] : [[T]] + // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]] + // CHECK: return %[[REM]] : [[T]] + %res = mod_arith.add %lhs, %rhs : !Zpv + return %res : !Zpv } // CHECK-LABEL: @test_lower_sub -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_sub(%lhs : i8, %rhs : i8) -> i8 { +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_sub(%lhs : !Zp, %rhs : !Zp) -> !Zp { // CHECK-NOT: mod_arith.sub - // CHECK: %[[CMOD:.*]] = arith.constant 217 : [[INTERMEDIATE_TYPE:.*]] - // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[SUB:.*]] = arith.subi %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[SHIFT:.*]] = arith.addi %[[SUB]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[SHIFT]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]] - // CHECK: return %[[TRUNC]] : [[TYPE]] - %res = mod_arith.sub %lhs, %rhs {modulus = 217 }: i8 - return %res : i8 + // CHECK: %[[CMOD:.*]] = arith.constant 65537 : [[T]] + // CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[T]] + // CHECK: %[[ADD:.*]] = arith.addi %[[SUB]], %[[CMOD]] : [[T]] + // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]] + // CHECK: return %[[REM]] : [[T]] + %res = mod_arith.sub %lhs, %rhs : !Zp + return %res : !Zp } // CHECK-LABEL: @test_lower_sub_vec -// CHECK-SAME: (%[[LHS:.*]]: 
[[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_sub_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_sub_vec(%lhs : !Zpv, %rhs : !Zpv) -> !Zpv { // CHECK-NOT: mod_arith.sub - // CHECK: %[[CMOD:.*]] = arith.constant dense<217> : [[INTERMEDIATE_TYPE:.*]] - // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[SUB:.*]] = arith.subi %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[SHIFT:.*]] = arith.addi %[[SUB]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[SHIFT]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]] - // CHECK: return %[[TRUNC]] : [[TYPE]] - %res = mod_arith.sub %lhs, %rhs {modulus = 217 }: tensor<4xi8> - return %res : tensor<4xi8> -} - -// CHECK-LABEL: @test_lower_simple_mul -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_simple_mul(%lhs : i16, %rhs : i16) -> i16 { - // CHECK-NOT: mod_arith.mul - // CHECK: %[[MUL:.*]] = arith.muli %[[LHS]], %[[RHS]] : [[TYPE]] - // CHECK: %[[CMOD:.*]] = arith.constant 17 : [[TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[TYPE]] - // CHECK: return %[[REM]] : [[TYPE]] - %res = mod_arith.mul %lhs, %rhs {modulus = 17}: i16 - return %res : i16 -} - -// CHECK-LABEL: @test_lower_simple_mul_vec -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_simple_mul_vec(%lhs : tensor<4xi16>, %rhs : tensor<4xi16>) -> tensor<4xi16> { - // CHECK-NOT: mod_arith.mul - // CHECK: %[[MUL:.*]] = arith.muli %[[LHS]], %[[RHS]] : [[TYPE]] - // CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] 
: [[TYPE]] - // CHECK: return %[[REM]] : [[TYPE]] - %res = mod_arith.mul %lhs, %rhs {modulus = 17}: tensor<4xi16> - return %res : tensor<4xi16> + // CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[T]] + // CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[T]] + // CHECK: %[[ADD:.*]] = arith.addi %[[SUB]], %[[CMOD]] : [[T]] + // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]] + // CHECK: return %[[REM]] : [[T]] + %res = mod_arith.sub %lhs, %rhs : !Zpv + return %res : !Zpv } // CHECK-LABEL: @test_lower_mul -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_mul(%lhs : i8, %rhs : i8) -> i8 { +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_mul(%lhs : !Zp, %rhs : !Zp) -> !Zp { // CHECK-NOT: mod_arith.mul - // CHECK: %[[CMOD:.*]] = arith.constant 217 : [[INTERMEDIATE_TYPE:.*]] - // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]] - // CHECK: return %[[TRUNC]] : [[TYPE]] - %res = mod_arith.mul %lhs, %rhs {modulus = 217 }: i8 - return %res : i8 + // CHECK: %[[CMOD:.*]] = arith.constant 65537 : [[TEXT:.*]] + // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[T]] to [[TEXT]] + // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[T]] to [[TEXT]] + // CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[TEXT]] + // CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[TEXT]] + // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[TEXT]] to [[T]] + // CHECK: return %[[TRUNC]] : [[T]] + %res = mod_arith.mul %lhs, %rhs : !Zp + return %res : !Zp } // CHECK-LABEL: @test_lower_mul_vec -// CHECK-SAME: 
(%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_mul_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_mul_vec(%lhs : !Zpv, %rhs : !Zpv) -> !Zpv { // CHECK-NOT: mod_arith.mul - // CHECK: %[[CMOD:.*]] = arith.constant dense<217> : [[INTERMEDIATE_TYPE:.*]] - // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]] - // CHECK: return %[[TRUNC]] : [[TYPE]] - %res = mod_arith.mul %lhs, %rhs {modulus = 217 }: tensor<4xi8> - return %res : tensor<4xi8> -} - -// CHECK-LABEL: @test_lower_simple_mac -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]], %[[ACC:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_simple_mac(%lhs : tensor<4xi16>, %rhs : tensor<4xi16>, %acc : tensor<4xi16>) -> tensor<4xi16> { - // CHECK-NOT: mod_arith.mac - // CHECK: %[[MUL:.*]] = arith.muli %[[LHS]], %[[RHS]] : [[TYPE]] - // CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[ACC]] : [[TYPE]] - // CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TYPE]] - // CHECK: return %[[REM]] : [[TYPE]] - %res = mod_arith.mac %lhs, %rhs, %acc {modulus = 17}: tensor<4xi16> - return %res : tensor<4xi16> -} - -// CHECK-LABEL: @test_lower_simple_mac_vec -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]], %[[ACC:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_simple_mac_vec(%lhs : i16, %rhs : i16, %acc : i16) -> i16 { - // CHECK-NOT: mod_arith.mac - // CHECK: %[[MUL:.*]] = arith.muli %[[LHS]], %[[RHS]] : [[TYPE]] - // CHECK: 
%[[ADD:.*]] = arith.addi %[[MUL]], %[[ACC]] : [[TYPE]] - // CHECK: %[[CMOD:.*]] = arith.constant 17 : [[TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TYPE]] - // CHECK: return %[[REM]] : [[TYPE]] - %res = mod_arith.mac %lhs, %rhs, %acc{modulus = 17}: i16 - return %res : i16 + // CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[TEXT:.*]] + // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[T]] to [[TEXT]] + // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[T]] to [[TEXT]] + // CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[TEXT]] + // CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[TEXT]] + // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[TEXT]] to [[T]] + // CHECK: return %[[TRUNC]] : [[T]] + %res = mod_arith.mul %lhs, %rhs : !Zpv + return %res : !Zpv } // CHECK-LABEL: @test_lower_mac -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]], %[[ACC:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_mac(%lhs : i8, %rhs : i8, %acc : i8) -> i8 { +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]], %[[ACC:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_mac(%lhs : !Zp, %rhs : !Zp, %acc : !Zp) -> !Zp { // CHECK-NOT: mod_arith.mac - // CHECK: %[[CMOD:.*]] = arith.constant 217 : [[INTERMEDIATE_TYPE:.*]] - // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT2:.*]] = arith.extui %[[ACC]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[EXT2]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]] - // CHECK: return %[[TRUNC]] : [[TYPE]] - %res = mod_arith.mac %lhs, %rhs, %acc {modulus = 217 }: i8 - return %res : i8 + // CHECK: 
%[[CMOD:.*]] = arith.constant 65537 : [[TEXT:.*]] + // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[T]] to [[TEXT]] + // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[T]] to [[TEXT]] + // CHECK: %[[EXT2:.*]] = arith.extui %[[ACC]] : [[T]] to [[TEXT]] + // CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[TEXT]] + // CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[EXT2]] : [[TEXT]] + // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TEXT]] + // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[TEXT]] to [[T]] + // CHECK: return %[[TRUNC]] : [[T]] + %res = mod_arith.mac %lhs, %rhs, %acc : !Zp + return %res : !Zp } // CHECK-LABEL: @test_lower_mac_vec -// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]], %[[ACC:.*]]: [[TYPE]]) -> [[TYPE]] { -func.func @test_lower_mac_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>, %acc : tensor<4xi8>) -> tensor<4xi8> { +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]], %[[ACC:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_mac_vec(%lhs : !Zpv, %rhs : !Zpv, %acc : !Zpv) -> !Zpv { // CHECK-NOT: mod_arith.mac - // CHECK: %[[CMOD:.*]] = arith.constant dense<217> : [[INTERMEDIATE_TYPE:.*]] - // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[EXT2:.*]] = arith.extui %[[ACC]] : [[TYPE]] to [[INTERMEDIATE_TYPE]] - // CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[EXT2]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[INTERMEDIATE_TYPE]] - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]] - // CHECK: return %[[TRUNC]] : [[TYPE]] - %res = mod_arith.mac %lhs, %rhs, %acc {modulus = 217 }: tensor<4xi8> - return %res : tensor<4xi8> -} - - -// ----- - -// CHECK-LABEL: @test_lower_reduce -// CHECK-SAME: (%[[ARG:.*]]: [[TENSOR_TYPE:.*]]) -> 
[[TENSOR_TYPE]] { -func.func @test_lower_reduce(%arg : tensor<4xi8>) -> tensor<4xi8> { - // CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TENSOR_TYPE]] - - // CHECK: %[[MOD:.*]] = arith.remsi %[[ARG]], %[[CMOD]] : [[TENSOR_TYPE]] - // CHECK: %[[SHIFT:.*]] = arith.addi %[[MOD]], %[[CMOD]] : [[TENSOR_TYPE]] - // CHECK: %[[RES:.*]] = arith.remui %[[SHIFT]], %[[CMOD]] : [[TENSOR_TYPE]] - %res = mod_arith.reduce %arg { modulus = 17 } : tensor<4xi8> - return %res : tensor<4xi8> -} - -// ----- - -// CHECK-LABEL: @test_lower_reduce_int -// CHECK-SAME: (%[[ARG:.*]]: [[INT_TYPE:.*]]) -> [[INT_TYPE]] { -func.func @test_lower_reduce_int(%arg : i8) -> i8 { - // CHECK: %[[CMOD:.*]] = arith.constant 17 : [[INT_TYPE]] - - // CHECK: %[[MOD:.*]] = arith.remsi %[[ARG]], %[[CMOD]] : [[INT_TYPE]] - // CHECK: %[[SHIFT:.*]] = arith.addi %[[MOD]], %[[CMOD]] : [[INT_TYPE]] - // CHECK: %[[RES:.*]] = arith.remui %[[SHIFT]], %[[CMOD]] : [[INT_TYPE]] - %res = mod_arith.reduce %arg { modulus = 17 } : i8 - return %res : i8 -} - -// ----- - -// CHECK-LABEL: @test_lower_reduce_int_max_modulus -// CHECK-SAME: (%[[ARG:.*]]: [[INT_TYPE:.*]]) -> [[INT_TYPE]] { -func.func @test_lower_reduce_int_max_modulus(%arg : i8) -> i8 { - // CHECK: %[[CMOD:.*]] = arith.constant -128 : [[INT_TYPE:i8]] - - // CHECK: %[[MOD:.*]] = arith.remsi %[[ARG]], %[[CMOD]] : [[INT_TYPE]] - // CHECK: %[[SHIFT:.*]] = arith.addi %[[MOD]], %[[CMOD]] : [[INT_TYPE]] - // CHECK: %[[RES:.*]] = arith.remui %[[SHIFT]], %[[CMOD]] : [[INT_TYPE]] - %res = mod_arith.reduce %arg { modulus = 128 : i32 } : i8 - return %res : i8 + // CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[TEXT:.*]] + // CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[T]] to [[TEXT]] + // CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[T]] to [[TEXT]] + // CHECK: %[[EXT2:.*]] = arith.extui %[[ACC]] : [[T]] to [[TEXT]] + // CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[TEXT]] + // CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[EXT2]] : [[TEXT]] + // 
CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TEXT]] + // CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[TEXT]] to [[T]] + // CHECK: return %[[TRUNC]] : [[T]] + %res = mod_arith.mac %lhs, %rhs, %acc : !Zpv + return %res : !Zpv } // ----- diff --git a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_add.mlir b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_add.mlir index 0cf9c915e..5dfcc1f93 100644 --- a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_add.mlir +++ b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_add.mlir @@ -5,10 +5,20 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } +!Zp = !mod_arith.mod_arith<7681 : i26> +!Zpv = tensor<4x!Zp> + func.func @test_lower_add() { + // 67108862 is -2 %x = arith.constant dense<[29498763, 42, 67108862, 7681]> : tensor<4xi26> + // 36789492 is -30319372, 67108863 is -1 %y = arith.constant dense<[36789492, 7234, 67108863, 7681]> : tensor<4xi26> - %1 = mod_arith.add %x, %y { modulus = 7681 } : tensor<4xi26> + %ex = mod_arith.encapsulate %x : tensor<4xi26> -> !Zpv + %ey = mod_arith.encapsulate %y : tensor<4xi26> -> !Zpv + %mx = mod_arith.reduce %ex : !Zpv + %my = mod_arith.reduce %ey : !Zpv + %m1 = mod_arith.add %mx, %my : !Zpv + %1 = mod_arith.extract %m1 : !Zpv -> tensor<4xi26> %2 = arith.extui %1 : tensor<4xi26> to tensor<4xi32> %3 = bufferization.to_memref %2 : memref<4xi32> @@ -17,4 +27,4 @@ func.func @test_lower_add() { return } -// CHECK_TEST_ADD: [1225, 7276, 7645, 0] +// CHECK_TEST_ADD: [1258, 7276, 7678, 0] diff --git a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_mac.mlir b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_mac.mlir index 725cc00a1..e2757e26e 100644 --- a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_mac.mlir +++ b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_mac.mlir @@ -5,11 +5,23 @@ 
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } +!Zp = !mod_arith.mod_arith<7681 : i26> +!Zpv = tensor<4x!Zp> + func.func @test_lower_mac() { + // 67108862 is -2 %x = arith.constant dense<[29498763, 42, 67108862, 7681]> : tensor<4xi26> + // 36789492 is -30319372, 67108863 is -1 %y = arith.constant dense<[36789492, 7234, 67108863, 7681]> : tensor<4xi26> %z = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi26> - %1 = mod_arith.mac %x, %y, %z { modulus = 7681 } : tensor<4xi26> + %ex = mod_arith.encapsulate %x : tensor<4xi26> -> !Zpv + %ey = mod_arith.encapsulate %y : tensor<4xi26> -> !Zpv + %ez = mod_arith.encapsulate %z : tensor<4xi26> -> !Zpv + %mx = mod_arith.reduce %ex : !Zpv + %my = mod_arith.reduce %ey : !Zpv + %mz = mod_arith.reduce %ez : !Zpv + %m1 = mod_arith.mac %mx, %my, %mz : !Zpv + %1 = mod_arith.extract %m1 : !Zpv -> tensor<4xi26> %2 = arith.extui %1 : tensor<4xi26> to tensor<4xi32> %3 = bufferization.to_memref %2 : memref<4xi32> @@ -18,4 +30,4 @@ func.func @test_lower_mac() { return } -// CHECK_TEST_MAC: [5099, 4270, 4, 3] +// CHECK_TEST_MAC: [1600, 4270, 4, 3] diff --git a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_mul.mlir b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_mul.mlir index a2e33cf4f..52ed8f270 100644 --- a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_mul.mlir +++ b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_mul.mlir @@ -5,10 +5,20 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } +!Zp = !mod_arith.mod_arith<7681 : i26> +!Zpv = tensor<4x!Zp> + func.func @test_lower_mul() { + // 67108862 is -2 %x = arith.constant dense<[29498763, 42, 67108862, 7681]> : tensor<4xi26> + // 36789492 is -30319372, 67108863 is -1 %y = arith.constant dense<[36789492, 7234, 67108863, 7681]> : tensor<4xi26> - %1 = mod_arith.mul %x, %y { modulus = 7681 } : tensor<4xi26> + %ex = mod_arith.encapsulate %x 
: tensor<4xi26> -> !Zpv + %ey = mod_arith.encapsulate %y : tensor<4xi26> -> !Zpv + %mx = mod_arith.reduce %ex : !Zpv + %my = mod_arith.reduce %ey : !Zpv + %m1 = mod_arith.mul %mx, %my : !Zpv + %1 = mod_arith.extract %m1 : !Zpv -> tensor<4xi26> %2 = arith.extui %1 : tensor<4xi26> to tensor<4xi32> %3 = bufferization.to_memref %2 : memref<4xi32> @@ -17,4 +27,4 @@ func.func @test_lower_mul() { return } -// CHECK_TEST_MUL: [5099, 4269, 2, 0] +// CHECK_TEST_MUL: [1600, 4269, 2, 0] diff --git a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_reduce.mlir b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_reduce.mlir index 6bd182c73..c5983b349 100644 --- a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_reduce.mlir +++ b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_reduce.mlir @@ -5,11 +5,19 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } +!Zp1 = !mod_arith.mod_arith<7681 : i26> +!Zp1v = tensor<6x!Zp1> +// 33554431 = 2 ** 25 - 1 +!Zp2 = !mod_arith.mod_arith<33554431 : i26> +!Zp2v = tensor<6x!Zp2> + func.func @test_lower_reduce() { // reduce intends the input to be signed // 67108862 = 2 ** 26 - 2, equivalent to -2 as input %x = arith.constant dense<[29498763, 42, 67108862, 7681, -1, 7680]> : tensor<6xi26> - %1 = mod_arith.reduce %x { modulus = 7681 } : tensor<6xi26> + %e1 = mod_arith.encapsulate %x : tensor<6xi26> -> !Zp1v + %m1 = mod_arith.reduce %e1 : !Zp1v + %1 = mod_arith.extract %m1 : !Zp1v -> tensor<6xi26> // CHECK_TEST_REDUCE: [3723, 42, 7679, 0, 7680, 7680] %2 = arith.extui %1 : tensor<6xi26> to tensor<6xi32> @@ -20,22 +28,15 @@ func.func @test_lower_reduce() { // 67108862 = 2 ** 26 - 2, equivalent to -2 as input %y = arith.constant dense<[29498763, 42, 67108862, 67108863, -1, 7680]> : tensor<6xi26> - // 33554432 = 2 ** 25 - %4 = mod_arith.reduce %y { modulus = 33554432 } : tensor<6xi26> - // CHECK_TEST_REDUCE: [29498763, 42, 33554430, 33554431, 
33554431, 7680] + // 33554431 = 2 ** 25 - 1 + %e4 = mod_arith.encapsulate %y : tensor<6xi26> -> !Zp2v + %m4 = mod_arith.reduce %e4 : !Zp2v + %4 = mod_arith.extract %m4 : !Zp2v -> tensor<6xi26> + // CHECK_TEST_REDUCE: [29498763, 42, 33554429, 33554430, 33554430, 7680] %5 = arith.extui %4 : tensor<6xi26> to tensor<6xi32> %6 = bufferization.to_memref %5 : memref<6xi32> %V = memref.cast %6 : memref<6xi32> to memref<*xi32> func.call @printMemrefI32(%V) : (memref<*xi32>) -> () - - // 33554431 = 2 ** 25 - 1 - %7 = mod_arith.reduce %y { modulus = 33554431 } : tensor<6xi26> - // CHECK_TEST_REDUCE: [29498763, 42, 33554429, 33554430, 33554430, 7680] - - %8 = arith.extui %7 : tensor<6xi26> to tensor<6xi32> - %9 = bufferization.to_memref %8 : memref<6xi32> - %W = memref.cast %9 : memref<6xi32> to memref<*xi32> - func.call @printMemrefI32(%W) : (memref<*xi32>) -> () return } diff --git a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_sub.mlir b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_sub.mlir index d6c0ba36e..aa55cb46b 100644 --- a/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_sub.mlir +++ b/tests/Dialect/ModArith/Conversions/mod_arith_to_arith/runner/lower_sub.mlir @@ -5,10 +5,21 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } +!Zp = !mod_arith.mod_arith<7681 : i26> +!Zpv = tensor<4x!Zp> + func.func @test_lower_sub() { + // 67108862 is -2 %x = arith.constant dense<[29498763, 42, 67108862, 7681]> : tensor<4xi26> + // 36789492 is -30319372, 67108863 is -1 %y = arith.constant dense<[36789492, 7234, 67108863, 7681]> : tensor<4xi26> - %1 = mod_arith.sub %x, %y { modulus = 7681 } : tensor<4xi26> + %ex = mod_arith.encapsulate %x : tensor<4xi26> -> !Zpv + %ey = mod_arith.encapsulate %y : tensor<4xi26> -> !Zpv + %mx = mod_arith.reduce %ex : !Zpv + %my = mod_arith.reduce %ey : !Zpv + %m1 = mod_arith.sub %mx, %my : !Zpv + %1 = mod_arith.extract %m1 : !Zpv -> tensor<4xi26> + %2 = 
arith.extui %1 : tensor<4xi26> to tensor<4xi32> %3 = bufferization.to_memref %2 : memref<4xi32> diff --git a/tests/Dialect/ModArith/IR/invalid-ops.mlir b/tests/Dialect/ModArith/IR/invalid-ops.mlir index a1b167394..139e48eca 100644 --- a/tests/Dialect/ModArith/IR/invalid-ops.mlir +++ b/tests/Dialect/ModArith/IR/invalid-ops.mlir @@ -1,40 +1,35 @@ // RUN: heir-opt --verify-diagnostics --split-input-file %s | FileCheck %s -// CHECK-NOT: @test_bad_arith_syntax -func.func @test_bad_arith_syntax() { - %c_vec = arith.constant dense<[1, 2, 1, 2]> : tensor<4xi4> +!Zp = !mod_arith.mod_arith<255 : i8> - // expected-error@+1 {{input bitwidth is required to be in the range [w, 2w], where w is the smallest bit-width that contains the range [0, modulus).}} - %barrett = mod_arith.barrett_reduce %c_vec { modulus = 17 } : tensor<4xi4> - - return +// CHECK-NOT: @test_bad_mod +func.func @test_bad_mod(%lhs : i8) -> !Zp { + // expected-error@+1 {{underlying type's bitwidth must be 1 bit larger than the modulus bitwidth, but got 8 while modulus requires width 8.}} + %m = mod_arith.encapsulate %lhs : i8 -> !Zp + return %m : !Zp } // ----- -// CHECK-NOT: @test_bad_mod -func.func @test_bad_mod(%lhs : i8, %rhs : i8) -> i8 { - // expected-error@+1 {{underlying type's bitwidth must be at least as large as the modulus bitwidth, but got 8 while modulus requires width 23.}} - %res = mod_arith.add %lhs, %rhs {modulus = 6666666 }: i8 - return %res : i8 +!Zp = !mod_arith.mod_arith<255 : i32> + +// CHECK-NOT: @test_bad_extract +func.func @test_bad_extract(%lhs : !Zp) -> i8 { + // expected-error@+1 {{the result integer type should be of the same width as the mod arith type width, but got 8 while mod arith type width 32}} + %m = mod_arith.extract %lhs : !Zp -> i8 + return %m : i8 } // ----- -// CHECK-NOT: @test_bad_mod_reduce -func.func @test_bad_mod_reduce(%arg0 : i8) -> i8 { - // expected-error@+1 {{underlying type's bitwidth must be larger than the modulus bitwidth, but got 8 while modulus requires 
width 8.}} - %res = mod_arith.reduce %arg0 {modulus = 217 }: i8 - return %res : i8 -} +// CHECK-NOT: @test_bad_arith_syntax +func.func @test_bad_arith_syntax() { + %c_vec = arith.constant dense<[1, 2, 1, 2]> : tensor<4xi4> -// ----- + // expected-error@+1 {{input bitwidth is required to be in the range [w, 2w], where w is the smallest bit-width that contains the range [0, modulus).}} + %barrett = mod_arith.barrett_reduce %c_vec { modulus = 17 } : tensor<4xi4> -// CHECK-NOT: @test_neg_mod_err -func.func @test_neg_mod_err(%arg : i8) -> i8 { - // expected-error@+1 {{provided modulus -3 is not a positive integer.}} - %res = mod_arith.reduce %arg { modulus = -3 : i7 } : i8 - return %res : i8 + return } // ----- @@ -45,12 +40,3 @@ func.func @test_barrett_neg_mod_err(%arg : i8) -> i8 { %res = mod_arith.barrett_reduce %arg { modulus = -3 : i7 } : i8 return %res : i8 } - -// ----- - -// CHECK: @test_bad_mod_warning -func.func @test_bad_mod_warning(%lhs : i8, %rhs : i8) -> i8 { - // expected-warning@+1 {{for signed (or signless) underlying types, the bitwidth of the underlying type must be at least as large as modulus bitwidth + 1 (for the sign bit), but found 8 while modulus requires width 8.}} - %res = mod_arith.add %lhs, %rhs {modulus = 135 }: i8 - return %res : i8 -} diff --git a/tests/Dialect/ModArith/IR/syntax.mlir b/tests/Dialect/ModArith/IR/syntax.mlir index c716f9ea7..a2029872b 100644 --- a/tests/Dialect/ModArith/IR/syntax.mlir +++ b/tests/Dialect/ModArith/IR/syntax.mlir @@ -1,5 +1,8 @@ // RUN: heir-opt %s | FileCheck %s +!Zp = !mod_arith.mod_arith<17 : i10> +!Zp_vec = tensor<4x!Zp> + // CHECK-LABEL: @test_arith_syntax func.func @test_arith_syntax() { %zero = arith.constant 1 : i10 @@ -12,30 +15,45 @@ func.func @test_arith_syntax() { %c_vec3 = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi10> %cmod_vec = arith.constant dense<17> : tensor<4xi10> + // CHECK-COUNT-6: mod_arith.encapsulate + %e4 = mod_arith.encapsulate %c4 : i10 -> !Zp + %e5 = mod_arith.encapsulate %c5 
: i10 -> !Zp + %e6 = mod_arith.encapsulate %c6 : i10 -> !Zp + %e_vec = mod_arith.encapsulate %c_vec : tensor<4xi10> -> !Zp_vec + %e_vec2 = mod_arith.encapsulate %c_vec2 : tensor<4xi10> -> !Zp_vec + %e_vec3 = mod_arith.encapsulate %c_vec3 : tensor<4xi10> -> !Zp_vec + + // CHECK-COUNT-6: mod_arith.reduce + %m4 = mod_arith.reduce %e4 : !Zp + %m5 = mod_arith.reduce %e5 : !Zp + %m6 = mod_arith.reduce %e6 : !Zp + %m_vec = mod_arith.reduce %e_vec : !Zp_vec + %m_vec2 = mod_arith.reduce %e_vec2 : !Zp_vec + %m_vec3 = mod_arith.reduce %e_vec3 : !Zp_vec + + // CHECK: mod_arith.extract + %extract = mod_arith.extract %m4 : !Zp -> i10 + %extract_vec = mod_arith.extract %m_vec : !Zp_vec -> tensor<4xi10> + // CHECK: mod_arith.add // CHECK: mod_arith.add - %add = mod_arith.add %c5, %c6 { modulus = 17 } : i10 - %add_vec = mod_arith.add %c_vec, %c_vec2 { modulus = 17 } : tensor<4xi10> + %add = mod_arith.add %m5, %m6 : !Zp + %add_vec = mod_arith.add %m_vec, %m_vec2 : !Zp_vec // CHECK: mod_arith.sub // CHECK: mod_arith.sub - %sub = mod_arith.sub %c5, %c6 { modulus = 17 } : i10 - %sub_vec = mod_arith.sub %c_vec, %c_vec2 { modulus = 17 } : tensor<4xi10> + %sub = mod_arith.sub %m5, %m6 : !Zp + %sub_vec = mod_arith.sub %m_vec, %m_vec2 : !Zp_vec // CHECK: mod_arith.mul // CHECK: mod_arith.mul - %mul = mod_arith.mul %c5, %c6 { modulus = 17 } : i10 - %mul_vec = mod_arith.mul %c_vec, %c_vec2 { modulus = 17 } : tensor<4xi10> + %mul = mod_arith.mul %m5, %m6 : !Zp + %mul_vec = mod_arith.mul %m_vec, %m_vec2 : !Zp_vec // CHECK: mod_arith.mac // CHECK: mod_arith.mac - %mac = mod_arith.mac %c5, %c6, %c4 { modulus = 17 } : i10 - %mac_vec = mod_arith.mac %c_vec, %c_vec2, %c_vec3 { modulus = 17 } : tensor<4xi10> - - // CHECK: mod_arith.reduce - // CHECK: mod_arith.reduce - %reduce = mod_arith.reduce %c4 { modulus = 17 } : i10 - %reduce_vec = mod_arith.reduce %c_vec { modulus = 17 } : tensor<4xi10> + %mac = mod_arith.mac %m5, %m6, %m4 : !Zp + %mac_vec = mod_arith.mac %m_vec, %m_vec2, %m_vec3 : !Zp_vec // 
CHECK: mod_arith.barrett_reduce // CHECK: mod_arith.barrett_reduce diff --git a/tests/Dialect/Polynomial/Transforms/BUILD b/tests/Dialect/Polynomial/Transforms/BUILD index c571e6fc6..35dba5a61 100644 --- a/tests/Dialect/Polynomial/Transforms/BUILD +++ b/tests/Dialect/Polynomial/Transforms/BUILD @@ -6,5 +6,6 @@ glob_lit_tests( name = "all_tests", data = ["@heir//tests:test_utilities"], driver = "@heir//tests:run_lit.sh", + exclude = ["ntt_rewrites.mlir"], # TODO(#1095): disabled for mod_arith type migration test_file_exts = ["mlir"], )