Skip to content

Commit

Permalink
chore: make db a single column family handler
Browse files Browse the repository at this point in the history
  • Loading branch information
ethe committed Jul 16, 2024
1 parent 596f286 commit 4635609
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 122 deletions.
1 change: 0 additions & 1 deletion src/inmem/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::{
};

use arrow::array::RecordBatch;
use futures_util::stream::{self, Iter};

use super::mutable::Mutable;
use crate::{
Expand Down
75 changes: 60 additions & 15 deletions src/inmem/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,26 +57,29 @@ where
R: Record + Send + Sync,
R::Key: Send,
{
pub(crate) fn insert(&self, record: Timestamped<R>) {
let (record, ts) = record.into_parts();
pub(crate) fn insert(&self, record: R, ts: Timestamp) {
self.data
// TODO: remove key cloning
.insert(Timestamped::new(record.key().to_key(), ts), Some(record));
}

pub(crate) fn remove(&self, key: Timestamped<R::Key>) {
self.data.insert(key, None);
pub(crate) fn remove(&self, key: R::Key, ts: Timestamp) {
self.data.insert(Timestamped::new(key, ts), None);
}

fn get(
&self,
key: &TimestampedRef<R::Key>,
key: &R::Key,
ts: Timestamp,
) -> Option<Entry<'_, Timestamped<R::Key>, Option<R>>> {
self.data
.range::<TimestampedRef<R::Key>, _>((Bound::Included(key), Bound::Unbounded))
.range::<TimestampedRef<R::Key>, _>((
Bound::Included(TimestampedRef::new(key, ts)),
Bound::Unbounded,
))
.next()
.and_then(|entry| {
if &entry.key().value == key.value() {
if &entry.key().value == key {
Some(entry)
} else {
None
Expand Down Expand Up @@ -111,9 +114,11 @@ where

#[cfg(test)]
mod tests {
use std::ops::Bound;

use super::Mutable;
use crate::{
oracle::timestamp::{Timestamped, TimestampedRef},
oracle::timestamp::Timestamped,
record::Record,
tests::{Test, TestRef},
};
Expand All @@ -125,26 +130,24 @@ mod tests {

let mem_table = Mutable::default();

mem_table.insert(Timestamped::new(
mem_table.insert(
Test {
vstring: key_1.clone(),
vu32: 1,
vobool: Some(true),
},
0_u32.into(),
));
mem_table.insert(Timestamped::new(
);
mem_table.insert(
Test {
vstring: key_2.clone(),
vu32: 2,
vobool: None,
},
1_u32.into(),
));
);

let entry = mem_table
.get(TimestampedRef::new(&key_1, 0_u32.into()))
.unwrap();
let entry = mem_table.get(&key_1, 0_u32.into()).unwrap();
assert_eq!(
entry.value().as_ref().unwrap().as_record_ref(),
TestRef {
Expand All @@ -154,4 +157,46 @@ mod tests {
}
)
}

#[test]
fn range() {
let mutable = Mutable::<String>::new();

mutable.insert("1".into(), 0_u32.into());
mutable.insert("2".into(), 0_u32.into());
mutable.insert("2".into(), 1_u32.into());
mutable.insert("3".into(), 1_u32.into());
mutable.insert("4".into(), 0_u32.into());

let mut scan = mutable.scan((Bound::Unbounded, Bound::Unbounded), 0_u32.into());

assert_eq!(
scan.next().unwrap().key(),
&Timestamped::new("1".into(), 0_u32.into())
);
assert_eq!(
scan.next().unwrap().key(),
&Timestamped::new("2".into(), 0_u32.into())
);
assert_eq!(
scan.next().unwrap().key(),
&Timestamped::new("4".into(), 0_u32.into())
);

let lower = "1".to_string();
let upper = "4".to_string();
let mut scan = mutable.scan(
(Bound::Included(&lower), Bound::Included(&upper)),
1_u32.into(),
);

assert_eq!(
scan.next().unwrap().key(),
&Timestamped::new("2".into(), 1_u32.into())
);
assert_eq!(
scan.next().unwrap().key(),
&Timestamped::new("3".into(), 1_u32.into())
);
}
}
166 changes: 75 additions & 91 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,113 +9,59 @@ mod record;
mod stream;
mod transaction;

use std::{
any::TypeId,
collections::{hash_map::Entry, HashMap, VecDeque},
io, mem,
ops::Bound,
sync::Arc,
};

use async_lock::RwLock;
use std::{collections::VecDeque, io, mem, ops::Bound, sync::Arc};

use async_lock::{RwLock, RwLockReadGuard};
use futures_core::Stream;
use futures_util::StreamExt;
use inmem::{immutable::Immutable, mutable::Mutable};
use oracle::{timestamp::Timestamped, Timestamp};
use oracle::Timestamp;
use parquet::errors::ParquetError;
use record::Record;
use stream::{merge::MergeStream, Entry, ScanStream};

#[derive(Debug)]
pub struct DB {
schemas: std::sync::RwLock<HashMap<TypeId, *const ()>>,
pub struct DB<R>
where
R: Record,
{
schema: Arc<RwLock<Schema<R>>>,
}

impl DB {
pub fn empty() -> Self {
impl<R> Default for DB<R>
where
R: Record,
{
fn default() -> Self {
Self {
schemas: std::sync::RwLock::new(HashMap::new()),
schema: Arc::new(RwLock::new(Schema::default())),
}
}
}

pub(crate) async fn write<R>(&self, record: R, ts: Timestamp) -> io::Result<()>
where
R: Record + Send + Sync,
R::Key: Send,
{
let columns = self.get_schema::<R>();
let columns = columns.read().await;
impl<R> DB<R>
where
R: Record + Send + Sync,
R::Key: Send,
{
pub(crate) async fn write(&self, record: R, ts: Timestamp) -> io::Result<()> {
let columns = self.schema.read().await;
columns.write(record, ts).await
}

pub(crate) async fn write_batch<R>(
pub(crate) async fn write_batch(
&self,
records: impl Iterator<Item = R>,
ts: Timestamp,
) -> io::Result<()>
where
R: Record + Send + Sync,
R::Key: Send,
{
let columns = self.get_schema::<R>();
let columns = columns.read().await;
) -> io::Result<()> {
let columns = self.schema.read().await;
for record in records {
columns.write(record, ts).await?;
}
Ok(())
}

pub(crate) async fn get<R: Record>(&self, key: Timestamped<R::Key>) -> io::Result<Option<&R>> {
let columns = self.get_schema::<R>();
let columns = columns.read().await;
// columns.get(key, ts).await
todo!()
}

pub async fn range_scan<T: Record>(&self, start: Bound<&T::Key>, end: Bound<&T::Key>) {}

fn get_schema<R>(&self) -> Arc<RwLock<Schema<R>>>
where
R: Record,
{
let schemas = self.schemas.read().unwrap();
match schemas.get(&TypeId::of::<R>()) {
Some(schema) => {
let inner = unsafe { Arc::from_raw(*schema as *const RwLock<Schema<R>>) };
let schema = inner.clone();
std::mem::forget(inner);
schema
}
None => {
drop(schemas);
let mut schemas = self.schemas.write().unwrap();
match schemas.entry(TypeId::of::<R>()) {
Entry::Occupied(o) => unsafe {
let inner = Arc::from_raw(*o.get() as *const RwLock<Schema<R>>);
let schema = inner.clone();
std::mem::forget(inner);
schema
},
Entry::Vacant(v) => {
let schema = Schema {
mutable: Mutable::new(),
immutables: VecDeque::new(),
};
let columns = Arc::new(RwLock::new(schema));
v.insert(Arc::into_raw(columns.clone()) as *const ());
columns
}
}
}
}
}
}

impl Drop for DB {
fn drop(&mut self) {
self.schemas
.write()
.unwrap()
.values()
.for_each(|schema| unsafe {
Arc::from_raw(*schema as *const RwLock<()>);
});
pub(crate) async fn read(&self) -> RwLockReadGuard<'_, Schema<R>> {
self.schema.read().await
}
}

Expand All @@ -127,16 +73,56 @@ where
immutables: VecDeque<Immutable<R::Columns>>,
}

impl<R> Default for Schema<R>
where
R: Record,
{
fn default() -> Self {
Self {
mutable: Mutable::default(),
immutables: VecDeque::default(),
}
}
}

impl<R> Schema<R>
where
R: Record + Send + Sync,
R::Key: Send + Sync,
{
async fn write(&self, record: R, ts: Timestamp) -> io::Result<()> {
self.mutable.insert(Timestamped::new(record, ts));
self.mutable.insert(record, ts);
Ok(())
}

async fn get<'get>(
&'get self,
key: &'get R::Key,
ts: Timestamp,
) -> Result<Option<Entry<'get, R>>, ParquetError> {
self.scan(Bound::Included(key), Bound::Unbounded, ts)
.await?
.next()
.await
.transpose()
}

async fn scan<'scan>(
&'scan self,
lower: Bound<&'scan R::Key>,
uppwer: Bound<&'scan R::Key>,
ts: Timestamp,
) -> Result<impl Stream<Item = Result<Entry<'scan, R>, ParquetError>>, ParquetError> {
let mut streams = Vec::<ScanStream<R>>::with_capacity(self.immutables.len() + 1);
streams.push(self.mutable.scan((lower, uppwer), ts).into());
for immutable in &self.immutables {
streams.push(immutable.scan((lower, uppwer), ts).into());
}
// TODO: sstable scan

MergeStream::from_vec(streams).await
}

fn freeze(&mut self) {
let mutable = mem::replace(&mut self.mutable, Mutable::new());
let immutable = Immutable::from(mutable);
Expand Down Expand Up @@ -242,7 +228,7 @@ pub(crate) mod tests {
}

pub(crate) async fn get_test_record_batch() -> RecordBatch {
let db = DB::empty();
let db = DB::default();

db.write(
Test {
Expand All @@ -265,9 +251,7 @@ pub(crate) mod tests {
.await
.unwrap();

let schema = db.get_schema::<Test>();

let mut schema = schema.write().await;
let mut schema = db.schema.write().await;

schema.freeze();

Expand Down
Loading

0 comments on commit 4635609

Please sign in to comment.