|
import os |
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import requests |
|
import io |
|
import matplotlib.colors as mcolors |
|
import cv2 |
|
from io import BytesIO |
|
import urllib.request |
|
import tempfile |
|
import rasterio |
|
import warnings |
|
import pandas as pd |
|
import joblib |
|
warnings.filterwarnings("ignore")

# segmentation_models_pytorch provides the DeepLabV3+ architecture. If it is
# missing, attempt a one-off install so the app can still start; otherwise the
# placeholder model defined later is used.
try:
    import segmentation_models_pytorch as smp

    smp_available = True
    print("Successfully imported segmentation_models_pytorch")
except ImportError:
    smp_available = False
    print("Warning: segmentation_models_pytorch not available, will try to install it")
    import importlib
    import subprocess
    import sys

    try:
        # Run pip through the current interpreter so the package is installed
        # into the environment this process imports from (a bare "pip" on PATH
        # may belong to a different Python).
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "segmentation-models-pytorch"
        ])
        # The import system caches directory listings; refresh them so the
        # freshly installed package can be found in this same process.
        importlib.invalidate_caches()
        import segmentation_models_pytorch as smp

        smp_available = True
        print("Successfully installed and imported segmentation_models_pytorch")
    except Exception as exc:
        # Best effort: any failure (pip error, missing pip, import error) falls
        # back to the placeholder model, but KeyboardInterrupt/SystemExit still
        # propagate and the reason is reported.
        print(f"Failed to install segmentation_models_pytorch: {exc}")
|
|
|
|
|
# albumentations is used for pad + centre-crop preprocessing. If it is missing,
# attempt a one-off install; otherwise callers fall back to cv2.resize.
try:
    import albumentations as A

    albumentations_available = True
    print("Successfully imported albumentations")
except ImportError:
    albumentations_available = False
    print("Warning: albumentations not available, will try to install it")
    import importlib
    import subprocess
    import sys

    try:
        # Install with the running interpreter's pip so the package lands in
        # the environment this process imports from.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "albumentations"
        ])
        # Refresh import finder caches so the new package is importable now.
        importlib.invalidate_caches()
        import albumentations as A

        albumentations_available = True
        print("Successfully installed and imported albumentations")
    except Exception as exc:
        # Best effort: keep running with the OpenCV fallback, but report why.
        print(f"Failed to install albumentations, will use OpenCV for transforms ({exc})")
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
|
|
|
|
|
if not smp_available:
    print("Warning: Using a placeholder model that won't produce valid predictions.")
    from torch import nn

    class PlaceholderModel(nn.Module):
        """Stand-in network: a single 3x3 convolution from 3 channels to 1.

        It only keeps the app runnable when segmentation_models_pytorch is
        unavailable; its outputs carry no meaning.
        """

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 1, 3, padding=1)

        def forward(self, x):
            return self.conv(x)

    model = PlaceholderModel()
else:
    # DeepLabV3+ with a ResNet-34 encoder, 3-channel input and one output
    # channel. encoder_weights=None: no pretrained encoder is downloaded here;
    # trained weights are loaded separately from a checkpoint.
    model = smp.DeepLabV3Plus(
        encoder_name="resnet34",
        encoder_weights=None,
        in_channels=3,
        classes=1,
    )
|
|
|
|
|
SEGMENTATION_MODEL_REPO = "dcrey7/wetlands_segmentation_deeplabsv3plus"
SEGMENTATION_MODEL_FILENAME = "DeepLabV3plus_best_model.pth"


def download_model_weights():
    """Download the segmentation weights from the HuggingFace repository.

    The file is cached as ``weights/<SEGMENTATION_MODEL_FILENAME>`` relative
    to the working directory. If that file already exists it is reused
    without contacting the network.

    Returns:
        The local path of the weights file, or None if the download failed
        (the error is printed).
    """
    partial_path = None
    try:
        os.makedirs('weights', exist_ok=True)
        local_path = os.path.join('weights', SEGMENTATION_MODEL_FILENAME)

        if os.path.exists(local_path):
            print(f"Model weights already downloaded at {local_path}")
            return local_path

        print(f"Downloading model weights from {SEGMENTATION_MODEL_REPO}...")
        url = f"https://huggingface.co/{SEGMENTATION_MODEL_REPO}/resolve/main/{SEGMENTATION_MODEL_FILENAME}"
        # Download to a side file and move it into place only once complete:
        # an interrupted download must not leave a truncated file at
        # local_path, because the cache check above would accept it forever.
        partial_path = local_path + ".part"
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, local_path)
        partial_path = None
        print(f"Model weights downloaded to {local_path}")
        return local_path
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        # Drop any incomplete download so the next attempt starts clean.
        if partial_path is not None and os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError:
                pass
        return None
|
|
|
|
|
def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint object.

    Handles a bare state dict, a dict wrapping it under "state_dict" or
    "model_state_dict", and a pickled nn.Module. A "module." prefix (as left
    by nn.DataParallel) is stripped when every key carries it. Anything else
    is returned unchanged.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        return checkpoint
    for wrapper_key in ("state_dict", "model_state_dict"):
        inner = checkpoint.get(wrapper_key)
        if isinstance(inner, dict):
            checkpoint = inner
            break
    prefix = "module."
    if checkpoint and all(isinstance(k, str) and k.startswith(prefix) for k in checkpoint):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint


weights_path = download_model_weights()
if weights_path:
    try:
        checkpoint = torch.load(weights_path, map_location=device)
        state_dict = _extract_state_dict(checkpoint)
        # strict=False tolerates partially matching checkpoints but silently
        # ignores every mismatched key, so report them instead of always
        # claiming success.
        load_result = model.load_state_dict(state_dict, strict=False)
        missing = list(load_result.missing_keys)
        unexpected = list(load_result.unexpected_keys)
        if missing or unexpected:
            print(
                "Warning: model weights only partially matched the architecture "
                f"({len(missing)} missing keys, {len(unexpected)} unexpected keys)"
            )
        else:
            print("Model weights loaded successfully")
    except Exception as e:
        print(f"Error loading model weights: {e}")
else:
    print("No weights available. Model will not produce valid predictions.")

model.to(device)
model.eval()
|
|
|
|
|
def load_cloud_detection_model(model_path="cloud_detection_lightgbm.joblib"):
    """Load the serialized cloud-detection classifier with joblib.

    Args:
        model_path: Path of the joblib file. The default is relative, so it is
            resolved against the current working directory.

    Returns:
        The deserialized model, or None when the file does not exist or
        fails to load (a message is printed in both cases).

    NOTE(review): joblib.load unpickles arbitrary objects; only point this at
    trusted files.
    """
    try:
        if not os.path.exists(model_path):
            print(f"Cloud detection model file not found at {model_path}")
            return None
        cloud_model = joblib.load(model_path)
        print(f"Cloud detection model loaded successfully from {model_path}")
        return cloud_model
    except Exception as e:
        print(f"Error loading cloud detection model: {e}")
        return None
|
|
|
|
|
cloud_model = load_cloud_detection_model()

# Compare against None explicitly: load failures return None, while some
# estimator objects (e.g. scikit-learn Pipelines) define __len__, so plain
# truthiness is not a reliable "model loaded" test.
if cloud_model is not None:
    print("Cloud detection model is ready for predictions")
else:
    print("Warning: Cloud detection model could not be loaded")
|
|
|
def normalize(band):
    """Stretch ``band`` onto [0, 1] using its 2nd and 98th percentiles.

    The percentiles are taken over finite values only. Results outside the
    window are clipped, so +inf maps to 1, -inf to 0, and NaN stays NaN.

    Returns:
        ``band`` itself if it contains no finite values; an array of zeros
        (same dtype as ``band``) if both percentiles are equal; otherwise the
        rescaled array.
    """
    finite = band[np.isfinite(band)]
    if not finite.size:
        return band

    lo, hi = np.percentile(finite, (2, 98))
    if lo == hi:
        return np.zeros_like(band)

    return np.clip((band - lo) / (hi - lo), 0, 1)
|
|
|
def calculate_cv(band):
    """Return the coefficient of variation (std / mean) of a band.

    The band is first rescaled with ``normalize`` and only finite values
    contribute. Returns 0 when no finite values remain or when the mean is
    effectively zero (|mean| < 1e-10).
    """
    values = normalize(band)
    values = values[np.isfinite(values)]
    if len(values) == 0:
        return 0

    mu = np.mean(values)
    if abs(mu) < 1e-10:
        return 0
    return np.std(values) / mu
|
|
|
def read_tiff_image_for_segmentation(tiff_path):
    """Read a raster as an (H, W, 3) float32 image for wetland segmentation.

    Bands 1-3 are used as R, G, B. A single-band raster is replicated into
    all three channels, and a two-band raster is padded with a zero band.
    Non-finite pixels (NaN/inf, e.g. float nodata) are set to 0, then the
    image is divided by its maximum when that maximum is positive.

    Returns:
        The image array, or None if the file cannot be read (the error is
        printed).
    """
    try:
        with rasterio.open(tiff_path) as src:
            band_count = min(src.count, 3)
            bands = [src.read(i) for i in range(1, band_count + 1)]

        if len(bands) == 1:
            bands = bands * 3
        while len(bands) < 3:
            # Raises IndexError (caught below) for a raster with no bands.
            bands.append(np.zeros_like(bands[0]))

        image = np.dstack(bands).astype(np.float32)
        # A single NaN would make image.max() NaN, which compares False with 0,
        # skipping normalisation and sending NaNs to the model.
        image[~np.isfinite(image)] = 0.0
        peak = image.max()
        if peak > 0:
            image = image / peak
        return image
    except Exception as e:
        print(f"Error reading TIFF file for segmentation: {e}")
        return None
|
|
|
def extract_cloud_features_from_tiff(tiff_path):
    """Build the cloud-detection feature dict from a raster file.

    For each of the first (up to) 10 bands, the coefficient of variation is
    computed with ``calculate_cv`` and stored under ``band{i}_cv`` (1-based).
    Bands the raster does not have are filled with 0.0, so the result always
    has the keys band1_cv .. band10_cv.

    Returns:
        The feature dict, or None if the file cannot be read (the traceback is
        printed).
    """
    try:
        with rasterio.open(tiff_path) as src:
            used = min(src.count, 10)
            features = {
                f'band{idx}_cv': calculate_cv(src.read(idx))
                for idx in range(1, used + 1)
            }
        features.update({f'band{idx}_cv': 0.0 for idx in range(used + 1, 11)})
        return features
    except Exception as e:
        print(f"Error extracting cloud features from TIFF: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
def extract_cloud_features_from_rgb(image):
    """Build the cloud-detection feature dict from an in-memory image.

    The coefficient of variation (``calculate_cv``) of each of the first three
    channels is stored under band1_cv .. band3_cv. The remaining keys up to
    band10_cv are filled with 0.0, so the dict has the same 10 keys as the
    TIFF-based extractor. A 2-D (grayscale) array is treated as one band.

    Non-float inputs are converted to float32, and images whose maximum
    exceeds 1.0 are divided by 255.

    Returns:
        The feature dict, or None on error (the traceback is printed).
    """
    try:
        if image.dtype != np.float32 and image.dtype != np.float64:
            image = image.astype(np.float32)

        if image.max() > 1.0:
            image = image / 255.0

        if image.ndim == 2:
            image = image[:, :, np.newaxis]

        features = {}
        # Use up to three colour channels as bands 1-3.
        for i in range(min(3, image.shape[2])):
            features[f'band{i + 1}_cv'] = calculate_cv(image[:, :, i])

        # Zero-fill the bands an RGB image does not have.
        for i in range(len(features) + 1, 11):
            features[f'band{i}_cv'] = 0.0

        return features
    except Exception as e:
        print(f"Error extracting cloud features from RGB: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
def predict_cloud(features_dict, model):
    """Classify an image as cloudy or not from its per-band CV features.

    Args:
        features_dict: Mapping with keys ``band1_cv`` .. ``band10_cv``;
            missing keys default to 0.0 and extra keys are ignored.
        model: A fitted classifier exposing ``predict_proba`` or ``predict``,
            or None.

    Returns:
        dict with ``prediction`` ('Cloudy' when probability >= 0.5,
        'Non-Cloudy', 'Model unavailable' or 'Error') and ``probability``: the
        positive-class probability (or the raw ``predict`` output) as a
        float; 0.0 when the model is unavailable or an error occurs.
    """
    if model is None:
        return {'prediction': 'Model unavailable', 'probability': 0.0}

    try:
        # Fixed column order band1_cv .. band10_cv, the layout the model is fed.
        feature_names = [f'band{i}_cv' for i in range(1, 11)]
        feature_df = pd.DataFrame(
            [{name: features_dict.get(name, 0.0) for name in feature_names}]
        )

        # LightGBM's scikit-learn wrapper accepts arbitrary extra parameters;
        # scikit-learn estimators reject unknown ones with ValueError. That
        # must not turn every prediction into an error, so it is best effort.
        if hasattr(model, 'set_params'):
            try:
                model.set_params(predict_disable_shape_check=True)
            except (ValueError, TypeError):
                pass

        if hasattr(model, 'predict_proba'):
            proba = model.predict_proba(feature_df)
            # One column per class: take the positive class when present.
            probability = proba[0][1] if proba.shape[1] > 1 else proba[0][0]
        else:
            probability = model.predict(feature_df)[0]
        probability = float(probability)

        prediction = 'Cloudy' if probability >= 0.5 else 'Non-Cloudy'
        return {'prediction': prediction, 'probability': probability}
    except Exception as e:
        print(f"Error predicting cloud: {e}")
        import traceback
        traceback.print_exc()
        return {'prediction': 'Error', 'probability': 0.0}
|
|
|
def read_tiff_mask(mask_path):
    """Read the first band of a raster mask and return it as a uint8 array.

    Note that ``astype(np.uint8)`` truncates/wraps values outside 0-255.

    Returns:
        The mask array, or None if the file cannot be read (the error is
        printed).
    """
    try:
        with rasterio.open(mask_path) as src:
            first_band = src.read(1)
        return first_band.astype(np.uint8)
    except Exception as e:
        print(f"Error reading mask file: {e}")
        return None
|
|
|
def _ensure_rgb(array):
    """Return ``array`` as an (H, W, 3) image, whatever its channel layout."""
    if array.ndim == 2:
        return np.stack([array, array, array], axis=-1)
    channels = array.shape[2]
    if channels == 1:
        return np.repeat(array, 3, axis=2)
    if channels == 2:
        # Gray + alpha: replicate the gray channel and drop alpha.
        return np.repeat(array[:, :, :1], 3, axis=2)
    # RGB, RGBA or multi-band: keep the first three channels.
    return array[:, :, :3]


def preprocess_image(image, target_size=(128, 128)):
    """Turn an uploaded image into a model-ready tensor plus a display copy.

    Args:
        image: A NumPy array (as delivered by ``gr.Image(type="numpy")``) or a
            PIL image. Grayscale, gray+alpha, RGBA and >3-channel inputs are
            converted to three channels.
        target_size: ``(height, width)`` of the network input.

    Returns:
        ``(tensor, display_image)``: a float32 tensor of shape
        (1, 3, height, width) and the 3-channel image before scaling; or
        ``(None, None)`` for unsupported input types.
    """
    if isinstance(image, Image.Image):
        image = _ensure_rgb(np.array(image))
        display_image = image.copy()
        # PIL images are treated as 8-bit and always rescaled.
        image = image.astype(np.float32) / 255.0
    elif isinstance(image, np.ndarray):
        image = _ensure_rgb(image)
        display_image = image.copy()
        # Only rescale data that looks like 0-255; [0, 1] floats pass through.
        image = image.astype(np.float32)
        if display_image.max() > 1.0:
            image = image / 255.0
    else:
        print(f"Unsupported image type: {type(image)}")
        return None, None

    height, width = target_size
    if albumentations_available:
        # Pad small images and centre-crop large ones to the input size.
        aug = A.Compose([
            A.PadIfNeeded(min_height=height, min_width=width,
                          border_mode=cv2.BORDER_CONSTANT, value=0),
            A.CenterCrop(height=height, width=width),
        ])
        image_resized = aug(image=image)['image']
    else:
        # cv2.resize takes dsize as (width, height).
        image_resized = cv2.resize(image, (width, height), interpolation=cv2.INTER_LINEAR)

    # HWC -> CHW and add a batch dimension.
    image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
    return image_tensor, display_image
|
|
|
def extract_file_content(file_obj):
    """Return the raw bytes of an uploaded file, whatever form it arrives in.

    Supported inputs:
      * ``bytes`` / ``bytearray``: returned as bytes;
      * ``os.PathLike``: the file is read (a missing file is an error);
      * ``str`` naming an existing path: the file is read;
      * other ``str``: the string itself is encoded, as latin-1 for str
        subclasses carrying a ``name`` attribute and UTF-8 otherwise;
      * file-like objects with ``read()``: rewound when seekable and read in
        full; text-mode results are UTF-8 encoded.

    Returns:
        The content as bytes, or None for unsupported inputs or on error
        (a message is printed).
    """
    try:
        if isinstance(file_obj, (bytes, bytearray)):
            return bytes(file_obj)

        if isinstance(file_obj, os.PathLike):
            with open(file_obj, 'rb') as f:
                return f.read()

        if isinstance(file_obj, str):
            if os.path.exists(file_obj):
                with open(file_obj, 'rb') as f:
                    return f.read()
            # Not a path: treat the string itself as the content.
            encoding = 'latin1' if hasattr(file_obj, 'name') else 'utf-8'
            return file_obj.encode(encoding)

        if hasattr(file_obj, 'read'):
            # A file that was just written is positioned at its end; rewind so
            # the whole content is returned rather than b"".
            if hasattr(file_obj, 'seek'):
                try:
                    file_obj.seek(0)
                except (OSError, ValueError):
                    pass  # unseekable stream: read from the current position
            data = file_obj.read()
            return data.encode('utf-8') if isinstance(data, str) else data

        print(f"Unsupported file object type: {type(file_obj)}")
        return None
    except Exception as e:
        print(f"Error extracting file content: {e}")
        return None
|
|
|
def process_uploaded_tiff(file_obj):
    """Prepare an uploaded TIFF for both segmentation and cloud detection.

    The upload is copied to a temporary ``.tif`` file, which is always
    removed afterwards. From it this function reads a 3-channel image (via
    ``read_tiff_image_for_segmentation``) and per-band cloud features (via
    ``extract_cloud_features_from_tiff``). The image is then padded/centre-
    cropped to 128x128, or resized when albumentations is unavailable.

    Returns:
        ``(image_tensor, display_image, cloud_features)``, with image_tensor of
        shape (1, C, 128, 128) and display_image as uint8 when the image is in
        [0, 1]. Returns ``(None, None, None)`` on failure. cloud_features may
        be None if feature extraction failed.
    """
    temp_path = None
    try:
        file_content = extract_file_content(file_obj)
        if file_content is None:
            print("Failed to extract file content")
            return None, None, None

        # rasterio reads from a path, so materialise the upload on disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as temp_file:
            temp_path = temp_file.name
            temp_file.write(file_content)

        image_for_segmentation = read_tiff_image_for_segmentation(temp_path)
        cloud_features = extract_cloud_features_from_tiff(temp_path)

        if image_for_segmentation is None:
            return None, None, None

        if image_for_segmentation.max() <= 1.0:
            display_image = (image_for_segmentation * 255).astype(np.uint8)
        else:
            display_image = image_for_segmentation.copy()

        if albumentations_available:
            aug = A.Compose([
                A.PadIfNeeded(min_height=128, min_width=128,
                              border_mode=cv2.BORDER_CONSTANT, value=0),
                A.CenterCrop(height=128, width=128),
            ])
            image_resized = aug(image=image_for_segmentation)['image']
        else:
            image_resized = cv2.resize(image_for_segmentation, (128, 128),
                                       interpolation=cv2.INTER_LINEAR)

        # HWC -> CHW and add a batch dimension.
        image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
        return image_tensor, display_image, cloud_features
    except Exception as e:
        print(f"Error processing uploaded TIFF: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None
    finally:
        # Remove the temporary copy even if writing or reading it failed.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except OSError:
                pass
|
|
|
def process_uploaded_mask(file_obj):
    """Load an uploaded ground-truth mask as a 128x128 binary uint8 array.

    The upload is copied to a temporary file that keeps the upload's
    extension and is always removed afterwards. ``.tif``/``.tiff`` files are
    read with ``read_tiff_mask`` and other formats with PIL. Colour masks are
    converted to grayscale. The mask is padded/centre-cropped to 128x128, or
    resized with nearest-neighbour when albumentations is unavailable, and
    every value > 0 becomes 1.

    Returns:
        The binary mask, or None if the upload cannot be read.
    """
    temp_path = None
    try:
        file_content = extract_file_content(file_obj)
        if file_content is None:
            return None

        # Keep the upload's extension so TIFFs go through rasterio and other
        # formats through PIL. Taking splitext of the basename avoids picking
        # up dots in directory names, which would put a path separator into
        # the temp-file suffix.
        source_name = getattr(file_obj, 'name', None)
        if source_name is None and isinstance(file_obj, (str, os.PathLike)):
            source_name = os.fspath(file_obj)
        suffix = '.tif'
        if isinstance(source_name, str):
            ext = os.path.splitext(os.path.basename(source_name))[1].lower()
            if ext:
                suffix = ext

        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_path = temp_file.name
            temp_file.write(file_content)

        if suffix in ('.tif', '.tiff'):
            mask = read_tiff_mask(temp_path)
        else:
            try:
                # Context manager closes the file handle before the temp file
                # is deleted.
                with Image.open(temp_path) as mask_img:
                    mask = np.array(mask_img)
            except Exception as e:
                print(f"Error opening mask as regular image: {e}")
                return None
            if mask.ndim == 3:
                if mask.shape[2] in (3, 4):
                    mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
                else:
                    # Gray (+alpha) layouts: keep the gray channel.
                    mask = mask[:, :, 0]
            if mask.dtype == bool:
                # 1-bit images load as bool, which OpenCV cannot pad/resize.
                mask = mask.astype(np.uint8)

        if mask is None:
            return None

        if albumentations_available:
            aug = A.Compose([
                A.PadIfNeeded(min_height=128, min_width=128,
                              border_mode=cv2.BORDER_CONSTANT, value=0),
                A.CenterCrop(height=128, width=128),
            ])
            mask_resized = aug(image=mask)['image']
        else:
            mask_resized = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)

        return (mask_resized > 0).astype(np.uint8)
    except Exception as e:
        print(f"Error processing uploaded mask: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Remove the temporary copy on every exit path.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except OSError:
                pass
|
|
|
def predict_segmentation(image_tensor):
    """Run the global segmentation ``model`` on a batched input tensor.

    The tensor is moved to ``device`` and evaluated under ``torch.no_grad()``.
    If the model returns a dict, its ``'out'`` entry is used.

    Returns:
        A NumPy mask: class indices (argmax over dim 1, batch squeezed) for
        multi-channel outputs, or a uint8 0/1 array from ``sigmoid > 0.5``
        (all size-1 dims squeezed) otherwise. Returns None if inference raises
        (the error is printed).
    """
    try:
        batch = image_tensor.to(device)
        with torch.no_grad():
            logits = model(batch)
            if isinstance(logits, dict):
                logits = logits['out']
            if logits.shape[1] <= 1:
                foreground = torch.sigmoid(logits) > 0.5
                mask = foreground.squeeze().cpu().numpy().astype(np.uint8)
            else:
                mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
        return mask
    except Exception as e:
        print(f"Error during prediction: {e}")
        return None
|
|
|
def calculate_metrics(pred_mask, gt_mask):
    """Score a predicted mask against the ground truth.

    In both masks, any value > 0 counts as foreground.

    Returns:
        dict with float values for "IoU", "Precision", "Recall" and
        "F1 Score". A metric whose denominator is zero is reported as 0.0
        (e.g. IoU when both masks are empty).
    """
    def safe_div(num, den):
        return num / den if den > 0 else 0

    predicted = pred_mask > 0
    actual = gt_mask > 0

    tp = np.count_nonzero(np.logical_and(predicted, actual))
    fp = np.count_nonzero(predicted) - tp
    fn = np.count_nonzero(actual) - tp

    iou = safe_div(tp, tp + fp + fn)
    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    f1 = safe_div(2 * precision * recall, precision + recall)

    return {
        "IoU": float(iou),
        "Precision": float(precision),
        "Recall": float(recall),
        "F1 Score": float(f1),
    }
|
|
|
def _render_results_figure(display_image, pred_mask, gt_mask=None):
    """Render input / (optional) ground truth / prediction panels as a PIL image."""
    from matplotlib.figure import Figure

    # Use a standalone Figure instead of pyplot. pyplot keeps global,
    # non-thread-safe state and a figure it creates leaks unless closed,
    # which the previous code skipped whenever an exception occurred.
    # Nothing is registered with pyplot here, so nothing needs closing.
    fig = Figure(figsize=(12, 6))
    if gt_mask is not None:
        panels = [
            (display_image, None, "Input Image"),
            (gt_mask, 'binary', "Ground Truth"),
            (pred_mask, 'binary', "Prediction"),
        ]
    else:
        panels = [
            (display_image, None, "Input Image"),
            (pred_mask, 'binary', "Predicted Wetlands"),
        ]

    axes = fig.subplots(1, len(panels))
    for ax, (data, cmap, title) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()

    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)


def _format_result_text(pred_mask, cloud_result, metrics_text):
    """Build the statistics text shown next to the results figure."""
    wetland_percentage = np.mean(pred_mask) * 100

    # 'probability' is the Cloudy-class probability (predict_cloud labels
    # >= 0.5 as Cloudy), so confidence in a Non-Cloudy verdict is its
    # complement.
    cloud_probability = cloud_result['probability']
    if cloud_result['prediction'] == 'Non-Cloudy':
        confidence = 1.0 - cloud_probability
    else:
        confidence = cloud_probability

    result_text = f"Wetland Coverage: {wetland_percentage:.2f}%\n\n"
    result_text += f"Cloud Detection: {cloud_result['prediction']} "
    result_text += f"({confidence*100:.2f}% confidence)\n\n"
    if metrics_text:
        result_text += f"Evaluation Metrics:\n{metrics_text}"
    return result_text


def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
    """Gradio handler: wetland segmentation plus cloud detection for one input.

    A TIFF upload takes precedence over a regular image when both are given.
    If an optional ground-truth mask is supplied and can be read,
    IoU/Precision/Recall/F1 are reported.

    Returns:
        ``(result_image, result_text)``: a PIL image with the input and
        prediction panels and a statistics string; ``(None, message)`` on
        failure.
    """
    try:
        if input_image is None and input_tiff is None:
            return None, "Please upload an image or TIFF file."

        cloud_features = None
        if input_tiff:
            image_tensor, display_image, cloud_features = process_uploaded_tiff(input_tiff)
            if image_tensor is None:
                return None, "Failed to process the input TIFF file."
        elif input_image is not None:
            image_tensor, display_image = preprocess_image(input_image)
            if image_tensor is None:
                return None, "Failed to process the input image."
            cloud_features = extract_cloud_features_from_rgb(display_image)
        else:
            return None, "No valid input provided."

        pred_mask = predict_segmentation(image_tensor)
        if pred_mask is None:
            return None, "Failed to generate wetland segmentation prediction."

        cloud_result = {'prediction': 'Unknown', 'probability': 0.0}
        # Explicit None check: a loaded estimator's truthiness is not a
        # reliable "model present" test.
        if cloud_features and cloud_model is not None:
            cloud_result = predict_cloud(cloud_features, cloud_model)

        gt_mask_processed = None
        metrics_text = ""
        if gt_mask_file:
            gt_mask_processed = process_uploaded_mask(gt_mask_file)
            if gt_mask_processed is not None:
                metrics = calculate_metrics(pred_mask, gt_mask_processed)
                metrics_text = "\n".join(f"{k}: {v:.4f}" for k, v in metrics.items())

        result_image = _render_results_figure(display_image, pred_mask, gt_mask_processed)
        result_text = _format_result_text(pred_mask, cloud_result, metrics_text)
        return result_image, result_text
    except Exception as e:
        print(f"Error in processing: {e}")
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
|
|
|
|
|
with gr.Blocks(title="Wetlands Segmentation & Cloud Detection") as demo:
    gr.Markdown("# Wetlands Segmentation & Cloud Detection from Satellite Imagery")
    gr.Markdown("Upload a satellite image or TIFF file to identify wetland areas and detect cloud cover. Optionally, you can also upload a ground truth mask for evaluation.")

    with gr.Row():
        with gr.Column():
            # Inputs: a regular image or a (multi-band) TIFF, plus an optional
            # ground-truth mask used for evaluation metrics.
            gr.Markdown("### Input")
            with gr.Tab("Upload Image"):
                input_image = gr.Image(label="Upload Satellite Image", type="numpy")

            with gr.Tab("Upload TIFF"):
                input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"])

            gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])

            process_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Column():
            # Outputs: the rendered comparison figure and the statistics text.
            gr.Markdown("### Results")
            output_image = gr.Image(label="Segmentation Results", type="pil")
            output_text = gr.Textbox(label="Statistics", lines=8)

    gr.Markdown("### About these models")
    gr.Markdown("""
This application uses two models:

**1. Wetland Segmentation Model:**
- Architecture: DeepLabv3+ with ResNet-34
- Input: RGB satellite imagery
- Output: Binary segmentation mask (Wetland vs Background)
- Resolution: 128×128 pixels

**2. Cloud Detection Model:**
- Architecture: LightGBM Classifier
- Input: CV features extracted from up to 10 image bands
- Output: Binary classification (Cloudy vs Non-Cloudy) with probability

**Tips for best results:**
- For Cloudy image - train_11202327_p1, for Non cloudy image - train_02202325_p1
- For Cloudy image - test_07202330_p1, for Non cloudy image - test_02202325_p1
- The models work best with multi-band satellite imagery (TIFF files)
- For optimal cloud detection results, use TIFF files with 10 bands
- For optimal results, use images with similar characteristics to those used in training
- The wetland model focuses on identifying wetland regions in natural landscapes
- The cloud model detects cloud cover based on image band statistics
- For ground truth masks, both TIFF and standard image formats are supported

**Repository:** [dcrey7/wetlands_segmentation_deeplabsv3plus](https://huggingface.co/dcrey7/wetlands_segmentation_deeplabsv3plus)
""")

    process_btn.click(
        fn=process_images,
        inputs=[input_image, input_tiff, gt_mask_file],
        outputs=[output_image, output_text],
    )


# Launch only when run as a script, so importing this module (e.g. for
# tests or tooling) does not start a blocking web server.
if __name__ == "__main__":
    demo.launch()