File size: 2,253 Bytes
4bfd44d
 
 
 
 
 
 
 
 
 
 
e2bcb00
4bfd44d
 
425083f
 
4bfd44d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42decf6
425083f
 
 
 
4bfd44d
 
 
42decf6
4bfd44d
 
 
 
 
 
 
 
 
425083f
4bfd44d
 
 
42decf6
4bfd44d
42decf6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import base64
from io import BytesIO
import gradio as gr
from PIL import Image
import json

from tools.tools import convertToBuffer
from visualize.visualize import removeBgFromSegmentImage, removeOnlyBg
from models.model import getMask, loadModel
from models.preprocess import preprocess

# Load the model from 'FastSAM.pt' once at import time; segment_marker hands
# this module-level object to getMask on each call instead of reloading it.
FAST_SAM = loadModel('FastSAM.pt')

# Main processing function


def segment_marker(img_rgb: Image.Image, marker_coordinates: str) -> dict:
    """Segment ``img_rgb`` around the given markers and return two cut-outs.

    Args:
        img_rgb: Input image; passed unchanged to ``getMask``,
            ``removeBgFromSegmentImage`` and ``removeOnlyBg``.
        marker_coordinates: JSON text describing the markers. The decoded
            value is handed to ``preprocess``, which defines its schema.

    Returns:
        On success, a dict with the keys ``'bg_removed_segmented_img'`` and
        ``'bg_only_removed_segmented_img'``; each value is the output of
        ``convertToBuffer`` prefixed with ``data:image/png;base64,``.
        On any failure, a dict with a single ``'error'`` key holding a
        human-readable message. The function never raises for bad input.
    """
    # Parse marker coordinates from JSON string
    try:
        marker_coordinates = json.loads(marker_coordinates)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a missing (None) value. The error is returned as a
        # dict, like every other path, so the JSON output declared on the
        # interface always receives a serialisable object rather than a bare
        # string that is not itself valid JSON.
        return {
            'error': "Invalid marker coordinates format. Ensure it's valid JSON."
        }

    try:
        # Process marker points and labels
        input_points, input_labels = preprocess(marker_coordinates)

        print(f"Processing image with {len(input_points)} marker points...")
        # Get mask for segmentation
        masks = getMask(img_rgb, FAST_SAM, input_points, input_labels)

        # Only the first returned mask is used for both outputs.
        first_mask = masks[0]

        # Generate the segmented images
        bg_removed_segmented_img = removeBgFromSegmentImage(img_rgb, first_mask)
        img_base64_bg_segmented = convertToBuffer(bg_removed_segmented_img)

        bg_only_removed_img = removeOnlyBg(img_rgb, first_mask)
        img_base64_only_bg = convertToBuffer(bg_only_removed_img)

        # Return the images in a dictionary format as base64 strings
        return {
            'bg_removed_segmented_img': f'data:image/png;base64,{img_base64_bg_segmented}',
            'bg_only_removed_segmented_img': f'data:image/png;base64,{img_base64_only_bg}'
        }

    except Exception as e:
        # Top-level boundary for the UI callback: log the cause server-side
        # and give the client a generic message instead of a stack trace.
        print(f"An error occurred: {str(e)}")
        return {'error': "An error occurred while processing the image."}


# Gradio UI: an image upload plus a text field for the marker JSON, wired to
# segment_marker, with the returned dictionary rendered through a JSON output.
iface = gr.Interface(
    title="Image Segmentation with Background Removal",
    description=(
        "Upload an image and JSON-formatted marker coordinates "
        "to perform image segmentation and background removal."
    ),
    fn=segment_marker,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Markers Coordinates (JSON format)"),
    ],
    # JSON output so the dictionary returned by segment_marker can be displayed
    outputs="json",
)

# Launch the app only when this file is executed directly as a script
if __name__ == "__main__":
    iface.launch(share=True)