Skip to content
This repository has been archived by the owner on Jul 17, 2023. It is now read-only.

Commit

Permalink
vsock: passthrough fd support for hybrid mode
Browse files Browse the repository at this point in the history
After applying this commit, the hybrid vsock would support two
types of connections:

1) pass "CONNECT <port>\n" to the hybrid unix socket;
2) pass "PASSFD\n" to the hybrid unix socket first,
   then, using sendmsg and msg_control to pass the <port>
   number and fd to the unix socket.

In both of the cases, you should read on the socket to
get a response of "OK <assigned_hostside_port>\n".

Signed-off-by: fupan <[email protected]>
  • Loading branch information
fupan authored and studychao committed Jul 17, 2023
1 parent 3b17110 commit d30406b
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 17 deletions.
1 change: 1 addition & 0 deletions crates/dbs-virtio-devices/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ virtio-bindings = "0.1.0"
virtio-queue = "0.6.0"
vmm-sys-util = "0.11.0"
vm-memory = { version = "0.9.0", features = [ "backend-mmap" ] }
sendfd = "0.4.3"

[dev-dependencies]
vm-memory = { version = "0.9.0", features = [ "backend-mmap", "backend-atomic" ] }
Expand Down
11 changes: 10 additions & 1 deletion crates/dbs-virtio-devices/src/vsock/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
/// or even the protocol created by us.
use std::any::Any;
use std::io::{Read, Write};
use std::os::unix::io::AsRawFd;
use std::os::unix::io::{AsRawFd, RawFd};
use std::time::Duration;

mod inner;
Expand All @@ -15,6 +15,7 @@ mod unix_stream;

pub use self::inner::{VsockInnerBackend, VsockInnerConnector, VsockInnerStream};
pub use self::tcp::VsockTcpBackend;
pub use self::unix_stream::HybridUnixStreamBackend;
pub use self::unix_stream::VsockUnixStreamBackend;

/// The type of vsock backend.
Expand Down Expand Up @@ -59,6 +60,14 @@ pub trait VsockStream: Read + Write + AsRawFd + Send {
fn set_write_timeout(&mut self, _dur: Option<Duration>) -> std::io::Result<()> {
Err(std::io::Error::from(std::io::ErrorKind::InvalidInput))
}
/// Receive the port and fd from the peer.
fn recv_data_fd(
&self,
_bytes: &mut [u8],
_fds: &mut [RawFd],
) -> std::io::Result<(usize, usize)> {
Err(std::io::Error::from(std::io::ErrorKind::InvalidInput))
}
/// Used to downcast to the specific type.
fn as_any(&self) -> &dyn Any;
}
66 changes: 66 additions & 0 deletions crates/dbs-virtio-devices/src/vsock/backend/unix_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,77 @@
// SPDX-License-Identifier: Apache-2.0

use std::any::Any;
use std::io::{Read, Write};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::time::Duration;

use log::info;
use sendfd::RecvWithFd;

use super::super::{Result, VsockError};
use super::{VsockBackend, VsockBackendType, VsockStream};

pub struct HybridUnixStreamBackend {
pub unix_stream: Box<dyn VsockStream>,
pub slave_stream: Option<Box<dyn VsockStream>>,
}

impl VsockStream for HybridUnixStreamBackend {
fn backend_type(&self) -> VsockBackendType {
self.unix_stream.backend_type()
}

fn set_nonblocking(&mut self, nonblocking: bool) -> std::io::Result<()> {
self.unix_stream.set_nonblocking(nonblocking)
}

fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
self.unix_stream.set_read_timeout(dur)
}

fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
self.unix_stream.set_write_timeout(dur)
}

fn as_any(&self) -> &dyn Any {
self.unix_stream.as_any()
}

fn recv_data_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> std::io::Result<(usize, usize)> {
self.unix_stream.recv_data_fd(bytes, fds)
}
}

impl AsRawFd for HybridUnixStreamBackend {
fn as_raw_fd(&self) -> RawFd {
self.unix_stream.as_raw_fd()
}
}

impl Read for HybridUnixStreamBackend {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.unix_stream.read(buf)
}
}

impl Write for HybridUnixStreamBackend {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
// The slave stream was only used to reply the connect result "ok <port>",
// thus it was only used once here, and the data would be replied by the
// main stream.
if let Some(mut stream) = self.slave_stream.take() {
stream.write(buf)
} else {
self.unix_stream.write(buf)
}
}

fn flush(&mut self) -> std::io::Result<()> {
self.unix_stream.flush()
}
}

impl VsockStream for UnixStream {
fn backend_type(&self) -> VsockBackendType {
VsockBackendType::UnixStream
Expand All @@ -31,6 +93,10 @@ impl VsockStream for UnixStream {
fn as_any(&self) -> &dyn Any {
self
}

fn recv_data_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> std::io::Result<(usize, usize)> {
self.recv_with_fd(bytes, fds)
}
}

/// The backend implementation that using Unix Stream.
Expand Down
131 changes: 115 additions & 16 deletions crates/dbs-virtio-devices/src/vsock/muxer/muxer_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@
/// `HashMap` object, mapping `RawFd`s to `EpollListener`s.
use std::collections::{HashMap, HashSet};
use std::io::Read;
use std::os::fd::FromRawFd;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;

use log::{debug, error, info, trace, warn};

use super::super::backend::{VsockBackend, VsockBackendType, VsockStream};
use super::super::backend::{HybridUnixStreamBackend, VsockBackend, VsockBackendType, VsockStream};

use super::super::csm::{ConnState, VsockConnection};
use super::super::defs::uapi;
use super::super::packet::VsockPacket;
Expand All @@ -68,6 +71,11 @@ pub enum MuxerRx {
RstPkt { local_port: u32, peer_port: u32 },
}

enum ReadPortResult {
PassFd,
Connect(u32),
}

/// An epoll listener, registered under the muxer's nested epoll FD.
pub enum EpollListener {
/// The listener is a `VsockConnection`, identified by `key`, and interested
Expand All @@ -84,6 +92,9 @@ pub enum EpollListener {
/// A listener interested in reading host "connect <port>" commands from a
/// freshly connected host socket.
LocalStream(Box<dyn VsockStream>),
/// A listener interested in recvmsg from host to get the <port> and a
/// socket/pipe fd.
PassFdStream(Box<dyn VsockStream>),
}

/// The vsock connection multiplexer.
Expand Down Expand Up @@ -415,6 +426,7 @@ impl VsockMuxer {
// from a "connect" command received on this socket,
// so the next step is to ask to be notified the
// moment we can read from it.

self.add_listener(
stream.as_raw_fd(),
EpollListener::LocalStream(stream),
Expand All @@ -433,15 +445,55 @@ impl VsockMuxer {
Some(EpollListener::LocalStream(_)) => {
if let Some(EpollListener::LocalStream(mut stream)) = self.remove_listener(fd) {
Self::read_local_stream_port(&mut stream)
.map(|peer_port| (self.allocate_local_port(), peer_port))
.and_then(|(local_port, peer_port)| {
.and_then(|read_port_result| match read_port_result {
ReadPortResult::Connect(peer_port) => {
let local_port = self.allocate_local_port();
self.add_connection(
ConnMapKey {
local_port,
peer_port,
},
VsockConnection::new_local_init(
stream,
uapi::VSOCK_HOST_CID,
self.cid,
local_port,
peer_port,
),
)
}
ReadPortResult::PassFd => self.add_listener(
stream.as_raw_fd(),
EpollListener::PassFdStream(stream),
),
})
.unwrap_or_else(|err| {
info!("vsock: error adding local-init connection: {:?}", err);
})
}
}

Some(EpollListener::PassFdStream(_)) => {
if let Some(EpollListener::PassFdStream(mut stream)) = self.remove_listener(fd) {
Self::passfd_read_port_and_fd(&mut stream)
.map(|(nfd, peer_port)| (nfd, self.allocate_local_port(), peer_port))
.and_then(|(nfd, local_port, peer_port)| {
// Here we should make sure the nfd the sole owner to convert it
// into an UnixStream object, otherwise, it could cause memory unsafety.
let nstream = unsafe { UnixStream::from_raw_fd(nfd) };

let hybridstream = HybridUnixStreamBackend {
unix_stream: Box::new(nstream),
slave_stream: Some(stream),
};

self.add_connection(
ConnMapKey {
local_port,
peer_port,
},
VsockConnection::new_local_init(
stream,
Box::new(hybridstream),
uapi::VSOCK_HOST_CID,
self.cid,
local_port,
Expand All @@ -450,7 +502,10 @@ impl VsockMuxer {
)
})
.unwrap_or_else(|err| {
info!("vsock: error adding local-init connection: {:?}", err);
info!(
"vsock: error adding local-init passthrough fd connection: {:?}",
err
);
})
}
}
Expand All @@ -465,22 +520,23 @@ impl VsockMuxer {
}

/// Parse a host "connect" command, and extract the destination vsock port.
fn read_local_stream_port(stream: &mut Box<dyn VsockStream>) -> Result<u32> {
fn read_local_stream_port(stream: &mut Box<dyn VsockStream>) -> Result<ReadPortResult> {
let mut buf = [0u8; 32];

// This is the minimum number of bytes that we should be able to read,
// when parsing a valid connection request. I.e. `b"connect 0\n".len()`.
const MIN_READ_LEN: usize = 10;
// when parsing a valid connection request. I.e. `b"passfd\n"`, otherwise,
// it would be `b"connect 0\n".len()`.
const MIN_READ_LEN: usize = 7;

// Bring in the minimum number of bytes that we should be able to read.
stream
.read(&mut buf[..MIN_READ_LEN])
.map_err(Error::BackendRead)?;

// Now, finish reading the destination port number, by bringing in one
// byte at a time, until we reach an EOL terminator (or our buffer space
// runs out). Yeah, not particularly proud of this approach, but it
// will have to do for now.
// Now, finish reading the destination port number if it's connect <port> command,
// by bringing in one byte at a time, until we reach an EOL terminator (or our buffer
// space runs out). Yeah, not particularly proud of this approach, but it will have to
// do for now.
let mut blen = MIN_READ_LEN;
while buf[blen - 1] != b'\n' && blen < buf.len() {
stream
Expand All @@ -497,17 +553,59 @@ impl VsockMuxer {
.next()
.ok_or(Error::InvalidPortRequest)
.and_then(|word| {
if word.to_lowercase() == "connect" {
Ok(())
let key = word.to_lowercase();
if key == "connect" {
Ok(true)
} else if key == "passfd" {
Ok(false)
} else {
Err(Error::InvalidPortRequest)
}
})
.and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest))
.and_then(|word| word.parse::<u32>().map_err(|_| Error::InvalidPortRequest))
.and_then(|connect| {
if connect {
word_iter.next().ok_or(Error::InvalidPortRequest).map(Some)
} else {
Ok(None)
}
})
.and_then(|word| {
word.map_or_else(
|| Ok(ReadPortResult::PassFd),
|word| {
word.parse::<u32>()
.map_or(Err(Error::InvalidPortRequest), |word| {
Ok(ReadPortResult::Connect(word))
})
},
)
})
.map_err(|_| Error::InvalidPortRequest)
}

fn passfd_read_port_and_fd(stream: &mut Box<dyn VsockStream>) -> Result<(RawFd, u32)> {
let mut buf = [0u8; 32];
let mut fds = [0, 1];
let (data_len, fd_len) = stream
.recv_data_fd(&mut buf, &mut fds)
.map_err(Error::BackendRead)?;

if fd_len != 1 || fds[0] <= 0 {
return Err(Error::InvalidPortRequest);
}

let mut port_iter = std::str::from_utf8(&buf[..data_len])
.map_err(|_| Error::InvalidPortRequest)?
.split_whitespace();

let port = port_iter
.next()
.ok_or(Error::InvalidPortRequest)
.and_then(|word| word.parse::<u32>().map_err(|_| Error::InvalidPortRequest))?;

Ok((fds[0], port))
}

/// Add a new connection to the active connection pool.
fn add_connection(&mut self, key: ConnMapKey, conn: VsockConnection) -> Result<()> {
// We might need to make room for this new connection, so let's sweep
Expand Down Expand Up @@ -582,6 +680,7 @@ impl VsockMuxer {
EpollListener::Connection { evset, .. } => evset,
EpollListener::LocalStream(_) => epoll::Events::EPOLLIN,
EpollListener::Backend(_) => epoll::Events::EPOLLIN,
EpollListener::PassFdStream(_) => epoll::Events::EPOLLIN,
};

epoll::ctl(
Expand Down

0 comments on commit d30406b

Please sign in to comment.