import json
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import onnxruntime as ort
from skimage import measure

from croplands.io import read_zarr, read_zarr_profile
from croplands.polygonize import polygonize_raster
from croplands.utils import impute_nan, normalize_s2


class CroplandHandler:
    """Runs cropland segmentation inference with an exported UTAE ONNX model."""

    def __init__(self, input_dir: str, output_dir: str, device: str = "cpu") -> None:
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)

        assert self.input_dir.exists(), "Input directory doesn't exist"
        assert self.output_dir.exists(), "Output directory doesn't exist"
        assert device == "cpu" or device.startswith("cuda"), f"{device} is not a valid device."

        # Load the exported model on the execution provider matching the requested device.
        model_path = "model_repository/utae.onnx"
        provider = "CUDAExecutionProvider" if device.startswith("cuda") else "CPUExecutionProvider"
        self.session = ort.InferenceSession(str(model_path), providers=[provider])

        # Acquisition months for each patch, keyed by the patch file stem.
        with open("months_per_patch.json") as f:
            self.dates = json.load(f)
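
    # For reference, a hypothetical "months_per_patch.json" layout consistent with
    # how self.dates is used in preprocess below (one list of acquisition-month
    # positions per patch stem; the stems and values here are made up):
    #   {"patch_0042": [3, 4, 5, 6, 7, 8], "patch_0043": [4, 5, 6, 7]}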

    def preprocess(self, file: str) -> Tuple[np.ndarray, Dict, np.ndarray]:
        assert file is not None, "Missing input file for inference"

        # Read the Sentinel-2 patch, impute missing values, and normalize the bands.
        file_path = self.input_dir / file
        data = read_zarr(file_path)
        data = impute_nan(data)
        data = normalize_s2(data)
        profile = read_zarr_profile(file_path)

        # Prepend a batch dimension to both the pixel array and its acquisition dates.
        dates = self.dates[file_path.stem]
        batch = np.expand_dims(data, axis=0)
        dates = np.expand_dims(np.array(dates), axis=0)
        return batch, profile, dates

    def postprocess(self, outputs: Any, file: str, profile: Dict, save_raster: bool = False) -> np.ndarray:
        outputs = np.array(outputs)

        if save_raster:
            # Per-pixel class labels from the model output, then a binary cropland mask.
            out_class = np.argmax(outputs[0][0], axis=0)
            out_bin = (out_class != 0).astype(np.uint8)
            # Label connected components and vectorize them into georeferenced polygons.
            # Note: despite the flag name, this saves vector polygons (parquet), not a raster.
            components = measure.label(out_bin, connectivity=1)
            gdf = polygonize_raster(out_class, components, tolerance=0.0001,
                                    transform=profile["transform"], crs=profile["crs"])
            data_path = self.input_dir / file
            save_path = self.output_dir / (data_path.stem + ".parquet")
            gdf.to_parquet(save_path)

        return outputs

    def predict(self, file: str, save_raster: bool = False) -> np.ndarray:
        # Full single-patch pipeline: preprocess -> ONNX inference -> postprocess.
        # Takes one file name, not a list: preprocess resolves a single path, so the
        # original List[str] signature would fail at self.input_dir / file.
        batch, profile, dates = self.preprocess(file)
        outputs = self.session.run(None, {"input": batch, "batch_positions": dates})
        outputs = self.postprocess(outputs, file, profile, save_raster)
        return outputs
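

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the directory names
    # and the patch file name below are hypothetical, and the patch stem must have
    # a matching entry in months_per_patch.json for preprocess to find its dates.
    handler = CroplandHandler(input_dir="data/patches", output_dir="data/predictions")
    outputs = handler.predict("patch_0042.zarr", save_raster=True)
    print(outputs.shape)  # raw model outputs with a leading batch dimension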