import json
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple, Union

import numpy as np
import onnxruntime as ort
from skimage import measure

from croplands.io import read_zarr, read_zarr_profile
from croplands.polygonize import polygonize_raster
from croplands.utils import impute_nan, normalize_s2


class CroplandHandler:
    """Run the UTAE ONNX model on zarr patches and optionally save polygonized output.

    Patches are read from ``input_dir``. When ``save_raster`` is requested,
    the per-pixel argmax prediction is polygonized and written to
    ``output_dir`` as ``<patch stem>.parquet``.
    """

    def __init__(
        self,
        input_dir: str,
        output_dir: str,
        device: str = "cpu",
        model_path: str = "model_repository/utae.onnx",
        dates_path: str = "months_per_patch.json",
    ) -> None:
        """Validate directories, create the ONNX Runtime session and load patch dates.

        Args:
            input_dir: Directory containing the input zarr patches.
            output_dir: Directory where ``.parquet`` results are written.
            device: ``"cpu"`` or a string starting with ``"cuda"``. A
                ``"cuda:N"`` suffix selects GPU ``N``.
            model_path: Path to the ONNX model. Relative paths resolve
                against the current working directory.
            dates_path: JSON file mapping patch stem -> sequence of dates fed
                to the model's ``batch_positions`` input.

        Raises:
            FileNotFoundError: If ``input_dir`` or ``output_dir`` is missing.
            ValueError: If ``device`` is neither ``"cpu"`` nor ``"cuda..."``.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if not self.input_dir.exists():
            raise FileNotFoundError(f"Input directory doesn't exist: {self.input_dir}")
        if not self.output_dir.exists():
            raise FileNotFoundError(f"Output directory doesn't exist: {self.output_dir}")
        if not (device == "cpu" or device.startswith("cuda")):
            raise ValueError(f"{device} is not a valid device.")

        self.session = ort.InferenceSession(
            str(model_path), providers=[self._execution_provider(device)]
        )
        with open(dates_path, encoding="utf-8") as dates_file:
            # Keyed by patch file stem; see ``preprocess``.
            self.dates = json.load(dates_file)

    @staticmethod
    def _execution_provider(device: str) -> Union[str, Tuple[str, Dict[str, int]]]:
        """Map a device string to an ONNX Runtime provider spec, keeping any GPU index."""
        if not device.startswith("cuda"):
            return "CPUExecutionProvider"
        _, _, index = device.partition(":")
        if index.isdigit():
            # Without ``device_id`` ORT always picks the default GPU, ignoring "cuda:N".
            return ("CUDAExecutionProvider", {"device_id": int(index)})
        return "CUDAExecutionProvider"

    def preprocess(self, file: str) -> Tuple[np.ndarray, Dict, np.ndarray]:
        """Load one patch and build the model inputs.

        Args:
            file: Patch path, relative to ``input_dir``.

        Returns:
            ``(batch, profile, dates)``:
            - ``batch``: the NaN-imputed, S2-normalized data with a leading
              batch axis of size 1.
            - ``profile``: the dict returned by ``read_zarr_profile``.
            - ``dates``: the patch's date sequence with a leading batch axis.

        Raises:
            ValueError: If ``file`` is None.
            KeyError: If the dates file has no entry for this patch's stem.
        """
        if file is None:
            raise ValueError("Missing input file for inference")
        file_path = self.input_dir / file

        data = read_zarr(file_path)
        data = impute_nan(data)
        data = normalize_s2(data)
        profile = read_zarr_profile(file_path)

        try:
            dates = self.dates[file_path.stem]
        except KeyError:
            raise KeyError(
                f"No dates found for patch '{file_path.stem}' in the dates file"
            ) from None

        batch = np.expand_dims(data, axis=0)
        dates = np.expand_dims(np.array(dates), axis=0)
        return batch, profile, dates

    def postprocess(
        self, outputs: Any, file: str, profile: Dict, save_raster: bool = False
    ) -> np.ndarray:
        """Convert raw session outputs to an array and optionally save polygons.

        Args:
            outputs: The list returned by ``InferenceSession.run``.
            file: Patch path (relative to ``input_dir``). Only its stem is
                used, to name the output file.
            profile: Must provide ``"transform"`` and ``"crs"`` when
                ``save_raster`` is true.
            save_raster: If true, write ``<stem>.parquet`` to ``output_dir``.
                Despite the name, the saved artifact is the vector output of
                ``polygonize_raster``, not a raster.

        Returns:
            ``np.array(outputs)``, unchanged by the saving step.
        """
        outputs = np.array(outputs)
        if save_raster:
            # First model output, first (only) batch element. Argmax over its
            # leading axis gives a per-pixel class index. This assumes a
            # (classes, H, W) layout -- TODO confirm against the model.
            out_class = np.argmax(outputs[0][0], axis=0)
            # Non-zero classes form the foreground mask that is split into
            # 4-connected components (connectivity=1).
            out_bin = (out_class != 0).astype(np.uint8)
            components = measure.label(out_bin, connectivity=1)
            gdf = polygonize_raster(
                out_class,
                components,
                tolerance=0.0001,
                transform=profile["transform"],
                crs=profile["crs"],
            )
            stem = (self.input_dir / file).stem
            gdf.to_parquet(self.output_dir / f"{stem}.parquet")
        return outputs

    def predict(
        self, files: Union[str, Sequence[str]], save_raster: bool = False
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Run preprocessing, inference and postprocessing.

        Args:
            files: A single patch path, or a sequence of patch paths, relative
                to ``input_dir``. Each patch is run as its own batch of one.
            save_raster: Forwarded to ``postprocess``.

        Returns:
            The output array for a single path. For a sequence, a list of
            output arrays in the same order.
        """
        # A single path (or None, so ``preprocess`` reports it) keeps the
        # original single-result return type.
        if files is None or isinstance(files, (str, Path)):
            return self._predict_file(files, save_raster)
        return [self._predict_file(file, save_raster) for file in files]

    def _predict_file(self, file: str, save_raster: bool) -> np.ndarray:
        """Full preprocess -> ONNX inference -> postprocess pipeline for one patch."""
        batch, profile, dates = self.preprocess(file)
        outputs = self.session.run(None, {"input": batch, "batch_positions": dates})
        return self.postprocess(outputs, file, profile, save_raster)