Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ModArithType for mod_arith dialect #1088

Merged
merged 1 commit into from
Nov 18, 2024

Conversation

ZenithalHourlyRate
Copy link
Contributor

See #1084

This PR is in draft state, opened here for reviewing the ModArithType type itself.

It is worth noting is that the type explicitly does not allow modulus value to be too large to fill the entire underlying type as there would be signedness issues. As we are working in FHE situations where modulus width typically goes up to only 60 for 64bit native int, this requirement does not harm the functionality and edge cases for lowering to native int could be avoided.

Later pushes will replace all existing ModArithOps to use the ModArithType and corresponding ModArithToArith lowering. Currently mreduce and madd serves only as a demo for the type verification and would be renamed to reduce/add once finished migration.

Syntax now:

!Zp1 = !mod_arith.mod_arith<modulus = 65537 : i32>

module {
  func.func @dot_product() -> !Zp1 {
    %c = arith.constant 1 : i32
    %0 = mod_arith.mreduce %c : i32 -> !Zp1
    %1 = mod_arith.madd %0, %0 : !Zp1
    return %1 : !Zp1
  }
}

@j2kun
Copy link
Collaborator

j2kun commented Nov 12, 2024

IMO this looks great. I think the use of the : i64 type specifier for the "storage type" is a nice way to avoid extra verbosity.

@ZenithalHourlyRate
Copy link
Contributor Author

ZenithalHourlyRate commented Nov 13, 2024

Encountered issues of integer range when trying to migrate mod_arith.barrett_reduce and subifge because barett reduction will give a result in [0, 2p), and [p, 2p) are not represented in mod_arith.mod_arith type.

HEaaN.mlir resolves such issue by using a modf

Both ModArith and Poly operations may optionally carry modulus factor (modf ). If an instruction is tagged with a modulus factor 𝑓𝑖, its output is allowed to fit in [0, 𝑓𝑖· 𝑝) where 𝑝 is the exact modulus. Assigning a larger 𝑓𝑖 generates a faster loop because SubIfGE in Barrett reduction can be removed. We describe our algorithm for tagging modulus factors of instructions at Sec. 6.1.

Should we annotate the modf in the op or in the type? I think the former is easier to implement and can be done on the fly by the analysis pass, and lowering can use such information to do a cheaper lowering. We can also verify how many modf are availble given the storage type and the modulus in the analysis pass. These are all for later PRs.

For this PR. I think migrating basic ops like add/sub/mul with basic lowering is enough. And as barett_reduce is not used for now (though there was a PR #712 doing such transform), I'll just use the ModArithType for barrett_reduce without ensuring its range, I'll leave it as is so that we can have a useable ModArithType earlier.

@ZenithalHourlyRate ZenithalHourlyRate marked this pull request as ready for review November 13, 2024 11:57
@j2kun
Copy link
Collaborator

j2kun commented Nov 14, 2024

Should we annotate the modf in the op or in the type? I think the former is easier to implement and can be done on the fly by the analysis pass, and lowering can use such information to do a cheaper lowering. We can also verify how many modf are availble given the storage type and the modulus in the analysis pass. These are all for later PRs.

For this PR. I think migrating basic ops like add/sub/mul with basic lowering is enough. And as barett_reduce is not used for now (though there was a PR #712 doing such transform), I'll just use the ModArithType for barrett_reduce without ensuring its range, I'll leave it as is so that we can have a useable ModArithType earlier.

Your reasoning is sound for the decision here. From my perspective, the main trade-off here is the inefficiency loss by incurring an extra field on every mod_arith type vs the extra work of lowering that type. The former doesn't seem that bad alone, but the worry is always that 5 or 10 more such optimizations will lead to 5 or 10 more fields on the type, so having a systematic strategy to avoid that is good.

If we put it on the ops, I just want to quickly sketch out the extra work that would require. Say we want to lower modarith<5 : i3> and there's a factor of 3 on some ops. The pass ultimately needs to be able to type-convert types without knowing which op it's on. So the pass setup would need to first run an analysis over the IR (or, say, at the func level) to determine the appropriate storage type to lower to. This might be as simple as just a max over modulus factors found across all ops, in which case perhaps we could reuse int range analysis from upstream. In this example, the analysis would determine we lower to i5 (because fi*p = 3*5 = 15 fits in an i4 but is too close to the max), and then that width would be given to the TypeConverter's constructor, and it would use that instead of the : i3 annotation when converting the type (the analysis would default to a width of 3 if no modulus-factors are found among the relevant ops).

Copy link
Collaborator

@j2kun j2kun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fantastic! My main critique is that we should be clear at the start about the operation semantics in the tablegen description fields, as that will be the source of truth when folks (particularly me) forget what we decided on.

@j2kun j2kun added the pull_ready Indicates whether a PR is ready to pull. The copybara worker will import for internal testing label Nov 15, 2024
Comment on lines +204 to +209
auto x = b.create<arith::ExtUIOp>(modulusType(op, true),
adaptor.getOperands()[0]);
auto y = b.create<arith::ExtUIOp>(modulusType(op, true),
adaptor.getOperands()[1]);
auto acc = b.create<arith::ExtUIOp>(modulusType(op, true),
adaptor.getOperands()[2]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just for later - the OpAdaptor comes with named accessors for the operands (like adaptor.getLhs(), adaptor.getRhs(), adapator.getAcc())

@copybara-service copybara-service bot merged commit 3874e8a into google:main Nov 18, 2024
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull_ready Indicates whether a PR is ready to pull. The copybara worker will import for internal testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[modarith] Add a modarith type to describe storage type & modulus
3 participants