Skip to content

Commit

Permalink
Add ML-KEM 512, 768, and 1024
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail committed Oct 1, 2024
1 parent a537ded commit b474f2a
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 19 deletions.
144 changes: 133 additions & 11 deletions aws-lc-rs/src/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
//! use aws_lc_rs::{
//! error::Unspecified,
//! kem::{Ciphertext, DecapsulationKey, EncapsulationKey},
//! unstable::kem::{AlgorithmId, get_algorithm}
//! unstable::kem::{AlgorithmId, ML_KEM_512}
//! };
//!
//! let kyber512_r3 = get_algorithm(AlgorithmId::Kyber512_R3).ok_or(Unspecified)?;
//!
//! // Alice generates their (private) decapsulation key.
//! let decapsulation_key = DecapsulationKey::generate(kyber512_r3)?;
//! let decapsulation_key = DecapsulationKey::generate(&ML_KEM_512)?;
//!
//! // Alices computes the (public) encapsulation key.
//! let encapsulation_key = decapsulation_key.encapsulation_key()?;
Expand All @@ -31,12 +29,12 @@
//! let encapsulation_key_bytes = encapsulation_key_bytes.as_ref();
//!
//! // Bob constructs the (public) encapsulation key from the key bytes provided by Alice.
//! let retrieved_encapsulation_key = EncapsulationKey::new(kyber512_r3, encapsulation_key_bytes)?;
//! let retrieved_encapsulation_key = EncapsulationKey::new(&ML_KEM_512, encapsulation_key_bytes)?;
//!
//! // Bob executes the encapsulation algorithm to to produce their copy of the secret, and associated ciphertext.
//! let (ciphertext, bob_secret) = retrieved_encapsulation_key.encapsulate()?;
//!
//! // Alice recieves ciphertext bytes from bob
//! // Alice receives ciphertext bytes from bob
//! let ciphertext_bytes = ciphertext.as_ref();
//!
//! // Bob sends Alice the ciphertext computed from the encapsulation algorithm, Alice runs decapsulation to derive their
Expand All @@ -58,7 +56,8 @@ use alloc::borrow::Cow;
use aws_lc::{
EVP_PKEY_CTX_kem_set_params, EVP_PKEY_CTX_new_id, EVP_PKEY_decapsulate, EVP_PKEY_encapsulate,
EVP_PKEY_get_raw_private_key, EVP_PKEY_get_raw_public_key, EVP_PKEY_kem_new_raw_public_key,
EVP_PKEY_keygen, EVP_PKEY_keygen_init, EVP_PKEY, EVP_PKEY_KEM,
EVP_PKEY_keygen, EVP_PKEY_keygen_init, EVP_PKEY, EVP_PKEY_KEM, NID_MLKEM1024, NID_MLKEM512,
NID_MLKEM768,
};
use core::{cmp::Ordering, ptr::null_mut};
use zeroize::Zeroize;
Expand Down Expand Up @@ -136,14 +135,27 @@ where
/// Identifier for a KEM algorithm.
///
/// See [`crate::unstable::kem::AlgorithmId`] and [`crate::unstable::kem::get_algorithm`] for
/// access to algorithms not subject to semantic versioning gurantees.
/// access to algorithms not subject to semantic versioning guarantees.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AlgorithmId {}
pub enum AlgorithmId {
/// NIST FIPS 203 ML-KEM-512 algorithm.
MlKem512,

/// NIST FIPS 203 ML-KEM-768 algorithm.
MlKem768,

/// NIST FIPS 203 ML-KEM-1024 algorithm.
MlKem1024,
}

impl AlgorithmIdentifier for AlgorithmId {
fn nid(self) -> i32 {
unreachable!()
match self {
AlgorithmId::MlKem512 => NID_MLKEM512,
AlgorithmId::MlKem768 => NID_MLKEM768,
AlgorithmId::MlKem1024 => NID_MLKEM1024,
}
}
}

Expand Down Expand Up @@ -457,7 +469,12 @@ fn kem_key_generate(nid: i32) -> Result<LcPtr<EVP_PKEY>, Unspecified> {

#[cfg(test)]
mod tests {
use super::{Ciphertext, SharedSecret};
use crate::error::KeyRejected;

use super::{Ciphertext, DecapsulationKey, EncapsulationKey, SharedSecret};

#[cfg(feature = "unstable")]
use crate::unstable::kem::{ML_KEM_1024, ML_KEM_512, ML_KEM_768};

#[test]
fn ciphertext() {
Expand All @@ -477,4 +494,109 @@ mod tests {
let shared_secret = SharedSecret::new(secret_bytes.into_boxed_slice());
assert_eq!(shared_secret.as_ref(), &[42, 42, 42, 42]);
}

#[test]
fn test_kem_serialize() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);

let pub_key = priv_key.encapsulation_key().unwrap();
let pubkey_raw_bytes = pub_key.key_bytes().unwrap();
let pub_key_from_bytes =
EncapsulationKey::new(algorithm, pubkey_raw_bytes.as_ref()).unwrap();

assert_eq!(
pub_key.key_bytes().unwrap().as_ref(),
pub_key_from_bytes.key_bytes().unwrap().as_ref()
);
assert_eq!(pub_key.algorithm(), pub_key_from_bytes.algorithm());
}
}

#[test]
#[cfg(feature = "unstable")]
fn test_kem_wrong_sizes() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let too_long_bytes = vec![0u8; algorithm.encapsulate_key_size() + 1];
let long_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_long_bytes);
assert_eq!(
long_pub_key_from_bytes.err(),
Some(KeyRejected::too_large())
);

let too_short_bytes = vec![0u8; algorithm.encapsulate_key_size() - 1];
let short_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_short_bytes);
assert_eq!(
short_pub_key_from_bytes.err(),
Some(KeyRejected::too_small())
);
}
}

#[test]
#[cfg(feature = "unstable")]
fn test_kem_e2e() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);

let pub_key = priv_key.encapsulation_key().unwrap();

let (alice_ciphertext, alice_secret) =
pub_key.encapsulate().expect("encapsulate successful");

let bob_secret = priv_key
.decapsulate(alice_ciphertext)
.expect("decapsulate successful");

assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
}
}

#[test]
#[cfg(feature = "unstable")]
fn test_serialized_kem_e2e() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);

let pub_key = priv_key.encapsulation_key().unwrap();

// Generate public key bytes to send to bob
let pub_key_bytes = pub_key.key_bytes().unwrap();

// Test that priv_key's EVP_PKEY isn't entirely freed since we remove this pub_key's reference.
drop(pub_key);

let retrieved_pub_key =
EncapsulationKey::new(algorithm, pub_key_bytes.as_ref()).unwrap();
let (ciphertext, bob_secret) = retrieved_pub_key
.encapsulate()
.expect("encapsulate successful");

let alice_secret = priv_key
.decapsulate(ciphertext)
.expect("encapsulate successful");

assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
}
}

#[test]
#[cfg(feature = "unstable")]
fn test_debug_fmt() {
let private = DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
assert_eq!(
format!("{private:?}"),
"DecapsulationKey { algorithm: MlKem512, .. }"
);
assert_eq!(
format!(
"{:?}",
private.encapsulation_key().expect("public key retrievable")
),
"EncapsulationKey { algorithm: MlKem512, .. }"
);
}
}
2 changes: 1 addition & 1 deletion aws-lc-rs/src/unstable/kdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
//! ```
//! # Single-step Key Derivation Function (SSKDF)
//!
//! [`sskdf_digest`] and [`sskd_hmac`] provided implementations of a one-step key derivation function defined in
//! [`sskdf_digest`] and [`sskdf_hmac`] provided implementations of a one-step key derivation function defined in
//! section 4 of [NIST SP 800-56Cr2](https://doi.org/10.6028/NIST.SP.800-56Cr2).
//!
//! These functions are used to derive keying material from a shared secret during a key establishment scheme.
Expand Down
67 changes: 60 additions & 7 deletions aws-lc-rs/src/unstable/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
//! use aws_lc_rs::{
//! error::Unspecified,
//! kem::{Ciphertext, DecapsulationKey, EncapsulationKey},
//! unstable::kem::{AlgorithmId, get_algorithm}
//! unstable::kem::{AlgorithmId, ML_KEM_512}
//! };
//!
//! let kyber512_r3 = get_algorithm(AlgorithmId::Kyber512_R3).ok_or(Unspecified)?;
//!
//! // Alice generates their (private) decapsulation key.
//! let decapsulation_key = DecapsulationKey::generate(kyber512_r3)?;
//! let decapsulation_key = DecapsulationKey::generate(&ML_KEM_512)?;
//!
//! // Alices computes the (public) encapsulation key.
//! let encapsulation_key = decapsulation_key.encapsulation_key()?;
Expand All @@ -31,12 +29,12 @@
//! let encapsulation_key_bytes = encapsulation_key_bytes.as_ref();
//!
//! // Bob constructs the (public) encapsulation key from the key bytes provided by Alice.
//! let retrieved_encapsulation_key = EncapsulationKey::new(kyber512_r3, encapsulation_key_bytes)?;
//! let retrieved_encapsulation_key = EncapsulationKey::new(&ML_KEM_512, encapsulation_key_bytes)?;
//!
//! // Bob executes the encapsulation algorithm to to produce their copy of the secret, and associated ciphertext.
//! let (ciphertext, bob_secret) = retrieved_encapsulation_key.encapsulate()?;
//!
//! // Alice recieves ciphertext bytes from bob
//! // Alice receives ciphertext bytes from bob
//! let ciphertext_bytes = ciphertext.as_ref();
//!
//! // Bob sends Alice the ciphertext computed from the encapsulation algorithm, Alice runs decapsulation to derive their
Expand All @@ -54,6 +52,48 @@ use core::fmt::Debug;
use crate::kem::Algorithm;
use aws_lc::{NID_KYBER1024_R3, NID_KYBER512_R3, NID_KYBER768_R3};

const ML_KEM_512_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_512_PUBLIC_KEY_LENGTH: usize = 800;
const ML_KEM_512_SECRET_KEY_LENGTH: usize = 1632;
const ML_KEM_512_CIPHERTEXT_LENGTH: usize = 768;

const ML_KEM_768_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_768_PUBLIC_KEY_LENGTH: usize = 1184;
const ML_KEM_768_SECRET_KEY_LENGTH: usize = 2400;
const ML_KEM_768_CIPHERTEXT_LENGTH: usize = 1088;

const ML_KEM_1024_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_1024_PUBLIC_KEY_LENGTH: usize = 1568;
const ML_KEM_1024_SECRET_KEY_LENGTH: usize = 3168;
const ML_KEM_1024_CIPHERTEXT_LENGTH: usize = 1568;

/// NIST FIPS 203 ML-KEM-512 algorithm.
pub const ML_KEM_512: Algorithm<crate::kem::AlgorithmId> = Algorithm {
id: crate::kem::AlgorithmId::MlKem512,
decapsulate_key_size: ML_KEM_512_SECRET_KEY_LENGTH,
encapsulate_key_size: ML_KEM_512_PUBLIC_KEY_LENGTH,
ciphertext_size: ML_KEM_512_CIPHERTEXT_LENGTH,
shared_secret_size: ML_KEM_512_SHARED_SECRET_LENGTH,
};

/// NIST FIPS 203 ML-KEM-768 algorithm.
pub const ML_KEM_768: Algorithm<crate::kem::AlgorithmId> = Algorithm {
id: crate::kem::AlgorithmId::MlKem768,
decapsulate_key_size: ML_KEM_768_SECRET_KEY_LENGTH,
encapsulate_key_size: ML_KEM_768_PUBLIC_KEY_LENGTH,
ciphertext_size: ML_KEM_768_CIPHERTEXT_LENGTH,
shared_secret_size: ML_KEM_768_SHARED_SECRET_LENGTH,
};

/// NIST FIPS 203 ML-KEM-1024 algorithm.
pub const ML_KEM_1024: Algorithm<crate::kem::AlgorithmId> = Algorithm {
id: crate::kem::AlgorithmId::MlKem1024,
decapsulate_key_size: ML_KEM_1024_SECRET_KEY_LENGTH,
encapsulate_key_size: ML_KEM_1024_PUBLIC_KEY_LENGTH,
ciphertext_size: ML_KEM_1024_CIPHERTEXT_LENGTH,
shared_secret_size: ML_KEM_1024_SHARED_SECRET_LENGTH,
};

// Key lengths defined as stated on the CRYSTALS website:
// https://pq-crystals.org/kyber/

Expand All @@ -73,6 +113,7 @@ const KYBER1024_R3_PUBLIC_KEY_LENGTH: usize = 1568;
const KYBER1024_R3_SHARED_SECRET_LENGTH: usize = 32;

/// NIST Round 3 submission of the Kyber-512 algorithm.
#[allow(deprecated)]
const KYBER512_R3: Algorithm<AlgorithmId> = Algorithm {
id: AlgorithmId::Kyber512_R3,
decapsulate_key_size: KYBER512_R3_SECRET_KEY_LENGTH,
Expand All @@ -82,6 +123,7 @@ const KYBER512_R3: Algorithm<AlgorithmId> = Algorithm {
};

/// NIST Round 3 submission of the Kyber-768 algorithm.
#[allow(deprecated)]
const KYBER768_R3: Algorithm<AlgorithmId> = Algorithm {
id: AlgorithmId::Kyber768_R3,
decapsulate_key_size: KYBER768_R3_SECRET_KEY_LENGTH,
Expand All @@ -91,6 +133,7 @@ const KYBER768_R3: Algorithm<AlgorithmId> = Algorithm {
};

/// NIST Round 3 submission of the Kyber-1024 algorithm.
#[allow(deprecated)]
const KYBER1024_R3: Algorithm<AlgorithmId> = Algorithm {
id: AlgorithmId::Kyber1024_R3,
decapsulate_key_size: KYBER1024_R3_SECRET_KEY_LENGTH,
Expand All @@ -105,18 +148,22 @@ const KYBER1024_R3: Algorithm<AlgorithmId> = Algorithm {
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AlgorithmId {
/// NIST Round 3 submission of the Kyber-512 algorithm.
#[deprecated]
Kyber512_R3,

/// NIST Round 3 submission of the Kyber-768 algorithm.
#[deprecated]
Kyber768_R3,

/// NIST Round 3 submission of the Kyber-1024 algorithm.
#[deprecated]
Kyber1024_R3,
}

impl crate::kem::AlgorithmIdentifier for AlgorithmId {
#[inline]
fn nid(self) -> i32 {
#[allow(deprecated)]
match self {
AlgorithmId::Kyber512_R3 => NID_KYBER512_R3,
AlgorithmId::Kyber768_R3 => NID_KYBER768_R3,
Expand All @@ -131,6 +178,7 @@ impl crate::sealed::Sealed for AlgorithmId {}
/// May return [`None`] if support for the algorithm has been removed from the unstable module.
#[must_use]
pub const fn get_algorithm(id: AlgorithmId) -> Option<&'static Algorithm<AlgorithmId>> {
#[allow(deprecated)]
match id {
AlgorithmId::Kyber512_R3 => Some(&KYBER512_R3),
AlgorithmId::Kyber768_R3 => Some(&KYBER768_R3),
Expand All @@ -140,12 +188,17 @@ pub const fn get_algorithm(id: AlgorithmId) -> Option<&'static Algorithm<Algorit

#[cfg(test)]
mod tests {
#![allow(deprecated)]

use crate::{
error::KeyRejected,
kem::{DecapsulationKey, EncapsulationKey},
};

use super::{get_algorithm, AlgorithmId, KYBER1024_R3, KYBER512_R3, KYBER768_R3};
use super::{
get_algorithm, AlgorithmId, KYBER1024_R3, KYBER512_R3, KYBER768_R3, ML_KEM_1024,
ML_KEM_512, ML_KEM_768,
};

#[test]
fn test_kem_serialize() {
Expand Down

0 comments on commit b474f2a

Please sign in to comment.