alkzar90's picture
create an image segmentation model
bf601e4
raw
history blame
1.13 kB
import gradio as gr
import torch
from transformers import (SegformerFeatureExtractor,
SegformerForSemanticSegmentation)
from PIL import Image
# Local directory passed to ``from_pretrained`` for the fine-tuned SegFormer.
MODEL_PATH = "./best_model_test/"
# Inference device; both the model and its input tensors are placed here.
device = torch.device("cpu")
# Preprocessing settings come from the public ADE20k SegFormer-b0 checkpoint,
# not from MODEL_PATH.
# NOTE(review): presumably the fine-tuned model was trained with the same
# preprocessing; confirm, or load the preprocessor from MODEL_PATH if that
# directory ships its own preprocessor config.
preprocessor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
)
# Move the model to the same device the inputs are sent to, so changing
# ``device`` cannot produce a model/input device mismatch at inference time.
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH).to(device)
# Switch to inference mode (disables train-only layers such as dropout).
model.eval()
def _spatial_size(img):
    """Return ``(height, width)`` of an input image.

    Accepts an array-like exposing ``.shape`` laid out as
    ``(height, width[, channels])`` (e.g. the NumPy array ``gr.Image()``
    passes by default) or a PIL-style image exposing ``.size`` as
    ``(width, height)``.
    """
    if hasattr(img, "shape"):
        return tuple(img.shape[:2])
    width, height = img.size
    return height, width


def query_image(img):
    """Segment *img* and return the per-pixel class map at the input's resolution.

    Args:
        img: Image supplied by the Gradio ``gr.Image()`` input (by default a
            NumPy array of shape ``(H, W, C)``; a PIL image is also accepted).
            ``None`` (no image provided) yields ``None``.

    Returns:
        A 2-D NumPy array of shape ``(H, W)`` holding the argmax class index
        for every pixel, or ``None`` when *img* is ``None``.
    """
    if img is None:
        return None
    height, width = _spatial_size(img)

    # ``pixel_values`` returned with return_tensors="pt" already carries a
    # leading batch dimension, so no extra unsqueeze is needed.
    inputs = preprocessor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits

    # The logits are not at the input resolution, so resize them back to the
    # original (H, W) before taking the per-pixel argmax over classes.
    upscaled = torch.nn.functional.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    labels = upscaled.argmax(dim=1)
    return labels[0].cpu().numpy()
def visualize_instance_seg_mask(mask):
    """Identity pass-through: hand back *mask* exactly as received.

    NOTE(review): nothing visible in this file calls this function; it looks
    like a stub for turning a label map into a displayable image — confirm.
    """
    visualized = mask
    return visualized
# Gradio UI: a single image input; the predicted class map is shown as an image.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="SegFormer Model for rock glacier image segmentation",
)

# Start the web server only when run as a script, so importing this module
# (and its ``demo`` object) does not block on ``launch()``.
if __name__ == "__main__":
    demo.launch()