Mimicking ImageJ’s watershed algorithm

In ImageJ, there is an algorithm called “Watershed” that splits touching objects in a segmented binary image. This notebook demonstrates how to achieve a similar operation in Python.

The “Watershed” in ImageJ was applied to a binary image using this macro:

from skimage.io import imread, imshow
import matplotlib.pyplot as plt
import napari_segment_blobs_and_things_with_membranes as nsbatwm
import numpy as np

from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.filters import gaussian, sobel
from skimage.measure import label
from skimage.segmentation import watershed
from skimage.morphology import binary_opening

Starting point for the demonstration is a binary image.

binary_image = imread("../../data/blobs_otsu.tif")
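If the file `../../data/blobs_otsu.tif` is not at hand, a small synthetic image of two overlapping discs can stand in for it while experimenting. This is a minimal sketch; the helper `make_touching_discs` and its default sizes are hypothetical choices, not part of the original notebook.

```python
import numpy as np
from scipy import ndimage as ndi


def make_touching_discs(shape=(80, 120), radius=20, centers=((40, 46), (40, 74))):
    """Hypothetical helper: binary image of overlapping discs forming one object."""
    yy, xx = np.mgrid[:shape[0], :shape[1]]
    image = np.zeros(shape, dtype=bool)
    for cy, cx in centers:
        # Pixels within `radius` of a disc center belong to the foreground
        image |= (yy - cy) ** 2 + (xx - cx) ** 2 <= radius ** 2
    return image


synthetic_binary = make_touching_discs()
# The discs overlap, so they form a single connected component
_, n_components = ndi.label(synthetic_binary)
print(n_components)  # 1
```

To mimic an 8-bit image read from disk, `synthetic_binary.astype(np.uint8)` can be used wherever the notebook uses `binary_image`.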

After applying the macro shown above, the result image in ImageJ looks like this:

binary_watershed_imagej = imread("../../data/blobs_otsu_watershed.tif")

The napari plugin napari-segment-blobs-and-things-with-membranes (imported as nsbatwm) offers a function, split_touching_objects, that mimics this ImageJ functionality.

binary_watershed_nsbatwm = nsbatwm.split_touching_objects(binary_image)

Comparing results

A side-by-side comparison shows that the two results are similar but not identical.

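fig, axs = plt.subplots(1, 2, figsize=(10,10))
axs[0].imshow(binary_watershed_imagej)
axs[0].set_title('ImageJ')
axs[1].imshow(binary_watershed_nsbatwm)
axs[1].set_title('nsbatwm')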
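To put a number on "not identical", one option is to count the pixels on which two binary results disagree and how many objects each contains. The helper below is a hypothetical sketch, not part of nsbatwm. Note that `scipy.ndimage.label` uses 4-connectivity in 2D by default, so diagonally touching pixels count as separate objects.

```python
import numpy as np
from scipy import ndimage as ndi


def compare_binary_results(a, b):
    """Hypothetical helper: summarize how two binary segmentations differ."""
    a = np.asarray(a) != 0  # any non-zero value (e.g. 255) counts as foreground
    b = np.asarray(b) != 0
    differing = np.logical_xor(a, b)
    _, n_objects_a = ndi.label(a)
    _, n_objects_b = ndi.label(b)
    return {
        "differing_pixels": int(differing.sum()),
        "differing_fraction": float(differing.mean()),
        "objects_a": n_objects_a,
        "objects_b": n_objects_b,
    }


# Toy example: b equals a, except for a one-pixel-wide cut through the object
a = np.zeros((5, 7), dtype=bool)
a[1:4, 1:6] = True
b = a.copy()
b[:, 3] = False
summary = compare_binary_results(a, b)
print(summary)  # differing_pixels=3, objects_a=1, objects_b=2
```

In the notebook, this would be called as `compare_binary_results(binary_watershed_imagej, binary_watershed_nsbatwm)`.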

Fine-tuning results

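The result can be modified by tuning the sigma parameter. It controls how strongly the distance image is blurred before local maxima are detected, as explained in the next section.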

fig, axs = plt.subplots(1, 4, figsize=(10,10))

for i, sigma in enumerate(np.arange(2, 6, 1)):
    result = nsbatwm.split_touching_objects(binary_image, sigma=sigma)
    axs[i].imshow(result)
    axs[i].set_title(f"sigma={sigma}")
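To experiment with sigma without installing nsbatwm, the steps explained in the next section can be bundled into one function. `split_touching_objects_sketch` is a hypothetical name, and the function is an approximation of that procedure, not the plugin's actual source. It is run here on two synthetic overlapping discs (hypothetical test data).

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.filters import gaussian, sobel
from skimage.measure import label
from skimage.morphology import binary_opening
from skimage.segmentation import watershed


def split_touching_objects_sketch(binary_image, sigma=3.5):
    """Hypothetical bundle of the steps from "How does it work?".

    Returns the watershed label image and the binary image with cuts.
    """
    foreground = np.asarray(binary_image) != 0
    # Distance of every foreground pixel to the nearest background pixel
    distance = ndi.distance_transform_edt(foreground)
    # Blurring suppresses small, nearby maxima (fewer, larger split objects)
    blurred_distance = gaussian(distance, sigma=sigma)
    # Local maxima in a 3x3 (3D: 3x3x3) neighborhood, inside the foreground only
    fp = np.ones((3,) * foreground.ndim)
    coords = peak_local_max(blurred_distance, footprint=fp, labels=foreground.astype(int))
    # One marker per connected group of maxima
    mask = np.zeros(distance.shape, dtype=bool)
    mask[tuple(coords.T)] = True
    markers = label(mask)
    # Flood the negated distance from the markers, restricted to the foreground
    labels = watershed(-blurred_distance, markers, mask=foreground)
    # Edges between labels, but not along the object outline
    edges = np.logical_xor(sobel(labels.astype(float)) != 0,
                           sobel(foreground.astype(float)) != 0)
    almost = np.logical_not(edges) & foreground
    # Opening removes thin leftovers, e.g. where the cut meets the outline
    return labels, binary_opening(almost)


# Hypothetical test data: two overlapping discs of radius 20
yy, xx = np.mgrid[:80, :120]
discs = ((yy - 40) ** 2 + (xx - 46) ** 2 <= 400) | ((yy - 40) ** 2 + (xx - 74) ** 2 <= 400)

for sigma in (2, 3, 4, 5):
    labels, result = split_touching_objects_sketch(discs, sigma=sigma)
    print(f"sigma={sigma}: {labels.max()} regions")
```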

How does it work?

Under the hood, ImageJ’s watershed algorithm computes a distance image from the binary image and detects spots (local maxima) in it. The following code attempts to replicate the result with scipy and scikit-image.

Again, we start from the binary image.


The first step is to produce a distance image.

distance = ndi.distance_transform_edt(binary_image)
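For intuition: `ndi.distance_transform_edt` assigns every foreground (non-zero) pixel its Euclidean distance to the nearest background (zero) pixel, so object centers receive the highest values. A tiny hypothetical example:

```python
import numpy as np
from scipy import ndimage as ndi

# A 5x5 foreground square surrounded by a one-pixel background border
tiny = np.zeros((7, 7), dtype=bool)
tiny[1:6, 1:6] = True

tiny_distance = ndi.distance_transform_edt(tiny)
print(tiny_distance[3])  # middle row: [0. 1. 2. 3. 2. 1. 0.]
```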

To avoid splitting off very small objects, we blur the distance image with a Gaussian filter whose width is given by the sigma parameter.

sigma = 3.5

blurred_distance = gaussian(distance, sigma=sigma)
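Why does blurring lead to fewer split objects? Two nearby maxima merge into one once the blur is wide enough. For two equal Gaussian bumps, the sum has a single peak once their separation is no more than twice their width, and blurring increases that width. The one-dimensional sketch below illustrates this with `scipy.ndimage.gaussian_filter1d`, used here instead of `skimage.filters.gaussian` for brevity; the profile and the helper `count_peaks` are hypothetical.

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d


def count_peaks(signal, min_fraction=0.1):
    """Hypothetical helper: count strict interior maxima above min_fraction * max."""
    s = np.asarray(signal, dtype=float)
    inner = s[1:-1]
    is_peak = (inner > s[:-2]) & (inner > s[2:]) & (inner > min_fraction * s.max())
    return int(is_peak.sum())


# Hypothetical 1D "distance profile": two bumps of width 5, centered 20 samples apart
x = np.arange(101)
profile = np.exp(-(x - 40) ** 2 / (2 * 5.0 ** 2)) + np.exp(-(x - 60) ** 2 / (2 * 5.0 ** 2))

for sigma in (2, 12):
    n = count_peaks(gaussian_filter1d(profile, sigma))
    print(f"sigma={sigma}: {n} peak(s)")
# sigma=2 keeps the bumps apart (2 peaks); sigma=12 merges them (1 peak)
```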

Within this blurred image, we search for local maxima and receive them as a list of coordinates.

fp = np.ones((3,) * binary_image.ndim)
coords = peak_local_max(blurred_distance, footprint=fp, labels=binary_image)

# show the first 5 only
coords[:5]
array([[  8, 254],
       [ 97,   1],
       [ 10, 108],
       [230, 180],
       [182, 179]], dtype=int64)
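The `footprint` defines the neighborhood in which a pixel must be the maximum (here 3x3), and `labels` restricts the search to the foreground. The toy example below uses hypothetical values to show that a bright pixel outside the labeled region is ignored once `labels` is given.

```python
import numpy as np
from skimage.feature import peak_local_max

# Hypothetical toy "blurred distance" image: two foreground peaks (5 and 4)
# and one brighter pixel (9) that lies in the background
image = np.zeros((7, 9))
image[3, 2] = 5.0
image[3, 4] = 9.0
image[3, 6] = 4.0

# Foreground: two 5x3 blocks separated by background column 4
foreground = np.zeros((7, 9), dtype=int)
foreground[1:6, 1:4] = 1
foreground[1:6, 5:8] = 1

fp = np.ones((3,) * image.ndim)  # 3x3 neighborhood
all_peaks = peak_local_max(image, footprint=fp)
fg_peaks = peak_local_max(image, footprint=fp, labels=foreground)

print(sorted(map(tuple, all_peaks.tolist())))  # [(3, 2), (3, 4), (3, 6)]
print(sorted(map(tuple, fg_peaks.tolist())))   # [(3, 2), (3, 6)]
```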

We next write these maxima into a new image and label them.

mask = np.zeros(distance.shape, dtype=bool)
mask[tuple(coords.T)] = True
markers = label(mask)
imshow(markers, cmap='jet')
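Labeling the mask, rather than numbering the coordinates directly, means that maxima touching each other (for example two equal pixels on a plateau) end up as a single marker. `mask[tuple(coords.T)]` indexes the mask with one array of row indices and one array of column indices. A small hypothetical example:

```python
import numpy as np
from skimage.measure import label

# Hypothetical peak coordinates: (2, 2) and (2, 3) are neighbors on a plateau,
# (5, 6) is a separate maximum
coords = np.array([[2, 2], [2, 3], [5, 6]])

mask = np.zeros((8, 8), dtype=bool)
mask[tuple(coords.T)] = True  # tuple(coords.T) == (rows, columns)
markers = label(mask)

print(markers.max())  # 2
```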

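Next, we apply scikit-image’s watershed algorithm (see the watershed example in the scikit-image gallery). It takes an image to flood and a marker label image as input; we pass the negated blurred distance image, so that object centers become basins. The optional mask, here binary_image, prevents the labels from spreading too far into the background.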

labels = watershed(-blurred_distance, markers, mask=binary_image)
imshow(labels, cmap='jet')
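scikit-image's `watershed` floods the input image starting from its lowest values, which is why the distance image is negated: object centers become the deepest basins. To isolate the flooding step, the sketch below places two markers by hand at the known centers of synthetic overlapping discs (hypothetical test data) instead of detecting maxima.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.segmentation import watershed

# Hypothetical test data: two overlapping discs forming one connected object
yy, xx = np.mgrid[:80, :120]
discs = ((yy - 40) ** 2 + (xx - 46) ** 2 <= 400) | ((yy - 40) ** 2 + (xx - 74) ** 2 <= 400)
distance = ndi.distance_transform_edt(discs)

# One marker per disc, placed by hand at the disc centers
markers = np.zeros(discs.shape, dtype=int)
markers[40, 46] = 1
markers[40, 74] = 2

# Negated distance: maxima (object centers) become minima (basins)
labels = watershed(-distance, markers, mask=discs)
print(labels.max())  # 2
```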

To turn the labels back into a binary image, as ImageJ does, we now identify the edges between adjacent labels.

# identify label-cutting edges
edges_labels = sobel(labels)
edges_binary = sobel(binary_image)

edges = np.logical_xor(edges_labels != 0, edges_binary != 0)
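`sobel(labels)` responds both along the object outline and where two labels meet, while `sobel(binary_image)` responds only along the outline, so their XOR keeps the cuts between labels. The hypothetical toy label image below shows this, and also what happens where a cut meets the outline: there both responses are non-zero, the XOR is false, and a thin foreground bridge survives the subtraction. This may be one reason the notebook finishes with an opening. Casting to float before `sobel` is a choice made here to sidestep any dtype-dependent rescaling of integer input.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.filters import sobel

# Hypothetical label image: one rectangle, left half label 1, right half label 2
labels = np.zeros((8, 10), dtype=int)
labels[1:7, 1:5] = 1
labels[1:7, 5:9] = 2
binary = labels > 0

# Non-zero where labels change, and where foreground meets background
edges_labels = sobel(labels.astype(float))
edges_binary = sobel(binary.astype(float))
# Keep pixels where only the label image has an edge: the cut between labels
edges = np.logical_xor(edges_labels != 0, edges_binary != 0)

# Subtract the cut from the foreground, as in the notebook
almost = np.logical_not(edges) & binary
print(edges[2:6, 4:6].all())     # cut inside the object: True
print(edges[1, 4], edges[6, 4])  # where the cut meets the outline: False False
print(ndi.label(almost)[1])      # still 1 object, joined along rows 1 and 6
```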

Next, we subtract those edges from the original binary_image.

almost = np.logical_not(edges) * binary_image

As this result is not perfect yet, we apply a binary opening.

result = binary_opening(almost)
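A binary opening (an erosion followed by a dilation) removes structures thinner than its footprint, such as one-pixel-wide bridges between objects, while keeping larger shapes mostly intact; as a side effect it also rounds off sharp corners. A hypothetical toy example with scikit-image's default cross-shaped footprint:

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.morphology import binary_opening

# Hypothetical toy image: two 4x4 squares joined by a 1-pixel-thick, 3-pixel-long bridge
toy = np.zeros((8, 13), dtype=bool)
toy[2:6, 1:5] = True
toy[2:6, 8:12] = True
toy[3, 5:8] = True  # the thin bridge

opened = binary_opening(toy)  # default footprint: cross (connectivity 1)
parts_before = ndi.label(toy)[1]
parts_after = ndi.label(opened)[1]
print(parts_before, parts_after)  # 1 2
```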