|
import functools

import gradio as gr
import tensorflow as tf

import gcvit
from gcvit.utils import get_gradcam_model, get_gradcam_prediction
|
|
|
|
@functools.lru_cache(maxsize=8)
def _load_gradcam_model(model_name):
    """Build (once) and return the Grad-CAM wrapper for the named GCViT model.

    Results are cached per ``model_name``, so pretrained weights are loaded
    only the first time a model is requested rather than on every prediction.
    An unknown name raises ``AttributeError`` from ``getattr``; lru_cache does
    not cache exceptions, so failed lookups do not occupy cache slots.
    """
    model_cls = getattr(gcvit, model_name)
    model = model_cls(pretrain=True)
    return get_gradcam_model(model)


def predict_fn(image, model_name):
    """A predict function that will be invoked by gradio.

    Args:
        image: the input image as delivered by the gradio Image input component.
        model_name: name of a model class exported by ``gcvit``
            (e.g. ``'GCViTTiny'``).

    Returns:
        A two-element list ``[preds, overlay]``:
        ``preds`` is a dict mapping each predicted label to its score as a
        Python ``float`` (what a ``Label`` output expects), and ``overlay`` is
        the Grad-CAM overlay image returned by ``get_gradcam_prediction``.

    Raises:
        AttributeError: if ``gcvit`` has no attribute named ``model_name``.
    """
    gradcam_model = _load_gradcam_model(model_name)
    preds, overlay = get_gradcam_prediction(
        image, gradcam_model, cmap='jet', alpha=0.4, pred_index=None
    )
    # NOTE(review): assumes each entry of `preds` is an indexable record with
    # the label at [1] and the score at [2] (Keras decode_predictions style);
    # confirm against gcvit.utils.get_gradcam_prediction.
    preds = {x[1]: float(x[2]) for x in preds}
    return [preds, overlay]
|
|
|
|
# Gradio UI: an image plus a model choice go in; top-class scores and a
# Grad-CAM overlay come out. Uses the unified component API (gr.Image,
# gr.Label, gr.Radio) so input and output components are declared
# consistently.
demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Radio(
            ['GCViTTiny', 'GCViTSmall', 'GCViTBase'],
            value='GCViTTiny',
            label='Model Name',
        ),
    ],
    outputs=[
        # Receives the label->score dict returned first by predict_fn.
        gr.Label(label="Prediction"),
        # Displays the Grad-CAM overlay returned second by predict_fn; this is
        # an output component, not an input one.
        gr.Image(label="GradCAM"),
    ],
    title="Global Context Vision Transformer (GCViT) Demo",
    description="Image Classification with GCViT Model using ImageNet Pretrain Weights.",
    examples=[
        ["example/hot_air_ballon.jpg", 'GCViTTiny'],
        ["example/chelsea.png", 'GCViTTiny'],
        ["example/penguin.JPG", 'GCViTTiny'],
        ["example/bus.jpg", 'GCViTTiny'],
    ],
)
|
|
# Launch only when run as a script (e.g. `python app.py`), so importing this
# module (e.g. Gradio reload mode) does not start a second server.
if __name__ == "__main__":
    demo.launch()