Fill-Mask
Transformers
Safetensors
esm
svincoff's picture
adding code for creating Fig 1B and Fig 2B
0e3c3b0
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from fuson_plm.utils.visualizing import set_font
import matplotlib.pyplot as plt
import numpy as np
# Cosine-increase masking-rate schedule.
def compute_cosine_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate of a cosine-increase schedule.

    The rate goes from ``min_rate`` at ``progress == 0`` to ``max_rate`` at
    ``progress == 1`` along half a cosine period, so its slope is zero at
    both ends and largest mid-way.

    Args:
        progress: fraction of the schedule completed (scalar or NumPy array;
            NumPy operations make this element-wise).
        min_rate: rate returned at progress 0.
        max_rate: rate returned at progress 1.

    Returns:
        ``min_rate + (max_rate - min_rate) * (1 - cos(pi * progress)) / 2``.
    """
    # Map cos over [0, pi] (which runs 1 -> -1) onto a 0 -> 1 ramp.
    ramp = (1 - np.cos(progress * np.pi)) / 2
    span = max_rate - min_rate
    return min_rate + span * ramp
def compute_log_linear_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate of a logarithmic-increase schedule.

    The ramp is ``log1p(progress) / log1p(1)``, which is 0 at progress 0 and
    1 at progress 1. It is concave, so the rate rises faster early in the
    schedule than late.

    Args:
        progress: fraction of the schedule completed. Scalars and NumPy
            arrays are accepted (evaluated element-wise).
        min_rate: rate at progress 0.
        max_rate: rate at progress 1.

    Returns:
        ``min_rate + (max_rate - min_rate) * log1p(progress) / log1p(1)``.
    """
    # Clamp from below so a negative progress can never reach
    # log1p(-1) = -inf (log1p(0) itself is fine). np.maximum rather than the
    # builtin max keeps this element-wise for arrays, like the cosine
    # schedule; the 1e-10 floor is kept so scalar results are unchanged.
    progress = np.maximum(progress, 1e-10)
    # Normalize by log1p(1) = ln 2 so the ramp spans [0, 1] on [0, 1].
    log_linear_increase = np.log1p(progress) / np.log1p(1)
    return min_rate + (max_rate - min_rate) * log_linear_increase
def compute_stepwise_masking_rate(progress, min_rate, max_rate, total_batches, num_steps):
    """Return the masking rate of a staircase schedule.

    An epoch of ``total_batches`` batches is split into intervals of
    ``total_batches // num_steps`` batches (at least 1). The rate is constant
    within an interval and grows by a fixed increment from ``min_rate`` on
    the first step to ``max_rate`` on the last one. Batches left over by the
    integer division stay on the last step.

    Args:
        progress: fraction of the epoch completed, i.e. ``batch / total_batches``.
        min_rate: rate of the first step.
        max_rate: rate of the last step.
        total_batches: number of batches in one epoch.
        num_steps: number of distinct rate levels; must be at least 2.

    Returns:
        The masking rate of the step that ``progress`` falls into.

    Raises:
        ValueError: if ``num_steps < 2`` (the per-step increment is undefined).
    """
    if num_steps < 2:
        raise ValueError(
            f"num_steps must be >= 2 to include both min_rate and max_rate, got {num_steps}"
        )
    # At least one batch per step; otherwise total_batches < num_steps would
    # make the interval 0 and the division below would fail.
    batch_interval = max(total_batches // num_steps, 1)
    # Divide by num_steps - 1 so the last step lands exactly on max_rate.
    rate_increment = (max_rate - min_rate) / (num_steps - 1)
    # progress * total_batches recovers the batch index, but the round trip
    # through a float fraction can land just below an integer
    # (e.g. 1/49*49 == 0.9999999999999999), and int() would then truncate to
    # the previous step. A tiny epsilon absorbs that representation error.
    step_position = progress * total_batches / batch_interval
    current_step = int(step_position + 1e-9)
    # Cap at the last step so the trailing remainder batches keep max_rate.
    current_step = min(current_step, num_steps - 1)
    return min_rate + current_step * rate_increment
def n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=5, scheduler="cosine", num_steps=20):
    """Plot a masking-rate schedule repeated over several epochs and save it.

    The per-epoch schedule is evaluated once per batch and then tiled
    ``num_epochs`` times, so it restarts at ``min_rate`` every epoch. The
    figure is saved as ``"{scheduler}_{num_epochs}_epochs.png"`` (300 dpi)
    in the current working directory, then shown and closed.

    Args:
        min_rate: lowest masking rate of the schedule.
        max_rate: highest masking rate of the schedule.
        total_batches: number of batches in one epoch.
        num_epochs: number of epochs to draw.
        scheduler: ``"cosine"``, ``"log_linear"`` or ``"stepwise"``; any
            other value makes the function return without plotting.
        num_steps: number of rate levels, used only by ``"stepwise"``.
    """
    set_font()
    # Evaluate the chosen schedule at every batch of a single epoch.
    batch_numbers = np.arange(total_batches)
    if scheduler == "cosine":
        masking_rates = [compute_cosine_masking_rate(batch / total_batches, min_rate, max_rate) for batch in batch_numbers]
    elif scheduler == "log_linear":
        masking_rates = [compute_log_linear_masking_rate(batch / total_batches, min_rate, max_rate) for batch in batch_numbers]
    elif scheduler == "stepwise":
        masking_rates = [compute_stepwise_masking_rate(batch / total_batches, min_rate, max_rate, total_batches, num_steps) for batch in batch_numbers]
    else:
        return

    # The schedule restarts each epoch, so the multi-epoch curve is the
    # single-epoch curve repeated.
    epoch_masking_rates = masking_rates * num_epochs
    extended_batch_numbers = np.arange(len(epoch_masking_rates))

    fig = plt.figure(figsize=(10, 4))
    plt.plot(extended_batch_numbers, epoch_masking_rates, color='black', linewidth=3)

    # Fixed y ticks over 0.15-0.40 (hard-coded; they do not follow min_rate/max_rate).
    plt.yticks(
        [0.15, 0.20, 0.25, 0.30, 0.35, 0.40],
        labels=["0.15", "0.20", "0.25", "0.30", "0.35", "0.40"],
        fontsize=30
    )

    # One x tick at the last batch of each epoch: "1", "2", ..., with the final one labelled "N".
    wave_positions = [total_batches * (i + 1) - 1 for i in range(num_epochs)]
    wave_labels = [str(i + 1) if i < num_epochs - 1 else "N" for i in range(num_epochs)]
    plt.xticks(
        wave_positions,
        labels=wave_labels,
        fontsize=30
    )

    # Put "..." between the second and the last epoch tick to suggest omitted epochs.
    # y=0.12 is in data coordinates, just below the lowest y tick.
    if num_epochs > 2:
        mid_x = (wave_positions[1] + wave_positions[-1]) / 2
        plt.text(mid_x, 0.12, "...", ha="center", fontsize=30)

    # No axis labels or title.
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_title('')
    plt.tight_layout()

    # Save before show(): once show() returns with an interactive backend the
    # figure may already be closed, and saving afterwards writes a blank image.
    fig.savefig(f"{scheduler}_{num_epochs}_epochs.png", dpi=300)
    plt.show()
    # Release the figure; main() creates many figures in one process.
    plt.close(fig)
def plot_masking_rate_range(min_rate, max_rate=None):
    """Draw a masking rate, or a range of rates, on a 0.15-0.40 axis and save it.

    With only ``min_rate`` a vertical black line marks that rate. With
    ``max_rate`` too, the interval ``[min_rate, max_rate]`` is shaded
    semi-transparent black. The image is saved as ``mask_rate_{min_rate}.png``
    or ``mask_rate_{min_rate}_{max_rate}.png`` (300 dpi), then shown and closed.

    Args:
        min_rate: the single rate, or the lower end of the range.
        max_rate: upper end of the range, or ``None`` for a single rate.
    """
    set_font()
    fig = plt.figure(figsize=(5, 1))  # short figure so the line/band dominates

    if max_rate is None:
        # Single rate: vertical black line.
        plt.axvline(x=min_rate, color='black', linestyle='-', linewidth=4, label=f"Rate = {min_rate}")
    else:
        # Rate range: shaded black band between the two rates.
        plt.fill_betweenx([0, 1], min_rate, max_rate, color='black', alpha=0.5)

    # Fixed x axis covering 0.15-0.40, slightly padded on the left.
    plt.xlim(0.145, 0.40)
    plt.xticks([0.15, 0.20, 0.25, 0.30, 0.35, 0.40], fontsize=20)
    plt.tick_params(axis='y', which='both', left=False, labelleft=False)  # no y ticks or labels

    # Keep only a thin bottom spine; hide everything else.
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_linewidth(0.5)
    ax.yaxis.set_visible(False)
    plt.xlabel("")
    plt.title("")
    plt.tight_layout()

    plot_title = f"mask_rate_{min_rate}.png"
    if max_rate is not None:
        plot_title = f"mask_rate_{min_rate}_{max_rate}.png"
    # Save before show(): once show() returns with an interactive backend the
    # figure may already be closed, and saving afterwards writes a blank image.
    fig.savefig(plot_title, dpi=300)
    plt.show()
    # Release the figure; main() creates many figures in one process.
    plt.close(fig)
def main():
    """Generate the 3-epoch scheduler plots and the masking-rate range plots."""
    min_rate = 0.15
    max_rate = 0.40
    num_steps = 20
    total_batches = 4215
    num_epochs = 3

    # 3-epoch cosine, log-linear and stepwise plots. num_steps is passed
    # explicitly so this local setting actually controls the stepwise plot.
    for scheduler in ("cosine", "log_linear", "stepwise"):
        n_epochs_scheduler_plot(
            min_rate,
            max_rate,
            total_batches,
            num_epochs=num_epochs,
            scheduler=scheduler,
            num_steps=num_steps,
        )

    # Single fixed rates...
    for rate in (0.15, 0.20, 0.25):
        plot_masking_rate_range(rate)
    # ...then ranges starting at 0.15.
    for upper in (0.20, 0.25, 0.30, 0.35, 0.40):
        plot_masking_rate_range(0.15, upper)


if __name__ == "__main__":
    main()