from typing import List, Literal, Tuple, Type, Union from pathlib import Path import skimage.io import numpy as np from pydantic import BaseModel, Field, validator import imaging_server_kit as serverkit import orientationpy import matplotlib def rescale_intensity_quantile(image): """Rescale the image intensity based on the 2nd and 98th quantiles.""" image = image.astype(np.float64) image_normed = image - np.quantile(image, 0.02) image_normed = image / np.quantile(image_normed, 0.98) return image_normed class Parameters(BaseModel): """Defines the algorithm parameters""" image: str = Field( ..., title="Image", description="Base64-encoded numpy array. Should be decoded to a numpy array.", json_schema_extra={"widget_type": "image"}, ) mode: Literal["fiber", "membrane"] = Field( ..., title="Mode", description="The orientation computation mode.", json_schema_extra={"widget_type": "dropdown"}, ) scale: float = Field( title="Structural scale", description="The scale at which orientation is computed.", default=1.0, ge=0.1, le=10.0, json_schema_extra={"widget_type": "float"}, ) with_colors: bool = Field( default=False, title="Output color-coded orientation", description="Whether to output a color-coded representation of orientation or not.", json_schema_extra={"widget_type": "bool"}, ) vector_spacing: float = Field( title="Vector spacing", description="The spacing at which the orientation vectors are rendered.", default=1, ge=1, le=10, json_schema_extra={"widget_type": "int"}, ) @validator("image", pre=False, always=True) def decode_image_array(cls, v) -> np.ndarray: image_array = serverkit.decode_contents(v) if image_array.ndim not in [2, 3]: raise ValueError("Array has the wrong dimensionality.") return image_array class Server(serverkit.Server): def __init__( self, algorithm_name: str = "orientationpy", parameters_model: Type[BaseModel] = Parameters, ): super().__init__(algorithm_name, parameters_model) def run_algorithm( self, image: np.ndarray, mode: str, scale: float, with_colors: bool, vector_spacing: int, **kwargs, ) -> List[tuple]: """Run the orientationpy algorithm.""" if image.ndim == 2: mode = "fiber" # no membranes in 2D gradients = orientationpy.computeGradient(image, mode="splines") structureTensor = orientationpy.computeStructureTensor(gradients, sigma=scale) orientation_returns = orientationpy.computeOrientation( structureTensor, mode=mode, computeEnergy=False, computeCoherency=True, ) theta = orientation_returns.get("theta") + 90 phi = orientation_returns.get("phi") coherency = rescale_intensity_quantile(orientation_returns.get("coherency")) boxVectorCoords = orientationpy.anglesToVectors(orientation_returns) # This is messy... but it works (kinda)! node_spacings = np.array([vector_spacing] * image.ndim).astype(int) slices = [slice(n // 2, None, n) for n in node_spacings] grid = np.mgrid[[slice(0, x) for x in image.shape]] node_origins = np.stack([g[tuple(slices)] for g in grid]) slices.insert(0, slice(len(boxVectorCoords))) displacements = boxVectorCoords[tuple(slices)].copy() displacements *= np.mean(node_spacings) displacements = np.reshape(displacements, (image.ndim, -1)).T origins = np.reshape(node_origins, (image.ndim, -1)).T origins = origins - displacements / 2 displacement_vectors = np.stack((origins, displacements)) displacement_vectors = np.rollaxis(displacement_vectors, 1) data_tuple = [ ( displacement_vectors, { "name": "Orientation vectors", "edge_width": np.max(node_spacings) / 5.0, "opacity": 1.0, "ndim": image.ndim, "edge_color": "blue", "vector_style": "line", }, "vectors", ) ] if with_colors: if image.ndim == 3: imDisplayHSV = np.stack( (phi / 360, np.sin(np.deg2rad(theta)), image / image.max()), axis=-1 ) else: imDisplayHSV = np.stack( (theta / 180, coherency, image / image.max()), axis=-1 ) imdisplay_rgb = matplotlib.colors.hsv_to_rgb(imDisplayHSV) data_tuple.append( ( imdisplay_rgb, { "name": "Color-coded orientation", "rgb": True, }, "image", ) ) return data_tuple def load_sample_images(self) -> List["np.ndarray"]: """Load one or multiple sample images.""" image_dir = Path(__file__).parent / "sample_images" images = [skimage.io.imread(image_path) for image_path in image_dir.glob("*")] return images server = Server() app = server.app