svincoff commited on
Commit
0e3c3b0
·
1 Parent(s): 7fc386b

adding code for creating Fig 1B and Fig 2B

Browse files
README.md CHANGED
@@ -47,7 +47,7 @@ with torch.no_grad():
47
  embeddings = embeddings[1:-1, :]
48
 
49
  # Convert embeddings to numpy array (if needed)
50
- embeddings = embeddings.numpy()
51
 
52
  print("Per-residue embeddings shape:", embeddings.shape)
53
 
 
47
  embeddings = embeddings[1:-1, :]
48
 
49
  # Convert embeddings to numpy array (if needed)
50
+ embeddings = embeddings.cpu().numpy()
51
 
52
  print("Per-residue embeddings shape:", embeddings.shape)
53
 
fuson_plm/paper_figures/README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Paper figures
2
+
3
+ This folder holds code for creating Fig. 1B and all of the plots in Fig. 2B. These figures are programmatically generated, but not part of a particular training/benchmarking pipeline. For all other figures in the paper, any code used to generate figures will be in the appropriate `training/`, `data/`, or `benchmarking/` subfolder.
4
+
5
+ ### Figure 1B
6
+
7
+ To recreate the plotted circles in Figure 1B, go to the `fig1` folder and enter `python data_circles.py`.
8
+
9
+ ### Figure 2B
10
+
11
+ To recreate the masking rate and scheduler plots in figure 2B, go to the `fig2` folder and enter `python mask_rate_plots.py`.
fuson_plm/paper_figures/fig1/circles.png ADDED
fuson_plm/paper_figures/fig1/data_circles.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import matplotlib.patches as patches
3
+
4
+ # Calculate the quotients
5
+ quotient = (65000000 / 44414)/1#00
6
+ ht_quotient = (65000000 / 9249)/1#00
7
+
8
+ # Sizes for the circles
9
+ grey_circle_size = 20000
10
+ pink_circle_size = grey_circle_size/quotient
11
+
12
+ # Plot the circles side by side
13
+ fig, ax = plt.subplots(figsize=(20, 20))
14
+
15
+ # Plot grey circle on the left with a black outline
16
+ grey_radius = grey_circle_size**0.5 / 500
17
+ grey_circle = plt.Circle((0.3, 0.5), radius=grey_radius, color="grey", ec="black", lw=2, alpha=0.3)
18
+ ax.add_patch(grey_circle)
19
+
20
+ # Calculate the red/blue circle size based on ht_quotient
21
+ red_blue_circle_radius = (grey_circle_size / ht_quotient)**0.5 / 500
22
+
23
+ # Position the red/blue circle on the right edge of the grey circle
24
+ red_blue_circle_center_x = 0.25 + grey_radius - red_blue_circle_radius
25
+
26
+ # Add the half red, half blue circle by overlaying two half-circle patches
27
+ red_half_circle = patches.Wedge((red_blue_circle_center_x, 0.5), red_blue_circle_radius, 90, 270, color="#de8a8a", ec="black", lw=1)
28
+ blue_half_circle = patches.Wedge((red_blue_circle_center_x, 0.5), red_blue_circle_radius, 270, 90, color="#6ea4da", ec="black", lw=1)
29
+
30
+ # Add the half-circle patches to the plot
31
+ ax.add_patch(red_half_circle)
32
+ ax.add_patch(blue_half_circle)
33
+
34
+ # Plot pink circle on the right with a black outline
35
+ pink_circle = plt.Circle((0.6, 0.5), radius=pink_circle_size**0.5 / 500, color="mediumpurple", ec="black", lw=2, alpha=0.7)
36
+ ax.add_patch(pink_circle)
37
+
38
+ # Set aspect ratio and limits
39
+ ax.set_aspect('equal')
40
+ ax.set_xlim(0, 1)
41
+ ax.set_ylim(0, 1)
42
+ ax.axis('off') # Turn off axes
43
+ plt.tight_layout()
44
+ plt.savefig('circles.png', dpi=300)
45
+ plt.show()
fuson_plm/paper_figures/fig2/cosine_3_epochs.png ADDED
fuson_plm/paper_figures/fig2/log_linear_3_epochs.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_0.15.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.2.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.25.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.3.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.35.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_0.15_0.4.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_0.2.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_0.25.png ADDED
fuson_plm/paper_figures/fig2/mask_rate_plots.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import matplotlib.patches as patches
3
+ import numpy as np
4
+ from fuson_plm.utils.visualizing import set_font
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
+ # Cosine Increase Masking Rate Scheduler implementation
9
+ def compute_cosine_masking_rate(progress, min_rate, max_rate):
10
+ cosine_increase = 0.5 * (1 - np.cos(np.pi * progress))
11
+ return min_rate + (max_rate - min_rate) * cosine_increase
12
+
13
+ def compute_log_linear_masking_rate(progress, min_rate, max_rate):
14
+ # Avoid log(0) by clamping progress to a minimum of a small positive number
15
+ progress = max(progress, 1e-10)
16
+ log_linear_increase = np.log1p(progress) / np.log1p(1) # Normalizing to keep range in [0, 1]
17
+ return min_rate + (max_rate - min_rate) * log_linear_increase
18
+
19
+ def compute_stepwise_masking_rate(progress, min_rate, max_rate, total_batches, num_steps):
20
+ # Compute the batch interval and rate increment
21
+ batch_interval = total_batches // num_steps
22
+ rate_increment = (max_rate - min_rate) / (num_steps - 1) # Include max_rate in steps
23
+
24
+ # Determine the current step based on progress
25
+ current_step = int(progress * total_batches / batch_interval)
26
+ # Cap the step number to `num_steps - 1` to ensure max rate is included
27
+ current_step = min(current_step, num_steps - 1)
28
+
29
+ # Calculate the masking rate for the current step
30
+ masking_rate = min_rate + current_step * rate_increment
31
+ return masking_rate
32
+
33
+ def n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=5, scheduler="cosine", num_steps=20):
34
+ set_font()
35
+ # Parameters for the scheduler - using training
36
+ batch_numbers = np.arange(total_batches)
37
+
38
+ masking_rates = None
39
+ if scheduler == "cosine":
40
+ masking_rates = [compute_cosine_masking_rate(batch / total_batches, min_rate, max_rate) for batch in batch_numbers]
41
+ elif scheduler == "log_linear":
42
+ masking_rates = [compute_log_linear_masking_rate(batch / total_batches, min_rate, max_rate) for batch in batch_numbers]
43
+ elif scheduler == "stepwise":
44
+ masking_rates = [compute_stepwise_masking_rate(batch / total_batches, min_rate, max_rate, total_batches, num_steps) for batch in batch_numbers]
45
+ else:
46
+ return
47
+
48
+ # Generate masking rates for multiple epochs
49
+ epoch_masking_rates = []
50
+ for epoch in range(num_epochs):
51
+ epoch_masking_rates.extend(masking_rates)
52
+
53
+ # Generate batch numbers for the extended epochs
54
+ extended_batch_numbers = np.arange(len(epoch_masking_rates))
55
+
56
+ # Plot the masking rate over the batches for multiple epochs
57
+ plt.figure(figsize=(10, 4))
58
+ plt.plot(extended_batch_numbers, epoch_masking_rates, color='black', linewidth=3)
59
+
60
+ # Add y ticks
61
+ plt.yticks(
62
+ [0.15, 0.20, 0.25, 0.30, 0.35, 0.40],
63
+ labels=["0.15", "0.20", "0.25", "0.30", "0.35", "0.40"],
64
+ fontsize=30
65
+ )
66
+
67
+ # Add x tick labels at the end of each wave
68
+ wave_positions = [total_batches * (i + 1) - 1 for i in range(num_epochs)]
69
+ wave_labels = [str(i + 1) if i < num_epochs - 1 else "N" for i in range(num_epochs)]
70
+
71
+ plt.xticks(
72
+ wave_positions,
73
+ labels=wave_labels,
74
+ fontsize=30
75
+ )
76
+
77
+ # Add "..." between the second and last wave
78
+ if num_epochs > 2:
79
+ mid_x = (wave_positions[1] + wave_positions[-1]) / 2
80
+ plt.text(mid_x, 0.12, "...", ha="center", fontsize=30)
81
+
82
+
83
+ # Remove axis labels and title
84
+ plt.gca().set_xlabel('') # Remove x-axis label
85
+ plt.gca().set_ylabel('') # Remove y-axis label
86
+ plt.title('') # Remove the title
87
+
88
+ plt.tight_layout()
89
+ plt.show()
90
+ plt.savefig(f"{scheduler}_{num_epochs}_epochs.png", dpi=300)
91
+
92
+
93
+ def plot_masking_rate_range(min_rate, max_rate=None):
94
+ set_font()
95
+ plt.figure(figsize=(5, 1)) # Make the plot short to emphasize the rectangle
96
+
97
+ if max_rate is None:
98
+ # Plot a vertical red line at min_rate
99
+ plt.axvline(x=min_rate, color='black', linestyle='-', linewidth=4, label=f"Rate = {min_rate}")
100
+ #plt.text(min_rate, 0.5, f"{min_rate:.2f}", color='red', ha='center', va='center', fontsize=10)
101
+ else:
102
+ # Shade the range from min_rate to max_rate in red
103
+ plt.fill_betweenx([0, 1], min_rate, max_rate, color='black', alpha=0.5)
104
+ #plt.text(min_rate, 0.5, f"{min_rate:.2f}", color='black', ha='center', va='center', fontsize=10)
105
+ #plt.text(max_rate, 0.5, f"{max_rate:.2f}", color='black', ha='center', va='center', fontsize=10)
106
+
107
+ # Adjust x-axis
108
+ plt.xlim(0.145, 0.40)
109
+ plt.xticks([0.15, 0.20, 0.25, 0.30, 0.35, 0.40], fontsize=20)
110
+ plt.tick_params(axis='y', which='both', left=False, labelleft=False) # Remove y-axis ticks and labels
111
+
112
+ # Remove unnecessary elements
113
+ plt.gca().spines['top'].set_visible(False)
114
+ plt.gca().spines['right'].set_visible(False)
115
+ plt.gca().spines['left'].set_visible(False)
116
+ plt.gca().spines['bottom'].set_linewidth(0.5)
117
+ plt.gca().yaxis.set_visible(False) # Remove y-axis entirely
118
+ plt.xlabel("") # No x-axis label
119
+ plt.title("") # No title
120
+
121
+ plt.tight_layout()
122
+ plt.show()
123
+ plot_title = f"mask_rate_{min_rate}.png"
124
+ if max_rate is not None:
125
+ plot_title =f"mask_rate_{min_rate}_{max_rate}.png"
126
+ plt.savefig(plot_title, dpi=300)
127
+
128
+
129
+ def main():
130
+ min_rate = 0.15
131
+ max_rate = 0.40
132
+ num_steps = 20
133
+ total_batches = 4215
134
+ num_epochs = 3
135
+
136
+ # Make the 3-epoch cosine plot
137
+ n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=num_epochs, scheduler="cosine")
138
+
139
+ # Make the 3-epoch log-linear plot
140
+ n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=num_epochs, scheduler="log_linear")
141
+
142
+ # Make the 3-epoch stepwise plot
143
+ n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=num_epochs, scheduler="stepwise")
144
+
145
+ # Make all the rate plots
146
+ plot_masking_rate_range(0.15)
147
+ plot_masking_rate_range(0.20)
148
+ plot_masking_rate_range(0.25)
149
+ plot_masking_rate_range(0.15, 0.20)
150
+ plot_masking_rate_range(0.15, 0.25)
151
+ plot_masking_rate_range(0.15, 0.30)
152
+ plot_masking_rate_range(0.15, 0.35)
153
+ plot_masking_rate_range(0.15, 0.40)
154
+
155
+ if __name__ == "__main__":
156
+ main()
fuson_plm/paper_figures/fig2/stepwise_3_epochs.png ADDED