From 297025c5ece13443d00999e27347a5dc5df71b9c Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 26 Sep 2024 11:15:01 +0800 Subject: [PATCH] Fix/support float delta_t (#140) * Fix/support float delta_t * bump version --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1726ef7..1732ca3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.0.9" +version = "5.0.10" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 0a8b87f..78e31b5 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -213,7 +213,7 @@ def __init__( self.x_train = pad_sequence( dataframe["tensor"].to_list(), batch_first=True, padding_value=0 ) - self.t_train = torch.tensor(dataframe["delta_t"].values, dtype=torch.int) + self.t_train = torch.tensor(dataframe["delta_t"].values, dtype=torch.float) self.y_train = torch.tensor(dataframe["y"].values, dtype=torch.float) self.seq_len = torch.tensor( dataframe["tensor"].map(len).values, dtype=torch.long