diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 3c13096..7ad7b1f 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -95,9 +95,13 @@ def stability_short_term(self, state: Tensor, rating: Tensor) -> Tensor: new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3 + self.w[18])) return new_s + def init_d(self, rating: Tensor) -> Tensor: + new_d = self.w[4] - torch.exp(self.w[5] * (X[:, 1] - 1)) + 1 + return new_d + def next_d(self, state: Tensor, rating: Tensor) -> Tensor: new_d = state[:, 1] - self.w[6] * (rating - 3) - new_d = self.mean_reversion(self.w[4], new_d) + new_d = self.mean_reversion(init_d(4), new_d) return new_d def step(self, X: Tensor, state: Tensor) -> Tensor: @@ -113,7 +117,7 @@ def step(self, X: Tensor, state: Tensor) -> Tensor: # first learn, init memory states new_s = torch.ones_like(state[:, 0]) new_s[index[0]] = self.w[index[1]] - new_d = self.w[4] - torch.exp(self.w[5] * (X[:, 1] - 1)) + 1 + new_d = init_d(X[:, 1]) new_d = new_d.clamp(1, 10) else: r = power_forgetting_curve(X[:, 0], state[:, 0])