Commit 
							
							·
						
						634d7ce
	
0
								Parent(s):
							
							
add files
Browse files- .gitattributes +3 -0
 - Dockerfile +97 -0
 - README.md +11 -0
 - app.py +234 -0
 - chip_102_345_merged.tif +3 -0
 - chip_104_104_merged.tif +3 -0
 - chip_109_421_merged.tif +3 -0
 - requirements.txt +3 -0
 
    	
        .gitattributes
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            chip_102_345_merged.tif filter=lfs diff=lfs merge=lfs -text
         
     | 
| 2 | 
         
            +
            chip_104_104_merged.tif filter=lfs diff=lfs merge=lfs -text
         
     | 
| 3 | 
         
            +
            chip_109_421_merged.tif filter=lfs diff=lfs merge=lfs -text
         
     | 
    	
        Dockerfile
    ADDED
    
    | 
         @@ -0,0 +1,97 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            FROM ubuntu:18.04
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            RUN apt-get update && apt-get install --no-install-recommends -y \
         
     | 
| 5 | 
         
            +
              build-essential \
         
     | 
| 6 | 
         
            +
              python3.8 \
         
     | 
| 7 | 
         
            +
              python3-pip \
         
     | 
| 8 | 
         
            +
              python3-setuptools \
         
     | 
| 9 | 
         
            +
              git \
         
     | 
| 10 | 
         
            +
              wget \
         
     | 
| 11 | 
         
            +
              && apt-get clean && rm -rf /var/lib/apt/lists/*
         
     | 
| 12 | 
         
            +
              
         
     | 
| 13 | 
         
            +
            RUN apt-get update && apt-get install ffmpeg libsm6 libxext6  -y
         
     | 
| 14 | 
         
            +
              
         
     | 
| 15 | 
         
            +
            # RUN echo $(ls /run/secrets/)
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            WORKDIR /code
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            # COPY ./requirements.txt /code/requirements.txt
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            # add conda
         
     | 
| 23 | 
         
            +
            RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -P /code/
         
     | 
| 24 | 
         
            +
            RUN chmod 777 /code/Miniconda3-latest-Linux-x86_64.sh
         
     | 
| 25 | 
         
            +
            RUN /code/Miniconda3-latest-Linux-x86_64.sh -b -p /code/miniconda
         
     | 
| 26 | 
         
            +
            ENV PATH="/code/miniconda/bin:${PATH}"
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            RUN groupadd miniconda
         
     | 
| 29 | 
         
            +
            RUN chgrp -R miniconda /code/miniconda/ 
         
     | 
| 30 | 
         
            +
            RUN chmod 770 -R /code/miniconda/ 
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            # Set up a new user named "user" with user ID 1000
         
     | 
| 34 | 
         
            +
            RUN useradd -m -u 1000 user
         
     | 
| 35 | 
         
            +
            RUN adduser user miniconda
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            # Switch to the "user" user
         
     | 
| 38 | 
         
            +
            USER user
         
     | 
| 39 | 
         
            +
            # Set home to the user's home directory
         
     | 
| 40 | 
         
            +
            ENV HOME=/home/user \
         
     | 
| 41 | 
         
            +
            	PATH=/home/user/.local/bin:$PATH \
         
     | 
| 42 | 
         
            +
                PYTHONPATH=$HOME/app \
         
     | 
| 43 | 
         
            +
            	PYTHONUNBUFFERED=1 \
         
     | 
| 44 | 
         
            +
            	GRADIO_ALLOW_FLAGGING=never \
         
     | 
| 45 | 
         
            +
            	GRADIO_NUM_PORTS=1 \
         
     | 
| 46 | 
         
            +
            	GRADIO_SERVER_NAME=0.0.0.0 \
         
     | 
| 47 | 
         
            +
            	GRADIO_THEME=huggingface \
         
     | 
| 48 | 
         
            +
            	SYSTEM=spaces
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            RUN conda install python=3.8
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
            RUN pip3 install setuptools-rust
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
            RUN conda install pillow -y
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
            # RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            RUN conda install -c pytorch pytorch==1.7.1 torchvision==0.8.2 
         
     | 
| 59 | 
         
            +
            # RUN pip install torchvision-cpu==0.8.2
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
            # WORKDIR /home/user
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            # RUN git clone https://github.com/open-mmlab/mim.git
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
            RUN pip3 install openmim
         
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
            RUN conda install -c conda-forge gradio -y 
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
             
         
     | 
| 71 | 
         
            +
            # RUN --mount=type=secret,id=git_token,mode=0444,required=true \
         
     | 
| 72 | 
         
            +
            #     echo $(https://$(cat /run/secrets/git_token)@github.com/NASA-IMPACT/hls-foundation-os.git) 
         
     | 
| 73 | 
         
            +
            #     # git clone https://$(cat /run/secrets/git_token)@github.com/NASA-IMPACT/hls-foundation-os.git
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
            WORKDIR /home/user
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
            RUN --mount=type=secret,id=git_token,mode=0444,required=true \
         
     | 
| 78 | 
         
            +
                git clone https://$(cat /run/secrets/git_token)@github.com/NASA-IMPACT/hls-foundation-os.git
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
            WORKDIR hls-foundation-os 
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
            RUN pip3 install fine-tuning-examples/
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
            # RUN --mount=type=secret,id=git_token,mode=0444,required=true \
         
     | 
| 87 | 
         
            +
            #     pip3 install git+https://$(cat /run/secrets/git_token)@github.com/NASA-IMPACT/hls-foundation-os.git@mmseg-only
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
            RUN mim install mmcv-full==1.5.0
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
            # Set the working directory to the user's home directory
         
     | 
| 92 | 
         
            +
            WORKDIR $HOME/app
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
            # Copy the current directory contents into the container at $HOME/app setting the owner to the user
         
     | 
| 95 | 
         
            +
            COPY --chown=user . $HOME/app
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
            CMD ["python3", "app.py"]
         
     | 
    	
        README.md
    ADDED
    
    | 
         @@ -0,0 +1,11 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            ---
         
     | 
| 2 | 
         
            +
            title: Prithvi 100M Burn Scars Demo
         
     | 
| 3 | 
         
            +
            emoji: 🌖
         
     | 
| 4 | 
         
            +
            colorFrom: purple
         
     | 
| 5 | 
         
            +
            colorTo: green
         
     | 
| 6 | 
         
            +
            sdk: docker
         
     | 
| 7 | 
         
            +
            pinned: false
         
     | 
| 8 | 
         
            +
            license: apache-2.0
         
     | 
| 9 | 
         
            +
            ---
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         
     | 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,234 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            ######### pull files
         
     | 
| 2 | 
         
            +
            import os
         
     | 
| 3 | 
         
            +
            from huggingface_hub import hf_hub_download
         
     | 
| 4 | 
         
            +
            config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification", 
         
     | 
| 5 | 
         
            +
                                        filename="multi_temporal_crop_classification_Prithvi_100M.py", 
         
     | 
| 6 | 
         
            +
                                        token=os.environ.get("token"))
         
     | 
| 7 | 
         
            +
            ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification", 
         
     | 
| 8 | 
         
            +
                                 filename='multi_temporal_crop_classification_best_mIoU_epoch_66.pth', 
         
     | 
| 9 | 
         
            +
                                 token=os.environ.get("token"))
         
     | 
| 10 | 
         
            +
            ##########
         
     | 
| 11 | 
         
            +
            import argparse
         
     | 
| 12 | 
         
            +
            from mmcv import Config
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            from mmseg.models import build_segmentor
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            from mmseg.datasets.pipelines import Compose, LoadImageFromFile
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            import rasterio
         
     | 
| 19 | 
         
            +
            import torch
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            from mmseg.apis import init_segmentor
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            from mmcv.parallel import collate, scatter
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
            import numpy as np
         
     | 
| 26 | 
         
            +
            import glob
         
     | 
| 27 | 
         
            +
            import os
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            import time
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            import numpy as np
         
     | 
| 32 | 
         
            +
            import gradio as gr
         
     | 
| 33 | 
         
            +
            from functools import partial
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            import pdb
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
            def open_tiff(fname):
         
     | 
| 41 | 
         
            +
                
         
     | 
| 42 | 
         
            +
                with rasterio.open(fname, "r") as src:
         
     | 
| 43 | 
         
            +
                    
         
     | 
| 44 | 
         
            +
                    data = src.read()
         
     | 
| 45 | 
         
            +
                    
         
     | 
| 46 | 
         
            +
                return data
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
            def write_tiff(img_wrt, filename, metadata):
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                """
         
     | 
| 51 | 
         
            +
                It writes a raster image to file.
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
                :param img_wrt: numpy array containing the data (can be 2D for single band or 3D for multiple bands)
         
     | 
| 54 | 
         
            +
                :param filename: file path to the output file
         
     | 
| 55 | 
         
            +
                :param metadata: metadata to use to write the raster to disk
         
     | 
| 56 | 
         
            +
                :return:
         
     | 
| 57 | 
         
            +
                """
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                with rasterio.open(filename, "w", **metadata) as dest:
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                    if len(img_wrt.shape) == 2:
         
     | 
| 62 | 
         
            +
                        
         
     | 
| 63 | 
         
            +
                        img_wrt = img_wrt[None]
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                    for i in range(img_wrt.shape[0]):
         
     | 
| 66 | 
         
            +
                        dest.write(img_wrt[i, :, :], i + 1)
         
     | 
| 67 | 
         
            +
                
         
     | 
| 68 | 
         
            +
                return filename
         
     | 
| 69 | 
         
            +
                        
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
            def get_meta(fname):
         
     | 
| 72 | 
         
            +
                
         
     | 
| 73 | 
         
            +
                with rasterio.open(fname, "r") as src:
         
     | 
| 74 | 
         
            +
                    
         
     | 
| 75 | 
         
            +
                    meta = src.meta
         
     | 
| 76 | 
         
            +
                    
         
     | 
| 77 | 
         
            +
                return meta
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
            def preprocess_example(example_list):
         
     | 
| 80 | 
         
            +
                
         
     | 
| 81 | 
         
            +
                example_list = [os.path.join(os.path.abspath(''), x) for x in example_list]
         
     | 
| 82 | 
         
            +
                
         
     | 
| 83 | 
         
            +
                return example_list
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
            def inference_segmentor(model, imgs, custom_test_pipeline=None):
         
     | 
| 87 | 
         
            +
                """Inference image(s) with the segmentor.
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                Args:
         
     | 
| 90 | 
         
            +
                    model (nn.Module): The loaded segmentor.
         
     | 
| 91 | 
         
            +
                    imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
         
     | 
| 92 | 
         
            +
                        images.
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                Returns:
         
     | 
| 95 | 
         
            +
                    (list[Tensor]): The segmentation result.
         
     | 
| 96 | 
         
            +
                """
         
     | 
| 97 | 
         
            +
                cfg = model.cfg
         
     | 
| 98 | 
         
            +
                device = next(model.parameters()).device  # model device
         
     | 
| 99 | 
         
            +
                # build the data pipeline
         
     | 
| 100 | 
         
            +
                test_pipeline = [LoadImageFromFile()] + cfg.data.test.pipeline[1:] if custom_test_pipeline == None else custom_test_pipeline
         
     | 
| 101 | 
         
            +
                test_pipeline = Compose(test_pipeline)
         
     | 
| 102 | 
         
            +
                # prepare data
         
     | 
| 103 | 
         
            +
                data = []
         
     | 
| 104 | 
         
            +
                imgs = imgs if isinstance(imgs, list) else [imgs]
         
     | 
| 105 | 
         
            +
                for img in imgs:
         
     | 
| 106 | 
         
            +
                    img_data = {'img_info': {'filename': img}}
         
     | 
| 107 | 
         
            +
                    img_data = test_pipeline(img_data)
         
     | 
| 108 | 
         
            +
                    data.append(img_data)
         
     | 
| 109 | 
         
            +
                # print(data.shape)
         
     | 
| 110 | 
         
            +
                
         
     | 
| 111 | 
         
            +
                data = collate(data, samples_per_gpu=len(imgs))
         
     | 
| 112 | 
         
            +
                if next(model.parameters()).is_cuda:
         
     | 
| 113 | 
         
            +
                    # data = collate(data, samples_per_gpu=len(imgs))
         
     | 
| 114 | 
         
            +
                    # scatter to specified GPU
         
     | 
| 115 | 
         
            +
                    data = scatter(data, [device])[0]
         
     | 
| 116 | 
         
            +
                else:
         
     | 
| 117 | 
         
            +
                    # img_metas = scatter(data['img_metas'],'cpu')
         
     | 
| 118 | 
         
            +
                    # data['img_metas'] = [i.data[0] for i in data['img_metas']]
         
     | 
| 119 | 
         
            +
                    
         
     | 
| 120 | 
         
            +
                    img_metas = data['img_metas'].data[0]
         
     | 
| 121 | 
         
            +
                    img = data['img']
         
     | 
| 122 | 
         
            +
                    data = {'img': img, 'img_metas':img_metas}
         
     | 
| 123 | 
         
            +
                
         
     | 
| 124 | 
         
            +
                with torch.no_grad():
         
     | 
| 125 | 
         
            +
                    result = model(return_loss=False, rescale=True, **data)
         
     | 
| 126 | 
         
            +
                return result
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
            def inference_on_file(target_image, model, custom_test_pipeline):
         
     | 
| 130 | 
         
            +
             
     | 
| 131 | 
         
            +
                target_image = target_image.name
         
     | 
| 132 | 
         
            +
                # print(type(target_image))
         
     | 
| 133 | 
         
            +
             
     | 
| 134 | 
         
            +
                # output_image = target_image.replace('.tif', '_pred.tif')
         
     | 
| 135 | 
         
            +
                time_taken=-1
         
     | 
| 136 | 
         
            +
                try:
         
     | 
| 137 | 
         
            +
                    st = time.time()
         
     | 
| 138 | 
         
            +
                    print('Running inference...')
         
     | 
| 139 | 
         
            +
                    result = inference_segmentor(model, target_image, custom_test_pipeline)
         
     | 
| 140 | 
         
            +
                    print("Output has shape: " + str(result[0].shape))
         
     | 
| 141 | 
         
            +
                    
         
     | 
| 142 | 
         
            +
                    ##### get metadata mask
         
     | 
| 143 | 
         
            +
                    mask = open_tiff(target_image)
         
     | 
| 144 | 
         
            +
                    # rgb = mask[[2, 1, 0], :, :].transpose((1,2,0))
         
     | 
| 145 | 
         
            +
                    rgb1 = mask[[2, 1, 0], :, :].transpose((1,2,0))
         
     | 
| 146 | 
         
            +
                    rgb2 = mask[[8, 7, 6], :, :].transpose((1,2,0))
         
     | 
| 147 | 
         
            +
                    rgb3 = mask[[14, 13, 12], :, :].transpose((1,2,0))
         
     | 
| 148 | 
         
            +
                    meta = get_meta(target_image)
         
     | 
| 149 | 
         
            +
                    mask = np.where(mask == meta['nodata'], 1, 0)
         
     | 
| 150 | 
         
            +
                    mask = np.max(mask, axis=0)[None]
         
     | 
| 151 | 
         
            +
                    
         
     | 
| 152 | 
         
            +
                    result[0] = np.where(mask == 1, -1, result[0])
         
     | 
| 153 | 
         
            +
                    
         
     | 
| 154 | 
         
            +
                    ##### Save file to disk
         
     | 
| 155 | 
         
            +
                    meta["count"] = 1
         
     | 
| 156 | 
         
            +
                    meta["dtype"] = "int16"
         
     | 
| 157 | 
         
            +
                    meta["compress"] = "lzw"
         
     | 
| 158 | 
         
            +
                    meta["nodata"] = -1
         
     | 
| 159 | 
         
            +
                    print('Saving output...')
         
     | 
| 160 | 
         
            +
                    # write_tiff(result[0], output_image, meta)
         
     | 
| 161 | 
         
            +
                    et = time.time()
         
     | 
| 162 | 
         
            +
                    time_taken = np.round(et - st, 1)
         
     | 
| 163 | 
         
            +
                    print(f'Inference completed in {str(time_taken)} seconds')
         
     | 
| 164 | 
         
            +
                    
         
     | 
| 165 | 
         
            +
                except:
         
     | 
| 166 | 
         
            +
                    print(f'Error on image {target_image} \nContinue to next input')
         
     | 
| 167 | 
         
            +
                    
         
     | 
| 168 | 
         
            +
                return rgb, result[0][0]*255
         
     | 
| 169 | 
         
            +
             
     | 
| 170 | 
         
            +
            def process_test_pipeline(custom_test_pipeline, bands=None):
         
     | 
| 171 | 
         
            +
                
         
     | 
| 172 | 
         
            +
                # change extracted bands if necessary
         
     | 
| 173 | 
         
            +
                if bands is not None:
         
     | 
| 174 | 
         
            +
                    
         
     | 
| 175 | 
         
            +
                    extract_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'] == 'BandsExtract' ]
         
     | 
| 176 | 
         
            +
                    
         
     | 
| 177 | 
         
            +
                    if len(extract_index) > 0:
         
     | 
| 178 | 
         
            +
                        
         
     | 
| 179 | 
         
            +
                        custom_test_pipeline[extract_index[0]]['bands'] = eval(bands)
         
     | 
| 180 | 
         
            +
                        
         
     | 
| 181 | 
         
            +
                collect_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'].find('Collect') > -1]
         
     | 
| 182 | 
         
            +
                
         
     | 
| 183 | 
         
            +
                # adapt collected keys if necessary
         
     | 
| 184 | 
         
            +
                if len(collect_index) > 0:
         
     | 
| 185 | 
         
            +
                    
         
     | 
| 186 | 
         
            +
                    keys = ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']
         
     | 
| 187 | 
         
            +
                    custom_test_pipeline[collect_index[0]]['meta_keys'] = keys
         
     | 
| 188 | 
         
            +
                
         
     | 
| 189 | 
         
            +
                return custom_test_pipeline
         
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
            config = Config.fromfile(config_path)
         
     | 
| 192 | 
         
            +
            config.model.backbone.pretrained=None
         
     | 
| 193 | 
         
            +
            model = init_segmentor(config, ckpt, device='cpu')
         
     | 
| 194 | 
         
            +
            custom_test_pipeline=process_test_pipeline(model.cfg.data.test.pipeline, None)
         
     | 
| 195 | 
         
            +
             
     | 
| 196 | 
         
            +
            func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
         
     | 
| 197 | 
         
            +
             
     | 
| 198 | 
         
            +
            with gr.Blocks() as demo:
         
     | 
| 199 | 
         
            +
               
         
     | 
| 200 | 
         
            +
                gr.Markdown(value='# Prithvi multi temporal crop classification')
         
     | 
| 201 | 
         
            +
                gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to classify crop and other land use categories using multi temporal data. More detailes can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n
         
     | 
| 202 | 
         
            +
                The user needs to provide an HLS geotiff image, including 18 bands for 3 time-step, and each time-step includes the channels described above (Blue, Green, Red, Narrow NIR, SWIR, SWIR 2) in order.
         
     | 
| 203 | 
         
            +
                ''')
         
     | 
| 204 | 
         
            +
                with gr.Row():
         
     | 
| 205 | 
         
            +
                    with gr.Column():
         
     | 
| 206 | 
         
            +
                        inp = gr.File()
         
     | 
| 207 | 
         
            +
                        btn = gr.Button("Submit")
         
     | 
| 208 | 
         
            +
                
         
     | 
| 209 | 
         
            +
                with gr.Row():
         
     | 
| 210 | 
         
            +
                    gr.Markdown(value='### T1')
         
     | 
| 211 | 
         
            +
                    gr.Markdown(value='### T2')
         
     | 
| 212 | 
         
            +
                    gr.Markdown(value='### T3')
         
     | 
| 213 | 
         
            +
                    gr.Markdown(value='### Model prediction')
         
     | 
| 214 | 
         
            +
                    
         
     | 
| 215 | 
         
            +
                with gr.Row():
         
     | 
| 216 | 
         
            +
                    inp1=gr.Image(image_mode='RGB')
         
     | 
| 217 | 
         
            +
                    inp2=gr.Image(image_mode='RGB')
         
     | 
| 218 | 
         
            +
                    inp3=gr.Image(image_mode='RGB')
         
     | 
| 219 | 
         
            +
                    out = gr.Image(image_mode='L')
         
     | 
| 220 | 
         
            +
                
         
     | 
| 221 | 
         
            +
                btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])
         
     | 
| 222 | 
         
            +
                
         
     | 
| 223 | 
         
            +
                with gr.Row():
         
     | 
| 224 | 
         
            +
                    gr.Examples(examples=["chip_102_345_merged.tif",
         
     | 
| 225 | 
         
            +
                                 "chip_104_104_merged.tif",
         
     | 
| 226 | 
         
            +
                                 "chip_109_421_merged.tif"],
         
     | 
| 227 | 
         
            +
                                inputs=inp,
         
     | 
| 228 | 
         
            +
                                outputs=[inp1, inp2, inp3, out],
         
     | 
| 229 | 
         
            +
                                preprocess=preprocess_example,
         
     | 
| 230 | 
         
            +
                                fn=func,
         
     | 
| 231 | 
         
            +
                                cache_examples=True,
         
     | 
| 232 | 
         
            +
                )
         
     | 
| 233 | 
         
            +
             
     | 
| 234 | 
         
            +
            demo.launch() 
         
     | 
    	
        chip_102_345_merged.tif
    ADDED
    
    | 
											 | 
									
								
											Git LFS Details
  | 
									
    	
        chip_104_104_merged.tif
    ADDED
    
    | 
											 | 
									
								
											Git LFS Details
  | 
									
    	
        chip_109_421_merged.tif
    ADDED
    
    | 
											 | 
									
								
											Git LFS Details
  | 
									
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            pytorch==1.7.1
         
     | 
| 2 | 
         
            +
            torchvision==0.8.2
         
     | 
| 3 | 
         
            +
            openmim
         
     |