Skip to content

Commit

Permalink
Allow user to provide buffers to asynch::Session (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev authored Jul 23, 2024
1 parent 41a1b56 commit e8a33f2
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 29 deletions.
53 changes: 28 additions & 25 deletions esp-mbedtls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ where
pub mod asynch {
use super::*;

pub struct Session<T, const BUFFER_SIZE: usize = 4096> {
pub struct Session<'a, T, const BUFFER_SIZE: usize = 4096> {
stream: T,
drbg_context: *mut mbedtls_ctr_drbg_context,
ssl_context: *mut mbedtls_ssl_context,
Expand All @@ -686,12 +686,12 @@ pub mod asynch {
client_crt: *mut mbedtls_x509_crt,
private_key: *mut mbedtls_pk_context,
eof: bool,
tx_buffer: BufferedBytes<BUFFER_SIZE>,
rx_buffer: BufferedBytes<BUFFER_SIZE>,
tx_buffer: BufferedBytes<'a, BUFFER_SIZE>,
rx_buffer: BufferedBytes<'a, BUFFER_SIZE>,
owns_rsa: bool,
}

impl<T, const BUFFER_SIZE: usize> Session<T, BUFFER_SIZE> {
impl<'a, T, const BUFFER_SIZE: usize> Session<'a, T, BUFFER_SIZE> {
/// Create a session for a TLS stream.
///
/// # Arguments
Expand All @@ -715,6 +715,9 @@ pub mod asynch {
mode: Mode,
min_version: TlsVersion,
certificates: Certificates,

tx_buffer: &'a mut [u8; BUFFER_SIZE],
rx_buffer: &'a mut [u8; BUFFER_SIZE],
) -> Result<Self, TlsError> {
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
Expand All @@ -727,8 +730,8 @@ pub mod asynch {
client_crt,
private_key,
eof: false,
tx_buffer: Default::default(),
rx_buffer: Default::default(),
tx_buffer: BufferedBytes::new(tx_buffer),
rx_buffer: BufferedBytes::new(rx_buffer),
owns_rsa: false,
});
}
Expand All @@ -748,7 +751,7 @@ pub mod asynch {
}
}

impl<T, const BUFFER_SIZE: usize> Drop for Session<T, BUFFER_SIZE> {
impl<T, const BUFFER_SIZE: usize> Drop for Session<'_, T, BUFFER_SIZE> {
fn drop(&mut self) {
log::debug!("session dropped - freeing memory");
unsafe {
Expand All @@ -774,13 +777,13 @@ pub mod asynch {
}
}

impl<T, const BUFFER_SIZE: usize> Session<T, BUFFER_SIZE>
impl<'a, T, const BUFFER_SIZE: usize> Session<'a, T, BUFFER_SIZE>
where
T: embedded_io_async::Read + embedded_io_async::Write,
{
pub async fn connect<'b>(
pub async fn connect(
mut self,
) -> Result<AsyncConnectedSession<T, BUFFER_SIZE>, TlsError> {
) -> Result<AsyncConnectedSession<'a, T, BUFFER_SIZE>, TlsError> {
unsafe {
mbedtls_ssl_set_bio(
self.ssl_context,
Expand Down Expand Up @@ -979,22 +982,23 @@ pub mod asynch {
}
}

pub struct AsyncConnectedSession<T, const BUFFER_SIZE: usize>
pub struct AsyncConnectedSession<'a, T, const BUFFER_SIZE: usize>
where
T: embedded_io_async::Read + embedded_io_async::Write,
{
pub(crate) session: Session<T, BUFFER_SIZE>,
pub(crate) session: Session<'a, T, BUFFER_SIZE>,
}

impl<T, const BUFFER_SIZE: usize> embedded_io_async::ErrorType
for AsyncConnectedSession<T, BUFFER_SIZE>
for AsyncConnectedSession<'_, T, BUFFER_SIZE>
where
T: embedded_io_async::Read + embedded_io_async::Write,
{
type Error = TlsError;
}

impl<T, const BUFFER_SIZE: usize> embedded_io_async::Read for AsyncConnectedSession<T, BUFFER_SIZE>
impl<T, const BUFFER_SIZE: usize> embedded_io_async::Read
for AsyncConnectedSession<'_, T, BUFFER_SIZE>
where
T: embedded_io_async::Read + embedded_io_async::Write,
{
Expand All @@ -1016,7 +1020,8 @@ pub mod asynch {
}
}

impl<T, const BUFFER_SIZE: usize> embedded_io_async::Write for AsyncConnectedSession<T, BUFFER_SIZE>
impl<T, const BUFFER_SIZE: usize> embedded_io_async::Write
for AsyncConnectedSession<'_, T, BUFFER_SIZE>
where
T: embedded_io_async::Read + embedded_io_async::Write,
{
Expand All @@ -1038,24 +1043,22 @@ pub mod asynch {
.map_err(|_| TlsError::Unknown)
}
}
pub(crate) struct BufferedBytes<const BUFFER_SIZE: usize> {
buffer: [u8; BUFFER_SIZE],
pub(crate) struct BufferedBytes<'a, const BUFFER_SIZE: usize> {
buffer: &'a mut [u8; BUFFER_SIZE],
write_idx: usize,
read_idx: usize,
}

impl<const BUFFER_SIZE: usize> Default for BufferedBytes<BUFFER_SIZE> {
fn default() -> Self {
impl<'a, const BUFFER_SIZE: usize> BufferedBytes<'a, BUFFER_SIZE> {
pub fn new(buffer: &'a mut [u8; BUFFER_SIZE]) -> Self {
Self {
buffer: [0u8; BUFFER_SIZE],
write_idx: Default::default(),
read_idx: Default::default(),
buffer,
write_idx: 0,
read_idx: 0,
}
}
}

impl<const BUFFER_SIZE: usize> BufferedBytes<BUFFER_SIZE> {
pub fn pull<'a>(&'a mut self, max: usize) -> &'a [u8] {
pub fn pull(&mut self, max: usize) -> &[u8] {
if self.read_idx == self.write_idx {
self.read_idx = 0;
self.write_idx = 0;
Expand Down
4 changes: 3 additions & 1 deletion examples/async_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async fn main(spawner: Spawner) -> ! {

set_debug(0);

let tls: Session<_, 4096> = Session::new(
let tls = Session::new(
&mut socket,
"www.google.com",
Mode::Client,
Expand All @@ -124,6 +124,8 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
make_static!([0; 4096]),
make_static!([0; 4096]),
)
.unwrap()
.with_hardware_rsa(peripherals.RSA);
Expand Down
4 changes: 3 additions & 1 deletion examples/async_client_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,14 @@ async fn main(spawner: Spawner) -> ! {
password: None,
};

let tls: Session<_, 4096> = Session::new(
let tls = Session::new(
&mut socket,
"certauth.cryptomix.com",
Mode::Client,
TlsVersion::Tls1_3,
certificates,
make_static!([0; 4096]),
make_static!([0; 4096]),
)
.unwrap()
.with_hardware_rsa(peripherals.RSA);
Expand Down
4 changes: 3 additions & 1 deletion examples/async_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async fn main(spawner: Spawner) -> ! {

let mut buffer = [0u8; 1024];
let mut pos = 0;
let tls: Session<_, 4096> = Session::new(
let tls = Session::new(
&mut socket,
"",
Mode::Server,
Expand All @@ -145,6 +145,8 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
make_static!([0; 4096]),
make_static!([0; 4096]),
)
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);
Expand Down
4 changes: 3 additions & 1 deletion examples/async_server_mTLS.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async fn main(spawner: Spawner) -> ! {

let mut buffer = [0u8; 1024];
let mut pos = 0;
let tls: Session<_, 4096> = Session::new(
let tls = Session::new(
&mut socket,
"",
Mode::Server,
Expand All @@ -164,6 +164,8 @@ async fn main(spawner: Spawner) -> ! {
.ok(),
..Default::default()
},
make_static!([0; 4096]),
make_static!([0; 4096]),
)
.unwrap()
.with_hardware_rsa(&mut peripherals.RSA);
Expand Down

0 comments on commit e8a33f2

Please sign in to comment.