Skip to content

Commit

Permalink
weh
Browse files Browse the repository at this point in the history
  • Loading branch information
ifd3f committed Jul 10, 2024
1 parent 777943a commit 74fe871
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 24 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ itertools = "0.12.1"
libc = "0.2.154"
lz4_flex = "0.11.3"
md-5 = "0.10.5"
pin-project = "1.1.5"
process_path = "0.1.4"
ratatui = "0.26.0"
ruzstd = "0.6.0"
Expand Down
164 changes: 140 additions & 24 deletions src/writer_process/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
use std::{
fs::File,
io::{BufReader, Read, Seek, Write},
pin::{pin, Pin},
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::{Context, Poll},
};

use futures::{future::BoxFuture, FutureExt};
use pin_project::pin_project;
use tokio::{
fs::File,
io::{AsyncRead, AsyncSeek, AsyncWrite},
};

use crate::compression::{decompress, CompressionFormat, DecompressRead};
Expand Down Expand Up @@ -40,12 +52,14 @@ impl<R: Read> Read for CountRead<R> {

/// Wraps a writer and counts how many bytes we've written in total, without
/// making any system calls.
pub struct CountWrite<W: Write> {
#[pin_project]
pub struct CountWrite<W: AsyncWrite> {
// The wrapped writer. Marked `#[pin]` so the `AsyncWrite` impl can reach it as
// `Pin<&mut W>` through the projection and forward `poll_*` calls to it.
#[pin]
w: W,
// Running total of the byte counts the wrapped writer has reported from
// successful writes.
count: u64,
}

impl<W: Write> CountWrite<W> {
impl<W: AsyncWrite> CountWrite<W> {
#[inline(always)]
pub fn new(w: W) -> Self {
Self { w, count: 0 }
Expand All @@ -57,17 +71,38 @@ impl<W: Write> CountWrite<W> {
}
}

impl<W: Write> Write for CountWrite<W> {
/// Calls `f` with a reference to the value inside `p` when it is
/// [`Poll::Ready`], then hands `p` back unchanged.
///
/// A [`Poll::Pending`] is passed straight through without invoking `f`.
fn inspect_poll<T>(p: Poll<T>, f: impl FnOnce(&T)) -> Poll<T> {
    if let Poll::Ready(value) = &p {
        f(value);
    }
    p
}

impl<W: AsyncWrite> AsyncWrite for CountWrite<W> {
#[inline(always)]
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let bytes = self.w.write(buf)?;
self.count += bytes as u64;
Ok(bytes)
/// Forwards the write to the wrapped writer and adds the number of bytes it
/// reports as written to the running total.
///
/// The counter only advances on `Poll::Ready(Ok(n))`; a pending poll or an
/// error leaves it untouched. The wrapped writer's result is returned as-is.
fn poll_write(
    self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    buf: &[u8],
) -> Poll<std::io::Result<usize>> {
    let this = self.project();
    // Pull the counter out of the projection first so the closure below only
    // captures it, independently of the pinned writer being consumed.
    let count = this.count;
    inspect_poll(this.w.poll_write(cx, buf), |result| {
        if let Ok(written) = result {
            *count += *written as u64;
        }
    })
}

#[inline(always)]
fn flush(&mut self) -> std::io::Result<()> {
self.w.flush()
/// Flushes the wrapped writer; the byte counter is not touched.
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
    let inner = self.project().w;
    inner.poll_flush(cx)
}

#[inline(always)]
/// Shuts down the wrapped writer; the byte counter is not touched.
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
    let inner = self.project().w;
    inner.poll_shutdown(cx)
}
}

Expand All @@ -78,41 +113,122 @@ impl<W: Write> Write for CountWrite<W> {
/// - delegates [`AsyncWrite::poll_write`] to the wrapped file
/// - replaces [`AsyncWrite::poll_flush`] with the platform-specific data-sync call to ensure
///   that the data has been written to the disk.
pub struct SyncDataFile(pub File);
#[pin_project]
pub struct SyncDataFile {
// Either the idle file, or the file together with the flush future that is
// currently using it. Pinned so the idle file can be projected to
// `Pin<&mut File>` for the delegating `poll_*` methods.
#[pin]
state: SyncDataState,
}

// Flush state machine for `SyncDataFile`. `poll_flush` moves the value from
// `NotFlushing` to `Flushing` while a `sync_data()` future is pending, and back
// to `NotFlushing` once that future completes.
#[pin_project(project = SyncDataStateProj)]
enum SyncDataState {
// No flush is in flight; the file is directly available for I/O.
NotFlushing(#[pin] File),
// A `sync_data()` call is in flight.
Flushing {
// Handle to the file being synced. `poll_flush` gives the future its own
// clone of this `Arc` and unwraps it back into a `File` once the future is done.
file: Arc<File>,
// The boxed `sync_data()` future, kept so the next `poll_flush` can resume it.
future: BoxFuture<'static, std::io::Result<()>>,
},
}

impl SyncDataFile {
/// Wraps `file`, starting out with no flush in progress.
fn new(file: File) -> Self {
    let state = SyncDataState::NotFlushing(file);
    Self { state }
}

impl Read for SyncDataFile {
/// Returns the pinned file when no flush is in flight, or `None` while the
/// file is owned by a pending flush.
fn try_get_file(self: Pin<&mut Self>) -> Option<Pin<&mut File>> {
    let state = self.project().state;
    match state.project() {
        SyncDataStateProj::NotFlushing(file) => Some(file),
        SyncDataStateProj::Flushing { .. } => None,
    }
}
}

impl AsyncRead for SyncDataFile {
#[inline(always)]
fn read(&mut self, buf: &mut [u8]) -> futures_io::Result<usize> {
self.0.read(buf)
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.try_get_file() {
Some(file) => file.poll_read(cx, buf),
None => Poll::Pending,
}
}
}

impl Write for SyncDataFile {
impl AsyncWrite for SyncDataFile {
#[inline(always)]
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.write(buf)
/// Writes `buf` to the wrapped file.
///
/// If a flush started by `poll_flush` is still in flight, the file is not
/// available yet, so the flush is driven first. Returning a bare
/// `Poll::Pending` here would register no waker and make no progress, leaving
/// the calling task parked forever.
///
/// # Errors
///
/// Returns the wrapped file's write error, or the error of the pending flush
/// if that flush fails while being driven from here.
fn poll_write(
    mut self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    buf: &[u8],
) -> Poll<std::io::Result<usize>> {
    loop {
        if let Some(file) = self.as_mut().try_get_file() {
            return file.poll_write(cx, buf);
        }
        // A flush owns the file. Polling it registers `cx`'s waker with the
        // flush future; once it is ready the file is available again and the
        // next loop iteration performs the write.
        std::task::ready!(self.as_mut().poll_flush(cx))?;
    }
}

#[inline(always)]
fn flush(&mut self) -> std::io::Result<()> {
// Flushes by syncing the file's data to disk rather than by a plain flush.
//
// On Linux this drives a `sync_data()` future across polls: the first poll
// creates the future and stores it in the `Flushing` state; later polls resume
// it; when it completes the state returns to `NotFlushing`.
// On macOS this is a no-op that reports success immediately.
//
// NOTE(review): only `linux` and `macos` branches are present, so on any other
// target this body appears to have no tail expression — confirm the intended
// platform coverage.
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
#[cfg(target_os = "linux")]
{
self.0.sync_data()
// Either start a new sync future (wrapping the file in an `Arc` so the
// `'static` future can hold its own handle), or pick up the future stored by
// an earlier pending poll.
// NOTE(review): `match self.state` takes the enum by value from behind
// `Pin<&mut Self>`; this looks like a move out of a borrow — confirm it
// compiles, or swap the state out explicitly.
let (file, mut fut) = match self.state {
SyncDataState::NotFlushing(f) => {
let f = Arc::new(f);
let f2 = f.clone();
(f, async move { f2.sync_data().await }.boxed())
}
SyncDataState::Flushing { file, future } => (file, future),
};
let p = fut.poll_unpin(cx);
// Store the follow-up state: back to idle when the sync finished (whether it
// succeeded or failed), otherwise keep the future for the next poll.
self.state = match &p {
Poll::Ready(_) => {
// Drop the future first so the `Arc` clone it captured is released;
// `try_unwrap` below only succeeds when this is the sole remaining handle.
drop(fut);
SyncDataState::NotFlushing(
Arc::try_unwrap(file)
.expect("this should be the last instance of this Arc!"),
)
}
Poll::Pending => SyncDataState::Flushing { file, future: fut },
};
p
}

// On MacOS, calling sync_data() on a disk yields "Inappropriate ioctl for device (os error 25)"
// so for now we will just no-op.
#[cfg(target_os = "macos")]
{
Ok(())
Poll::Ready(Ok(()))
}
}
}

impl Seek for SyncDataFile {
#[inline(always)]
fn seek(&mut self, pos: futures_io::SeekFrom) -> std::io::Result<u64> {
self.0.seek(pos)
/// Shuts down the wrapped file.
///
/// If a flush started by `poll_flush` is still in flight, the file is not
/// available yet, so the flush is driven first. Returning a bare
/// `Poll::Pending` here would register no waker and make no progress, leaving
/// the calling task parked forever.
///
/// # Errors
///
/// Returns the wrapped file's shutdown error, or the error of the pending
/// flush if that flush fails while being driven from here.
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
    loop {
        if let Some(file) = self.as_mut().try_get_file() {
            return file.poll_shutdown(cx);
        }
        // A flush owns the file. Polling it registers `cx`'s waker with the
        // flush future; once it is ready the file is available again and the
        // next loop iteration performs the shutdown.
        std::task::ready!(self.as_mut().poll_flush(cx))?;
    }
}
}

impl AsyncSeek for SyncDataFile {
fn start_seek(self: Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> {
match self.try_get_file() {
Some(file) => file.start_seek(position),
None => Err(std::io::Error::new(
std::io::ErrorKind::Other,
"other file operation is pending, call poll_complete before start_seek",
)),
}
}

fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
match self.try_get_file() {
Some(file) => file.poll_complete(cx),
None => Poll::Pending,
}
}
}

Expand Down

0 comments on commit 74fe871

Please sign in to comment.