import base64
import json
import logging
from io import BytesIO

import gradio as gr
from PIL import Image

from models.model import getMask, loadModel
from models.preprocess import preprocess
from tools.tools import convertToBuffer
from visualize.visualize import removeBgFromSegmentImage, removeOnlyBg

logger = logging.getLogger(__name__)

# Load the FastSAM weights once at import time so every request reuses the same model.
FAST_SAM = loadModel('FastSAM.pt')


def _to_data_uri(image) -> str:
    """Encode *image* as a PNG data URI string.

    NOTE(review): assumes ``convertToBuffer`` returns the base64-encoded PNG
    bytes as a ``str`` (the ``data:image/png;base64,`` prefix presumes this) —
    confirm against tools.tools.
    """
    return f'data:image/png;base64,{convertToBuffer(image)}'


# Main processing function
def segment_marker(img_rgb: Image.Image, marker_coordinates: str) -> dict:
    """Segment *img_rgb* around the given markers and return two cut-out images.

    Args:
        img_rgb: The uploaded image. Gradio passes ``None`` when nothing was
            uploaded.
        marker_coordinates: A JSON string describing the marker points; it is
            decoded with ``json.loads`` and handed to ``preprocess``, which
            yields the point coordinates and their labels.

    Returns:
        On success, a dict with keys ``'bg_removed_segmented_img'`` and
        ``'bg_only_removed_segmented_img'``, each a ``data:image/png;base64,...``
        URI built from the first mask returned by ``getMask``.
        On any failure, a dict with a single ``'error'`` key holding a
        human-readable message (the function never raises to the UI).
    """
    if img_rgb is None:
        return {'error': "No image provided."}

    # Parse marker coordinates from JSON string. TypeError covers a None value
    # (e.g. the field omitted in an API call), which json.loads rejects.
    try:
        marker_coordinates = json.loads(marker_coordinates)
    except (json.JSONDecodeError, TypeError):
        return {'error': "Invalid marker coordinates format. Ensure it's valid JSON."}

    # Top-level request boundary: any failure below is logged with its
    # traceback and reported to the caller as a JSON error instead of crashing.
    try:
        input_points, input_labels = preprocess(marker_coordinates)
        logger.info("Processing image with %d marker points...", len(input_points))

        masks = getMask(img_rgb, FAST_SAM, input_points, input_labels)
        # Only the first mask is used; guard against getMask producing none.
        if masks is None or len(masks) == 0:
            return {'error': "No mask was produced for the given markers."}
        mask = masks[0]

        return {
            'bg_removed_segmented_img': _to_data_uri(removeBgFromSegmentImage(img_rgb, mask)),
            'bg_only_removed_segmented_img': _to_data_uri(removeOnlyBg(img_rgb, mask)),
        }
    except Exception:
        logger.exception("An error occurred while processing the image.")
        return {'error': "An error occurred while processing the image."}


# Set up the Gradio interface
iface = gr.Interface(
    fn=segment_marker,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Markers Coordinates (JSON format)"),
    ],
    outputs="json",  # Set output to JSON format to return the dictionary
    title="Image Segmentation with Background Removal",
    description="Upload an image and JSON-formatted marker coordinates to perform image segmentation and background removal.",
)

# Run the Gradio app
if __name__ == "__main__":
    # Make the INFO-level progress messages visible when run as a script.
    logging.basicConfig(level=logging.INFO)
    # share=True requests a public Gradio share link in addition to the local server.
    iface.launch(share=True)