aida-cropland-models / handler.py
import json
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import onnxruntime as ort
from skimage import measure

from croplands.io import read_zarr, read_zarr_profile
from croplands.polygonize import polygonize_raster
from croplands.utils import impute_nan, normalize_s2


class CroplandHandler:
    """Inference handler for the UTAE cropland segmentation model exported to ONNX."""

    def __init__(self, input_dir: str, output_dir: str, device: str = "cpu") -> None:
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        assert self.input_dir.exists(), "Input directory doesn't exist"
        assert self.output_dir.exists(), "Output directory doesn't exist"
        assert device == "cpu" or device.startswith("cuda"), f"{device} is not a valid device."

        # Load the ONNX model on the requested execution provider.
        model_path = "model_repository/utae.onnx"
        provider = "CUDAExecutionProvider" if device.startswith("cuda") else "CPUExecutionProvider"
        self.session = ort.InferenceSession(str(model_path), providers=[provider])

        # Acquisition months for each patch, keyed by the patch file stem.
        with open("months_per_patch.json") as dates:
            self.dates = json.load(dates)
    def preprocess(self, file: str) -> Tuple[np.ndarray, Dict, np.ndarray]:
        assert file is not None, "Missing input file for inference"
        file_path = self.input_dir / file
        # Read the patch, impute missing values and normalize the Sentinel-2 bands.
        data = read_zarr(file_path)
        data = impute_nan(data)
        data = normalize_s2(data)
        profile = read_zarr_profile(file_path)
        dates = self.dates[file_path.stem]
        # Add the batch dimension expected by the ONNX model.
        batch = np.expand_dims(data, axis=0)
        dates = np.expand_dims(np.array(dates), axis=0)
        return batch, profile, dates
    def postprocess(self, outputs: Any, file: str, profile: Dict, save_raster: bool = False) -> np.ndarray:
        outputs = np.array(outputs)
        if save_raster:
            # Per-pixel class map from the logits; class 0 is treated as background.
            out_class = np.argmax(outputs[0][0], axis=0)
            out_bin = (out_class != 0).astype(np.uint8)
            # Label connected components and polygonize them into a GeoDataFrame.
            components = measure.label(out_bin, connectivity=1)
            gdf = polygonize_raster(out_class, components, tolerance=0.0001,
                                    transform=profile["transform"], crs=profile["crs"])
            data_path = self.input_dir / file
            save_path = self.output_dir / (data_path.stem + ".parquet")
            gdf.to_parquet(save_path)
        return outputs
    def predict(self, file: str, save_raster: bool = False) -> np.ndarray:
        # Preprocessing
        batch, profile, dates = self.preprocess(file)
        # Inference
        outputs = self.session.run(None, {"input": batch, "batch_positions": dates})
        # Postprocessing
        outputs = self.postprocess(outputs, file, profile, save_raster)
        return outputs
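

# A minimal usage sketch (not part of the original handler): the directory layout,
# patch file name and device below are assumptions for illustration and must match
# your own data and the keys in months_per_patch.json.
if __name__ == "__main__":
    handler = CroplandHandler(input_dir="data/patches", output_dir="outputs", device="cpu")
    # Run inference on a single Zarr patch; with save_raster=True the polygonized
    # prediction is also written to outputs/<patch stem>.parquet.
    logits = handler.predict("patch_0001.zarr", save_raster=True)
    print(logits.shape)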