Parthebhan commited on
Commit
8cc50cf
·
verified ·
1 Parent(s): d83235b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForImageSegmentation, AutoTokenizer
3
+ import torch
4
+ from PIL import Image
5
+
6
+ # Load the model
7
+ model = AutoModelForImageSegmentation.from_pretrained("path/to/gelan-c-seg.pt")
8
+
9
+ # Load the tokenizer (if needed)
10
+ tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
11
+
12
+ # Define the prediction function
13
+ def predict_segmentation(image):
14
+ # Convert image to PyTorch tensor
15
+ image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
16
+ # Perform inference
17
+ output = model(image)
18
+ # Process the output as needed (e.g., post-processing for segmentation masks)
19
+ # (Replace this with your actual processing code)
20
+ segmentation_mask = output.logits.argmax(dim=1).squeeze().detach().numpy()
21
+ return segmentation_mask
22
+
23
+ # Create a Gradio interface
24
+ inputs = gr.inputs.Image(shape=(224, 224))
25
+ outputs = gr.outputs.Image(type="numpy", label="Segmentation Mask")
26
+ title = "Image Segmentation Demo"
27
+ description = "Upload an image and get the segmentation mask."
28
+ examples = [["example.jpg"]] # Add example images here if needed
29
+ interface = gr.Interface(fn=predict_segmentation, inputs=inputs, outputs=outputs, title=title, description=description, examples=examples)
30
+
31
+ # Run the interface
32
+ if __name__ == "__main__":
33
+ interface.launch()