# segmentation / app.py
# Uploaded by Parthebhan in commit 8cc50cf ("Create app.py", verified); 1.3 kB.
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from PIL import ImageOps
from transformers import AutoModelForImageSegmentation, AutoTokenizer
# Build the segmentation network once, at import time; predict_segmentation
# reads this module-level instance on every call.
# NOTE(review): from_pretrained() normally takes a model directory or a Hub
# repo id rather than a single ".pt" checkpoint file — confirm this path loads.
model = AutoModelForImageSegmentation.from_pretrained(
    "path/to/gelan-c-seg.pt",
)

# Tokenizer loaded next to the model. Nothing in the visible code reads it, and
# the model is fed pixel tensors rather than tokens, so it may be removable —
# TODO confirm no other module imports `tokenizer` before dropping it.
tokenizer = AutoTokenizer.from_pretrained(
    "path/to/tokenizer",
)
# Define the prediction function
def predict_segmentation(image):
    """Segment one image and return the per-pixel arg-max class mask.

    Args:
        image: The uploaded image as an (H, W, C) array, which is what
            ``gr.Image(type="numpy")`` delivers. Anything ``np.array`` can
            turn into such an array (e.g. a ``PIL.Image``) also works; the
            input must be 3-D because channels are moved to the front.
            ``None`` (an empty image input) is returned unchanged.

    Returns:
        A ``numpy`` array of class indices, the arg-max over dim 1 of
        ``output.logits`` for the single batch element, or ``None`` when
        *image* is ``None``.
    """
    if image is None:
        return None

    # np.array copies into a fresh, writable array (arrays backed by PIL images
    # can be read-only), then HWC -> CHW plus a batch axis: (1, C, H, W) float32.
    # NOTE(review): pixel values are passed through unscaled (0-255 for uint8
    # input) — confirm this matches the model's expected preprocessing.
    pixels = torch.from_numpy(np.array(image))
    batch = pixels.permute(2, 0, 1).unsqueeze(0).float()

    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        output = model(batch)

    # argmax over dim 1 assumes logits are (batch, num_classes, ...) — TODO
    # confirm for the loaded model. Indexing [0] (the one batch element)
    # instead of squeeze() keeps a mask whose height or width is 1 from
    # losing that axis.
    mask = output.logits.argmax(dim=1)[0]
    return mask.cpu().numpy()
# Create a Gradio interface.
# The legacy gr.inputs / gr.outputs namespaces were removed in Gradio 4.0, so
# the components are built from gr.Image directly.
INPUT_SIZE = (224, 224)  # (width, height) uploads are fitted to before inference


def _predict_at_input_size(image):
    """Fit the upload to INPUT_SIZE, then run predict_segmentation on it.

    Stands in for the ``shape=(224, 224)`` option of the legacy
    ``gr.inputs.Image``, which center-cropped and resized uploads before they
    reached the prediction function; ``gr.Image`` no longer accepts ``shape``.

    Args:
        image: (H, W, 3) uint8 array from the RGB numpy ``gr.Image`` input
            below, or ``None`` when nothing was uploaded.

    Returns:
        Whatever predict_segmentation returns, or ``None`` for a missing image.
    """
    if image is None:
        return None
    fitted = ImageOps.fit(Image.fromarray(image), INPUT_SIZE)
    # np.array (not asarray) hands the model a writable copy of the pixels.
    return predict_segmentation(np.array(fitted))


inputs = gr.Image(type="numpy", image_mode="RGB")
outputs = gr.Image(type="numpy", label="Segmentation Mask")
title = "Image Segmentation Demo"
description = "Upload an image and get the segmentation mask."
# Offer the example only when the file is actually present: Gradio resolves
# example paths while building the interface, so a missing file can abort
# startup.
examples = [["example.jpg"]] if os.path.isfile("example.jpg") else None
interface = gr.Interface(
    fn=_predict_at_input_size,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)

# Run the interface
if __name__ == "__main__":
    interface.launch()