adding code for creating Fig 1B and Fig 2B
Browse files- README.md +1 -1
- fuson_plm/paper_figures/README.md +11 -0
- fuson_plm/paper_figures/fig1/circles.png +0 -0
- fuson_plm/paper_figures/fig1/data_circles.py +45 -0
- fuson_plm/paper_figures/fig2/cosine_3_epochs.png +0 -0
- fuson_plm/paper_figures/fig2/log_linear_3_epochs.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_0.15.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_0.15_0.2.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_0.15_0.25.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_0.15_0.3.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_0.15_0.35.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_0.15_0.4.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_0.2.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_0.25.png +0 -0
- fuson_plm/paper_figures/fig2/mask_rate_plots.py +156 -0
- fuson_plm/paper_figures/fig2/stepwise_3_epochs.png +0 -0
README.md
CHANGED
@@ -47,7 +47,7 @@ with torch.no_grad():
|
|
47 |
embeddings = embeddings[1:-1, :]
|
48 |
|
49 |
# Convert embeddings to numpy array (if needed)
|
50 |
-
embeddings = embeddings.numpy()
|
51 |
|
52 |
print("Per-residue embeddings shape:", embeddings.shape)
|
53 |
|
|
|
47 |
embeddings = embeddings[1:-1, :]
|
48 |
|
49 |
# Convert embeddings to numpy array (if needed)
|
50 |
+
embeddings = embeddings.cpu().numpy()
|
51 |
|
52 |
print("Per-residue embeddings shape:", embeddings.shape)
|
53 |
|
fuson_plm/paper_figures/README.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Paper figures
|
2 |
+
|
3 |
+
This folder holds the code for creating Fig. 1B and all of the plots in Fig. 2B. These figures are programmatically generated, but are not part of a particular training/benchmarking pipeline. For all other figures in the paper, any code used to generate them is in the appropriate `training/`, `data/`, or `benchmarking/` subfolder.
|
4 |
+
|
5 |
+
### Figure 1B
|
6 |
+
|
7 |
+
To recreate the plotted circles in Figure 1B, go to the `fig1` folder and enter `python data_circles.py`.
|
8 |
+
|
9 |
+
### Figure 2B
|
10 |
+
|
11 |
+
To recreate the masking rate and scheduler plots in Figure 2B, go to the `fig2` folder and run `python mask_rate_plots.py`.
|
fuson_plm/paper_figures/fig1/circles.png
ADDED
![]() |
fuson_plm/paper_figures/fig1/data_circles.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import matplotlib.patches as patches
|
3 |
+
|
4 |
+
# Draws three circles whose AREAS are proportional to three counts, plus a small
# two-coloured disc overlaid on the large circle, and writes the result to
# circles.png in the current working directory.
# NOTE(review): the three counts below are presumably dataset sizes from the
# paper -- confirm what each one refers to before reusing them.
REFERENCE_COUNT = 65000000
SECOND_COUNT = 44414
THIRD_COUNT = 9249

# How many times larger the reference count is than each of the other counts.
# (The original expressions ended in "/1#00": a division by one followed by a
# comment, i.e. a no-op, so it is omitted here.)
quotient = REFERENCE_COUNT / SECOND_COUNT
ht_quotient = REFERENCE_COUNT / THIRD_COUNT

# Circle "sizes" are areas in arbitrary units; the big circle fixes the scale.
grey_circle_size = 20000
small_circle_size = grey_circle_size / quotient


def _radius_from_size(size):
    # Radius grows with the square root of the area; 500 rescales it to the
    # 0..1 axis limits set at the bottom of the script.
    return size**0.5 / 500


grey_radius = _radius_from_size(grey_circle_size)
two_tone_radius = _radius_from_size(grey_circle_size / ht_quotient)

fig, ax = plt.subplots(figsize=(20, 20))

# Large translucent grey circle on the left, outlined in black.
ax.add_patch(
    plt.Circle((0.3, 0.5), radius=grey_radius, color="grey", ec="black", lw=2, alpha=0.3)
)

# Two-tone disc placed near the right-hand side of the grey circle.
# NOTE(review): the offset uses 0.25 although the grey circle is centred at
# x=0.3, so the disc sits 0.05 inside the grey circle's right edge rather than
# touching it -- confirm whether that inset is intentional.
two_tone_center = (0.25 + grey_radius - two_tone_radius, 0.5)

# Left half (90 to 270 degrees) in red, right half (270 to 90 degrees) in blue.
for start_angle, end_angle, fill in ((90, 270, "#de8a8a"), (270, 90, "#6ea4da")):
    ax.add_patch(
        patches.Wedge(two_tone_center, two_tone_radius, start_angle, end_angle, color=fill, ec="black", lw=1)
    )

# Small purple circle to the right, outlined in black.
ax.add_patch(
    plt.Circle(
        (0.6, 0.5),
        radius=_radius_from_size(small_circle_size),
        color="mediumpurple",
        ec="black",
        lw=2,
        alpha=0.7,
    )
)

# Equal aspect keeps the circles round; the axes themselves are hidden.
ax.set_aspect('equal')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('off')
plt.tight_layout()
plt.savefig('circles.png', dpi=300)
plt.show()
|
fuson_plm/paper_figures/fig2/cosine_3_epochs.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/log_linear_3_epochs.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_0.15.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.2.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.25.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.3.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.35.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.4.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_0.2.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_0.25.png
ADDED
![]() |
fuson_plm/paper_figures/fig2/mask_rate_plots.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import matplotlib.patches as patches
|
3 |
+
import numpy as np
|
4 |
+
from fuson_plm.utils.visualizing import set_font
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
# Cosine Increase Masking Rate Scheduler implementation
|
9 |
+
def compute_cosine_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate on a half-cosine ramp from min_rate to max_rate.

    Args:
        progress: Position within the schedule; 0 gives min_rate and 1 gives
            max_rate.
        min_rate: Masking rate at the start of the ramp.
        max_rate: Masking rate at the end of the ramp.

    Returns:
        The interpolated masking rate.
    """
    rate_span = max_rate - min_rate
    # 0.5 * (1 - cos(pi * p)) rises smoothly from 0 at p=0 to 1 at p=1.
    ramp_fraction = 0.5 * (1 - np.cos(np.pi * progress))
    return min_rate + rate_span * ramp_fraction
|
12 |
+
|
13 |
+
def compute_log_linear_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate on a logarithmic ramp from min_rate to max_rate.

    The ramp is log1p(progress) / log1p(1), which is 0 at progress 0 and 1 at
    progress 1, rising fastest at the start.

    Args:
        progress: Position within the schedule. May be a scalar or a NumPy
            array; values below 1e-10 are raised to 1e-10.
        min_rate: Masking rate at the start of the ramp.
        max_rate: Masking rate at the end of the ramp.

    Returns:
        The interpolated masking rate (same shape as ``progress``).
    """
    # Floor progress at a tiny positive value. log1p(0) is well defined, so
    # this is not needed to avoid a log(0); the floor is kept so scalar results
    # match the previous implementation, and it also guards against negative
    # progress. np.maximum (rather than the builtin max) lets array input work.
    progress = np.maximum(progress, 1e-10)
    log_linear_increase = np.log1p(progress) / np.log1p(1)  # Normalized to [0, 1]
    return min_rate + (max_rate - min_rate) * log_linear_increase
|
18 |
+
|
19 |
+
def compute_stepwise_masking_rate(progress, min_rate, max_rate, total_batches, num_steps):
    """Return the masking rate on a staircase from min_rate to max_rate.

    The schedule is split into ``num_steps`` plateaus of equal height
    difference; the first plateau is min_rate and the last is max_rate.

    Args:
        progress: Position within the schedule, as a fraction of total_batches.
        min_rate: Masking rate of the first plateau.
        max_rate: Masking rate of the last plateau.
        total_batches: Number of batches the schedule spans.
        num_steps: Number of plateaus; must be at least 1.

    Returns:
        The masking rate of the plateau that ``progress`` falls on.

    Raises:
        ValueError: If num_steps is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if num_steps == 1:
        # A single plateau has no increment to apply (and would otherwise
        # divide by num_steps - 1 == 0).
        return min_rate

    # Batches per plateau; at least 1 so that schedules with fewer batches than
    # steps do not divide by zero below.
    batch_interval = max(1, total_batches // num_steps)
    # Dividing by num_steps - 1 makes the final plateau land exactly on max_rate.
    rate_increment = (max_rate - min_rate) / (num_steps - 1)

    # Plateau index for this progress value, capped at the last plateau because
    # the integer division above can leave a few trailing batches.
    current_step = int(progress * total_batches / batch_interval)
    current_step = min(current_step, num_steps - 1)

    return min_rate + current_step * rate_increment
|
32 |
+
|
33 |
+
def n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=5, scheduler="cosine", num_steps=20):
    """Plot a masking-rate schedule repeated over several epochs and save it.

    One epoch of the chosen schedule is computed batch by batch and then
    repeated ``num_epochs`` times. The figure is written to
    ``{scheduler}_{num_epochs}_epochs.png`` in the current working directory
    and then shown.

    Args:
        min_rate: Masking rate at the start of each epoch.
        max_rate: Masking rate the schedule climbs towards.
        total_batches: Number of batches in one epoch.
        num_epochs: How many times the one-epoch schedule is repeated.
        scheduler: One of "cosine", "log_linear" or "stepwise". Any other
            value makes the function return without plotting.
        num_steps: Number of plateaus; only used by the "stepwise" scheduler.
    """
    set_font()

    # Batch indices of a single epoch; progress is batch / total_batches.
    batch_numbers = np.arange(total_batches)

    if scheduler == "cosine":
        masking_rates = [
            compute_cosine_masking_rate(batch / total_batches, min_rate, max_rate)
            for batch in batch_numbers
        ]
    elif scheduler == "log_linear":
        masking_rates = [
            compute_log_linear_masking_rate(batch / total_batches, min_rate, max_rate)
            for batch in batch_numbers
        ]
    elif scheduler == "stepwise":
        masking_rates = [
            compute_stepwise_masking_rate(batch / total_batches, min_rate, max_rate, total_batches, num_steps)
            for batch in batch_numbers
        ]
    else:
        # Unknown scheduler name: nothing to plot.
        return

    # The schedule restarts every epoch, so the full curve is the one-epoch
    # curve repeated num_epochs times.
    epoch_masking_rates = masking_rates * num_epochs
    extended_batch_numbers = np.arange(len(epoch_masking_rates))

    plt.figure(figsize=(10, 4))
    plt.plot(extended_batch_numbers, epoch_masking_rates, color='black', linewidth=3)

    # Fixed y ticks
    plt.yticks(
        [0.15, 0.20, 0.25, 0.30, 0.35, 0.40],
        labels=["0.15", "0.20", "0.25", "0.30", "0.35", "0.40"],
        fontsize=30
    )

    # One x tick at the last batch of each epoch, labelled 1, 2, ... with the
    # final epoch labelled "N".
    wave_positions = [total_batches * (i + 1) - 1 for i in range(num_epochs)]
    wave_labels = [str(i + 1) if i < num_epochs - 1 else "N" for i in range(num_epochs)]

    plt.xticks(
        wave_positions,
        labels=wave_labels,
        fontsize=30
    )

    # Add "..." midway between the second and the last epoch tick
    if num_epochs > 2:
        mid_x = (wave_positions[1] + wave_positions[-1]) / 2
        plt.text(mid_x, 0.12, "...", ha="center", fontsize=30)

    # No axis labels and no title
    plt.gca().set_xlabel('')
    plt.gca().set_ylabel('')
    plt.title('')

    plt.tight_layout()
    # Save before showing: on interactive backends the figure is gone once the
    # window opened by show() is closed, and a later savefig() writes a blank
    # image.
    plt.savefig(f"{scheduler}_{num_epochs}_epochs.png", dpi=300)
    plt.show()
|
91 |
+
|
92 |
+
|
93 |
+
def plot_masking_rate_range(min_rate, max_rate=None):
    """Draw a masking rate, or a range of rates, on a short horizontal axis.

    With only ``min_rate`` a single vertical line is drawn at that rate and the
    figure is saved as ``mask_rate_{min_rate}.png``. With ``max_rate`` as well,
    the band between the two rates is shaded and the figure is saved as
    ``mask_rate_{min_rate}_{max_rate}.png``. Files are written to the current
    working directory, and the figure is then shown.

    Args:
        min_rate: The single rate to mark, or the lower end of the range.
        max_rate: Upper end of the range, or None for a single rate.
    """
    set_font()
    plt.figure(figsize=(5, 1))  # Short figure so only the marked rate/band stands out

    if max_rate is None:
        # Single rate: solid black vertical line at min_rate
        plt.axvline(x=min_rate, color='black', linestyle='-', linewidth=4, label=f"Rate = {min_rate}")
    else:
        # Range: half-transparent black band from min_rate to max_rate
        plt.fill_betweenx([0, 1], min_rate, max_rate, color='black', alpha=0.5)

    # x-axis spans the rates of interest; the lower limit sits slightly below
    # 0.15 so a mark at 0.15 is not drawn on the axis boundary.
    plt.xlim(0.145, 0.40)
    plt.xticks([0.15, 0.20, 0.25, 0.30, 0.35, 0.40], fontsize=20)
    plt.tick_params(axis='y', which='both', left=False, labelleft=False)  # No y ticks or labels

    # Keep only a thin bottom spine
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_linewidth(0.5)
    ax.yaxis.set_visible(False)  # Remove the y-axis entirely
    plt.xlabel("")  # No x-axis label
    plt.title("")  # No title

    plt.tight_layout()

    plot_title = f"mask_rate_{min_rate}.png"
    if max_rate is not None:
        plot_title = f"mask_rate_{min_rate}_{max_rate}.png"
    # Save before showing: on interactive backends the figure is gone once the
    # window opened by show() is closed, and a later savefig() writes a blank
    # image.
    plt.savefig(plot_title, dpi=300)
    plt.show()
|
127 |
+
|
128 |
+
|
129 |
+
def main():
    """Generate every scheduler plot and masking-rate plot for Figure 2B."""
    min_rate = 0.15
    max_rate = 0.40
    num_steps = 20
    total_batches = 4215
    num_epochs = 3

    # One 3-epoch plot per scheduler. num_steps is passed explicitly so the
    # value set above is the one actually used (only "stepwise" reads it).
    for scheduler in ("cosine", "log_linear", "stepwise"):
        n_epochs_scheduler_plot(
            min_rate,
            max_rate,
            total_batches,
            num_epochs=num_epochs,
            scheduler=scheduler,
            num_steps=num_steps,
        )

    # Single-rate plots
    for rate in (0.15, 0.20, 0.25):
        plot_masking_rate_range(rate)

    # Range plots, all starting at 0.15
    for upper_rate in (0.20, 0.25, 0.30, 0.35, 0.40):
        plot_masking_rate_range(0.15, upper_rate)


if __name__ == "__main__":
    main()
|
fuson_plm/paper_figures/fig2/stepwise_3_epochs.png
ADDED
![]() |