Skip to content

Commit

Permalink
Merge branch 'main' into patch-2
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Sep 8, 2024
2 parents 6a73234 + c7568ae commit 126e0de
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 41 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "5.0.6"
version = "5.0.7"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
81 changes: 41 additions & 40 deletions src/fsrs_optimizer/fsrs_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,9 @@ def sample(
forget_rating_offset=DEFAULT_FORGET_RATING_OFFSET,
forget_session_len=DEFAULT_FORGET_SESSION_LEN,
loss_aversion=2.5,
workload_only=False,
):
memorization = []
results = []

def best_sample_size(days_to_simulate):
if days_to_simulate <= 30:
Expand Down Expand Up @@ -296,8 +297,11 @@ def best_sample_size(days_to_simulate):
loss_aversion,
seed=42 + i,
)
memorization.append(cost_per_day.sum() / memorized_cnt_per_day[-1])
return np.mean(memorization)
if workload_only:
results.append(cost_per_day.sum())
else:
results.append(cost_per_day.sum() / memorized_cnt_per_day[-1])
return np.mean(results)


def brent(tol=0.01, maxiter=20, **kwargs):
Expand Down Expand Up @@ -417,68 +421,65 @@ def brent(tol=0.01, maxiter=20, **kwargs):
raise Exception("The algorithm terminated without finding a valid value.")


def workload_graph(default_params):
R = [x / 100 for x in range(70, 100)]
cost_per_memorization = [sample(r=r, **default_params) for r in R]
def workload_graph(default_params, sampling_size=30):
R = np.linspace(0.7, 0.999, sampling_size).tolist()
default_params["max_cost_perday"] = math.inf
default_params["learn_limit_perday"] = int(
default_params["deck_size"] / default_params["learn_span"]
)
default_params["review_limit_perday"] = math.inf
workload = [sample(r=r, workload_only=True, **default_params) for r in R]

# this is for testing
# cost_per_memorization = [min(x, 2.3 * min(cost_per_memorization)) for x in cost_per_memorization]
min_w = min(cost_per_memorization) # minimum workload
max_w = max(cost_per_memorization) # maximum workload
min1_index = R.index(R[cost_per_memorization.index(min_w)])
# workload = [min(x, 2.3 * min(workload)) for x in workload]
min_w = min(workload) # minimum workload
max_w = max(workload) # maximum workload
min1_index = R.index(R[workload.index(min_w)])

min_w2 = 0
min_w3 = 0
target2 = 2 * min_w
target3 = 3 * min_w

for i in range(len(cost_per_memorization) - 1):
if (cost_per_memorization[i] <= target2) and (
cost_per_memorization[i + 1] >= target2
):
if abs(cost_per_memorization[i] - target2) < abs(
cost_per_memorization[i + 1] - target2
):
min_w2 = cost_per_memorization[i]
for i in range(len(workload) - 1):
if (workload[i] <= target2) and (workload[i + 1] >= target2):
if abs(workload[i] - target2) < abs(workload[i + 1] - target2):
min_w2 = workload[i]
else:
min_w2 = cost_per_memorization[i + 1]

for i in range(len(cost_per_memorization) - 1):
if (cost_per_memorization[i] <= target3) and (
cost_per_memorization[i + 1] >= target3
):
if abs(cost_per_memorization[i] - target3) < abs(
cost_per_memorization[i + 1] - target3
):
min_w3 = cost_per_memorization[i]
min_w2 = workload[i + 1]

for i in range(len(workload) - 1):
if (workload[i] <= target3) and (workload[i + 1] >= target3):
if abs(workload[i] - target3) < abs(workload[i + 1] - target3):
min_w3 = workload[i]
else:
min_w3 = cost_per_memorization[i + 1]
min_w3 = workload[i + 1]

if min_w2 == 0:
min2_index = len(R)
else:
min2_index = R.index(R[cost_per_memorization.index(min_w2)])
min2_index = R.index(R[workload.index(min_w2)])

min1_5_index = int(math.ceil((min2_index + 3 * min1_index) / 4))
if min_w3 == 0:
min3_index = len(R)
else:
min3_index = R.index(R[cost_per_memorization.index(min_w3)])
min3_index = R.index(R[workload.index(min_w3)])

fig = plt.figure(figsize=(16, 8))
ax = fig.gca()
if min1_index > 0:
ax.fill_between(
x=R[: min1_index + 1],
y1=0,
y2=cost_per_memorization[: min1_index + 1],
y2=workload[: min1_index + 1],
color="red",
alpha=1,
)
ax.fill_between(
x=R[min1_index : min1_5_index + 1],
y1=0,
y2=cost_per_memorization[min1_index : min1_5_index + 1],
y2=workload[min1_index : min1_5_index + 1],
color="gold",
alpha=1,
)
Expand All @@ -487,29 +488,29 @@ def workload_graph(default_params):
ax.fill_between(
x=R[: min1_5_index + 1],
y1=0,
y2=cost_per_memorization[: min1_5_index + 1],
y2=workload[: min1_5_index + 1],
color="gold",
alpha=1,
)

ax.fill_between(
x=R[min1_5_index : min2_index + 1],
y1=0,
y2=cost_per_memorization[min1_5_index : min2_index + 1],
y2=workload[min1_5_index : min2_index + 1],
color="limegreen",
alpha=1,
)
ax.fill_between(
x=R[min2_index : min3_index + 1],
y1=0,
y2=cost_per_memorization[min2_index : min3_index + 1],
y2=workload[min2_index : min3_index + 1],
color="gold",
alpha=1,
)
ax.fill_between(
x=R[min3_index:],
y1=0,
y2=cost_per_memorization[min3_index:],
y2=workload[min3_index:],
color="red",
alpha=1,
)
Expand All @@ -527,7 +528,7 @@ def workload_graph(default_params):

ax.set_ylim(0, lim)
ax.set_ylabel("Workload (minutes of study per day)", fontsize=20)
ax.set_xlabel("Retention", fontsize=20)
ax.set_xlabel("Desired Retention", fontsize=20)
ax.axhline(y=min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand Down Expand Up @@ -571,7 +572,7 @@ def workload_graph(default_params):
color="black",
fontsize=12,
)

fig.tight_layout(h_pad=0, w_pad=0)
return fig


Expand Down Expand Up @@ -641,4 +642,4 @@ def moving_average(data, window_size=365 // 20):
ax.set_title("Memorized Count per Day")
ax.grid(True)
plt.show()
workload_graph(default_params).savefig("workload.png")
workload_graph(default_params, sampling_size=300).savefig("workload.png")

0 comments on commit 126e0de

Please sign in to comment.