Skip to content

Commit

Permalink
updated tls options retreival at nw_socket connect
Browse files Browse the repository at this point in the history
  • Loading branch information
sbSteveK committed Sep 26, 2024
1 parent 06ff298 commit c6adea9
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 29 deletions.
9 changes: 7 additions & 2 deletions include/aws/io/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ struct aws_socket_options {
uint16_t keep_alive_max_failed_probes;
enum aws_event_loop_style event_loop_style;
bool keepalive;
void *tls_ctx;
struct aws_string *host_name;

/**
* THIS IS AN EXPERIMENTAL AND UNSTABLE API
Expand All @@ -74,6 +72,11 @@ struct aws_event_loop;
*/
typedef void(aws_socket_on_connection_result_fn)(struct aws_socket *socket, int error_code, void *user_data);

/**
* Called to retrieve TLS related options
*/
typedef void(aws_socket_retrieve_tls_options_fn)(struct aws_tls_connection_options **tls_ctx_options, void *user_data);

/**
* Called by a listening socket when either an incoming connection has been received or an error occurred.
*
Expand Down Expand Up @@ -127,6 +130,7 @@ struct aws_socket_vtable {
const struct aws_socket_endpoint *remote_endpoint,
struct aws_event_loop *event_loop,
aws_socket_on_connection_result_fn *on_connection_result,
aws_socket_retrieve_tls_options_fn *retrieve_tls_options,
void *user_data);
int (*socket_bind_fn)(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint);
int (*socket_listen_fn)(struct aws_socket *socket, int backlog_size);
Expand Down Expand Up @@ -232,6 +236,7 @@ AWS_IO_API int aws_socket_connect(
const struct aws_socket_endpoint *remote_endpoint,
struct aws_event_loop *event_loop,
aws_socket_on_connection_result_fn *on_connection_result,
aws_socket_retrieve_tls_options_fn *retrieve_tls_options,
void *user_data);

/**
Expand Down
23 changes: 15 additions & 8 deletions source/channel_bootstrap.c
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,14 @@ static void s_on_client_connection_established(struct aws_socket *socket, int er
}
}

/* Called when a socket connection attempt requires access to TLS options. Currently this is only necessary on
* iOS/tvOS where the parameters used to create the Apple Network Framework socket requires TLS options.
*/
static void s_retrieve_tls_options(struct aws_tls_connection_options **tls_ctx_options, void *user_data) {
struct client_connection_args *connection_args = user_data;
*tls_ctx_options = &connection_args->channel_data.tls_options;
}

struct connection_task_data {
struct aws_task task;
struct aws_socket_endpoint endpoint;
Expand All @@ -652,18 +660,12 @@ static void s_attempt_connection(struct aws_task *task, void *arg, enum aws_task
goto socket_init_failed;
}

/* Apple Network Framework TLS negotiation requires access to the stored SecItem identity
* and host_name. */
if (task_data->args->channel_data.use_tls) {
outgoing_socket->options.tls_ctx = task_data->args->channel_data.tls_options.ctx;
outgoing_socket->options.host_name = task_data->args->host_name;
}

if (aws_socket_connect(
outgoing_socket,
&task_data->endpoint,
task_data->connect_loop,
s_on_client_connection_established,
s_retrieve_tls_options,
task_data->args)) {

goto socket_connect_failed;
Expand Down Expand Up @@ -959,7 +961,12 @@ int aws_client_bootstrap_new_socket_channel(struct aws_socket_channel_bootstrap_

s_client_connection_args_acquire(client_connection_args);
if (aws_socket_connect(
outgoing_socket, &endpoint, connect_loop, s_on_client_connection_established, client_connection_args)) {
outgoing_socket,
&endpoint,
connect_loop,
s_on_client_connection_established,
s_retrieve_tls_options,
client_connection_args)) {
aws_socket_clean_up(outgoing_socket);
aws_mem_release(client_connection_args->bootstrap->allocator, outgoing_socket);
s_client_connection_args_release(client_connection_args);
Expand Down
37 changes: 28 additions & 9 deletions source/darwin/nw_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ struct nw_socket {
void *connect_accept_user_data;
struct aws_event_loop *event_loop;
struct aws_string *host_name;
struct aws_tls_ctx *tls_ctx;
};

struct socket_address {
Expand Down Expand Up @@ -192,13 +193,15 @@ static void s_setup_tcp_options(nw_protocol_options_t tcp_options, const struct
}
}

static int s_setup_socket_params(struct nw_socket *nw_socket, const struct aws_socket_options *options) {
static int s_setup_socket_params(
struct nw_socket *nw_socket,
const struct aws_socket_options *options,
struct aws_tls_connection_options *tls_ctx_options) {
if (options->type == AWS_SOCKET_STREAM) {
if (options->domain == AWS_SOCKET_IPV4 || options->domain == AWS_SOCKET_IPV6) {
/* options->user_data will contain the tls_ctx if tls_ctx was initialized */
if (options->tls_ctx) {
struct aws_tls_ctx *tls_ctx = options->tls_ctx;
struct secure_transport_ctx *transport_ctx = tls_ctx->impl;
if (nw_socket->tls_ctx) {
struct secure_transport_ctx *transport_ctx = nw_socket->tls_ctx->impl;
struct dispatch_loop *dispatch_loop = nw_socket->event_loop->impl_data;

nw_socket->socket_options_to_params = nw_parameters_create_secure_tcp(
Expand Down Expand Up @@ -363,6 +366,7 @@ static int s_socket_connect_fn(
const struct aws_socket_endpoint *remote_endpoint,
struct aws_event_loop *event_loop,
aws_socket_on_connection_result_fn *on_connection_result,
aws_socket_retrieve_tls_options_fn *retrieve_tls_options,
void *user_data);
static int s_socket_bind_fn(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint);
static int s_socket_listen_fn(struct aws_socket *socket, int backlog_size);
Expand Down Expand Up @@ -469,6 +473,11 @@ static void s_socket_impl_destroy(void *sock_ptr) {
aws_string_destroy(nw_socket->host_name);
}

if (nw_socket->tls_ctx) {
aws_tls_ctx_release(nw_socket->tls_ctx);
nw_socket->tls_ctx = NULL;
}

aws_mem_release(nw_socket->allocator, nw_socket->timeout_args);
aws_mem_release(nw_socket->allocator, nw_socket);
nw_socket = NULL;
Expand Down Expand Up @@ -545,26 +554,35 @@ static int s_socket_connect_fn(
const struct aws_socket_endpoint *remote_endpoint,
struct aws_event_loop *event_loop,
aws_socket_on_connection_result_fn *on_connection_result,
aws_socket_retrieve_tls_options_fn *retrieve_tls_options,
void *user_data) {
struct nw_socket *nw_socket = socket->impl;

AWS_ASSERT(event_loop);
AWS_ASSERT(!socket->event_loop);

if (socket->options.host_name) {
struct aws_tls_connection_options *tls_connection_options = NULL;
retrieve_tls_options(&tls_connection_options, user_data);

if (tls_connection_options->server_name) {
if (nw_socket->host_name != NULL) {
aws_string_destroy(nw_socket->host_name);
nw_socket->host_name = NULL;
}
nw_socket->host_name =
aws_string_new_from_string(socket->options.host_name->allocator, socket->options.host_name);
nw_socket->host_name = aws_string_new_from_string(
tls_connection_options->server_name->allocator, tls_connection_options->server_name);
if (nw_socket->host_name == NULL) {
return AWS_OP_ERR;
}
}

if (tls_connection_options->ctx) {
nw_socket->tls_ctx = tls_connection_options->ctx;
aws_tls_ctx_acquire(nw_socket->tls_ctx);
}

nw_socket->event_loop = event_loop;
if (s_setup_socket_params(nw_socket, &socket->options)) {
if (s_setup_socket_params(nw_socket, &socket->options, tls_connection_options)) {
return AWS_OP_ERR;
}

Expand Down Expand Up @@ -1054,6 +1072,7 @@ static int s_socket_shutdown_dir_fn(struct aws_socket *socket, enum aws_channel_
return s_socket_close_fn(socket);
}

/* DEBUG WIP REMOVE THIS OR SETUP s_setup_socket_params to get TLS options */
static int s_socket_set_options_fn(struct aws_socket *socket, const struct aws_socket_options *options) {
if (socket->options.domain != options->domain || socket->options.type != options->type) {
return aws_raise_error(AWS_IO_SOCKET_INVALID_OPTIONS);
Expand Down Expand Up @@ -1081,7 +1100,7 @@ static int s_socket_set_options_fn(struct aws_socket *socket, const struct aws_s
nw_socket->socket_options_to_params = NULL;
}

return s_setup_socket_params(nw_socket, options);
return s_setup_socket_params(nw_socket, options, NULL);
}

static int s_socket_assign_to_event_loop_fn(struct aws_socket *socket, struct aws_event_loop *event_loop) {
Expand Down
2 changes: 2 additions & 0 deletions source/posix/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ static int s_socket_connect(
const struct aws_socket_endpoint *remote_endpoint,
struct aws_event_loop *event_loop,
aws_socket_on_connection_result_fn *on_connection_result,
aws_socket_retrieve_tls_options_fn retrieve_tls_options,
void *user_data);
static int s_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint);
static int s_socket_listen(struct aws_socket *socket, int backlog_size);
Expand Down Expand Up @@ -660,6 +661,7 @@ static int s_socket_connect(
const struct aws_socket_endpoint *remote_endpoint,
struct aws_event_loop *event_loop,
aws_socket_on_connection_result_fn *on_connection_result,
aws_socket_retrieve_tls_options_fn *retrieve_tls_options,
void *user_data) {
AWS_ASSERT(event_loop);
AWS_ASSERT(!socket->event_loop);
Expand Down
4 changes: 3 additions & 1 deletion source/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@ int aws_socket_connect(
const struct aws_socket_endpoint *remote_endpoint,
struct aws_event_loop *event_loop,
aws_socket_on_connection_result_fn *on_connection_result,
aws_socket_retrieve_tls_options_fn *retrieve_tls_options,
void *user_data) {
AWS_PRECONDITION(socket->vtable && socket->vtable->socket_connect_fn);
AWS_PRECONDITION(socket->event_loop_style & event_loop->vtable->event_loop_style);
return socket->vtable->socket_connect_fn(socket, remote_endpoint, event_loop, on_connection_result, user_data);
return socket->vtable->socket_connect_fn(
socket, remote_endpoint, event_loop, on_connection_result, retrieve_tls_options, user_data);
}

int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint) {
Expand Down
1 change: 1 addition & 0 deletions source/windows/iocp/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ int aws_socket_connect(
const struct aws_socket_endpoint *remote_endpoint,
struct aws_event_loop *event_loop,
aws_socket_on_connection_result_fn *on_connection_result,
aws_socket_retrieve_tls_options_fn *retrieve_tls_options,
void *user_data) {
struct iocp_socket *socket_impl = socket->impl;
if (socket->options.type != AWS_SOCKET_DGRAM) {
Expand Down
19 changes: 10 additions & 9 deletions tests/socket_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ static int s_test_socket_ex(
ASSERT_SUCCESS(aws_socket_bind(&outgoing, local));
}
ASSERT_SUCCESS(
aws_socket_connect(&outgoing, endpoint, event_loop, s_local_outgoing_connection, &outgoing_args));
aws_socket_connect(&outgoing, endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args));

if (listener.options.type == AWS_SOCKET_STREAM) {
ASSERT_SUCCESS(aws_mutex_lock(&mutex));
Expand Down Expand Up @@ -709,7 +709,7 @@ static int s_test_connect_timeout(struct aws_allocator *allocator, void *ctx) {
struct aws_socket outgoing;
ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options));
ASSERT_SUCCESS(
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args));
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args));
aws_mutex_lock(&mutex);
ASSERT_SUCCESS(aws_condition_variable_wait_pred(
&condition_variable, &mutex, s_connection_completed_predicate, &outgoing_args));
Expand Down Expand Up @@ -797,7 +797,7 @@ static int s_test_connect_timeout_cancelation(struct aws_allocator *allocator, v
struct aws_socket outgoing;
ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options));
ASSERT_SUCCESS(
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args));
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args));

aws_event_loop_group_release(el_group);

Expand Down Expand Up @@ -859,7 +859,7 @@ static int s_test_outgoing_local_sock_errors(struct aws_allocator *allocator, vo
struct aws_socket outgoing;
ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options));

ASSERT_FAILS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_null_sock_connection, &args));
ASSERT_FAILS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_null_sock_connection, NULL, &args));
ASSERT_TRUE(
aws_last_error() == AWS_IO_SOCKET_CONNECTION_REFUSED || aws_last_error() == AWS_ERROR_FILE_INVALID_PATH);

Expand Down Expand Up @@ -907,7 +907,7 @@ static int s_test_outgoing_tcp_sock_error(struct aws_allocator *allocator, void
struct aws_socket outgoing;
ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options));
/* tcp connect is non-blocking, it should return success, but the error callback will be invoked. */
ASSERT_SUCCESS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_null_sock_connection, &args));
ASSERT_SUCCESS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_null_sock_connection, NULL, &args));
ASSERT_SUCCESS(aws_mutex_lock(&args.mutex));
ASSERT_SUCCESS(aws_condition_variable_wait_pred(
&args.condition_variable, &args.mutex, s_outgoing_tcp_error_predicate, &args));
Expand Down Expand Up @@ -1346,7 +1346,7 @@ static int s_cleanup_in_accept_doesnt_explode(struct aws_allocator *allocator, v
struct aws_socket outgoing;
ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options));
ASSERT_SUCCESS(
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args));
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args));

ASSERT_SUCCESS(aws_mutex_lock(&mutex));
ASSERT_SUCCESS(
Expand Down Expand Up @@ -1481,7 +1481,7 @@ static int s_cleanup_in_write_cb_doesnt_explode(struct aws_allocator *allocator,
struct aws_socket outgoing;
ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options));
ASSERT_SUCCESS(
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args));
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args));

ASSERT_SUCCESS(aws_mutex_lock(&mutex));
ASSERT_SUCCESS(
Expand Down Expand Up @@ -1743,7 +1743,7 @@ static int s_sock_write_cb_is_async(struct aws_allocator *allocator, void *ctx)
struct aws_socket outgoing;
ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options));
ASSERT_SUCCESS(
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args));
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args));

ASSERT_SUCCESS(aws_mutex_lock(&mutex));
ASSERT_SUCCESS(
Expand Down Expand Up @@ -1837,7 +1837,8 @@ static int s_local_socket_pipe_connected_race(struct aws_allocator *allocator, v
struct aws_socket outgoing;
ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options));

ASSERT_SUCCESS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args));
ASSERT_SUCCESS(
aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args));

ASSERT_SUCCESS(aws_socket_start_accept(&listener, event_loop, s_local_listener_incoming, &listener_args));
aws_mutex_lock(&mutex);
Expand Down
1 change: 1 addition & 0 deletions tests/tls_handler_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,7 @@ static int s_tls_server_hangup_during_negotiation_fn(struct aws_allocator *alloc
&local_server_tester.endpoint,
aws_event_loop_group_get_next_loop(c_tester.el_group),
s_on_client_connected_do_hangup,
NULL,
shutdown_tester));

/* Wait for client socket to close */
Expand Down

0 comments on commit c6adea9

Please sign in to comment.