---
library_name: keras
---

## Model description

This TransUNet model identifies contrails in satellite images. It takes pre-processed `.npy` images from the [OpenContrails dataset](https://arxiv.org/abs/2304.02122) as input and returns a mask image that marks the contrail pixels, which can be overlaid on the same area.
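
The model outputs a per-pixel probability for each of its two classes, and the getting-started code below turns that into a mask with `argmax` over the class axis. The following is a minimal NumPy sketch of that step plus an overlay that keeps only contrail pixels visible. It assumes an output shape of `(1, 256, 256, 2)` and that class 1 is contrail; the random `y_pred` is a hypothetical stand-in for `model.predict(...)`, and `probabilities_to_mask` and `contrail_overlay` are illustrative helper names, not part of the model's API.

```python
import numpy as np


def probabilities_to_mask(y_pred: np.ndarray) -> np.ndarray:
    """Convert softmax output of shape (1, H, W, 2) into an (H, W) mask of 0s and 1s."""
    # argmax over the last (class) axis picks the most likely class per pixel
    return np.argmax(y_pred[0], axis=-1).astype(np.uint8)


def contrail_overlay(mask: np.ndarray) -> np.ma.MaskedArray:
    """Hide background pixels (0) so only contrail pixels (assumed class 1) are drawn by imshow."""
    return np.ma.masked_where(mask == 0, mask)


# Hypothetical stand-in for model.predict(...): random probabilities that sum to 1 per pixel
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 256, 256, 2))
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

mask = probabilities_to_mask(y_pred)
overlay = contrail_overlay(mask)
print(mask.shape, mask.dtype)  # (256, 256) uint8
print(int(overlay.count()) == int(mask.sum()))  # True: only contrail pixels stay unmasked
```

With a real prediction, drawing `overlay` with `imshow` on top of the ash image shows the predicted contrails without hiding the rest of the scene.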

We achieve a mean IoU (intersection over union) of 0.6997 on the validation set.
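
Mean IoU averages the intersection over union of each class, here background and contrail. As an illustration only (not the evaluation script behind the figure above), this sketch computes it for integer label masks from a confusion matrix, which is how `tf.keras.metrics.MeanIoU` defines the metric. The `mean_iou` helper and the toy masks are hypothetical.

```python
import numpy as np


def mean_iou(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int = 2) -> float:
    """Mean IoU over classes, computed from a confusion matrix of integer label masks."""
    y_true = y_true.ravel().astype(np.int64)
    y_pred = y_pred.ravel().astype(np.int64)
    # confusion[i, j] counts pixels whose true class is i and predicted class is j
    confusion = np.bincount(num_classes * y_true + y_pred, minlength=num_classes**2)
    confusion = confusion.reshape(num_classes, num_classes)
    intersection = np.diag(confusion)
    # union = predicted pixels + true pixels - pixels counted in both
    union = confusion.sum(axis=0) + confusion.sum(axis=1) - intersection
    # Skip classes that appear in neither mask, as tf.keras.metrics.MeanIoU does
    valid = union > 0
    return float(np.mean(intersection[valid] / union[valid]))


# Toy 4x4 example: 1 = contrail, 0 = background
truth = np.array([[0, 0, 1, 1],
                  [0, 0, 1, 1],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0]])
pred = np.array([[0, 0, 1, 1],
                 [0, 0, 0, 1],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0]])
print(round(mean_iou(truth, pred), 4))  # 0.8365, the mean of 12/13 (background) and 3/4 (contrail)
```

Because the background class dominates most scenes, its IoU is usually high; the contrail IoU is the harder half of the average.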

## Intended uses

Contrails (vapor trails from airplanes) are the [largest](https://www.science.org/content/article/aviation-s-dirty-secret-airplane-contrails-are-surprisingly-potent-cause-global-warming) contributor to the aviation industry's impact on global warming.

We hope that data scientists and researchers working to reduce contrails will use this model to improve their work.

Efforts are underway to develop models that predict contrails, but one major limiting factor is that images are still labeled by humans, and labeled images are needed to validate contrail prediction models.

Labeling contrails in images is difficult and expensive. Our model helps researchers segment satellite images efficiently so they can validate and improve their contrail prediction models.

To learn more about our work, visit [our website](http://contrailsentinel.pythonanywhere.com/).

## How to Get Started with the Model

Use the code below to get started with the model.

```python
# Required imports and Hugging Face authentication
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# segmentation_models must be told to use tf.keras before it is imported
os.environ["SM_FRAMEWORK"] = "tf.keras"

import matplotlib.pyplot as plt
import numpy as np
import segmentation_models as sm
import tensorflow as tf
from tensorflow.keras import backend as K
from huggingface_hub import from_pretrained_keras, notebook_login

# Training loss; it must be passed as a custom object so the saved model can be deserialized
weights = [0.5, 0.5]  # hyperparameter: per-class weights for the Dice loss
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
TOTAL_LOSS_FACTOR = 5
total_loss = dice_loss + (TOTAL_LOSS_FACTOR * focal_loss)


def jaccard_coef(y_true, y_pred):
    """Custom Jaccard coefficient (soft IoU) metric, smoothed by +1 to avoid division by zero."""
    y_true_flatten = K.flatten(y_true)
    y_pred_flatten = K.flatten(y_pred)
    intersection = K.sum(y_true_flatten * y_pred_flatten)
    return (intersection + 1.0) / (
        K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0
    )


metrics = [tf.keras.metrics.MeanIoU(num_classes=2, sparse_y_true=False, sparse_y_pred=False, name="Mean IOU")]

# Log in to the Hugging Face Hub from a notebook; outside a notebook, run `huggingface-cli login` instead
notebook_login()

# Load the model from the Hugging Face Hub
model = from_pretrained_keras(
    "MIDSCapstoneTeam/ContrailSentinel",
    custom_objects={
        # Keys must match the names saved with the model, including the spelling "Dice Coeficient"
        "dice_loss_plus_5focal_loss": total_loss,
        "jaccard_coef": jaccard_coef,
        "IOU score": sm.metrics.IOUScore(threshold=0.9, name="IOU score"),
        "Dice Coeficient": sm.metrics.FScore(threshold=0.6, name="Dice Coeficient"),
    },
    compile=False,
)
model.compile(metrics=metrics)

# Inference: set IMAGE_DIR to the folder where the label and ash images are stored
IMAGE_DIR = "path/to/your/images/"  # placeholder, replace with your own path
label = np.load(os.path.join(IMAGE_DIR, "human_pixel_masks.npy"))  # human-labeled mask, for comparison
# Select index 4 of the last axis (assumed to be the time axis; index 4 is the labeled frame in OpenContrails)
ash_image = np.load(os.path.join(IMAGE_DIR, "ash_image.npy"))[..., 4]

# Add a batch dimension: (256, 256, 3) -> (1, 256, 256, 3)
y_pred = model.predict(ash_image.reshape(1, 256, 256, 3))
# Most likely class per pixel (0 = background, 1 = contrail, assumed)
prediction = np.argmax(y_pred[0], axis=-1)  # shape (256, 256)

fig, ax = plt.subplots(1, 2, figsize=(9, 5))
fig.tight_layout(pad=5.0)

ax[0].set_title("False colored satellite image")
ax[0].imshow(ash_image)
ax[0].axis("off")

ax[1].set_title("Contrail prediction")
ax[1].imshow(ash_image)
# Mask out background pixels so only the predicted contrails are drawn over the image
ax[1].imshow(np.ma.masked_where(prediction == 0, prediction), alpha=0.6)
ax[1].axis("off")

plt.show()
```
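
The model was trained with `dice_loss + 5 * focal_loss`, which is why it is registered as `dice_loss_plus_5focal_loss` when loading. To show what that combination measures, here is a NumPy sketch of a soft Dice loss and a categorical focal loss on one-hot masks. It uses the textbook definitions with equal class weights and an assumed focal `alpha=0.25`, `gamma=2.0` (assumed to match the `segmentation_models` defaults); the helper names are hypothetical, and the sketch is not guaranteed to reproduce `segmentation_models` numerically (smoothing constants and class weighting may differ).

```python
import numpy as np

EPS = 1e-7  # clip probabilities away from 0 and 1 so log() stays finite


def soft_dice_loss(y_true, y_pred, class_weights=(0.5, 0.5), smooth=1e-5):
    """1 minus the class-weighted soft Dice score for one-hot (N, H, W, C) arrays."""
    axes = (0, 1, 2)  # sum over batch and spatial axes, keep the class axis
    tp = np.sum(y_true * y_pred, axis=axes)
    dice = (2 * tp + smooth) / (np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes) + smooth)
    return float(1.0 - np.average(dice, weights=class_weights))


def categorical_focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
    """Cross-entropy down-weighted by (1 - p)^gamma, so well-classified pixels count less."""
    y_pred = np.clip(y_pred, EPS, 1.0 - EPS)
    return float(np.mean(-y_true * alpha * (1.0 - y_pred) ** gamma * np.log(y_pred)))


def total_loss(y_true, y_pred, focal_factor=5.0):
    """Dice loss plus a scaled focal loss, mirroring dice_loss + 5 * focal_loss."""
    return soft_dice_loss(y_true, y_pred) + focal_factor * categorical_focal_loss(y_true, y_pred)


# Toy example: one 2x2 image with two one-hot classes (background, contrail)
y_true = np.array([[[[1, 0], [0, 1]],
                    [[1, 0], [1, 0]]]], dtype=float)
perfect = y_true.copy()
uncertain = np.full_like(y_true, 0.5)
print(total_loss(y_true, perfect) < total_loss(y_true, uncertain))  # True
```

The Dice term is computed over whole masks, so it stays informative when contrail pixels are rare, while the focal term concentrates on pixels the model still gets wrong; `TOTAL_LOSS_FACTOR` sets the balance between the two.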