|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as patches |
|
import numpy as np |
|
from fuson_plm.utils.visualizing import set_font |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
|
|
def compute_cosine_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate at ``progress`` along a half-cosine ramp.

    The ramp ``0.5 * (1 - cos(pi * progress))`` is 0 at ``progress == 0`` and
    1 at ``progress == 1``, so the result moves from ``min_rate`` to
    ``max_rate`` over that interval.

    Args:
        progress: Position within the schedule (0 = start, 1 = end).
        min_rate: Masking rate returned when the ramp is 0.
        max_rate: Masking rate returned when the ramp is 1.

    Returns:
        ``min_rate + (max_rate - min_rate) * ramp``.
    """
    rate_span = max_rate - min_rate
    ramp = 0.5 * (1 - np.cos(np.pi * progress))
    return min_rate + rate_span * ramp
|
|
|
def compute_log_linear_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate at ``progress`` along a ``log1p`` ramp.

    ``progress`` is first clamped to be at least ``1e-10``. The ramp is
    ``log1p(progress) / log1p(1)``, which equals 1 at ``progress == 1``.

    Args:
        progress: Position within the schedule (0 = start, 1 = end).
        min_rate: Base masking rate the ramp is added to.
        max_rate: Masking rate returned when the ramp equals 1.

    Returns:
        ``min_rate + (max_rate - min_rate) * ramp``.
    """
    smallest_progress = 1e-10
    # Dividing by log1p(1) normalises the ramp so that progress == 1 maps to 1.
    ramp = np.log1p(max(progress, smallest_progress)) / np.log1p(1)
    rate_span = max_rate - min_rate
    return min_rate + rate_span * ramp
|
|
|
def compute_stepwise_masking_rate(progress, min_rate, max_rate, total_batches, num_steps):
    """Return the masking rate at ``progress`` along a staircase schedule.

    The schedule is split into ``num_steps`` plateaus of
    ``total_batches // num_steps`` batches each. The rate starts at
    ``min_rate`` and rises by a constant increment per plateau, so the last
    plateau sits at ``max_rate``.

    Args:
        progress: Position within the schedule (0 = start, 1 = end).
        min_rate: Masking rate of the first plateau.
        max_rate: Masking rate of the last plateau.
        total_batches: Number of batches spanned by the whole schedule.
        num_steps: Number of plateaus; must be at least 1.

    Returns:
        The masking rate of the plateau containing ``progress``.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Floor division yields 0 when there are fewer batches than steps; clamp
    # to 1 so the division below cannot raise ZeroDivisionError.
    batch_interval = max(1, total_batches // num_steps)

    # A single plateau has no increments; avoid dividing by num_steps - 1 == 0.
    if num_steps > 1:
        rate_increment = (max_rate - min_rate) / (num_steps - 1)
    else:
        rate_increment = 0.0

    # Index of the plateau containing `progress`, capped at the final plateau
    # (integer division can leave a few trailing batches past the last step).
    current_step = int(progress * total_batches / batch_interval)
    current_step = min(current_step, num_steps - 1)

    return min_rate + current_step * rate_increment
|
|
|
def n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=5, scheduler="cosine", num_steps=20):
    """Plot a masking-rate schedule repeated over several epochs and save it.

    One epoch's schedule is computed batch by batch with the chosen scheduler,
    then the same curve is repeated ``num_epochs`` times along the x axis. The
    figure is written to ``"{scheduler}_{num_epochs}_epochs.png"`` in the
    current working directory and then shown.

    Args:
        min_rate: Lowest masking rate of the schedule.
        max_rate: Highest masking rate of the schedule.
        total_batches: Number of batches in one epoch.
        num_epochs: How many times the one-epoch curve is repeated.
        scheduler: One of ``"cosine"``, ``"log_linear"`` or ``"stepwise"``.
            Any other value makes the function return without plotting.
        num_steps: Number of plateaus; only used by the ``"stepwise"``
            scheduler.

    Returns:
        None.
    """
    set_font()

    batch_numbers = np.arange(total_batches)

    # Masking rate for every batch of a single epoch.
    if scheduler == "cosine":
        masking_rates = [
            compute_cosine_masking_rate(batch / total_batches, min_rate, max_rate)
            for batch in batch_numbers
        ]
    elif scheduler == "log_linear":
        masking_rates = [
            compute_log_linear_masking_rate(batch / total_batches, min_rate, max_rate)
            for batch in batch_numbers
        ]
    elif scheduler == "stepwise":
        masking_rates = [
            compute_stepwise_masking_rate(batch / total_batches, min_rate, max_rate, total_batches, num_steps)
            for batch in batch_numbers
        ]
    else:
        # Unknown scheduler: nothing to plot.
        return

    # The schedule restarts every epoch, so the curve is one epoch repeated.
    epoch_masking_rates = masking_rates * num_epochs
    extended_batch_numbers = np.arange(len(epoch_masking_rates))

    fig = plt.figure(figsize=(10, 4))
    plt.plot(extended_batch_numbers, epoch_masking_rates, color='black', linewidth=3)

    plt.yticks(
        [0.15, 0.20, 0.25, 0.30, 0.35, 0.40],
        labels=["0.15", "0.20", "0.25", "0.30", "0.35", "0.40"],
        fontsize=30,
    )

    # One x tick at the last batch of each epoch; the final epoch is labelled "N".
    wave_positions = [total_batches * (i + 1) - 1 for i in range(num_epochs)]
    wave_labels = [str(i + 1) if i < num_epochs - 1 else "N" for i in range(num_epochs)]
    plt.xticks(wave_positions, labels=wave_labels, fontsize=30)

    # Ellipsis text centred between the second and the last epoch tick.
    if num_epochs > 2:
        mid_x = (wave_positions[1] + wave_positions[-1]) / 2
        plt.text(mid_x, 0.12, "...", ha="center", fontsize=30)

    plt.gca().set_xlabel('')
    plt.gca().set_ylabel('')
    plt.title('')

    plt.tight_layout()

    # Save BEFORE showing: with interactive backends the figure can be
    # destroyed once the window opened by show() is closed, which would leave
    # savefig() writing an empty canvas.
    fig.savefig(f"{scheduler}_{num_epochs}_epochs.png", dpi=300)
    plt.show()
|
|
|
|
|
def plot_masking_rate_range(min_rate, max_rate=None):
    """Draw a one-axis strip marking a masking rate or a range of rates.

    With only ``min_rate`` a vertical line is drawn at that rate; with both
    arguments the band between ``min_rate`` and ``max_rate`` is shaded. The
    x axis is fixed to 0.145-0.40 with ticks every 0.05 from 0.15. The figure
    is saved to ``mask_rate_{min_rate}.png`` or
    ``mask_rate_{min_rate}_{max_rate}.png`` in the current working directory
    and then shown.

    Args:
        min_rate: The single rate to mark, or the lower edge of the band.
        max_rate: Upper edge of the band; ``None`` draws a single line.

    Returns:
        None.
    """
    set_font()
    fig = plt.figure(figsize=(5, 1))

    if max_rate is None:
        plt.axvline(x=min_rate, color='black', linestyle='-', linewidth=4, label=f"Rate = {min_rate}")
        plot_title = f"mask_rate_{min_rate}.png"
    else:
        plt.fill_betweenx([0, 1], min_rate, max_rate, color='black', alpha=0.5)
        plot_title = f"mask_rate_{min_rate}_{max_rate}.png"

    plt.xlim(0.145, 0.40)
    plt.xticks([0.15, 0.20, 0.25, 0.30, 0.35, 0.40], fontsize=20)
    plt.tick_params(axis='y', which='both', left=False, labelleft=False)

    # Keep only a thin bottom spine; the y axis is hidden entirely.
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.spines['bottom'].set_linewidth(0.5)
    axes.yaxis.set_visible(False)
    plt.xlabel("")
    plt.title("")

    plt.tight_layout()

    # Save BEFORE showing: with interactive backends the figure can be
    # destroyed once the window opened by show() is closed, which would leave
    # savefig() writing an empty canvas.
    fig.savefig(plot_title, dpi=300)
    plt.show()
|
|
|
|
|
def main():
    """Generate every schedule plot and every masking-rate strip plot.

    Produces one multi-epoch plot per scheduler ("cosine", "log_linear",
    "stepwise"), then strip plots for three fixed rates and for five ranges
    that all start at ``min_rate``.
    """
    min_rate = 0.15
    max_rate = 0.40
    num_steps = 20
    total_batches = 4215
    num_epochs = 3

    # num_steps is forwarded explicitly so that editing the value above takes
    # effect; only the "stepwise" scheduler reads it.
    for scheduler in ("cosine", "log_linear", "stepwise"):
        n_epochs_scheduler_plot(
            min_rate,
            max_rate,
            total_batches,
            num_epochs=num_epochs,
            scheduler=scheduler,
            num_steps=num_steps,
        )

    # Single fixed masking rates.
    for fixed_rate in (0.15, 0.20, 0.25):
        plot_masking_rate_range(fixed_rate)

    # Ranges sharing the same lower bound with increasing upper bounds.
    for upper_rate in (0.20, 0.25, 0.30, 0.35, 0.40):
        plot_masking_rate_range(min_rate, upper_rate)
|
|
|
# Generate the plots only when this file is executed as a script.
if __name__ == "__main__":
    main()