Skip to content

Commit

Permalink
feat: Use builder pattern to enable rsa acceleration. (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnthonyGrondin authored Jul 16, 2024
1 parent c619fad commit 5fb0ca7
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 24 deletions.
57 changes: 49 additions & 8 deletions esp-mbedtls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ pub use esp_mbedtls_sys::bindings::{
use esp_mbedtls_sys::c_types::*;

/// Hold the RSA peripheral for cryptographic operations.
///
/// This is initialized when `with_hardware_rsa()` is called on a [Session] and is set back to None
/// when the session that called `with_hardware_rsa()` is dropped.
///
/// Note: Due to implementation constraints, this session and every other session will use the
/// hardware accelerated RSA driver until the session called with this function is dropped.
static mut RSA_REF: Option<Rsa<esp_hal::Blocking>> = None;

// these will come from esp-wifi (i.e. this can only be used together with esp-wifi)
Expand Down Expand Up @@ -385,6 +391,8 @@ pub struct Session<T> {
crt: *mut mbedtls_x509_crt,
client_crt: *mut mbedtls_x509_crt,
private_key: *mut mbedtls_pk_context,
// Indicate if this session is the one holding the RSA ref
owns_rsa: bool,
}

impl<T> Session<T> {
Expand All @@ -399,8 +407,6 @@ impl<T> Session<T> {
/// * `min_version` - The minimum TLS version for the connection, that will be accepted.
/// * `certificates` - Certificate chain for the connection. Will play a different role
/// depending on if running as client or server. See [Certificates] for more information.
/// * `rsa` - Optionally take an RSA driver instance. This session will use the hardware rsa crypto
/// accelerators for the session. Passing None will use the software implementation of RSA which is slower.
///
/// # Errors
///
Expand All @@ -413,20 +419,33 @@ impl<T> Session<T> {
mode: Mode,
min_version: TlsVersion,
certificates: Certificates,
rsa: Option<impl Peripheral<P = RSA>>,
) -> Result<Self, TlsError> {
let (ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
unsafe { RSA_REF = core::mem::transmute(rsa.map(|inner| Rsa::new(inner, None))) }
return Ok(Self {
stream,
ssl_context,
ssl_config,
crt,
client_crt,
private_key,
owns_rsa: false,
});
}

/// Enable the use of the hardware accelerated RSA peripheral for the [Session].
///
/// Note: Due to implementation constraints, this session and every other session will use the
/// hardware accelerated RSA driver until the sesssion called with this function is dropped.
///
/// # Arguments
///
/// * `rsa` - The RSA peripheral from the HAL
pub fn with_hardware_rsa(mut self, rsa: impl Peripheral<P = RSA>) -> Self {
unsafe { RSA_REF = core::mem::transmute(Some(Rsa::new(rsa, None))) }
self.owns_rsa = true;
self
}
}

impl<T> Session<T>
Expand Down Expand Up @@ -536,6 +555,11 @@ impl<T> Drop for Session<T> {
fn drop(&mut self) {
log::debug!("session dropped - freeing memory");
unsafe {
// If the struct that owns the RSA reference is dropped
// we remove RSA in static for safety
if self.owns_rsa {
RSA_REF = core::mem::transmute(None::<RSA>);
}
mbedtls_ssl_close_notify(self.ssl_context);
mbedtls_ssl_config_free(self.ssl_config);
mbedtls_ssl_free(self.ssl_context);
Expand Down Expand Up @@ -611,6 +635,7 @@ pub mod asynch {
eof: bool,
tx_buffer: BufferedBytes<BUFFER_SIZE>,
rx_buffer: BufferedBytes<BUFFER_SIZE>,
owns_rsa: bool,
}

impl<T, const BUFFER_SIZE: usize> Session<T, BUFFER_SIZE> {
Expand All @@ -625,8 +650,6 @@ pub mod asynch {
/// * `min_version` - The minimum TLS version for the connection, that will be accepted.
/// * `certificates` - Certificate chain for the connection. Will play a different role
/// depending on if running as client or server. See [Certificates] for more information.
/// * `rsa` - Optionally take an RSA driver instance. This session will use the hardware rsa crypto
/// accelerators for the session. Passing None will use the software implementation of RSA which is slower.
///
/// # Errors
///
Expand All @@ -639,11 +662,9 @@ pub mod asynch {
mode: Mode,
min_version: TlsVersion,
certificates: Certificates,
rsa: Option<impl Peripheral<P = RSA>>,
) -> Result<Self, TlsError> {
let (ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
unsafe { RSA_REF = core::mem::transmute(rsa.map(|inner| Rsa::new(inner, None))) }
return Ok(Self {
stream,
ssl_context,
Expand All @@ -654,14 +675,34 @@ pub mod asynch {
eof: false,
tx_buffer: Default::default(),
rx_buffer: Default::default(),
owns_rsa: false,
});
}

/// Enable the use of the hardware accelerated RSA peripheral for the [Session].
///
/// Note: Due to implementation constraints, this session and every other session will use the
/// hardware accelerated RSA driver until the sesssion called with this function is dropped.
///
/// # Arguments
///
/// * `rsa` - The RSA peripheral from the HAL
pub fn with_hardware_rsa(mut self, rsa: impl Peripheral<P = RSA>) -> Self {
unsafe { RSA_REF = core::mem::transmute(Some(Rsa::new(rsa, None))) }
self.owns_rsa = true;
self
}
}

impl<T, const BUFFER_SIZE: usize> Drop for Session<T, BUFFER_SIZE> {
fn drop(&mut self) {
log::debug!("session dropped - freeing memory");
unsafe {
// If the struct that owns the RSA reference is dropped
// we remove RSA in static for safety
if self.owns_rsa {
RSA_REF = core::mem::transmute(None::<RSA>);
}
mbedtls_ssl_close_notify(self.ssl_context);
mbedtls_ssl_config_free(self.ssl_config);
mbedtls_ssl_free(self.ssl_context);
Expand Down
4 changes: 2 additions & 2 deletions examples/async_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
Some(peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(peripherals.RSA);

println!("Start tls connect");
let mut tls = tls.connect().await.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions examples/async_client_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ async fn main(spawner: Spawner) -> ! {
Mode::Client,
TlsVersion::Tls1_3,
certificates,
Some(peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(peripherals.RSA);

println!("Start tls connect");
let mut tls = tls.connect().await.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions examples/async_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
Some(&mut peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);

println!("Start tls connect");
match tls.connect().await {
Expand Down
4 changes: 2 additions & 2 deletions examples/async_server_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
Some(&mut peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);

println!("Start tls connect");
match tls.connect().await {
Expand Down
4 changes: 2 additions & 2 deletions examples/sync_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ fn main() -> ! {
.ok(),
..Default::default()
},
Some(peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(peripherals.RSA);

println!("Start tls connect");
let mut tls = tls.connect().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions examples/sync_client_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ fn main() -> ! {
Mode::Client,
TlsVersion::Tls1_3,
certificates,
Some(peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(peripherals.RSA);

println!("Start tls connect");
let mut tls = tls.connect().unwrap();
Expand Down
5 changes: 3 additions & 2 deletions examples/sync_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ fn main() -> ! {
.ok(),
..Default::default()
},
Some(&mut peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);

match tls.connect() {
Ok(mut connected_session) => {
loop {
Expand Down
5 changes: 3 additions & 2 deletions examples/sync_server_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,10 @@ fn main() -> ! {
.ok(),
..Default::default()
},
Some(&mut peripherals.RSA),
)
.unwrap();
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);

match tls.connect() {
Ok(mut connected_session) => {
loop {
Expand Down

0 comments on commit 5fb0ca7

Please sign in to comment.