Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """effect of eta and iterations on sgd.ipynb | |
| Automatically generated by Colaboratory. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1Lso8y1XapdHGJOHnY0pL4ZeSGaa-6lOz | |
| """ | |
| # Commented out IPython magic to ensure Python compatibility. | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| # %matplotlib inline | |
# Default (width, height) for every figure created from here on.
plt.rcParams.update({"figure.figsize": (10, 5)})
def generate_data(m=100, rng=None):
    """Draw ``m`` noisy samples from the line ``y = 4 + 3x``.

    Parameters
    ----------
    m : int, default 100
        Number of samples to draw.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, NumPy's global legacy generator
        (``np.random``) is used exactly as before, so ``np.random.seed``
        still controls the output.

    Returns
    -------
    X : ndarray, shape (m, 1)
        Features drawn uniformly from [0, 2).
    y : ndarray, shape (m, 1)
        ``4 + 3 * X`` plus standard-normal noise.
    """
    if rng is None:
        # Same calls, in the same order, as the original implementation.
        X = 2 * np.random.rand(m, 1)
        noise = np.random.randn(m, 1)
    else:
        X = 2 * rng.random((m, 1))
        noise = rng.standard_normal((m, 1))
    y = 4 + 3 * X + noise
    return X, y
def get_norm_eqn(X, y):
    """Fit ``y ~ theta0 + theta1 * X`` by least squares and return the fitted line.

    Parameters
    ----------
    X : ndarray
        Feature column, one row per sample (``generate_data`` produces
        shape (m, 1)).
    y : ndarray
        Targets with the same number of rows as ``X``.

    Returns
    -------
    X_b : ndarray, shape (m, 2)
        ``X`` with a leading column of ones for the bias term.
    y_norm : ndarray
        Predictions of the least-squares line at every row of ``X``.
    """
    # Size the bias column from the data; a hard-coded 100 broke any other m.
    X_b = np.c_[np.ones((len(X), 1)), X]
    # lstsq solves the same normal-equation problem as
    # inv(X_b.T @ X_b) @ X_b.T @ y, but without forming an explicit inverse:
    # more stable numerically, and no failure on a rank-deficient X_b.
    theta_best, *_ = np.linalg.lstsq(X_b, y, rcond=None)
    y_norm = X_b.dot(theta_best)
    return X_b, y_norm
def generate_sgd_plot(eta, n_iterations):
    """Run gradient descent on fresh synthetic data and plot every step's line.

    Despite the name, each step uses all samples (batch gradient descent on
    the mean-squared-error loss).

    Parameters
    ----------
    eta : float
        Learning rate (step size).
    n_iterations : int or float
        Number of gradient steps. Coerced to ``int`` because UI sliders may
        deliver float values, and ``range`` rejects floats.

    Returns
    -------
    matplotlib.figure.Figure
        A new figure showing the data points, one dashed line per iteration,
        and the normal-equation (least-squares) line.
    """
    # Local import: matplotlib is already a dependency of this file.
    from matplotlib.figure import Figure

    n_iterations = int(n_iterations)
    # Draw the random start before the data, in the same order as before, so a
    # global np.random seed reproduces the same run.
    theta = np.random.randn(2, 1)
    X, y = generate_data()
    X_b, y_norm = get_norm_eqn(X, y)
    m = len(X_b)  # number of samples (was hard-coded to 100)

    # A fresh Figure per call. Drawing on pyplot's shared "current figure"
    # made every run pile its lines onto the previous runs' plot.
    fig = Figure()
    ax = fig.subplots()
    # Plot how the parameters change relative to the normal-equation line as
    # the algorithm learns.
    ax.scatter(X, y, c='#7678ed', label="data points")
    ax.axis([0, 2.0, 0, 14])  # fixed limits so diverging lines don't rescale
    for _ in range(n_iterations):
        # Gradient of MSE = (1/m) * ||X_b @ theta - y||^2 with respect to theta.
        gradients = 2 / m * X_b.T.dot(X_b.dot(theta) - y)
        theta = theta - eta * gradients
        y_new = X_b.dot(theta)
        ax.plot(X, y_new, color='#f18701', linestyle='dashed', linewidth=0.2)
    ax.plot(X, y_norm, '#3d348b', label="Normal Equation line")
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.legend(loc='best')
    return fig
import gradio as gr

# Written at column 0 so Markdown never renders it as an indented code block.
DESCRIPTION = """
# How learning rate and number of iterations affect SGD
Move the sliders to change the values of eta and number of iterations to see how they affect the convergence rate of the algorithm.
"""

demo = gr.Blocks()
with demo:
    gr.Markdown(DESCRIPTION)
    inputs = [
        gr.Slider(minimum=0.02, maximum=0.5, label="learning rate, eta"),
        # The old positional call gr.Slider(500, 1000, 200) passed 200 as the
        # initial *value*, which lies outside [500, 1000]; start at the minimum.
        gr.Slider(
            minimum=500,
            maximum=1000,
            value=500,
            step=1,
            label="number of iterations",
        ),
    ]
    output = gr.Plot()
    btn = gr.Button("Run")
    btn.click(fn=generate_sgd_plot, inputs=inputs, outputs=output)

# Launch only when run as a script, so importing this module doesn't start a server.
if __name__ == "__main__":
    demo.launch()