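"""Coastal Erosion Analysis System.

A Gradio app that segments coastal images with DeepLabV3+ models, runs outlier
detection against per-site references, optionally aligns image pairs, and
uploads result masks to DigitalOcean Spaces.
"""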
import io
import os
import uuid
from glob import glob

import boto3
import cv2
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
from PIL import Image

from pipeline.ImgOutlier import detect_outliers
from pipeline.normalization import align_images

# Detect whether we are running inside a Hugging Face Space
HF_SPACE = os.environ.get('SPACE_ID') is not None
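# Uploads additionally require the DO_SPACES_KEY, DO_SPACES_SECRET,
# DO_SPACES_REGION, and DO_SPACES_BUCKET environment variables (see upload_mask).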


# DigitalOcean Spaces upload function
def upload_mask(image, prefix="mask"):
    """
    Upload a segmentation mask image to DigitalOcean Spaces.

    Args:
        image: PIL Image object
        prefix: filename prefix

    Returns:
        Public URL of the uploaded file, or an error message string.
    """
    try:
        # Get credentials from environment variables
        do_key = os.environ.get('DO_SPACES_KEY')
        do_secret = os.environ.get('DO_SPACES_SECRET')
        do_region = os.environ.get('DO_SPACES_REGION')
        do_bucket = os.environ.get('DO_SPACES_BUCKET')

        # Check that all credentials exist
        if not all([do_key, do_secret, do_region, do_bucket]):
            return "DigitalOcean credentials not set"

        # Create an S3-compatible client pointed at the Spaces endpoint
        session = boto3.session.Session()
        client = session.client(
            's3',
            region_name=do_region,
            endpoint_url=f'https://{do_region}.digitaloceanspaces.com',
            aws_access_key_id=do_key,
            aws_secret_access_key=do_secret,
        )

        # Generate a unique filename
        filename = f"{prefix}_{uuid.uuid4().hex}.png"

        # Convert the image to an in-memory PNG
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        # Upload to Spaces with public-read access
        client.upload_fileobj(
            img_byte_arr,
            do_bucket,
            filename,
            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'}
        )

        # Return the public URL
        return f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}'
    except Exception as e:
        print(f"Upload failed: {str(e)}")
        return f"Upload error: {str(e)}"


# Global configuration
MODEL_PATHS = {
    "Metal Marcy": "models/MM_best_model.pth",
    "Silhouette Jaenette": "models/SJ_best_model.pth"
}

REFERENCE_VECTOR_PATHS = {
    "Metal Marcy": "models/MM_mean.npy",
    "Silhouette Jaenette": "models/SJ_mean.npy"
}

REFERENCE_IMAGE_DIRS = {
    "Metal Marcy": "reference_images/MM",
    "Silhouette Jaenette": "reference_images/SJ"
}

# Class names and color mapping
CLASSES = ['background', 'cobbles', 'drysand', 'plant', 'sky', 'water', 'wetsand']
COLORS = [
    [0, 0, 0],        # background - black
    [139, 137, 137],  # cobbles - dark gray
    [255, 228, 181],  # drysand - light yellow
    [0, 128, 0],      # plant - green
    [135, 206, 235],  # sky - sky blue
    [0, 0, 255],      # water - blue
    [194, 178, 128]   # wetsand - sand brown
]
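# Note: COLORS holds RGB triples and is index-aligned with CLASSES, so a class
# id produced by argmax doubles as an index into this palette.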


# Load a segmentation model from a checkpoint
def load_model(model_path, device="cuda"):
    try:
        # Inside HF Spaces default to CPU; otherwise fall back to CPU only
        # when CUDA is unavailable
        if HF_SPACE or not torch.cuda.is_available():
            device = "cpu"
        model = smp.create_model(
            "DeepLabV3Plus",
            encoder_name="efficientnet-b6",
            in_channels=3,
            classes=len(CLASSES),
            encoder_weights=None
        )
        state_dict = torch.load(model_path, map_location=device)
        # Checkpoints saved from a wrapper module prefix every key with
        # 'model.'; strip that prefix before loading
        if all(k.startswith('model.') for k in state_dict.keys()):
            state_dict = {k[len('model.'):]: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print(f"Model loaded successfully: {model_path}")
        return model
    except Exception as e:
        print(f"Model loading failed: {e}")
        return None
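# Hedged usage sketch:
#   model = load_model(MODEL_PATHS["Metal Marcy"])  # returns None on failure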


# Load the per-site reference vector used for outlier detection
def load_reference_vector(vector_path):
    try:
        if not os.path.exists(vector_path):
            print(f"Reference vector file not found: {vector_path}")
            return []
        ref_vector = np.load(vector_path)
        print(f"Reference vector loaded successfully: {vector_path}")
        return ref_vector
    except Exception as e:
        print(f"Reference vector loading failed {vector_path}: {e}")
        return []


# Load reference images for a site
def load_reference_images(ref_dir):
    try:
        if not os.path.exists(ref_dir):
            print(f"Reference image directory not found: {ref_dir}")
            os.makedirs(ref_dir, exist_ok=True)
            return []
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob(os.path.join(ref_dir, ext)))
        image_files.sort()
        reference_images = []
        # Only the first four images are used as the reference set
        for file in image_files[:4]:
            img = cv2.imread(file)
            if img is not None:
                reference_images.append(img)
        print(f"Loaded {len(reference_images)} images from {ref_dir}")
        return reference_images
    except Exception as e:
        print(f"Image loading failed {ref_dir}: {e}")
        return []


# Preprocess an RGB image for the model
def preprocess_image(image):
    # Drop the alpha channel if present
    if image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    orig_h, orig_w = image.shape[:2]
    image_resized = cv2.resize(image, (1024, 1024))
    # Scale to [0, 1] and normalize with the standard ImageNet statistics
    image_norm = image_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_norm = (image_norm - mean) / std
    # HWC -> CHW, then add a batch dimension
    image_tensor = torch.from_numpy(image_norm.transpose(2, 0, 1)).float().unsqueeze(0)
    return image_tensor, orig_h, orig_w
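# Hedged usage sketch (hypothetical input file):
#   rgb = np.asarray(Image.open("beach.jpg").convert("RGB"))
#   tensor, h, w = preprocess_image(rgb)  # tensor shape: (1, 3, 1024, 1024)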


# Generate a color segmentation map from the raw model output
def generate_segmentation_map(prediction, orig_h, orig_w):
    # Per-pixel class ids, resized back to the original resolution
    mask = prediction.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
    mask_resized = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    # Dilate each foreground class into adjacent background pixels to close
    # small gaps in the prediction
    kernel = np.ones((5, 5), np.uint8)
    processed_mask = mask_resized.copy()
    for idx in range(1, len(CLASSES)):
        class_mask = (mask_resized == idx).astype(np.uint8)
        dilated_mask = cv2.dilate(class_mask, kernel, iterations=2)
        dilated_effect = dilated_mask & (mask_resized == 0)
        processed_mask[dilated_effect > 0] = idx
    # Paint each class with its RGB color
    segmentation_map = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    for idx, color in enumerate(COLORS):
        segmentation_map[processed_mask == idx] = color
    return segmentation_map


# Build the analysis-result HTML (per-class pixel percentages)
def create_analysis_result(mask):
    total_pixels = mask.size
    percentages = {cls: round((np.sum(mask == i) / total_pixels) * 100, 1)
                   for i, cls in enumerate(CLASSES)}
    # Report every class except 'background', in a fixed display order
    ordered = ['sky', 'cobbles', 'plant', 'drysand', 'wetsand', 'water']
    result = "<div style='font-size:18px;font-weight:bold;'>"
    result += " | ".join(f"{cls}: {percentages.get(cls, 0)}%" for cls in ordered)
    result += "</div>"
    return result


# Blend the segmentation map over the original image
def create_overlay(image, segmentation_map, alpha=0.5):
    if image.shape[:2] != segmentation_map.shape[:2]:
        segmentation_map = cv2.resize(segmentation_map, (image.shape[1], image.shape[0]),
                                      interpolation=cv2.INTER_NEAREST)
    return cv2.addWeighted(image, 1 - alpha, segmentation_map, alpha, 0)


# Run segmentation on a BGR image and return the map, overlay, and analysis
def perform_segmentation(model, image_bgr):
    # Must match the device chosen in load_model
    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_tensor, orig_h, orig_w = preprocess_image(image_rgb)
    with torch.no_grad():
        prediction = model(image_tensor.to(device))
    seg_map = generate_segmentation_map(prediction, orig_h, orig_w)  # RGB
    overlay = create_overlay(image_rgb, seg_map)
    # Percentages are computed on the raw 1024x1024 argmax mask
    mask = prediction.argmax(1).squeeze().cpu().numpy()
    analysis = create_analysis_result(mask)
    return seg_map, overlay, analysis


# Single-image processing
def process_coastal_image(location, input_image):
    if input_image is None:
        return None, None, "Please upload an image", "Not detected", None
    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    model = load_model(MODEL_PATHS[location], device)
    if model is None:
        return None, None, "Error: Failed to load model", "Not detected", None

    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location])
    ref_images = load_reference_images(REFERENCE_IMAGE_DIRS[location])

    # Outlier detection: prefer the precomputed reference vector, fall back to
    # the reference images, and skip the check when neither is available
    is_outlier = False
    image_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    if len(ref_vector) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr], ref_vector)
        is_outlier = len(filtered) == 0
    elif len(ref_images) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr])
        is_outlier = len(filtered) == 0
    else:
        print("Warning: no reference images or reference vector available for outlier detection")

    outlier_status = ("Outlier Detection: <span style='color:red;font-weight:bold'>Failed</span>"
                      if is_outlier else
                      "Outlier Detection: <span style='color:green;font-weight:bold'>Passed</span>")

    seg_map, overlay, analysis = perform_segmentation(model, image_bgr)

    # Try uploading to DigitalOcean Spaces
    url = "Local Storage"
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {str(e)}"

    if is_outlier:
        analysis = ("<div style='color:red;font-weight:bold;margin-bottom:10px'>"
                    "Warning: the image failed outlier detection; the result may be inaccurate!"
                    "</div>") + analysis
    return seg_map, overlay, analysis, outlier_status, url


# Spatial alignment followed by segmentation
def process_with_alignment(location, reference_image, input_image):
    if reference_image is None or input_image is None:
        return None, None, None, None, "Please upload both reference and target images", "Not processed", None
    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    model = load_model(MODEL_PATHS[location], device)
    if model is None:
        return None, None, None, None, "Error: Failed to load model", "Not processed", None

    ref_bgr = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR)
    tgt_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Align the target image to the reference; align_images also expects a list
    # of masks, so zero arrays stand in for them here
    try:
        aligned, _ = align_images([ref_bgr, tgt_bgr], [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)])
        aligned_tgt_bgr = aligned[1]
    except Exception as e:
        print(f"Spatial alignment failed: {e}")
        return None, None, None, None, f"Spatial alignment failed: {str(e)}", "Processing failed", None

    seg_map, overlay, analysis = perform_segmentation(model, aligned_tgt_bgr)

    # Try uploading to DigitalOcean Spaces
    url = "Local Storage"
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {str(e)}"

    status = "Spatial Alignment: <span style='color:green;font-weight:bold'>Completed</span>"
    ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)
    aligned_tgt_rgb = cv2.cvtColor(aligned_tgt_bgr, cv2.COLOR_BGR2RGB)
    return ref_rgb, aligned_tgt_rgb, seg_map, overlay, analysis, status, url


# Create the Gradio interface
def create_interface():
    # Unified display size (roughly 4:3, to maintain the aspect ratio)
    disp_w, disp_h = 683, 512

    with gr.Blocks(title="Coastal Erosion Analysis System") as demo:
        gr.Markdown("""# Coastal Erosion Analysis System
Upload coastal images for analysis, including segmentation and spatial alignment.""")

        with gr.Tabs():
            with gr.TabItem("Single Image Segmentation"):
                with gr.Row():
                    loc1 = gr.Radio(list(MODEL_PATHS.keys()), label="Select Model",
                                    value=list(MODEL_PATHS.keys())[0])
                with gr.Row():
                    inp = gr.Image(label="Input Image", type="numpy", image_mode="RGB",
                                   height=disp_h, width=disp_w)
                    seg = gr.Image(label="Segmentation Map", type="numpy", height=disp_h, width=disp_w)
                    ovl = gr.Image(label="Overlay Image", type="numpy", height=disp_h, width=disp_w)
                with gr.Row():
                    btn1 = gr.Button("Run Segmentation")
                url1 = gr.Text(label="Segmentation Image URL")
                status1 = gr.HTML(label="Outlier Detection Status")
                res1 = gr.HTML(label="Analysis Result")
                btn1.click(fn=process_coastal_image, inputs=[loc1, inp],
                           outputs=[seg, ovl, res1, status1, url1])

            with gr.TabItem("Spatial Alignment Segmentation"):
                with gr.Row():
                    loc2 = gr.Radio(list(MODEL_PATHS.keys()), label="Select Model",
                                    value=list(MODEL_PATHS.keys())[0])
                with gr.Row():
                    ref_img = gr.Image(label="Reference Image", type="numpy", image_mode="RGB",
                                       height=disp_h, width=disp_w)
                    tgt_img = gr.Image(label="Target Image", type="numpy", image_mode="RGB",
                                       height=disp_h, width=disp_w)
                with gr.Row():
                    btn2 = gr.Button("Run Spatial Alignment and Segmentation")
                with gr.Row():
                    orig = gr.Image(label="Original Image", type="numpy", height=disp_h, width=disp_w)
                    aligned = gr.Image(label="Aligned Image", type="numpy", height=disp_h, width=disp_w)
                with gr.Row():
                    seg2 = gr.Image(label="Segmentation Map", type="numpy", height=disp_h, width=disp_w)
                    ovl2 = gr.Image(label="Overlay Image", type="numpy", height=disp_h, width=disp_w)
                url2 = gr.Text(label="Segmentation Image URL")
                status2 = gr.HTML(label="Alignment Status")
                res2 = gr.HTML(label="Analysis Result")
                btn2.click(fn=process_with_alignment, inputs=[loc2, ref_img, tgt_img],
                           outputs=[orig, aligned, seg2, ovl2, res2, status2, url2])

    return demo


if __name__ == "__main__":
    # Create the necessary directories
    for path in ["models", "reference_images/MM", "reference_images/SJ"]:
        os.makedirs(path, exist_ok=True)

    # Check that the model files exist
    for p in MODEL_PATHS.values():
        if not os.path.exists(p):
            print(f"Warning: model file {p} does not exist!")

    # Check that the DigitalOcean credentials exist
    do_creds = [
        os.environ.get('DO_SPACES_KEY'),
        os.environ.get('DO_SPACES_SECRET'),
        os.environ.get('DO_SPACES_REGION'),
        os.environ.get('DO_SPACES_BUCKET')
    ]
    if not all(do_creds):
        print("Warning: incomplete DigitalOcean Spaces credentials; upload functionality may not work")

    # Create and launch the interface
    demo = create_interface()
    if HF_SPACE:
        demo.launch()
    else:
        demo.launch(share=True)
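
# Hedged local-run sketch (assumes this file is saved as app.py and that the
# pipeline package, model checkpoints, and reference images are in place):
#   pip install torch segmentation-models-pytorch gradio opencv-python boto3 pillow
#   python app.py   # outside HF Spaces this launches Gradio with a public share link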