Model description
This TransUNet model identifies contrails in satellite images. It takes pre-processed satellite images from the OpenContrails dataset, stored as .npy files, as input. It returns a per-pixel mask marking the contrails, which can be overlaid on the input image. We achieve a mean IoU of 0.6997 on the validation set.
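Mean IoU (intersection over union) is computed separately for the two classes, contrail and non-contrail. For each class, it divides the number of pixels that both masks assign to that class by the number of pixels that either mask assigns to it, then averages those ratios. The NumPy sketch below shows this for a single pair of integer masks. The name `mean_iou` is ours. Skipping a class that is absent from both masks is an assumption, modeled on how `tf.keras.metrics.MeanIoU` leaves classes with an empty denominator out of the average.

```python
import numpy as np

def mean_iou(y_true, y_pred, num_classes=2):
    """Mean intersection-over-union over classes for two integer label maps."""
    ious = []
    for c in range(num_classes):
        true_c = y_true == c
        pred_c = y_pred == c
        union = np.logical_or(true_c, pred_c).sum()
        if union == 0:
            continue  # class absent from both maps: leave it out of the average
        intersection = np.logical_and(true_c, pred_c).sum()
        ious.append(intersection / union)
    return float(np.mean(ious))

# Toy 2x4 masks: 1 = contrail, 0 = background
y_true = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1]])
y_pred = np.array([[0, 0, 0, 1],
                   [0, 0, 1, 1]])
print(round(mean_iou(y_true, y_pred), 3))  # background 4/5, contrail 3/4 -> 0.775
```

Keras's `MeanIoU` accumulates a confusion matrix over every batch it is given. A score over a whole validation set is therefore computed from pooled pixel counts, not by averaging per-image scores.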
Intended uses
Contrails (vapor trails from airplanes) are the number one contributor to global warming from the aviation industry. We hope that data scientists and researchers working to reduce contrails will use this model to support their work. Efforts are underway to develop models that predict contrails, but a major limiting factor is that images are still labeled by humans, and labeled images are needed to validate contrail prediction models. Labeling contrails in images is a difficult and expensive task. Our model helps researchers segment satellite images efficiently so they can validate and improve contrail prediction models. To learn more about our work, visit our website.
How to Get Started with the Model
Use the code below to get started with the model.
```python
# Required imports (SM_FRAMEWORK must be set before segmentation_models is imported)
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # hide TensorFlow INFO and WARNING logs
os.environ["SM_FRAMEWORK"] = "tf.keras"   # make segmentation_models use tf.keras
import segmentation_models as sm
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from huggingface_hub import notebook_login
import numpy as np
import matplotlib.pyplot as plt

# Training loss (Dice + 5 x focal), passed to the loader as a custom object
weights = [0.5, 0.5]  # per-class weights for the two classes (a hyperparameter)
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
TOTAL_LOSS_FACTOR = 5
total_loss = dice_loss + (TOTAL_LOSS_FACTOR * focal_loss)
```
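The training loss above adds a Dice term to five times a categorical focal term. Dice rewards overlap between the predicted and true masks, which helps when contrail pixels are a small fraction of the image. Focal loss down-weights pixels the model already classifies confidently. The NumPy sketch below illustrates both terms. The function names are hypothetical. `alpha=0.25` and `gamma=2` are assumed to match the `segmentation_models` defaults, and the normalized class weighting is our own choice. Treat it as an illustration of the formulas, not as the library's implementation.

```python
import numpy as np

def soft_dice_loss(gt, pr, class_weights=None, smooth=1e-5):
    """1 minus the weighted mean of per-class soft Dice scores.

    gt: one-hot ground truth, pr: softmax probabilities, both shaped (N, H, W, C).
    """
    axes = (0, 1, 2)  # pool over batch and spatial axes, keep one value per class
    tp = np.sum(gt * pr, axis=axes)
    fp = np.sum(pr, axis=axes) - tp
    fn = np.sum(gt, axis=axes) - tp
    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    w = np.ones_like(dice) if class_weights is None else np.asarray(class_weights, dtype=float)
    return float(1.0 - np.sum(w * dice) / np.sum(w))  # weights normalized (assumption)

def categorical_focal_loss(gt, pr, alpha=0.25, gamma=2.0, eps=1e-7):
    """Mean focal loss: (1 - p)**gamma shrinks the loss of confident, correct pixels."""
    pr = np.clip(pr, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-gt * alpha * (1.0 - pr) ** gamma * np.log(pr)))

def combined_loss(gt, pr, focal_factor=5.0):
    """Dice + focal_factor * focal, mirroring TOTAL_LOSS_FACTOR = 5."""
    dice = soft_dice_loss(gt, pr, class_weights=[0.5, 0.5])
    return dice + focal_factor * categorical_focal_loss(gt, pr)

# Toy batch: one 2x2 image, two classes, one-hot encoded
labels = np.array([[[0, 1], [1, 0]]])  # shape (1, 2, 2)
gt = np.eye(2)[labels]                 # shape (1, 2, 2, 2)
print(round(combined_loss(gt, gt.copy()), 6))              # perfect prediction -> 0.0
print(round(combined_loss(gt, np.full_like(gt, 0.5)), 3))  # 50/50 guess -> 0.608
```

In the 50/50 case, the Dice term contributes about 0.5 and the scaled focal term about 0.108.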
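```python
from tensorflow.keras import backend as K  # backend ops used by jaccard_coef

def jaccard_coef(y_true, y_pred):
    """Custom Jaccard (IoU) coefficient metric with +1 smoothing."""
    y_true_flatten = K.flatten(y_true)
    y_pred_flatten = K.flatten(y_pred)
    intersection = K.sum(y_true_flatten * y_pred_flatten)
    final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
    return final_coef_value

# Mean IoU over the two classes; with sparse_*=False it takes the argmax over the class axis
metrics = [tf.keras.metrics.MeanIoU(num_classes=2, sparse_y_true=False, sparse_y_pred=False, name="Mean IOU")]

# Log in to the Hugging Face Hub (works in Jupyter/Colab notebooks)
notebook_login()

# Load the model from the Hugging Face Hub. Custom objects are looked up by name,
# so keep these keys exactly as written (including the "Dice Coeficient" spelling).
custom_objects = {
    "dice_loss_plus_5focal_loss": total_loss,
    "jaccard_coef": jaccard_coef,
    "IOU score": sm.metrics.IOUScore(threshold=0.9, name="IOU score"),
    "Dice Coeficient": sm.metrics.FScore(threshold=0.6, name="Dice Coeficient"),
}
model = from_pretrained_keras("MIDSCapstoneTeam/ContrailSentinel", custom_objects=custom_objects, compile=False)
model.compile(metrics=metrics)
```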
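The custom `jaccard_coef` metric above adds 1 to both the intersection and the union. The NumPy mirror of that formula below shows the effect of this smoothing. The name `jaccard_coef_np` is ours. An image with no contrails in either mask scores 1.0 instead of dividing by zero, and small masks are nudged upward, since (I + 1) / (U + 1) is never below I / U. Note that the metric flattens every channel of its inputs. If `y_true` and `y_pred` carry both class channels, as the one-hot `sparse_y_true=False` setting of the Mean IoU metric suggests, the metric measures a soft overlap over both classes rather than contrail IoU alone.

```python
import numpy as np

def jaccard_coef_np(y_true, y_pred, smooth=1.0):
    """NumPy version of jaccard_coef: (inter + 1) / (sum_true + sum_pred - inter + 1)."""
    t = np.ravel(y_true).astype(float)
    p = np.ravel(y_pred).astype(float)
    intersection = np.sum(t * p)
    return float((intersection + smooth) / (np.sum(t) + np.sum(p) - intersection + smooth))

empty = np.zeros(4)
print(jaccard_coef_np(empty, empty))                            # 1.0: two empty masks "match"
print(round(jaccard_coef_np([1, 1, 0, 0], [1, 0, 0, 0]), 3))    # 0.667 with smoothing
print(jaccard_coef_np([1, 1, 0, 0], [1, 0, 0, 0], smooth=0.0))  # 0.5 without smoothing
```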
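```python
# Inference: set image_path to the directory that holds a sample's
# human_pixel_masks.npy (human label) and ash_image.npy (false-color image)
image_path = "path/to/sample/"  # placeholder, replace with your own directory
label = np.load(os.path.join(image_path, "human_pixel_masks.npy"))
# Index 4 of the last axis selects one frame (presumably the labeled time step)
ash_image = np.load(os.path.join(image_path, "ash_image.npy"))[..., 4]
y_pred = model.predict(ash_image.reshape(1, 256, 256, 3))  # add a batch axis
prediction = np.argmax(y_pred[0], axis=-1)  # (256, 256) map of predicted classes

fig, ax = plt.subplots(1, 3, figsize=(13, 5))
fig.tight_layout(pad=5.0)
ax[0].set_title("False colored satellite image")
ax[0].imshow(ash_image)
ax[0].axis("off")
ax[1].set_title("Contrail prediction")
ax[1].imshow(ash_image)
# Mask out class 0 (assumed background) so only predicted contrails are drawn on top
ax[1].imshow(np.ma.masked_equal(prediction, 0), cmap="Reds", vmin=0, vmax=1, alpha=0.7)
ax[1].axis("off")
ax[2].set_title("Human label")
ax[2].imshow(np.squeeze(label), cmap="gray")
ax[2].axis("off")
plt.show()
```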
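The loaded `label` can also be used to score a prediction numerically. The sketch below uses hypothetical helper names and makes the following assumptions:

- `ash_image.npy` is a (256, 256, 3, T) false-color stack, and frame 4 is the one the label describes.
- Class 1 of the argmax is the contrail class.
- Both masks being empty counts as an IoU of 1.0 (our own convention).

The arrays are synthetic stand-ins, not OpenContrails data.

```python
import numpy as np

def prepare_input(ash_stack, frame=4):
    """(H, W, 3, T) false-color stack -> (1, H, W, 3) batch for model.predict."""
    return ash_stack[..., frame][np.newaxis, ...]

def to_mask(y_pred):
    """(1, H, W, 2) softmax output -> (H, W) mask; class 1 assumed to be contrail."""
    return np.argmax(y_pred[0], axis=-1).astype(np.uint8)

def contrail_iou(label, mask):
    """IoU of contrail pixels only; 1.0 when both masks are empty (our convention)."""
    truth = np.squeeze(label).astype(bool)
    pred = np.squeeze(mask).astype(bool)
    union = np.logical_or(truth, pred).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(truth, pred).sum() / union)

# Synthetic stand-ins for ash_image.npy, the model output and human_pixel_masks.npy
rng = np.random.default_rng(0)
ash_stack = rng.random((256, 256, 3, 8))
x = prepare_input(ash_stack)
print(x.shape)  # (1, 256, 256, 3)

label = np.zeros((256, 256, 1), dtype=np.uint8)
label[100:104, 20:200] = 1                 # horizontal strip standing in for a contrail
probs = np.zeros((1, 256, 256, 2))
probs[..., 0] = 1.0                        # background everywhere ...
probs[0, 100:104, 20:150, :] = [0.2, 0.8]  # ... except the part of the strip "detected"
mask = to_mask(probs)
print(round(contrail_iou(label, mask), 3))  # 520 / 720 -> 0.722
```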