refactor to 3.24.1

Changed files:
- app.py: +39 -36
- fsrs4anki_optimizer.ipynb: +0 -0 (deleted)
- memory_states.py: +0 -35 (deleted)
- model.py: +0 -110 (deleted)
- plot.py: +0 -99 (deleted)
- requirements.txt: +1 -7
- utilities.py: +0 -284
app.py
CHANGED

@@ -1,14 +1,11 @@
 import gradio as gr
 import pytz
-
+import os
 from datetime import datetime
-
-from utilities import extract, create_time_series_features, train_model, process_personalized_collection, my_loss, \
-    cleanup
 from markdown import instructions_markdown, faq_markdown
-from …
-from …
-
+from fsrs4anki_optimizer import Optimizer
+from pathlib import Path
+from utilities import cleanup

 def get_w_markdown(w):
     return f"""
@@ -20,30 +17,38 @@ def get_w_markdown(w):
 Check out the Analysis tab for more detailed information."""


-def anki_optimizer(file, timezone, next_day_starts_at, revlog_start_date, requestRetention,
+def anki_optimizer(file: gr.File, timezone, next_day_starts_at, revlog_start_date, requestRetention,
                    progress=gr.Progress(track_tqdm=True)):
     now = datetime.now()
     files = ['prediction.tsv', 'revlog.csv', 'revlog_history.tsv', 'stability_for_analysis.tsv',
-             '…
+             'expected_time.csv', 'evaluation.tsv']
     prefix = now.strftime(f'%Y_%m_%d_%H_%M_%S')
-    proj_dir = …
+    suffix = file.name.split('/')[-1].replace(".", "_").replace("@", "_")
+    proj_dir = Path(f'projects/{prefix}/{suffix}')
+    proj_dir.mkdir(parents=True, exist_ok=True)
+    print(proj_dir)
+    os.chdir(proj_dir)
+    proj_dir = Path('.')
+    optimizer = Optimizer()
+    optimizer.anki_extract(file.name)
+    analysis_markdown = optimizer.create_time_series(timezone, revlog_start_date, next_day_starts_at).replace("\n", "\n\n")
+    optimizer.define_model()
+    optimizer.train()
+    w_markdown = get_w_markdown(optimizer.w)
+    optimizer.predict_memory_states()
+    difficulty_distribution = optimizer.difficulty_distribution.to_string().replace("\n", "\n\n")
+    plot_output = optimizer.find_optimal_retention()[0]
+    suggested_retention_markdown = f"""# Suggested Retention: `{optimizer.optimal_retention:.2f}`"""
+    rating_markdown = optimizer.preview(requestRetention).replace("\n", "\n\n")
+    loss_before, loss_after = optimizer.evaluate()
+    loss_markdown = f"""
+**Loss before training**: {loss_before}
+
+**Loss after training**: {loss_after}
+"""
+    # optimizer.calibration_graph()
+    # optimizer.compare_with_sm2()
+    markdown_out = f"""{suggested_retention_markdown}

 # Loss Information
 {loss_markdown}
@@ -54,12 +59,13 @@ def anki_optimizer(file, timezone, next_day_starts_at, revlog_start_date, reques
 # Ratings
 {rating_markdown}
 """
-    files_out = […
+    files_out = [file for file in files if (proj_dir / file).exists()]
+    cleanup(proj_dir, files)
+    return w_markdown, markdown_out, plot_output, files_out


 description = """
-# FSRS4Anki Optimizer App - v3.…
+# FSRS4Anki Optimizer App - v3.24.1
 Based on the [tutorial](https://medium.com/@JarrettYe/how-to-use-the-next-generation-spaced-repetition-algorithm-fsrs-on-anki-5a591ca562e2)
 of [Jarrett Ye](https://github.com/L-M-Sherlock). This application can give you personalized anki parameters without having to code.

@@ -74,7 +80,6 @@ with gr.Blocks() as demo:
         with gr.Row():
             with gr.Column():
                 file = gr.File(label='Review Logs (Step 1)')
-                fast_mode_in = gr.Checkbox(value=False, label="Fast Mode (Will just return the optimized weights)")
             with gr.Column():
                 next_day_starts_at = gr.Number(value=4,
                                                label="Next Day Starts at (Step 2)",
@@ -95,15 +100,13 @@ with gr.Blocks() as demo:
         with gr.Row():
             markdown_output = gr.Markdown()
             with gr.Column():
-                df_output = gr.DataFrame()
                 plot_output = gr.Plot()
                 files_output = gr.Files(label="Analysis Files")
     with gr.Tab("FAQ"):
         gr.Markdown(faq_markdown)

     btn_plot.click(anki_optimizer,
-                   inputs=[file, timezone, next_day_starts_at, revlog_start_date, requestRetention…
-                   outputs=[w_output, …
+                   inputs=[file, timezone, next_day_starts_at, revlog_start_date, requestRetention],
+                   outputs=[w_output, markdown_output, plot_output, files_output])

-demo.queue().launch(show_error=True)
+demo.queue().launch(show_error=True)
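The refactor above replaces the in-repo pipeline (utilities.py, model.py, memory_states.py, plot.py) with calls into the published fsrs4anki_optimizer package. Stripped of the Gradio wiring, the flow that the new anki_optimizer() runs reduces to roughly the following sketch; it assumes the same package API that app.py calls above, and the export filename, timezone, start date, and requested retention are illustrative placeholders, not values from this commit.

from fsrs4anki_optimizer import Optimizer

optimizer = Optimizer()
optimizer.anki_extract("collection-2023-02-26@10-02-34.colpkg")  # path to an Anki export (hypothetical name)
optimizer.create_time_series("Europe/London", "2006-10-05", 4)   # timezone, revlog start date, next-day-starts-at hour
optimizer.define_model()
optimizer.train()                                # fit the FSRS weights on the review log
print(optimizer.w)                               # personalized weights for the scheduler

optimizer.predict_memory_states()                # memory-state (stability/difficulty) predictions for the history
fig = optimizer.find_optimal_retention()[0]      # expected-time-vs-retention plot
print(optimizer.optimal_retention)               # suggested requested retention
print(optimizer.preview(0.9))                    # interval preview at a requested retention of 0.9
loss_before, loss_after = optimizer.evaluate()   # log loss before vs. after training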
fsrs4anki_optimizer.ipynb
DELETED
The diff for this file is too large to render. See raw diff.
memory_states.py
DELETED

@@ -1,35 +0,0 @@
-import numpy as np
-from functools import partial
-
-import pandas as pd
-
-
-def predict_memory_states(my_collection, group):
-    states = my_collection.states(*group.name)
-    group['stability'] = float(states[0])
-    group['difficulty'] = float(states[1])
-    group['count'] = len(group)
-    return pd.DataFrame({
-        'r_history': [group.name[1]],
-        't_history': [group.name[0]],
-        'stability': [round(float(states[0]), 2)],
-        'difficulty': [round(float(states[1]), 2)],
-        'count': [len(group)]
-    })
-
-
-def get_my_memory_states(proj_dir, dataset, my_collection):
-    prediction = dataset.groupby(by=['t_history', 'r_history']).progress_apply(
-        partial(predict_memory_states, my_collection))
-    prediction.reset_index(drop=True, inplace=True)
-    prediction.sort_values(by=['r_history'], inplace=True)
-    prediction.to_csv(proj_dir / "prediction.tsv", sep='\t', index=None)
-    # print("prediction.tsv saved.")
-    prediction['difficulty'] = prediction['difficulty'].map(lambda x: int(round(x)))
-    difficulty_distribution = prediction.groupby(by=['difficulty'])['count'].sum() / prediction['count'].sum()
-    # print(difficulty_distribution)
-    difficulty_distribution_padding = np.zeros(10)
-    for i in range(10):
-        if i + 1 in difficulty_distribution.index:
-            difficulty_distribution_padding[i] = difficulty_distribution.loc[i + 1]
-    return difficulty_distribution_padding, difficulty_distribution
model.py
DELETED

@@ -1,110 +0,0 @@
-import numpy as np
-import torch
-from torch import nn
-
-init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2, -0.2, 0.2, 1]
-'''
-w[0]: initial_stability_for_again_answer
-w[1]: initial_stability_step_per_rating
-w[2]: initial_difficulty_for_good_answer
-w[3]: initial_difficulty_step_per_rating
-w[4]: next_difficulty_step_per_rating
-w[5]: next_difficulty_reversion_to_mean_speed (used to avoid ease hell)
-w[6]: next_stability_factor_after_success
-w[7]: next_stability_stabilization_decay_after_success
-w[8]: next_stability_retrievability_gain_after_success
-w[9]: next_stability_factor_after_failure
-w[10]: next_stability_difficulty_decay_after_success
-w[11]: next_stability_stability_gain_after_failure
-w[12]: next_stability_retrievability_gain_after_failure
-For more details about the parameters, please see:
-https://github.com/open-spaced-repetition/fsrs4anki/wiki/Free-Spaced-Repetition-Scheduler
-'''
-
-
-class FSRS(nn.Module):
-    def __init__(self, w):
-        super(FSRS, self).__init__()
-        self.w = nn.Parameter(torch.FloatTensor(w))
-        self.zero = torch.FloatTensor([0.0])
-
-    def forward(self, x, s, d):
-        '''
-        :param x: [review interval, review response]
-        :param s: stability
-        :param d: difficulty
-        :return:
-        '''
-        if torch.equal(s, self.zero):
-            # first learn, init memory states
-            new_s = self.w[0] + self.w[1] * (x[1] - 1)
-            new_d = self.w[2] + self.w[3] * (x[1] - 3)
-            new_d = new_d.clamp(1, 10)
-        else:
-            r = torch.exp(np.log(0.9) * x[0] / s)
-            new_d = d + self.w[4] * (x[1] - 3)
-            new_d = self.mean_reversion(self.w[2], new_d)
-            new_d = new_d.clamp(1, 10)
-            # recall
-            if x[1] > 1:
-                new_s = s * (1 + torch.exp(self.w[6]) *
-                             (11 - new_d) *
-                             torch.pow(s, self.w[7]) *
-                             (torch.exp((1 - r) * self.w[8]) - 1))
-            # forget
-            else:
-                new_s = self.w[9] * torch.pow(new_d, self.w[10]) * torch.pow(
-                    s, self.w[11]) * torch.exp((1 - r) * self.w[12])
-        return new_s, new_d
-
-    def loss(self, s, t, r):
-        return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))
-
-    def mean_reversion(self, init, current):
-        return self.w[5] * init + (1-self.w[5]) * current
-
-
-class WeightClipper(object):
-    def __init__(self, frequency=1):
-        self.frequency = frequency
-
-    def __call__(self, module):
-        if hasattr(module, 'w'):
-            w = module.w.data
-            w[0] = w[0].clamp(0.1, 10)
-            w[1] = w[1].clamp(0.1, 5)
-            w[2] = w[2].clamp(1, 10)
-            w[3] = w[3].clamp(-5, -0.1)
-            w[4] = w[4].clamp(-5, -0.1)
-            w[5] = w[5].clamp(0, 0.5)
-            w[6] = w[6].clamp(0, 2)
-            w[7] = w[7].clamp(-0.2, -0.01)
-            w[8] = w[8].clamp(0.01, 1.5)
-            w[9] = w[9].clamp(0.5, 5)
-            w[10] = w[10].clamp(-2, -0.01)
-            w[11] = w[11].clamp(0.01, 0.9)
-            w[12] = w[12].clamp(0.01, 2)
-            module.w.data = w
-
-
-def lineToTensor(line):
-    ivl = line[0].split(',')
-    response = line[1].split(',')
-    tensor = torch.zeros(len(response), 2)
-    for li, response in enumerate(response):
-        tensor[li][0] = int(ivl[li])
-        tensor[li][1] = int(response)
-    return tensor
-
-
-class Collection:
-    def __init__(self, w):
-        self.model = FSRS(w)
-
-    def states(self, t_history, r_history):
-        with torch.no_grad():
-            line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0])
-            output_t = [(self.model.zero, self.model.zero)]
-            for input_t in line_tensor:
-                output_t.append(self.model(input_t, *output_t[-1]))
-            return output_t[-1]
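Read directly from forward() above: on the first review the deleted model set S0 = w0 + w1·(rating − 1) and D0 = clamp(w2 + w3·(rating − 3), 1, 10). For later reviews it computed the retrievability R = 0.9^(Δt/S), updated difficulty with mean reversion toward w2 (D′ = w5·w2 + (1 − w5)·(D + w4·(rating − 3)), clamped to [1, 10]), then set the next stability to S·(1 + e^{w6}·(11 − D′)·S^{w7}·(e^{(1−R)·w8} − 1)) after a successful review, or to w9·D′^{w10}·S^{w11}·e^{(1−R)·w12} after a lapse. The app now delegates this training and prediction to the fsrs4anki_optimizer package.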
plot.py
DELETED

@@ -1,99 +0,0 @@
-from tqdm.auto import trange
-import gradio as gr
-import pandas as pd
-import numpy as np
-import plotly.express as px
-
-
-def make_plot(proj_dir, type_sequence, time_sequence, w, difficulty_distribution_padding, progress=gr.Progress(track_tqdm=True)):
-    base = 1.01
-    index_len = 793
-    index_offset = 200
-    d_range = 10
-    d_offset = 1
-    r_time = 8
-    f_time = 25
-    max_time = 200000
-
-    type_block = dict()
-    type_count = dict()
-    type_time = dict()
-    last_t = type_sequence[0]
-    type_block[last_t] = 1
-    type_count[last_t] = 1
-    type_time[last_t] = time_sequence[0]
-    for i,t in enumerate(type_sequence[1:]):
-        type_count[t] = type_count.setdefault(t, 0) + 1
-        type_time[t] = type_time.setdefault(t, 0) + time_sequence[i]
-        if t != last_t:
-            type_block[t] = type_block.setdefault(t, 0) + 1
-        last_t = t
-
-    r_time = round(type_time[1]/type_count[1]/1000, 1)
-
-    if 2 in type_count and 2 in type_block:
-        f_time = round(type_time[2]/type_block[2]/1000 + r_time, 1)
-
-    def stability2index(stability):
-        return int(round(np.log(stability) / np.log(base)) + index_offset)
-
-    def init_stability(d):
-        return max(((d - w[2]) / w[3] + 2) * w[1] + w[0], np.power(base, -index_offset))
-
-    def cal_next_recall_stability(s, r, d, response):
-        if response == 1:
-            return s * (1 + np.exp(w[6]) * (11 - d) * np.power(s, w[7]) * (np.exp((1 - r) * w[8]) - 1))
-        else:
-            return w[9] * np.power(d, w[10]) * np.power(s, w[11]) * np.exp((1 - r) * w[12])
-
-    stability_list = np.array([np.power(base, i - index_offset) for i in range(index_len)])
-    # print(f"terminal stability: {stability_list.max(): .2f}")
-    df = pd.DataFrame(columns=["retention", "difficulty", "time"])
-
-    for percentage in trange(96, 66, -2, desc='Time vs Retention plot'):
-        recall = percentage / 100
-        time_list = np.zeros((d_range, index_len))
-        time_list[:,:-1] = max_time
-        for d in range(d_range, 0, -1):
-            s0 = init_stability(d)
-            s0_index = stability2index(s0)
-            diff = max_time
-            while diff > 0.1:
-                s0_time = time_list[d - 1][s0_index]
-                for s_index in range(index_len - 2, -1, -1):
-                    stability = stability_list[s_index];
-                    interval = max(1, round(stability * np.log(recall) / np.log(0.9)))
-                    p_recall = np.power(0.9, interval / stability)
-                    recall_s = cal_next_recall_stability(stability, p_recall, d, 1)
-                    forget_d = min(d + d_offset, 10)
-                    forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)
-                    recall_s_index = min(stability2index(recall_s), index_len - 1)
-                    forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)
-                    recall_time = time_list[d - 1][recall_s_index] + r_time
-                    forget_time = time_list[forget_d - 1][forget_s_index] + f_time
-                    exp_time = p_recall * recall_time + (1.0 - p_recall) * forget_time
-                    if exp_time < time_list[d - 1][s_index]:
-                        time_list[d - 1][s_index] = exp_time
-                diff = s0_time - time_list[d - 1][s0_index]
-            df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] = [recall, d, s0_time]
-
-
-    df.sort_values(by=["difficulty", "retention"], inplace=True)
-    df.to_csv(proj_dir/"expected_time.csv", index=False)
-    # print("expected_repetitions.csv saved.")
-
-    optimal_retention_list = np.zeros(10)
-    df2 = pd.DataFrame()
-    for d in range(1, d_range + 1):
-        retention = df[df["difficulty"] == d]["retention"]
-        time = df[df["difficulty"] == d]["time"]
-        optimal_retention = retention.iat[time.argmin()]
-        optimal_retention_list[d - 1] = optimal_retention
-        df2 = df2.append(
-            pd.DataFrame({'retention': retention, 'expected time': time, 'd': d, 'r': optimal_retention}))
-
-    fig = px.line(df2, x="retention", y="expected time", color='d', log_y=True)
-
-    # print(f"\n-----suggested retention: {np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}-----")
-    suggested_retention_markdown = f"""# Suggested Retention: `{np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}`"""
-    return fig, suggested_retention_markdown
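The deleted search above rests on the forgetting-curve inversion visible in the inner loop: for a card with stability S, the interval that targets recall probability R is Δt = max(1, round(S · ln R / ln 0.9)), and the realized recall probability is 0.9^(Δt/S). For each difficulty 1-10 it sweeps the requested retention from 96% down to 68% in 2% steps, keeps the retention that minimizes expected review time, and reports as the suggested retention the difficulty-distribution-weighted average (np.inner) of those per-difficulty optima. The equivalent computation is now provided by optimizer.find_optimal_retention() in the packaged code.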
requirements.txt
CHANGED

@@ -1,7 +1 @@
-numpy==1.23.3
-pandas==1.3.2
-scikit_learn==1.1.2
-torch==1.9.0
-tqdm==4.64.1
-plotly==5.13.0
+fsrs4anki_optimizer==3.24.1
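With the notebook and the local training, plotting, and memory-state modules removed, the pinned optimizer package is the Space's only remaining requirement; reproducing the same pipeline locally should only need the usual install of this file, e.g.:

pip install fsrs4anki_optimizer==3.24.1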
utilities.py
CHANGED

@@ -1,27 +1,9 @@
-from functools import partial
-import datetime
 from zipfile import ZipFile
-
-import sqlite3
-import time
-
-import gradio as gr
-from tqdm.auto import tqdm
-import pandas as pd
-import numpy as np
 import os
-from datetime import timedelta, datetime
 from pathlib import Path

-import torch
-from sklearn.utils import shuffle
-
-from model import Collection, init_w, FSRS, WeightClipper, lineToTensor
-

 # Extract the collection file or deck file to get the .anki21 database.
-
-
 def extract(file, prefix):
     proj_dir = Path(f'projects/{prefix}_{file.orig_name.replace(".", "_").replace("@", "_")}')
     with ZipFile(file, 'r') as zip_ref:
@@ -29,272 +11,6 @@ def extract(file, prefix):
     # print(f"Extracted {file.orig_name} successfully!")
     return proj_dir

-
-def create_time_series_features(revlog_start_date, timezone, next_day_starts_at, proj_dir,
-                                progress=gr.Progress(track_tqdm=True)):
-    if os.path.isfile(proj_dir / "collection.anki21b"):
-        os.remove(proj_dir / "collection.anki21b")
-        raise gr.Error(
-            "Please export the file with `support older Anki versions` if you use the latest version of Anki.")
-    elif os.path.isfile(proj_dir / "collection.anki21"):
-        con = sqlite3.connect(proj_dir / "collection.anki21")
-    elif os.path.isfile(proj_dir / "collection.anki2"):
-        con = sqlite3.connect(proj_dir / "collection.anki2")
-    else:
-        raise Exception("Collection not exist!")
-    cur = con.cursor()
-    res = cur.execute("SELECT * FROM revlog")
-    revlog = res.fetchall()
-
-    df = pd.DataFrame(revlog)
-    df.columns = ['id', 'cid', 'usn', 'r', 'ivl',
-                  'last_lvl', 'factor', 'time', 'type']
-    df = df[(df['cid'] <= time.time() * 1000) &
-            (df['id'] <= time.time() * 1000) &
-            (df['r'] > 0)].copy()
-    df['create_date'] = pd.to_datetime(df['cid'] // 1000, unit='s')
-    df['create_date'] = df['create_date'].dt.tz_localize(
-        'UTC').dt.tz_convert(timezone)
-    df['review_date'] = pd.to_datetime(df['id'] // 1000, unit='s')
-    df['review_date'] = df['review_date'].dt.tz_localize(
-        'UTC').dt.tz_convert(timezone)
-    df.drop(df[df['review_date'].dt.year < 2006].index, inplace=True)
-    df.sort_values(by=['cid', 'id'], inplace=True, ignore_index=True)
-    type_sequence = np.array(df['type'])
-    time_sequence = np.array(df['time'])
-    df.to_csv(proj_dir / "revlog.csv", index=False)
-    # print("revlog.csv saved.")
-    df = df[df['type'] != 3].copy()
-    df['real_days'] = df['review_date'] - timedelta(hours=next_day_starts_at)
-    df['real_days'] = pd.DatetimeIndex(df['real_days'].dt.floor('D', ambiguous='infer', nonexistent='shift_forward')).to_julian_date()
-    df.drop_duplicates(['cid', 'real_days'], keep='first', inplace=True)
-    df['delta_t'] = df.real_days.diff()
-    df.dropna(inplace=True)
-    df['delta_t'] = df['delta_t'].astype(dtype=int)
-    df['i'] = 1
-    df['r_history'] = ""
-    df['t_history'] = ""
-    col_idx = {key: i for i, key in enumerate(df.columns)}
-
-    # code from https://github.com/L-M-Sherlock/anki_revlog_analysis/blob/main/revlog_analysis.py
-    def get_feature(x):
-        last_kind = None
-        for idx, log in enumerate(x.itertuples()):
-            if last_kind is not None and last_kind in (1, 2) and log.type == 0:
-                return x.iloc[:idx]
-            last_kind = log.type
-            if idx == 0:
-                if log.type != 0:
-                    return x.iloc[:idx]
-                x.iloc[idx, col_idx['delta_t']] = 0
-            if idx == x.shape[0] - 1:
-                break
-            x.iloc[idx + 1, col_idx['i']] = x.iloc[idx, col_idx['i']] + 1
-            x.iloc[idx + 1, col_idx[
-                't_history']] = f"{x.iloc[idx, col_idx['t_history']]},{x.iloc[idx, col_idx['delta_t']]}"
-            x.iloc[idx + 1, col_idx['r_history']] = f"{x.iloc[idx, col_idx['r_history']]},{x.iloc[idx, col_idx['r']]}"
-        return x
-
-    tqdm.pandas(desc='Saving Trainset')
-    df = df.groupby('cid', as_index=False, group_keys=False).progress_apply(get_feature)
-    df = df[df['id'] >= time.mktime(datetime.strptime(revlog_start_date, "%Y-%m-%d").timetuple()) * 1000]
-    df["t_history"] = df["t_history"].map(lambda x: x[1:] if len(x) > 1 else x)
-    df["r_history"] = df["r_history"].map(lambda x: x[1:] if len(x) > 1 else x)
-    df.to_csv(proj_dir / 'revlog_history.tsv', sep="\t", index=False)
-    # print("Trainset saved.")
-
-    def cal_retention(group: pd.DataFrame) -> pd.DataFrame:
-        group['retention'] = round(group['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]).mean(), 4)
-        group['total_cnt'] = group.shape[0]
-        return group
-
-    tqdm.pandas(desc='Calculating Retention')
-    df = df.groupby(by=['r_history', 'delta_t']).progress_apply(cal_retention)
-    # print("Retention calculated.")
-    df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_lvl', 'factor', 'time', 'type', 'create_date', 'review_date',
-                          'real_days', 'r', 't_history'])
-    df.drop_duplicates(inplace=True)
-    df['retention'] = df['retention'].map(lambda x: max(min(0.99, x), 0.01))
-
-    def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
-        group_cnt = sum(group['total_cnt'])
-        if group_cnt < 10:
-            return pd.DataFrame()
-        group['group_cnt'] = group_cnt
-        if group['i'].values[0] > 1:
-            r_ivl_cnt = sum(group['delta_t'] * group['retention'].map(np.log) * pow(group['total_cnt'], 2))
-            ivl_ivl_cnt = sum(group['delta_t'].map(lambda x: x ** 2) * pow(group['total_cnt'], 2))
-            group['stability'] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1)
-        else:
-            group['stability'] = 0.0
-        group['avg_retention'] = round(
-            sum(group['retention'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 3)
-        group['avg_interval'] = round(
-            sum(group['delta_t'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 1)
-        del group['total_cnt']
-        del group['retention']
-        del group['delta_t']
-        return group
-
-    tqdm.pandas(desc='Calculating Stability')
-    df = df.groupby(by=['r_history'], group_keys=False).progress_apply(cal_stability)
-    # print("Stability calculated.")
-    df.reset_index(drop=True, inplace=True)
-    df.drop_duplicates(inplace=True)
-    df.sort_values(by=['r_history'], inplace=True, ignore_index=True)
-
-    df_out = pd.DataFrame()
-    if df.shape[0] > 0:
-        for idx in tqdm(df.index):
-            item = df.loc[idx]
-            index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index
-            df.loc[index, 'last_stability'] = item['stability']
-        df['factor'] = round(df['stability'] / df['last_stability'], 2)
-        df = df[(df['i'] >= 2) & (df['group_cnt'] >= 100)]
-        df['last_recall'] = df['r_history'].map(lambda x: x[-1])
-        df = df[df.groupby(['i', 'r_history'], group_keys=False)['group_cnt'].transform(max) == df['group_cnt']]
-        df.to_csv(proj_dir / 'stability_for_analysis.tsv', sep='\t', index=None)
-        # print("1:again, 2:hard, 3:good, 4:easy\n")
-        # print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][
-        #     ['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(
-        #     index=False))
-        # print("Analysis saved!")
-
-        df_out = df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][
-            ['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']]
-    return type_sequence, time_sequence, df_out
-
-
-def train_model(proj_dir, progress=gr.Progress(track_tqdm=True)):
-    model = FSRS(init_w)
-
-    clipper = WeightClipper()
-    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
-
-    dataset = pd.read_csv(proj_dir / "revlog_history.tsv", sep='\t', index_col=None,
-                          dtype={'r_history': str, 't_history': str})
-    dataset = dataset[(dataset['i'] > 1) & (dataset['delta_t'] > 0) & (dataset['t_history'].str.count(',0') == 0)]
-
-    tqdm.pandas(desc='Tensorizing Line')
-    dataset['tensor'] = dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]),
-                                               axis=1)
-    # print("Tensorized!")
-
-    pre_train_set = dataset[dataset['i'] == 2]
-    # pretrain
-    epoch_len = len(pre_train_set)
-    n_epoch = 1
-    pbar = tqdm(desc="Pre-training", colour="red", total=epoch_len * n_epoch)
-
-    for k in range(n_epoch):
-        for i, (_, row) in enumerate(shuffle(pre_train_set, random_state=2022 + k).iterrows()):
-            model.train()
-            optimizer.zero_grad()
-            output_t = [(model.zero, model.zero)]
-            for input_t in row['tensor']:
-                output_t.append(model(input_t, *output_t[-1]))
-            loss = model.loss(output_t[-1][0], row['delta_t'],
-                              {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])
-            if np.isnan(loss.data.item()):
-                # Exception Case
-                # print(row, output_t)
-                raise Exception('error case')
-            loss.backward()
-            optimizer.step()
-            model.apply(clipper)
-            pbar.update()
-    pbar.close()
-    # for name, param in model.named_parameters():
-    #     print(f"{name}: {list(map(lambda x: round(float(x), 4), param))}")
-
-    train_set = dataset[dataset['i'] > 2]
-    epoch_len = len(train_set)
-    n_epoch = 1
-    print_len = max(epoch_len * n_epoch // 10, 1)
-    pbar = tqdm(desc="Training", total=epoch_len * n_epoch)
-
-    for k in range(n_epoch):
-        for i, (_, row) in enumerate(shuffle(train_set, random_state=2022 + k).iterrows()):
-            model.train()
-            optimizer.zero_grad()
-            output_t = [(model.zero, model.zero)]
-            for input_t in row['tensor']:
-                output_t.append(model(input_t, *output_t[-1]))
-            loss = model.loss(output_t[-1][0], row['delta_t'],
-                              {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])
-            if np.isnan(loss.data.item()):
-                # Exception Case
-                # print(row, output_t)
-                raise Exception('error case')
-            loss.backward()
-            for param in model.parameters():
-                param.grad[:2] = torch.zeros(2)
-            optimizer.step()
-            model.apply(clipper)
-            pbar.update()
-
-            # if (k * epoch_len + i) % print_len == 0:
-            #     print(f"iteration: {k * epoch_len + i + 1}")
-            #     for name, param in model.named_parameters():
-            #         print(f"{name}: {list(map(lambda x: round(float(x), 4), param))}")
-    pbar.close()
-
-    w = list(map(lambda x: round(float(x), 4), dict(model.named_parameters())['w'].data))
-
-    # print("\nTraining finished!")
-    return w, dataset
-
-
-def process_personalized_collection(requestRetention, w):
-    my_collection = Collection(w)
-    rating_dict = {1: "again", 2: "hard", 3: "good", 4: "easy"}
-    rating_markdown = []
-    for first_rating in (1, 2, 3, 4):
-        rating_markdown.append(f'## First Rating: {first_rating} ({rating_dict[first_rating]})')
-        t_history = "0"
-        d_history = "0"
-        r_history = f"{first_rating}"  # the first rating of the new card
-        # print("stability, difficulty, lapses")
-        for i in range(10):
-            states = my_collection.states(t_history, r_history)
-            # print('{0:9.2f} {1:11.2f} {2:7.0f}'.format(
-            #     *list(map(lambda x: round(float(x), 4), states))))
-            next_t = max(round(float(np.log(requestRetention) / np.log(0.9) * states[0])), 1)
-            difficulty = round(float(states[1]), 1)
-            t_history += f',{int(next_t)}'
-            d_history += f',{difficulty}'
-            r_history += f",3"
-        rating_markdown.append(f"**rating history**: {r_history}")
-        rating_markdown.append(f"**interval history**: {t_history}")
-        rating_markdown.append(f"**difficulty history**: {d_history}\n")
-    rating_markdown = '\n\n'.join(rating_markdown)
-    return my_collection, rating_markdown
-
-
-def log_loss(my_collection, row):
-    states = my_collection.states(row['t_history'], row['r_history'])
-    row['log_loss'] = float(my_collection.model.loss(states[0], row['delta_t'], {1: 0, 2: 1, 3: 1, 4: 1}[row['r']]))
-    return row
-
-
-def my_loss(dataset, w):
-    my_collection = Collection(init_w)
-    tqdm.pandas(desc='Calculating Loss before Training')
-    dataset = dataset.progress_apply(partial(log_loss, my_collection), axis=1)
-    # print(f"Loss before training: {dataset['log_loss'].mean():.4f}")
-    loss_before = f"{dataset['log_loss'].mean():.4f}"
-    my_collection = Collection(w)
-    tqdm.pandas(desc='Calculating Loss After Training')
-    dataset = dataset.progress_apply(partial(log_loss, my_collection), axis=1)
-    # print(f"Loss after training: {dataset['log_loss'].mean():.4f}")
-    loss_after = f"{dataset['log_loss'].mean():.4f}"
-    return f"""
-**Loss before training**: {loss_before}
-
-**Loss after training**: {loss_after}
-"""
-
-
 def cleanup(proj_dir: Path, files):
     """
     Delete all files in prefix that dont have filenames in files