# Spaces status at capture time: Runtime error
# Import libraries
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.datasets import fetch_olivetti_faces
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import RidgeCV
from sklearn.neighbors import KNeighborsRegressor
from sklearn.utils.validation import check_random_state
# Load the Olivetti faces. Each row of `data` is one flattened face image
# (the functions below reshape a row to 64x64); `targets` holds the integer
# label of each row.
data, targets = fetch_olivetti_faces(return_X_y=True)
n_pixels = data.shape[1]

# Rows whose label is below 30 are used for training; the rest are held out
# and later offered in the UI.
train = data[targets < 30]
test = data[targets >= 30]

# Split every training face horizontally: the regressors receive the upper
# half of the pixels as input and learn to output the lower half.
X_train, y_train = (
    train[:, : (n_pixels + 1) // 2],
    train[:, n_pixels // 2 :],
)

# The multi-output regressors compared side by side in the demo. They solve
# the same problem: given half of the image's features, extrapolate the rest.
ESTIMATORS = {
    "Extra trees": ExtraTreesRegressor(
        n_estimators=10, max_features=32, random_state=0
    ),
    "K-nn": KNeighborsRegressor(),
    "Linear regression": LinearRegression(),
    "Ridge": RidgeCV(),
}
# Fit each estimator once, at import time, so the UI callbacks only predict.
for name, estimator in ESTIMATORS.items():
    estimator.fit(X_train, y_train)

# Keep a fixed-seed sample of n_faces held-out rows for the UI. randint draws
# independently, so the same face may appear more than once.
n_faces = 15
rng = check_random_state(4)
face_ids = rng.randint(test.shape[0], size=(n_faces,))
test = test[face_ids, :]
# Function for returning 64*64 image, given the image index | |
def imageFromIndex(index): | |
return test[int(index)].reshape(1,-1).reshape(64, 64) | |
def extrapolateFace(index, ESTIMATORS=ESTIMATORS):
    """Complete the lower half of one test face with every estimator.

    The upper half of the selected face is passed to each estimator's
    ``predict``; the predicted lower half is stacked under the true upper
    half and drawn next to the original face.

    Args:
        index: 0-based row of the module-level ``test`` array; anything
            accepted by ``int()`` (the UI slider may deliver a float).
        ESTIMATORS: Mapping of display name to a fitted estimator exposing
            ``predict``. Defaults to the module-level ``ESTIMATORS``.

    Returns:
        A ``matplotlib.figure.Figure`` with a single row of panels: the true
        face first, then one completed face per estimator in sorted-name
        order. ``gr.Plot`` renders it directly.
    """
    # Use the object-oriented Figure API instead of pyplot. plt.figure()
    # registers every figure in pyplot's global state until plt.close() is
    # called, which leaked one figure per slider move. pyplot's shared
    # "current figure" is also unsafe when callbacks run concurrently.
    image_shape = (64, 64)
    image = test[int(index)].reshape(1, -1)
    n_pixels = image.shape[1]
    # Same half split as the training data: upper half is the model input,
    # lower half is the ground truth the models try to reproduce.
    X_upper = image[:, : (n_pixels + 1) // 2]
    y_ground_truth = image[:, n_pixels // 2 :]

    names = sorted(ESTIMATORS)
    n_cols = 1 + len(names)
    fig = Figure(figsize=(2.0 * n_cols, 2.5))

    # (panel title, flattened face) pairs, in display order.
    panels = [("true face", np.hstack((X_upper, y_ground_truth)))]
    for name in names:
        y_pred = ESTIMATORS[name].predict(X_upper)
        panels.append((name, np.hstack((X_upper[0], y_pred[0]))))

    for col, (title, face) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, n_cols, col, title=title)
        ax.axis("off")
        ax.imshow(face.reshape(image_shape), cmap="gray", interpolation="nearest")
    return fig
with gr.Blocks() as demo:
    link = "https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_multioutput_face_completion.html#sphx-glr-auto-examples-miscellaneous-plot-multioutput-face-completion-py"
    title = "Face completion with multi-output estimators"
    gr.Markdown(f"# {title}")
    gr.Markdown(f"### This demo is based on this [scikit-learn example]({link}).")
    gr.Markdown(
        "### In this demo, we compare 4 multi-output estimators to complete images. "
        "The goal is to predict the lower half of a face given its upper half."
    )
    gr.Markdown(
        "#### Use the below slider to choose a face's image. "
        "Consequently, observe how the four estimators complete the lower half of that face."
    )

    def _show_face(index):
        """Return the chosen face and its completion figure for both outputs."""
        return imageFromIndex(index), extrapolateFace(index)

    with gr.Row():
        with gr.Column(scale=1):
            # The slider value is used directly as a 0-based row index into
            # `test`, so it must span 0 .. len(test) - 1. A 1-based range would
            # never show face 0 and would index past the end at its maximum.
            image_index = gr.Slider(
                minimum=0,
                maximum=len(test) - 1,
                value=0,
                step=1,
                label="Image Index",
                info="Choose an image",
            )
            face_image = gr.Image()
        with gr.Column(scale=2):
            plot = gr.Plot(label="Face completion with multi-output estimators")

    # Refresh both outputs whenever the slider moves, and once on page load.
    image_index.change(_show_face, inputs=[image_index], outputs=[face_image, plot])
    demo.load(_show_face, inputs=[image_index], outputs=[face_image, plot])

if __name__ == "__main__":
    demo.launch(debug=True)