diff --git a/include/aws/io/socket.h b/include/aws/io/socket.h index d0740bf22..0ab5f180e 100644 --- a/include/aws/io/socket.h +++ b/include/aws/io/socket.h @@ -47,8 +47,6 @@ struct aws_socket_options { uint16_t keep_alive_max_failed_probes; enum aws_event_loop_style event_loop_style; bool keepalive; - void *tls_ctx; - struct aws_string *host_name; /** * THIS IS AN EXPERIMENTAL AND UNSTABLE API @@ -74,6 +72,11 @@ struct aws_event_loop; */ typedef void(aws_socket_on_connection_result_fn)(struct aws_socket *socket, int error_code, void *user_data); +/** + * Called to retrieve TLS related options + */ +typedef void(aws_socket_retrieve_tls_options_fn)(struct aws_tls_connection_options **tls_ctx_options, void *user_data); + /** * Called by a listening socket when either an incoming connection has been received or an error occurred. * @@ -127,6 +130,7 @@ struct aws_socket_vtable { const struct aws_socket_endpoint *remote_endpoint, struct aws_event_loop *event_loop, aws_socket_on_connection_result_fn *on_connection_result, + aws_socket_retrieve_tls_options_fn *retrieve_tls_options, void *user_data); int (*socket_bind_fn)(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint); int (*socket_listen_fn)(struct aws_socket *socket, int backlog_size); @@ -232,6 +236,7 @@ AWS_IO_API int aws_socket_connect( const struct aws_socket_endpoint *remote_endpoint, struct aws_event_loop *event_loop, aws_socket_on_connection_result_fn *on_connection_result, + aws_socket_retrieve_tls_options_fn *retrieve_tls_options, void *user_data); /** diff --git a/source/channel_bootstrap.c b/source/channel_bootstrap.c index 8fd143a0f..693d694da 100644 --- a/source/channel_bootstrap.c +++ b/source/channel_bootstrap.c @@ -628,6 +628,14 @@ static void s_on_client_connection_established(struct aws_socket *socket, int er } } +/* Called when a socket connection attempt requires access to TLS options. Currently this is only necessary on + * iOS/tvOS where the parameters used to create the Apple Network Framework socket requires TLS options. + */ +static void s_retrieve_tls_options(struct aws_tls_connection_options **tls_ctx_options, void *user_data) { + struct client_connection_args *connection_args = user_data; + *tls_ctx_options = &connection_args->channel_data.tls_options; +} + struct connection_task_data { struct aws_task task; struct aws_socket_endpoint endpoint; @@ -652,18 +660,12 @@ static void s_attempt_connection(struct aws_task *task, void *arg, enum aws_task goto socket_init_failed; } - /* Apple Network Framework TLS negotiation requires access to the stored SecItem identity - * and host_name. */ - if (task_data->args->channel_data.use_tls) { - outgoing_socket->options.tls_ctx = task_data->args->channel_data.tls_options.ctx; - outgoing_socket->options.host_name = task_data->args->host_name; - } - if (aws_socket_connect( outgoing_socket, &task_data->endpoint, task_data->connect_loop, s_on_client_connection_established, + s_retrieve_tls_options, task_data->args)) { goto socket_connect_failed; @@ -959,7 +961,12 @@ int aws_client_bootstrap_new_socket_channel(struct aws_socket_channel_bootstrap_ s_client_connection_args_acquire(client_connection_args); if (aws_socket_connect( - outgoing_socket, &endpoint, connect_loop, s_on_client_connection_established, client_connection_args)) { + outgoing_socket, + &endpoint, + connect_loop, + s_on_client_connection_established, + s_retrieve_tls_options, + client_connection_args)) { aws_socket_clean_up(outgoing_socket); aws_mem_release(client_connection_args->bootstrap->allocator, outgoing_socket); s_client_connection_args_release(client_connection_args); diff --git a/source/darwin/nw_socket.c b/source/darwin/nw_socket.c index c899ee94d..7849eec7d 100644 --- a/source/darwin/nw_socket.c +++ b/source/darwin/nw_socket.c @@ -156,6 +156,7 @@ struct nw_socket { void *connect_accept_user_data; struct aws_event_loop *event_loop; struct aws_string *host_name; + struct aws_tls_ctx *tls_ctx; }; struct socket_address { @@ -192,13 +193,15 @@ static void s_setup_tcp_options(nw_protocol_options_t tcp_options, const struct } } -static int s_setup_socket_params(struct nw_socket *nw_socket, const struct aws_socket_options *options) { +static int s_setup_socket_params( + struct nw_socket *nw_socket, + const struct aws_socket_options *options, + struct aws_tls_connection_options *tls_ctx_options) { if (options->type == AWS_SOCKET_STREAM) { if (options->domain == AWS_SOCKET_IPV4 || options->domain == AWS_SOCKET_IPV6) { /* options->user_data will contain the tls_ctx if tls_ctx was initialized */ - if (options->tls_ctx) { - struct aws_tls_ctx *tls_ctx = options->tls_ctx; - struct secure_transport_ctx *transport_ctx = tls_ctx->impl; + if (nw_socket->tls_ctx) { + struct secure_transport_ctx *transport_ctx = nw_socket->tls_ctx->impl; struct dispatch_loop *dispatch_loop = nw_socket->event_loop->impl_data; nw_socket->socket_options_to_params = nw_parameters_create_secure_tcp( @@ -363,6 +366,7 @@ static int s_socket_connect_fn( const struct aws_socket_endpoint *remote_endpoint, struct aws_event_loop *event_loop, aws_socket_on_connection_result_fn *on_connection_result, + aws_socket_retrieve_tls_options_fn *retrieve_tls_options, void *user_data); static int s_socket_bind_fn(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint); static int s_socket_listen_fn(struct aws_socket *socket, int backlog_size); @@ -469,6 +473,11 @@ static void s_socket_impl_destroy(void *sock_ptr) { aws_string_destroy(nw_socket->host_name); } + if (nw_socket->tls_ctx) { + aws_tls_ctx_release(nw_socket->tls_ctx); + nw_socket->tls_ctx = NULL; + } + aws_mem_release(nw_socket->allocator, nw_socket->timeout_args); aws_mem_release(nw_socket->allocator, nw_socket); nw_socket = NULL; @@ -545,26 +554,35 @@ static int s_socket_connect_fn( const struct aws_socket_endpoint *remote_endpoint, struct aws_event_loop *event_loop, aws_socket_on_connection_result_fn *on_connection_result, + aws_socket_retrieve_tls_options_fn *retrieve_tls_options, void *user_data) { struct nw_socket *nw_socket = socket->impl; AWS_ASSERT(event_loop); AWS_ASSERT(!socket->event_loop); - if (socket->options.host_name) { + struct aws_tls_connection_options *tls_connection_options = NULL; + retrieve_tls_options(&tls_connection_options, user_data); + + if (tls_connection_options->server_name) { if (nw_socket->host_name != NULL) { aws_string_destroy(nw_socket->host_name); nw_socket->host_name = NULL; } - nw_socket->host_name = - aws_string_new_from_string(socket->options.host_name->allocator, socket->options.host_name); + nw_socket->host_name = aws_string_new_from_string( + tls_connection_options->server_name->allocator, tls_connection_options->server_name); if (nw_socket->host_name == NULL) { return AWS_OP_ERR; } } + if (tls_connection_options->ctx) { + nw_socket->tls_ctx = tls_connection_options->ctx; + aws_tls_ctx_acquire(nw_socket->tls_ctx); + } + nw_socket->event_loop = event_loop; - if (s_setup_socket_params(nw_socket, &socket->options)) { + if (s_setup_socket_params(nw_socket, &socket->options, tls_connection_options)) { return AWS_OP_ERR; } @@ -1054,6 +1072,7 @@ static int s_socket_shutdown_dir_fn(struct aws_socket *socket, enum aws_channel_ return s_socket_close_fn(socket); } +/* DEBUG WIP REMOVE THIS OR SETUP s_setup_socket_params to get TLS options */ static int s_socket_set_options_fn(struct aws_socket *socket, const struct aws_socket_options *options) { if (socket->options.domain != options->domain || socket->options.type != options->type) { return aws_raise_error(AWS_IO_SOCKET_INVALID_OPTIONS); @@ -1081,7 +1100,7 @@ static int s_socket_set_options_fn(struct aws_socket *socket, const struct aws_s nw_socket->socket_options_to_params = NULL; } - return s_setup_socket_params(nw_socket, options); + return s_setup_socket_params(nw_socket, options, NULL); } static int s_socket_assign_to_event_loop_fn(struct aws_socket *socket, struct aws_event_loop *event_loop) { diff --git a/source/posix/socket.c b/source/posix/socket.c index 1251160f7..c4bfc7963 100644 --- a/source/posix/socket.c +++ b/source/posix/socket.c @@ -193,6 +193,7 @@ static int s_socket_connect( const struct aws_socket_endpoint *remote_endpoint, struct aws_event_loop *event_loop, aws_socket_on_connection_result_fn *on_connection_result, + aws_socket_retrieve_tls_options_fn retrieve_tls_options, void *user_data); static int s_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint); static int s_socket_listen(struct aws_socket *socket, int backlog_size); @@ -660,6 +661,7 @@ static int s_socket_connect( const struct aws_socket_endpoint *remote_endpoint, struct aws_event_loop *event_loop, aws_socket_on_connection_result_fn *on_connection_result, + aws_socket_retrieve_tls_options_fn *retrieve_tls_options, void *user_data) { AWS_ASSERT(event_loop); AWS_ASSERT(!socket->event_loop); diff --git a/source/socket.c b/source/socket.c index bbb8593a2..6d9949c78 100644 --- a/source/socket.c +++ b/source/socket.c @@ -108,10 +108,12 @@ int aws_socket_connect( const struct aws_socket_endpoint *remote_endpoint, struct aws_event_loop *event_loop, aws_socket_on_connection_result_fn *on_connection_result, + aws_socket_retrieve_tls_options_fn *retrieve_tls_options, void *user_data) { AWS_PRECONDITION(socket->vtable && socket->vtable->socket_connect_fn); AWS_PRECONDITION(socket->event_loop_style & event_loop->vtable->event_loop_style); - return socket->vtable->socket_connect_fn(socket, remote_endpoint, event_loop, on_connection_result, user_data); + return socket->vtable->socket_connect_fn( + socket, remote_endpoint, event_loop, on_connection_result, retrieve_tls_options, user_data); } int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint) { diff --git a/source/windows/iocp/socket.c b/source/windows/iocp/socket.c index 6039abb0c..aef7a73ef 100644 --- a/source/windows/iocp/socket.c +++ b/source/windows/iocp/socket.c @@ -434,6 +434,7 @@ int aws_socket_connect( const struct aws_socket_endpoint *remote_endpoint, struct aws_event_loop *event_loop, aws_socket_on_connection_result_fn *on_connection_result, + aws_socket_retrieve_tls_options_fn *retrieve_tls_options, void *user_data) { struct iocp_socket *socket_impl = socket->impl; if (socket->options.type != AWS_SOCKET_DGRAM) { diff --git a/tests/socket_test.c b/tests/socket_test.c index d5bc4a695..3010213ba 100644 --- a/tests/socket_test.c +++ b/tests/socket_test.c @@ -273,7 +273,7 @@ static int s_test_socket_ex( ASSERT_SUCCESS(aws_socket_bind(&outgoing, local)); } ASSERT_SUCCESS( - aws_socket_connect(&outgoing, endpoint, event_loop, s_local_outgoing_connection, &outgoing_args)); + aws_socket_connect(&outgoing, endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args)); if (listener.options.type == AWS_SOCKET_STREAM) { ASSERT_SUCCESS(aws_mutex_lock(&mutex)); @@ -709,7 +709,7 @@ static int s_test_connect_timeout(struct aws_allocator *allocator, void *ctx) { struct aws_socket outgoing; ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options)); ASSERT_SUCCESS( - aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args)); + aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args)); aws_mutex_lock(&mutex); ASSERT_SUCCESS(aws_condition_variable_wait_pred( &condition_variable, &mutex, s_connection_completed_predicate, &outgoing_args)); @@ -797,7 +797,7 @@ static int s_test_connect_timeout_cancelation(struct aws_allocator *allocator, v struct aws_socket outgoing; ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options)); ASSERT_SUCCESS( - aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args)); + aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args)); aws_event_loop_group_release(el_group); @@ -859,7 +859,7 @@ static int s_test_outgoing_local_sock_errors(struct aws_allocator *allocator, vo struct aws_socket outgoing; ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options)); - ASSERT_FAILS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_null_sock_connection, &args)); + ASSERT_FAILS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_null_sock_connection, NULL, &args)); ASSERT_TRUE( aws_last_error() == AWS_IO_SOCKET_CONNECTION_REFUSED || aws_last_error() == AWS_ERROR_FILE_INVALID_PATH); @@ -907,7 +907,7 @@ static int s_test_outgoing_tcp_sock_error(struct aws_allocator *allocator, void struct aws_socket outgoing; ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options)); /* tcp connect is non-blocking, it should return success, but the error callback will be invoked. */ - ASSERT_SUCCESS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_null_sock_connection, &args)); + ASSERT_SUCCESS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_null_sock_connection, NULL, &args)); ASSERT_SUCCESS(aws_mutex_lock(&args.mutex)); ASSERT_SUCCESS(aws_condition_variable_wait_pred( &args.condition_variable, &args.mutex, s_outgoing_tcp_error_predicate, &args)); @@ -1346,7 +1346,7 @@ static int s_cleanup_in_accept_doesnt_explode(struct aws_allocator *allocator, v struct aws_socket outgoing; ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options)); ASSERT_SUCCESS( - aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args)); + aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args)); ASSERT_SUCCESS(aws_mutex_lock(&mutex)); ASSERT_SUCCESS( @@ -1481,7 +1481,7 @@ static int s_cleanup_in_write_cb_doesnt_explode(struct aws_allocator *allocator, struct aws_socket outgoing; ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options)); ASSERT_SUCCESS( - aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args)); + aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args)); ASSERT_SUCCESS(aws_mutex_lock(&mutex)); ASSERT_SUCCESS( @@ -1743,7 +1743,7 @@ static int s_sock_write_cb_is_async(struct aws_allocator *allocator, void *ctx) struct aws_socket outgoing; ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options)); ASSERT_SUCCESS( - aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args)); + aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args)); ASSERT_SUCCESS(aws_mutex_lock(&mutex)); ASSERT_SUCCESS( @@ -1837,7 +1837,8 @@ static int s_local_socket_pipe_connected_race(struct aws_allocator *allocator, v struct aws_socket outgoing; ASSERT_SUCCESS(aws_socket_init(&outgoing, allocator, &options)); - ASSERT_SUCCESS(aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, &outgoing_args)); + ASSERT_SUCCESS( + aws_socket_connect(&outgoing, &endpoint, event_loop, s_local_outgoing_connection, NULL, &outgoing_args)); ASSERT_SUCCESS(aws_socket_start_accept(&listener, event_loop, s_local_listener_incoming, &listener_args)); aws_mutex_lock(&mutex); diff --git a/tests/tls_handler_test.c b/tests/tls_handler_test.c index 0b0f5c88c..c058e7a62 100644 --- a/tests/tls_handler_test.c +++ b/tests/tls_handler_test.c @@ -1831,6 +1831,7 @@ static int s_tls_server_hangup_during_negotiation_fn(struct aws_allocator *alloc &local_server_tester.endpoint, aws_event_loop_group_get_next_loop(c_tester.el_group), s_on_client_connected_do_hangup, + NULL, shutdown_tester)); /* Wait for client socket to close */