From 322cd722b28525e26101c60f111a0c75ad841ca9 Mon Sep 17 00:00:00 2001 From: Alsey Coleman Miller Date: Wed, 10 May 2023 11:36:44 -0700 Subject: [PATCH] Updated `AsyncSocketConfiguration` --- .../SocketManager/AsyncSocketManager.swift | 257 ++++++++---------- 1 file changed, 108 insertions(+), 149 deletions(-) diff --git a/Sources/Socket/SocketManager/AsyncSocketManager.swift b/Sources/Socket/SocketManager/AsyncSocketManager.swift index c18efce..494c18b 100644 --- a/Sources/Socket/SocketManager/AsyncSocketManager.swift +++ b/Sources/Socket/SocketManager/AsyncSocketManager.swift @@ -31,25 +31,23 @@ public struct AsyncSocketConfiguration { extension AsyncSocketConfiguration: SocketManagerConfiguration { - public static var manager: some SocketManager { + public static nonisolated var manager: some SocketManager { AsyncSocketManager.shared } public func configureManager() { Task { - await AsyncSocketManager.shared.storage.update { - $0.configuration = self - } + await AsyncSocketManager.shared.updateConfiguration(self) } } } /// Async Socket Manager -internal final class AsyncSocketManager: SocketManager { +internal actor AsyncSocketManager: SocketManager { // MARK: - Properties - fileprivate let storage = Storage() + fileprivate var state = ManagerState() // MARK: - Initialization @@ -61,11 +59,11 @@ internal final class AsyncSocketManager: SocketManager { func add( _ fileDescriptor: SocketDescriptor - ) async -> Socket.Event.Stream { - guard await sockets.keys.contains(fileDescriptor) == false else { + ) -> Socket.Event.Stream { + guard state.sockets.keys.contains(fileDescriptor) == false else { fatalError("Another socket for file descriptor \(fileDescriptor) already exists.") } - await log("Add socket \(fileDescriptor)") + log("Add socket \(fileDescriptor)") // make sure its non blocking do { var status = try fileDescriptor.getStatus() @@ -75,31 +73,41 @@ internal final class AsyncSocketManager: SocketManager { } } catch { - await log("Unable to set non blocking. \(error)") + log("Unable to set non blocking. \(error)") assertionFailure("Unable to set non blocking. \(error)") } // append socket with events continuation - let eventStream = await storage.update { manager in - Socket.Event.Stream(bufferingPolicy: .bufferingNewest(1)) { continuation in - manager.sockets[fileDescriptor] = SocketState( - fileDescriptor: fileDescriptor, - continuation: continuation - ) - } + let eventStream = Socket.Event.Stream(bufferingPolicy: .bufferingNewest(1)) { continuation in + state.sockets[fileDescriptor] = SocketState( + fileDescriptor: fileDescriptor, + continuation: continuation + ) } // start monitoring - await startMonitoring() + startMonitoring() return eventStream } - func remove(_ fileDescriptor: SocketDescriptor) async { - await storage.update { - $0.remove(fileDescriptor) + func remove(_ fileDescriptor: SocketDescriptor) { + guard let socket = state.sockets[fileDescriptor] else { + return // could have been removed previously + } + log("Remove socket \(fileDescriptor)") + // close underlying socket + try? fileDescriptor.close() + // cancel all pending actions + Task(priority: .userInitiated) { + await socket.dequeueAll(Errno.connectionAbort) } + // notify + socket.continuation.yield(.close) + socket.continuation.finish() + // update sockets to monitor + state.sockets[fileDescriptor] = nil } /// Write data to managed file descriptor. - func write( + nonisolated func write( _ data: Data, for fileDescriptor: SocketDescriptor ) async throws -> Int { @@ -109,7 +117,7 @@ internal final class AsyncSocketManager: SocketManager { } /// Read managed file descriptor. - func read( + nonisolated func read( _ length: Int, for fileDescriptor: SocketDescriptor ) async throws -> Data { @@ -118,7 +126,7 @@ internal final class AsyncSocketManager: SocketManager { return try await socket.read(length) } - func sendMessage( + nonisolated func sendMessage( _ data: Data, for fileDescriptor: SocketDescriptor ) async throws -> Int { @@ -127,7 +135,7 @@ internal final class AsyncSocketManager: SocketManager { return try await socket.sendMessage(data) } - func sendMessage( + nonisolated func sendMessage( _ data: Data, to address: Address, for fileDescriptor: SocketDescriptor @@ -137,7 +145,7 @@ internal final class AsyncSocketManager: SocketManager { return try await socket.sendMessage(data, to: address) } - func receiveMessage( + nonisolated func receiveMessage( _ length: Int, for fileDescriptor: SocketDescriptor ) async throws -> Data { @@ -146,7 +154,7 @@ internal final class AsyncSocketManager: SocketManager { return try await socket.receiveMessage(length) } - func receiveMessage( + nonisolated func receiveMessage( _ length: Int, fromAddressOf addressType: Address.Type, for fileDescriptor: SocketDescriptor @@ -157,9 +165,9 @@ internal final class AsyncSocketManager: SocketManager { } /// Accept a connection on a socket. - func accept(for fileDescriptor: SocketDescriptor) async throws -> SocketDescriptor { - let socket = try await storage.state.socket(for: fileDescriptor) - let result = try await retry(sleep: configuration.monitorInterval) { + nonisolated func accept(for fileDescriptor: SocketDescriptor) async throws -> SocketDescriptor { + let socket = try await socket(for: fileDescriptor) + let result = try await retry(sleep: state.configuration.monitorInterval) { fileDescriptor._accept(retryOnInterrupt: true) }.get() socket.continuation.yield(.connection) @@ -167,12 +175,12 @@ internal final class AsyncSocketManager: SocketManager { } /// Accept a connection on a socket. - func accept( + nonisolated func accept( _ address: Address.Type, for fileDescriptor: SocketDescriptor ) async throws -> (fileDescriptor: SocketDescriptor, address: Address) { - let socket = try await storage.state.socket(for: fileDescriptor) - let result = try await retry(sleep: configuration.monitorInterval) { + let socket = try await socket(for: fileDescriptor) + let result = try await retry(sleep: state.configuration.monitorInterval) { fileDescriptor._accept(address, retryOnInterrupt: true) }.get() socket.continuation.yield(.connection) @@ -180,77 +188,81 @@ internal final class AsyncSocketManager: SocketManager { } /// Initiate a connection on a socket. - func connect( + nonisolated func connect( to address: Address, for fileDescriptor: SocketDescriptor ) async throws { - let socket = try await storage.state.socket(for: fileDescriptor) - try await retry(sleep: configuration.monitorInterval) { + let socket = try await socket(for: fileDescriptor) + try await retry(sleep: state.configuration.monitorInterval) { fileDescriptor._connect(to: address, retryOnInterrupt: true) }.get() socket.continuation.yield(.connection) } +} + +// MARK: - Private Methods + +private extension AsyncSocketManager { - // MARK: - Private Methods + func updateConfiguration(_ configuration: AsyncSocketConfiguration) { + self.state.configuration = configuration + } - private func startMonitoring() async { - guard await storage.update({ - guard $0.isMonitoring == false else { return false } - $0.log("Will start monitoring") - $0.isMonitoring = true - return true - }) else { return } + func startMonitoring() { + guard state.isMonitoring == false + else { return } + log("Will start monitoring") + state.isMonitoring = true // Create top level task to monitor - let configuration = await AsyncSocketManager.shared.configuration - Task.detached(priority: configuration.monitorPriority) { [unowned self] in - var tasks = [Task]() - while await self.isMonitoring { - do { - let hasEvents = try await storage.update({ (state: inout ManagerState) -> Bool in - tasks.reserveCapacity(state.sockets.count * 2) - // poll - let hasEvents = try state.poll(&tasks) - // stop monitoring if no sockets - if state.pollDescriptors.isEmpty { - state.isMonitoring = false - } - return hasEvents - }) - // wait for each task to complete - for task in tasks { - await task.value - } - tasks.removeAll(keepingCapacity: true) - // sleep - if hasEvents == false { - try await Task.sleep(nanoseconds: configuration.monitorInterval) - } + Task.detached(priority: state.configuration.monitorPriority) { [unowned self] in + await self.run() + } + } + + func run() async { + var tasks = [Task]() + while self.state.isMonitoring { + do { + tasks.reserveCapacity(state.sockets.count * 2) + // poll + let hasEvents = try poll(&tasks) + // stop monitoring if no sockets + if state.pollDescriptors.isEmpty { + state.isMonitoring = false + } + // wait for each task to complete + for task in tasks { + await task.value } - catch { - await log("Socket monitoring failed. \(error.localizedDescription)") - assertionFailure("Socket monitoring failed. \(error.localizedDescription)") - await storage.update { - $0.isMonitoring = false - } + tasks.removeAll(keepingCapacity: true) + // sleep + if hasEvents == false { + try await Task.sleep(nanoseconds: state.configuration.monitorInterval) } } + catch { + log("Socket monitoring failed. \(error.localizedDescription)") + assertionFailure("Socket monitoring failed. \(error.localizedDescription)") + state.isMonitoring = false + return + } } } - private func contains(_ fileDescriptor: SocketDescriptor) async -> Bool { - return await sockets.keys.contains(fileDescriptor) + func contains(_ fileDescriptor: SocketDescriptor) -> Bool { + return state.sockets.keys.contains(fileDescriptor) } - private func wait( + func wait( for events: FileEvents, fileDescriptor: SocketDescriptor ) async throws -> SocketState { // wait - let socket = try await storage.state.socket(for: fileDescriptor) + let socket = try socket(for: fileDescriptor) guard await socket.pendingEvents.contains(events) == false else { return socket // execute immediately } - await log("Will wait for \(events) for \(fileDescriptor)") + log("Will wait for \(events) for \(fileDescriptor)") // store continuation to resume when event is polled try await withThrowingContinuation(for: fileDescriptor) { (continuation: SocketContinuation<(), Swift.Error>) -> () in // store pending continuation @@ -260,69 +272,25 @@ internal final class AsyncSocketManager: SocketManager { } return socket } -} - -private extension AsyncSocketManager { - - var configuration: AsyncSocketConfiguration { - get async { await storage.state.configuration } - } - - var sockets: [SocketDescriptor: SocketState] { - get async { await storage.state.sockets } - } - - var pollDescriptors: [SocketDescriptor.Poll] { - get async { await storage.state.pollDescriptors } - } - - var isMonitoring: Bool { - get async { await storage.state.isMonitoring } - } - - func log(_ message: String) async { - await storage.state.log(message) - } -} - -extension AsyncSocketManager.ManagerState { func socket( for fileDescriptor: SocketDescriptor ) throws -> AsyncSocketManager.SocketState { - guard let socket = self.sockets[fileDescriptor] else { + guard let socket = state.sockets[fileDescriptor] else { throw Errno.socketShutdown } return socket } - mutating func remove(_ fileDescriptor: SocketDescriptor) { - guard let socket = sockets[fileDescriptor] else { - return // could have been removed previously - } - log("Remove socket \(fileDescriptor)") - // close underlying socket - try? fileDescriptor.close() - // cancel all pending actions - Task(priority: .userInitiated) { - await socket.dequeueAll(Errno.connectionAbort) - } - // notify - socket.continuation.yield(.close) - socket.continuation.finish() - // update sockets to monitor - sockets[fileDescriptor] = nil - } - /// Poll for events. @discardableResult - mutating func poll(_ tasks: inout [Task]) throws -> Bool { + func poll(_ tasks: inout [Task]) throws -> Bool { // build poll descriptor array - let sockets = self.sockets + let sockets = state.sockets .lazy .sorted(by: { $0.key.rawValue < $1.key.rawValue }) - pollDescriptors.removeAll(keepingCapacity: true) - pollDescriptors.reserveCapacity(sockets.count) + state.pollDescriptors.removeAll(keepingCapacity: true) + state.pollDescriptors.reserveCapacity(sockets.count) let events: FileEvents = [ .read, .readUrgent, @@ -336,22 +304,22 @@ extension AsyncSocketManager.ManagerState { socket: fileDescriptor, events: events ) - pollDescriptors.append(poll) + state.pollDescriptors.append(poll) } - assert(pollDescriptors.count == sockets.count) + assert(state.pollDescriptors.count == sockets.count) // poll sockets do { - try pollDescriptors.poll() + try state.pollDescriptors.poll() } catch { log("Unable to poll for events. \(error.localizedDescription)") throw error } // wait for concurrent handling - let hasEvents = pollDescriptors.contains(where: { $0.returnedEvents.isEmpty == false }) + let hasEvents = state.pollDescriptors.contains(where: { $0.returnedEvents.isEmpty == false }) if hasEvents { - for poll in pollDescriptors { - guard let state = self.sockets[poll.socket] else { + for poll in state.pollDescriptors { + guard let state = state.sockets[poll.socket] else { preconditionFailure() continue } @@ -361,7 +329,7 @@ extension AsyncSocketManager.ManagerState { return hasEvents } - mutating func process(_ poll: SocketDescriptor.Poll, socket: AsyncSocketManager.SocketState, tasks: inout [Task]) { + func process(_ poll: SocketDescriptor.Poll, socket: AsyncSocketManager.SocketState, tasks: inout [Task]) { /* let isListening = self.sockets[poll.socket]?.isListening ?? false if isListening, poll.returnedEvents.contains([.read, .write]) { @@ -397,12 +365,12 @@ extension AsyncSocketManager.ManagerState { } } - mutating func error(_ error: Errno, for fileDescriptor: SocketDescriptor) { - self.sockets[fileDescriptor]?.continuation.yield(.error(error)) + func error(_ error: Errno, for fileDescriptor: SocketDescriptor) { + state.sockets[fileDescriptor]?.continuation.yield(.error(error)) remove(fileDescriptor) } - mutating func hangup(_ fileDescriptor: SocketDescriptor) { + func hangup(_ fileDescriptor: SocketDescriptor) { remove(fileDescriptor) } } @@ -521,14 +489,14 @@ fileprivate extension AsyncSocketManager.SocketState { } } -extension AsyncSocketManager.ManagerState { +extension AsyncSocketManager { #if DEBUG static let debugLogEnabled = ProcessInfo.processInfo.environment["SWIFTSOCKETDEBUG"] == "1" #endif func log(_ message: String) { - if let logger = configuration.log { + if let logger = state.configuration.log { logger(message) } else { #if DEBUG @@ -544,15 +512,6 @@ extension AsyncSocketManager.ManagerState { extension AsyncSocketManager { - actor Storage { - - var state = ManagerState() - - func update(_ block: (inout ManagerState) throws -> (T)) rethrows -> T { - try block(&self.state) - } - } - struct ManagerState { var configuration = AsyncSocketConfiguration()