Skip to content

Commit

Permalink
Merge pull request #2 from PureSwift/feature/system-socket-api
Browse files Browse the repository at this point in the history
Add `SocketDescriptor` and related low-level types.
  • Loading branch information
colemancda authored Apr 27, 2022
2 parents a90a85b + 203e2f8 commit e532539
Show file tree
Hide file tree
Showing 42 changed files with 3,946 additions and 60 deletions.
5 changes: 4 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ let package = Package(
),
],
dependencies: [
.package(url: "https://github.com/PureSwift/swift-system.git", .branch("master")),
.package(
url: "https://github.com/apple/swift-system",
from: "1.0.0"
),
],
targets: [
.target(
Expand Down
22 changes: 0 additions & 22 deletions Sources/Socket/Extensions/FileEvents.swift

This file was deleted.

46 changes: 42 additions & 4 deletions Sources/Socket/Socket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//

import Foundation
import SystemPackage
@_exported import SystemPackage

/// Socket
public struct Socket {
Expand All @@ -16,8 +16,8 @@ public struct Socket {
/// Configuration for fine-tuning socket performance.
public static var configuration = Socket.Configuration()

/// Underlying file descriptor
public let fileDescriptor: FileDescriptor
/// Underlying native socket handle.
public let fileDescriptor: SocketDescriptor

public let event: Socket.Event.Stream

Expand All @@ -27,14 +27,52 @@ public struct Socket {

/// Starts monitoring a socket.
public init(
fileDescriptor: FileDescriptor
fileDescriptor: SocketDescriptor
) async {
let manager = SocketManager.shared
self.fileDescriptor = fileDescriptor
self.manager = manager
self.event = await manager.add(fileDescriptor)
}

/// Initialize
public init<T: SocketProtocol>(
_ protocolID: T
) async throws {
let fileDescriptor = try SocketDescriptor(protocolID)
await self.init(fileDescriptor: fileDescriptor)
}

///
public init<Address: SocketAddress>(
_ protocolID: Address.ProtocolID,
bind address: Address
) async throws {
let fileDescriptor = try SocketDescriptor(protocolID, bind: address)
await self.init(fileDescriptor: fileDescriptor)
}

#if os(Linux)
///
public init<T: SocketProtocol>(
_ protocolID: T,
flags: SocketFlags
) async throws {
let fileDescriptor = try SocketDescriptor(protocolID, flags: flags)
await self.init(fileDescriptor: fileDescriptor)
}

///
public init<Address: SocketAddress>(
_ protocolID: Address.ProtocolID,
bind address: Address,
flags: SocketFlags
) async throws {
let fileDescriptor = try SocketDescriptor(protocolID, bind: address, flags: flags)
await self.init(fileDescriptor: fileDescriptor)
}
#endif

// MARK: - Methods

/// Write to socket
Expand Down
12 changes: 6 additions & 6 deletions Sources/Socket/SocketContinuation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ internal struct SocketContinuation<T, E> where E: Error {

private let continuation: CheckedContinuation<T, E>

private let fileDescriptor: FileDescriptor
private let fileDescriptor: SocketDescriptor

fileprivate init(
continuation: UnsafeContinuation<T, E>,
function: String,
fileDescriptor: FileDescriptor
fileDescriptor: SocketDescriptor
) {
self.continuation = CheckedContinuation(continuation: continuation, function: function)
self.function = function
Expand Down Expand Up @@ -50,7 +50,7 @@ extension SocketContinuation where T == Void {
}

internal func withContinuation<T>(
for fileDescriptor: FileDescriptor,
for fileDescriptor: SocketDescriptor,
function: String = #function,
_ body: (SocketContinuation<T, Never>) -> Void
) async -> T {
Expand All @@ -60,7 +60,7 @@ internal func withContinuation<T>(
}

internal func withThrowingContinuation<T>(
for fileDescriptor: FileDescriptor,
for fileDescriptor: SocketDescriptor,
function: String = #function,
_ body: (SocketContinuation<T, Swift.Error>) -> Void
) async throws -> T {
Expand All @@ -73,7 +73,7 @@ internal typealias SocketContinuation<T, E> = UnsafeContinuation<T, E> where E:

@inline(__always)
internal func withContinuation<T>(
for fileDescriptor: FileDescriptor,
for fileDescriptor: SocketDescriptor,
function: String = #function,
_ body: (SocketContinuation<T, Never>) -> Void
) async -> T {
Expand All @@ -82,7 +82,7 @@ internal func withContinuation<T>(

@inline(__always)
internal func withThrowingContinuation<T>(
for fileDescriptor: FileDescriptor,
for fileDescriptor: SocketDescriptor,
function: String = #function,
_ body: (SocketContinuation<T, Swift.Error>) -> Void
) async throws -> T {
Expand Down
57 changes: 35 additions & 22 deletions Sources/Socket/SocketManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ internal actor SocketManager {

static let shared = SocketManager()

private var sockets = [FileDescriptor: SocketState]()
private var sockets = [SocketDescriptor: SocketState]()

private var pollDescriptors = [FileDescriptor.Poll]()
private var pollDescriptors = [SocketDescriptor.Poll]()

private var isMonitoring = false

Expand Down Expand Up @@ -45,12 +45,12 @@ internal actor SocketManager {
}
}

func contains(_ fileDescriptor: FileDescriptor) -> Bool {
func contains(_ fileDescriptor: SocketDescriptor) -> Bool {
return sockets.keys.contains(fileDescriptor)
}

func add(
_ fileDescriptor: FileDescriptor
_ fileDescriptor: SocketDescriptor
) -> Socket.Event.Stream {
guard sockets.keys.contains(fileDescriptor) == false else {
fatalError("Another socket for file descriptor \(fileDescriptor) already exists.")
Expand Down Expand Up @@ -83,7 +83,7 @@ internal actor SocketManager {
return event
}

func remove(_ fileDescriptor: FileDescriptor, error: Error? = nil) async {
func remove(_ fileDescriptor: SocketDescriptor, error: Error? = nil) async {
guard let socket = sockets[fileDescriptor] else {
return // could have been removed by `poll()`
}
Expand All @@ -101,7 +101,7 @@ internal actor SocketManager {
}

@discardableResult
internal nonisolated func write(_ data: Data, for fileDescriptor: FileDescriptor) async throws -> Int {
internal nonisolated func write(_ data: Data, for fileDescriptor: SocketDescriptor) async throws -> Int {
guard let socket = await sockets[fileDescriptor] else {
log("Unable to write unknown socket \(fileDescriptor).")
assertionFailure("\(#function) Unknown socket \(fileDescriptor)")
Expand All @@ -112,7 +112,7 @@ internal actor SocketManager {
return try await socket.write(data)
}

internal nonisolated func read(_ length: Int, for fileDescriptor: FileDescriptor) async throws -> Data {
internal nonisolated func read(_ length: Int, for fileDescriptor: SocketDescriptor) async throws -> Data {
guard let socket = await sockets[fileDescriptor] else {
log("Unable to read unknown socket \(fileDescriptor).")
assertionFailure("\(#function) Unknown socket \(fileDescriptor)")
Expand All @@ -123,14 +123,14 @@ internal actor SocketManager {
return try await socket.read(length)
}

private func events(for fileDescriptor: FileDescriptor) throws -> FileEvents {
guard let poll = pollDescriptors.first(where: { $0.fileDescriptor == fileDescriptor }) else {
private func events(for fileDescriptor: SocketDescriptor) throws -> FileEvents {
guard let poll = pollDescriptors.first(where: { $0.socket == fileDescriptor }) else {
throw Errno.connectionAbort
}
return poll.returnedEvents
}

private nonisolated func wait(for event: FileEvents, fileDescriptor: FileDescriptor) async throws {
private nonisolated func wait(for event: FileEvents, fileDescriptor: SocketDescriptor) async throws {
guard let socket = await sockets[fileDescriptor] else {
log("Unable to wait for unknown socket \(fileDescriptor).")
assertionFailure("\(#function) Unknown socket \(fileDescriptor)")
Expand Down Expand Up @@ -161,7 +161,7 @@ internal actor SocketManager {
pollDescriptors = sockets.keys
.lazy
.sorted(by: { $0.rawValue < $1.rawValue })
.map { FileDescriptor.Poll(fileDescriptor: $0, events: .socket) }
.map { SocketDescriptor.Poll(socket: $0, events: .socketManager) }
}

private func poll() async throws {
Expand All @@ -177,29 +177,29 @@ internal actor SocketManager {
// wait for concurrent handling
for poll in pollDescriptors {
if poll.returnedEvents.contains(.write) {
await self.canWrite(poll.fileDescriptor)
await self.canWrite(poll.socket)
}
if poll.returnedEvents.contains(.read) {
await self.shouldRead(poll.fileDescriptor)
await self.shouldRead(poll.socket)
}
if poll.returnedEvents.contains(.invalidRequest) {
assertionFailure("Polled for invalid socket \(poll.fileDescriptor)")
await self.error(.badFileDescriptor, for: poll.fileDescriptor)
assertionFailure("Polled for invalid socket \(poll.socket)")
await self.error(.badFileDescriptor, for: poll.socket)
}
if poll.returnedEvents.contains(.hangup) {
await self.error(.connectionReset, for: poll.fileDescriptor)
await self.error(.connectionReset, for: poll.socket)
}
if poll.returnedEvents.contains(.error) {
await self.error(.connectionAbort, for: poll.fileDescriptor)
await self.error(.connectionAbort, for: poll.socket)
}
}
}

private func error(_ error: Errno, for fileDescriptor: FileDescriptor) async {
private func error(_ error: Errno, for fileDescriptor: SocketDescriptor) async {
await self.remove(fileDescriptor, error: error)
}

private func shouldRead(_ fileDescriptor: FileDescriptor) async {
private func shouldRead(_ fileDescriptor: SocketDescriptor) async {
guard let socket = self.sockets[fileDescriptor] else {
log("Pending read for unknown socket \(fileDescriptor).")
assertionFailure("\(#function) Unknown socket \(fileDescriptor)")
Expand All @@ -211,7 +211,7 @@ internal actor SocketManager {
socket.event.yield(.pendingRead)
}

private func canWrite(_ fileDescriptor: FileDescriptor) async {
private func canWrite(_ fileDescriptor: SocketDescriptor) async {
guard let socket = self.sockets[fileDescriptor] else {
log("Can write for unknown socket \(fileDescriptor).")
assertionFailure("\(#function) Unknown socket \(fileDescriptor)")
Expand All @@ -228,13 +228,13 @@ extension SocketManager {

actor SocketState {

let fileDescriptor: FileDescriptor
let fileDescriptor: SocketDescriptor

let event: Socket.Event.Stream.Continuation

private var pendingEvent = [FileEvents: [SocketContinuation<(), Error>]]()

init(fileDescriptor: FileDescriptor,
init(fileDescriptor: SocketDescriptor,
event: Socket.Event.Stream.Continuation
) {
self.fileDescriptor = fileDescriptor
Expand Down Expand Up @@ -287,3 +287,16 @@ extension SocketManager.SocketState {
return data
}
}

private extension FileEvents {

static var socketManager: FileEvents {
[
.read,
.write,
.error,
.hangup,
.invalidRequest
]
}
}
78 changes: 78 additions & 0 deletions Sources/Socket/System/AsyncSocketOperations.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import SystemPackage

#if swift(>=5.5)
@available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *)
public extension SocketDescriptor {

/// Accept a connection on a socket.
///
/// - Parameters:
/// - retryOnInterrupt: Whether to retry the receive operation
/// if it throws ``Errno/interrupted``.
/// The default is `true`.
/// Pass `false` to try only once and throw an error upon interruption.
/// - sleep: The number of nanoseconds to sleep if the operation
/// throws ``Errno/wouldBlock`` or other async I/O errors..
/// - Returns: The file descriptor of the new connection.
///
/// The corresponding C function is `accept`.
@_alwaysEmitIntoClient
func accept(
retryOnInterrupt: Bool = true,
sleep: UInt64 = 10_000_000
) async throws -> SocketDescriptor {
try await retry(sleep: sleep) {
_accept(retryOnInterrupt: retryOnInterrupt)
}.get()
}

/// Accept a connection on a socket.
///
/// - Parameters:
/// - address: The type of the `SocketAddress` expected for the new connection.
/// - retryOnInterrupt: Whether to retry the receive operation
/// if it throws ``Errno/interrupted``.
/// The default is `true`.
/// Pass `false` to try only once and throw an error upon interruption.
/// - sleep: The number of nanoseconds to sleep if the operation
/// throws ``Errno/wouldBlock`` or other async I/O errors.
/// - Returns: A tuple containing the file descriptor and address of the new connection.
///
/// The corresponding C function is `accept`.
@_alwaysEmitIntoClient
func accept<Address: SocketAddress>(
_ address: Address.Type,
retryOnInterrupt: Bool = true,
sleep: UInt64 = 10_000_000
) async throws -> (SocketDescriptor, Address) {
try await retry(sleep: sleep) {
_accept(address, retryOnInterrupt: retryOnInterrupt)
}.get()
}

/// Initiate a connection on a socket.
///
/// - Parameters:
/// - address: The peer address.
/// - retryOnInterrupt: Whether to retry the receive operation
/// if it throws ``Errno/interrupted``.
/// The default is `true`.
/// Pass `false` to try only once and throw an error upon interruption.
/// - sleep: The number of nanoseconds to sleep if the operation
/// throws ``Errno/wouldBlock`` or other async I/O errors.
/// - Returns: The file descriptor of the new connection.
///
/// The corresponding C function is `connect`.
@_alwaysEmitIntoClient
func connect<Address: SocketAddress>(
to address: Address,
retryOnInterrupt: Bool = true,
sleep: UInt64 = 10_000_000
) async throws {
try await retry(sleep: sleep) {
_connect(to: address, retryOnInterrupt: retryOnInterrupt)
}.get()
}
}

#endif
Loading

0 comments on commit e532539

Please sign in to comment.