Mimicking ImageJ’s watershed algorithm


In ImageJ, there is an algorithm called “Watershed” that splits touching objects in a binary image. This notebook demonstrates how to achieve a similar result in Python.

The “Watershed” in ImageJ was applied to a binary image using this macro:

open("../BioImageAnalysisNotebooks/data/blobs_otsu.tif");
run("Watershed");

In Python, we need the following imports:

from skimage.io import imread, imshow
import matplotlib.pyplot as plt
import napari_segment_blobs_and_things_with_membranes as nsbatwm
import numpy as np

from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.filters import gaussian, sobel
from skimage.measure import label
from skimage.segmentation import watershed
from skimage.morphology import binary_opening

The starting point for this demonstration is a binary image.

binary_image = imread("../../data/blobs_otsu.tif")
imshow(binary_image)

After applying the macro shown above in ImageJ, the resulting image looks like this:

binary_watershed_imagej = imread("../../data/blobs_otsu_watershed.tif")
imshow(binary_watershed_imagej)

The Napari plugin napari-segment-blobs-and-things-with-membranes (nsbatwm) offers a function, split_touching_objects, that mimics this ImageJ functionality.

binary_watershed_nsbatwm = nsbatwm.split_touching_objects(binary_image)
imshow(binary_watershed_nsbatwm)

Comparing results

Comparing the two results side by side shows that they are not 100% identical.

fig, axs = plt.subplots(1, 2, figsize=(10,10))

axs[0].imshow(binary_watershed_imagej)
axs[0].set_title("ImageJ")
axs[1].imshow(binary_watershed_nsbatwm)
axs[1].set_title("nsbatwm")
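To put a number on "not 100% identical", one option is to count the pixels that differ and the connected objects in each image. The helper `compare_binary_results` below is hypothetical: it is not part of nsbatwm or scikit-image. It assumes that both images mark objects with nonzero values. For a runnable check, it is shown on a small synthetic pair rather than on the blobs images.

```python
import numpy as np
from skimage.measure import label


def compare_binary_results(a, b):
    """Hypothetical helper: summarise how two binary split results differ."""
    a = np.asarray(a) > 0  # assumption: objects are nonzero, background is zero
    b = np.asarray(b) > 0
    return {
        "differing_pixels": int(np.count_nonzero(a != b)),
        "objects_a": int(label(a).max()),  # number of connected objects in a
        "objects_b": int(label(b).max()),  # number of connected objects in b
    }


# Synthetic pair: one bar, and the same bar cut in two by a one-pixel-wide column
joined = np.zeros((5, 9), dtype=bool)
joined[1:4, 1:8] = True
split = joined.copy()
split[:, 4] = False

print(compare_binary_results(joined, split))
# {'differing_pixels': 3, 'objects_a': 1, 'objects_b': 2}
```

In this notebook, the helper could be called as `compare_binary_results(binary_watershed_imagej, binary_watershed_nsbatwm)`, provided both images really use nonzero values for objects.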

Fine-tuning results

Modifying the result is possible by tuning the sigma parameter.

fig, axs = plt.subplots(1, 4, figsize=(10,10))

for i, sigma in enumerate(np.arange(2, 6, 1)):
    result = nsbatwm.split_touching_objects(binary_image, sigma=sigma)
    axs[i].imshow(result)
    axs[i].set_title("sigma="+str(sigma))
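One way to see why sigma matters is to count how many local maxima survive different amounts of blurring. Each maximum later becomes the seed of one split object, so fewer maxima means fewer cuts. The sketch below uses two synthetic overlapping disks instead of the blobs image. It follows the distance, blur and maxima steps described in the next section, and the disk geometry is an assumption. The exact counts depend on the object shapes, so they are printed rather than stated here.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.draw import disk
from skimage.feature import peak_local_max
from skimage.filters import gaussian

# Synthetic stand-in for two touching blobs: two overlapping disks (assumed geometry)
binary = np.zeros((100, 100), dtype=np.uint8)
for center in [(50, 35), (50, 65)]:
    rr, cc = disk(center, 20, shape=binary.shape)
    binary[rr, cc] = 1

distance = ndi.distance_transform_edt(binary)
footprint = np.ones((3, 3))


def count_markers(sigma):
    """Number of local maxima (future watershed seeds) after blurring with sigma."""
    blurred = gaussian(distance, sigma=sigma)
    return len(peak_local_max(blurred, footprint=footprint, labels=binary))


marker_counts = {sigma: count_markers(sigma) for sigma in [2, 4, 8, 16]}
print(marker_counts)
```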

How does it work?

Under the hood, ImageJ’s watershed algorithm uses a distance image and spot detection. The following code attempts to replicate its result.

Again, we start from the binary image.

imshow(binary_image)

The first step is to produce a distance image.

distance = ndi.distance_transform_edt(binary_image)
imshow(distance)
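In the distance image, each foreground pixel holds the Euclidean distance to the nearest background pixel. Object centres therefore become bright peaks, while necks between touching objects stay darker. Here is a tiny worked example on a synthetic 5x5 square, using the same `distance_transform_edt` call as above:

```python
import numpy as np
from scipy import ndimage as ndi

# A 5x5 square of foreground pixels surrounded by a one-pixel background border
square = np.zeros((7, 7), dtype=bool)
square[1:6, 1:6] = True

distance = ndi.distance_transform_edt(square)
print(distance[3])  # middle row: [0. 1. 2. 3. 2. 1. 0.]
```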

To avoid splitting objects into very small pieces, we blur the distance image with a Gaussian filter whose width is set by sigma.

sigma = 3.5

blurred_distance = gaussian(distance, sigma=sigma)
imshow(blurred_distance)
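The effect of this blurring can be illustrated in one dimension. A distance profile across two objects joined by a shallow neck has two maxima, and smoothing merges them into one. The profile values below are made up for illustration. The example calls scipy's `gaussian_filter` directly; scikit-image's `gaussian` builds on it.

```python
import numpy as np
from scipy import ndimage as ndi

# Hypothetical 1D distance profile with a shallow dip (neck) at index 4
profile = np.array([0, 1, 2, 3, 2.5, 3, 2, 1, 0], dtype=float)


def count_local_maxima(values):
    """Count interior samples that are strictly larger than both neighbours."""
    inner = values[1:-1]
    return int(np.sum((inner > values[:-2]) & (inner > values[2:])))


blurred = ndi.gaussian_filter(profile, sigma=1)
print(count_local_maxima(profile), count_local_maxima(blurred))  # 2 1
print(int(np.argmax(blurred)))  # 4: the merged maximum sits at the former neck
```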

Within this blurred image, we search for local maxima and obtain them as a list of coordinates.

fp = np.ones((3,) * binary_image.ndim)
coords = peak_local_max(blurred_distance, footprint=fp, labels=binary_image)

# show the first 5 only
coords[:5]
array([[  8, 254],
       [ 97,   1],
       [ 10, 108],
       [230, 180],
       [182, 179]], dtype=int64)
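To show what `peak_local_max` returns and what the `labels` argument changes, here is a minimal sketch on a synthetic 9x9 image with two isolated bright pixels. The pixel values and the region are arbitrary choices for illustration:

```python
import numpy as np
from skimage.feature import peak_local_max

# Synthetic image with two isolated bright pixels
image = np.zeros((9, 9))
image[2, 2] = 5.0
image[6, 6] = 3.0

footprint = np.ones((3, 3))  # compare each pixel with its 8 neighbours
coords = peak_local_max(image, footprint=footprint)
print(coords)  # one (row, column) pair per maximum

# With a label image, maxima are only searched inside labelled regions
region = np.zeros((9, 9), dtype=np.uint8)
region[:5, :5] = 1  # covers (2, 2) but not (6, 6)
coords_in_region = peak_local_max(image, footprint=footprint, labels=region)
print(coords_in_region)
```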

Next, we write these maxima into a new image and label them, so that each maximum receives its own integer label.

mask = np.zeros(distance.shape, dtype=bool)
mask[tuple(coords.T)] = True
markers = label(mask)
imshow(markers, cmap='jet')
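The indexing trick `mask[tuple(coords.T)]` turns an (N, 2) coordinate array into a (rows, columns) pair of index arrays. A minimal sketch with two hand-picked coordinates, standing in for `peak_local_max` output:

```python
import numpy as np
from skimage.measure import label

coords = np.array([[2, 2], [6, 6]])  # hand-picked stand-in for peak coordinates
mask = np.zeros((9, 9), dtype=bool)
mask[tuple(coords.T)] = True  # tuple(coords.T) is (array of rows, array of columns)

markers = label(mask)  # separate maxima get separate integer labels
print(markers[2, 2], markers[6, 6], markers.max())  # 1 2 2
```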

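Next, we apply scikit-image’s watershed algorithm. It takes an image to flood and a label image of markers as input. We pass the negated blurred distance image, so that the distance maxima become the basins from which flooding starts. The optional mask, binary_image, prevents the labels from spreading beyond the objects.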

labels = watershed(-blurred_distance, markers, mask=binary_image)
imshow(labels, cmap='jet')
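A minimal sketch of the same `watershed` call on synthetic data shows the effect of negating the distance and passing a mask. The image contains two overlapping disks, with markers placed by hand at the disk centres instead of being detected. Both the geometry and the marker positions are assumptions for illustration:

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.draw import disk
from skimage.segmentation import watershed

# Two overlapping disks as a synthetic binary image (assumed geometry)
binary = np.zeros((100, 100), dtype=bool)
centers = [(50, 35), (50, 65)]
for center in centers:
    rr, cc = disk(center, 20, shape=binary.shape)
    binary[rr, cc] = True

distance = ndi.distance_transform_edt(binary)

# Hand-placed markers at the disk centres (labels 1 and 2)
markers = np.zeros(binary.shape, dtype=np.int32)
for i, (r, c) in enumerate(centers, start=1):
    markers[r, c] = i

# Negating the distance turns the object centres into minima where flooding starts;
# the mask keeps the labels inside the foreground
labels = watershed(-distance, markers, mask=binary)
print(labels.max())  # 2
print(labels[50, 20], labels[50, 80], labels[0, 0])  # 1 2 0
```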

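To create a binary image again, as ImageJ does, we now identify the edges between neighboring labels. The Sobel filter of the label image responds at all label borders, while the Sobel filter of the binary image responds only at the object outlines. Their XOR therefore keeps mainly the borders between touching labels.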

# identify label-cutting edges
edges_labels = sobel(labels)
edges_binary = sobel(binary_image)

edges = np.logical_xor(edges_labels != 0, edges_binary != 0)
imshow(edges)
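A toy label image makes the XOR step concrete. It contains two rectangles, labelled 1 and 2, that touch along a vertical line. The arrays are cast to float before filtering, which is a choice made here so that `sobel` works on the raw label values:

```python
import numpy as np
from skimage.filters import sobel

# Toy label image: labels 1 and 2 touching between columns 4 and 5
labels = np.zeros((7, 10), dtype=np.int32)
labels[1:6, 1:5] = 1
labels[1:6, 5:9] = 2
binary = labels > 0

edges_labels = sobel(labels.astype(float)) != 0  # all label borders
edges_binary = sobel(binary.astype(float)) != 0  # object outline only
cut = np.logical_xor(edges_labels, edges_binary)

print(cut[3].astype(int))  # [0 0 0 0 1 1 0 0 0 0]: a two-pixel-wide cut line
```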

Next, we subtract those edges from the original binary_image.

almost = np.logical_not(edges) * binary_image
imshow(almost)
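A toy case suggests one possible reason why the subtracted image is not perfect yet. Where the cut meets the object outline, both Sobel images are nonzero, so the XOR drops those pixels. The two halves then stay joined by a one-pixel-thick bridge. The sketch uses a small synthetic label image. Whether this is the only imperfection in the blobs result is not checked here.

```python
import numpy as np
from skimage.filters import sobel
from skimage.measure import label

# Toy label image: labels 1 and 2 touching between columns 4 and 5
labels = np.zeros((7, 10), dtype=np.int32)
labels[1:6, 1:5] = 1
labels[1:6, 5:9] = 2
binary = labels > 0

cut = np.logical_xor(sobel(labels.astype(float)) != 0, sobel(binary.astype(float)) != 0)

# Subtract the cut from the binary image, as in the notebook
almost = np.logical_not(cut) * binary

print(label(almost).max())    # 1: the halves are still connected
print(almost[1].astype(int))  # [0 1 1 1 1 1 1 1 1 0]: the bridge along the top row
```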

As this result is not perfect yet, we apply a binary opening, which removes thin foreground structures that the structuring element does not fit into.

result = binary_opening(almost)
imshow(result)
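Why the opening separates objects that are joined by thin bridges can be checked on a hand-built example. The array `almost` below is hypothetical: a 5x8 object with a 3-row, 2-column cut in its middle, leaving one-pixel-thick connections at the top and bottom. `binary_opening` is called with its default footprint, a cross (4-connectivity).

```python
import numpy as np
from skimage.measure import label
from skimage.morphology import binary_opening

# Hypothetical cut object: halves joined only by rows 1 and 5
almost = np.zeros((7, 10), dtype=bool)
almost[1:6, 1:9] = True
almost[2:5, 4:6] = False

# Opening keeps only pixels covered by a copy of the footprint lying fully in the
# foreground; the cross cannot fit on the one-pixel-thick bridges, so they vanish
result = binary_opening(almost)

print(label(almost).max(), label(result).max())  # 1 2
```

On such a tiny object, the opening also trims corners noticeably. On the blobs image, the objects are much larger than the 3x3 cross, so the trimming is relatively small.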