---
datasets:
- conflux-xyz/tcga-tissue-segmentation
language:
- en
base_model:
- timm/mobilenetv3_small_100.lamb_in1k
pipeline_tag: image-segmentation
tags:
- histology
- pathology
license: apache-2.0
---

# CxTissueSeg

## Overview

The **CxTissueSeg** model performs binary segmentation of patches of tissue present in [H&E](https://en.wikipedia.org/wiki/H%26E_stain) pathology slides. It is architected to run efficiently on resource-constrained systems, providing tissue segmentation on a slide in under 1 second on a typical CPU.

The model is trained on a manually curated set of slides from [our linked dataset](https://huggingface.co/datasets/conflux-xyz/tcga-tissue-segmentation), where it achieves 0.93 mIoU for tissue on the test split.

By default, the model outputs logits, where the positive class is predicted tissue and the negative class is predicted background. It is recommended to use the model with our open-source [tiled inference framework](https://github.com/conflux-xyz/conflux-segmentation), which handles running inference on a full image by tiling the input and stitching the results.

This model was trained using PyTorch and [Segmentation Models PyTorch](https://smp.readthedocs.io/en/latest/). It uses a UNet decoder with a MobileNet-v3 encoder -- specifically, we use [`timm/mobilenetv3_small_100`](https://huggingface.co/timm/mobilenetv3_small_100.lamb_in1k) as the encoder.

We provide the model weights in both a [pickled format](https://pytorch.org/tutorials/beginner/saving_loading_models.html) ([`model.pth`](./model.pth)) and via [safetensors](https://huggingface.co/docs/safetensors/en/index) ([`model.safetensors`](./model.safetensors)). We also provide the model exported to ONNX ([`model.onnx`](./model.onnx)) to be used with ONNX Runtime, so it can be run even more efficiently and across programming languages. To try a demo of the model running in the browser via ONNX Runtime, see: http://www.conflux.xyz/demos/tissue-segmentation.

We also provide a statically quantized (int8) model usable via ONNX Runtime, [`model_qint8.onnx`](./model_qint8.onnx), although its performance is not on par with the full float32 model (0.85 mIoU rather than 0.93 mIoU).

For more details on the background of the model, check out the blog post here: http://www.conflux.xyz/blog/tissue-segmentation.

## Usage

**CxTissueSeg** was trained on 512 x 512 pixel patches from thumbnail images of whole slides at 40 microns per pixel (MPP) -- a 4x downsample from the images in the dataset. Thus, when running inference with the model, it is important to use 40 MPP thumbnails and to run inference on tiles of the same dimension (512 x 512). When padding tiles, pad with pure white: `rgb(255, 255, 255)`. To make this easier, we provide a more general segmentation library to aid in performing tiled inference: https://github.com/conflux-xyz/conflux-segmentation.
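As a minimal sketch of preparing such an input, the snippet below reads a whole-slide image with OpenSlide, scales it to roughly 40 MPP using the slide's reported base resolution, and pads it with pure white to a multiple of the 512-pixel tile size. The file path is a placeholder, and the explicit padding is purely illustrative; the `conflux-segmentation` library used below takes care of tiling and stitching for you.

```python
# Sketch only: prepare a ~40 MPP RGB thumbnail for inference.
# Assumes the slide is readable with OpenSlide and reports its base MPP.
# pip install openslide-python numpy
import numpy as np
import openslide

slide = openslide.OpenSlide("/path/to/slide.svs")  # placeholder path

# Base resolution of the slide in microns per pixel, as reported by the scanner
base_mpp = float(slide.properties[openslide.PROPERTY_NAME_MPP_X])

# Downsample factor needed to reach 40 MPP, and the resulting thumbnail size
scale = 40.0 / base_mpp
width, height = slide.dimensions
thumbnail = slide.get_thumbnail((round(width / scale), round(height / scale)))

# H x W x 3 uint8 RGB array, padded with pure white (255, 255, 255)
# so that both dimensions are multiples of the 512-pixel tile size
image = np.asarray(thumbnail.convert("RGB"))
pad_h = (-image.shape[0]) % 512
pad_w = (-image.shape[1]) % 512
image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), constant_values=255)
```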
### Create a segmentation model

#### ONNX

```python
# pip install conflux-segmentation[onnx] onnxruntime
import onnxruntime as ort
from conflux_segmentation import Segmenter

session = ort.InferenceSession("/path/to/model.onnx")
segmenter = Segmenter.from_onnx(session, activation="sigmoid")
```

#### PyTorch

```python
# pip install conflux-segmentation[torch] torch segmentation-models-pytorch
import torch
import segmentation_models_pytorch as smp
from conflux_segmentation import Segmenter

net = smp.Unet(encoder_name="tu-mobilenetv3_small_100", encoder_weights=None, activation=None)
net.load_state_dict(torch.load("/path/to/model.pth", weights_only=True))
# Alternatively, with safetensors:
# import safetensors.torch
# net.load_state_dict(safetensors.torch.load_file("/path/to/model.safetensors"))
net.eval()

# Optionally, trace the model to get a TorchScript ScriptModule
# example = torch.randn(1, 3, 512, 512)
# net = torch.jit.trace(net, example)
# net.eval()

segmenter = Segmenter.from_torch(net, activation="sigmoid")
```

### Segment!

```python
import cv2

# A 40 MPP thumbnail: H x W x 3 image array of np.uint8
image = cv2.cvtColor(cv2.imread("/path/to/large/image"), cv2.COLOR_BGR2RGB)
# Alternatively, use `openslide` or `tiffslide` to get a 40 MPP thumbnail

# H x W boolean array
mask = segmenter(image).to_binary().get_mask()

tissue_fraction = mask.sum() / mask.size
print(f"Fraction of slide with tissue: {tissue_fraction:.3f}")
```

## Acknowledgements

We are grateful to the TCGA Research Network, from which the slides used for training were originally sourced. Per their citation request (https://www.cancer.gov/ccg/research/genome-sequencing/tcga/using-tcga-data/citing):

> The results shown here are in whole or part based upon data generated by the TCGA Research Network: https://www.cancer.gov/tcga.