Mimicking ImageJ’s watershed algorithm#
In ImageJ there is an algorithm called “Watershed” which allows splitting segmented dense objects within a binary image. This notebook demonstrates how to achieve a similar operation in Python.
The “Watershed” in ImageJ was applied to a binary image using this macro:
open("../BioImageAnalysisNotebooks/data/blobs_otsu.tif");
run("Watershed");
from skimage.io import imread, imshow
import matplotlib.pyplot as plt
import napari_segment_blobs_and_things_with_membranes as nsbatwm
import numpy as np
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.filters import gaussian, sobel
from skimage.measure import label
from skimage.segmentation import watershed
from skimage.morphology import binary_opening
The starting point for the demonstration is a binary image.
binary_image = imread("../../data/blobs_otsu.tif")
imshow(binary_image)
<matplotlib.image.AxesImage at 0x1f8e41b2700>

After applying the macro shown above, the resulting image in ImageJ looks like this:
binary_watershed_imagej = imread("../../data/blobs_otsu_watershed.tif")
imshow(binary_watershed_imagej)
<matplotlib.image.AxesImage at 0x1f8e4209100>

The Napari plugin napari-segment-blobs-and-things-with-membranes offers a function for mimicking the functionality from ImageJ.
binary_watershed_nsbatwm = nsbatwm.split_touching_objects(binary_image)
imshow(binary_watershed_nsbatwm)
<matplotlib.image.AxesImage at 0x1f8e42a63d0>

Comparing results#
A side-by-side comparison shows that the two results are not identical.
fig, axs = plt.subplots(1, 2, figsize=(10,10))
axs[0].imshow(binary_watershed_imagej)
axs[0].set_title("ImageJ")
axs[1].imshow(binary_watershed_nsbatwm)
axs[1].set_title("nsbatwm")
Text(0.5, 1.0, 'nsbatwm')
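The visual comparison can be complemented with numbers. The following is a minimal sketch and not part of the original notebook: `jaccard_index` is a hypothetical helper, and the two small arrays are synthetic stand-ins for `binary_watershed_imagej` and `binary_watershed_nsbatwm`. Both masks are compared via `> 0`, on the assumption that one image may encode foreground as 255 and the other as 1 or True.

```python
import numpy as np


def jaccard_index(a, b):
    """Intersection over union of two binary masks (hypothetical helper)."""
    a = np.asarray(a) > 0
    b = np.asarray(b) > 0
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 1.0  # both masks are empty: treat them as identical
    return np.logical_and(a, b).sum() / union


# Synthetic stand-ins; in the notebook, pass the two watershed results instead.
mask_a = np.zeros((6, 6), dtype=np.uint8)
mask_a[1:5, 1:5] = 255           # 16 foreground pixels
mask_b = mask_a.copy()
mask_b[1:5, 3] = 0               # a one-pixel-wide cut removes 4 of them

# Number of pixels where the two masks disagree
n_different = int(np.count_nonzero((mask_a > 0) != (mask_b > 0)))
similarity = jaccard_index(mask_a, mask_b)
print(n_different, similarity)
```

A Jaccard index of 1 would mean both masks are pixel-identical; lower values indicate larger disagreement, for example cuts placed at different positions.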

Fine-tuning results#
The result can be modified by tuning the sigma parameter.
fig, axs = plt.subplots(1, 4, figsize=(10,10))
for i, sigma in enumerate(np.arange(2, 6, 1)):
    result = nsbatwm.split_touching_objects(binary_image, sigma=sigma)
    axs[i].imshow(result)
    axs[i].set_title("sigma="+str(sigma))

How does it work?#
Under the hood, ImageJ’s watershed algorithm uses a distance image and spot-detection. The following code attempts to replicate the result.
Again, we start from the binary image.
imshow(binary_image)
<matplotlib.image.AxesImage at 0x1f8e55b87f0>

The first step is to produce a distance image.
distance = ndi.distance_transform_edt(binary_image)
imshow(distance)
<matplotlib.image.AxesImage at 0x1f8e55f17c0>
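To make the meaning of the distance image concrete, here is a small sketch on a made-up 7x7 array (the array `tiny` is an assumption for illustration, not data from the notebook). Each foreground pixel receives its Euclidean distance to the nearest background pixel, so values grow towards the object centre.

```python
import numpy as np
from scipy import ndimage as ndi

# A 5x5 foreground square surrounded by a one-pixel background border
tiny = np.zeros((7, 7), dtype=np.uint8)
tiny[1:6, 1:6] = 1

tiny_distance = ndi.distance_transform_edt(tiny)

# Middle row: distances rise from the edge to the centre and fall again
print(tiny_distance[3])
```

The maxima of this image mark object centres, which is why they are used as seeds in the following steps.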

To avoid very small split objects, we blur the distance image using the sigma parameter.
sigma = 3.5
blurred_distance = gaussian(distance, sigma=sigma)
imshow(blurred_distance)
<matplotlib.image.AxesImage at 0x1f8e56ec5e0>
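Why blurring helps can be sketched in one dimension. Every local maximum later becomes a seed, so pixel-level jitter in the distance image would produce many tiny fragments. The example below is an illustration under assumptions: the profile is synthetic, `count_local_maxima` is a hypothetical helper, and `scipy.ndimage.gaussian_filter1d` stands in for the `gaussian` function from scikit-image used in the notebook.

```python
import numpy as np
from scipy import ndimage as ndi


def count_local_maxima(profile):
    """Count interior samples strictly larger than both neighbours (hypothetical helper)."""
    inner = profile[1:-1]
    return int(np.count_nonzero((inner > profile[:-2]) & (inner > profile[2:])))


x = np.arange(41)
hump = 20.0 - np.abs(x - 20)     # one broad object with its peak at x = 20
ripple = 1.5 * (x % 2)           # pixel-level jitter on top of the hump
raw_profile = hump + ripple

blurred_profile = ndi.gaussian_filter1d(raw_profile, sigma=3)

n_raw = count_local_maxima(raw_profile)
n_blurred = count_local_maxima(blurred_profile)
print(n_raw, n_blurred)
```

The blurred profile has far fewer maxima than the raw one. A larger sigma therefore tends to yield fewer, larger objects, while a smaller sigma splits more aggressively.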

Within this blurred image, we search for local maxima and obtain them as a list of coordinates.
fp = np.ones((3,) * binary_image.ndim)
coords = peak_local_max(blurred_distance, footprint=fp, labels=binary_image)
# show the first 5 only
coords[:5]
array([[ 8, 254],
[ 97, 1],
[ 10, 108],
[230, 180],
[182, 179]], dtype=int64)
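The output format of `peak_local_max` is easier to read on a toy input. In this sketch the image `toy` is made up and holds two isolated peaks; the `labels` argument used in the notebook is omitted for brevity. Each row of the returned array is one peak given as (row, column).

```python
import numpy as np
from skimage.feature import peak_local_max

toy = np.zeros((9, 9))
toy[2, 2] = 5.0   # stronger peak
toy[6, 5] = 3.0   # weaker peak

# Same 3x3 neighbourhood as in the notebook: a pixel counts as a peak
# if it is the maximum within this footprint.
toy_coords = peak_local_max(toy, footprint=np.ones((3, 3)))
print(toy_coords)
```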
We next write these maxima into a new image and label them.
mask = np.zeros(distance.shape, dtype=bool)
mask[tuple(coords.T)] = True
markers = label(mask)
imshow(markers, cmap='jet')
<matplotlib.image.AxesImage at 0x1f8e59c1c40>
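The indexing expression `mask[tuple(coords.T)]` may look cryptic, so here is a small sketch with made-up coordinates (`example_coords` is an assumption for illustration). Transposing turns the list of (row, column) pairs into one array of rows and one array of columns, which NumPy accepts as an index. Labelling then merges maxima that touch each other into a single marker.

```python
import numpy as np
from skimage.measure import label

# Three made-up maxima; the first two are direct neighbours
example_coords = np.array([[1, 1], [1, 2], [4, 4]])

example_mask = np.zeros((6, 6), dtype=bool)
example_mask[tuple(example_coords.T)] = True   # same as example_mask[rows, cols] = True

example_markers = label(example_mask)
print(example_markers.max())
```

Although three pixels were set, only two markers result, because the adjacent maxima share one label.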

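Next, we apply scikit-image’s watershed algorithm. It takes a distance image and a label image as input. Because the watershed floods from low to high values, the blurred distance image is negated so that object centres become basins. The optional mask parameter, here binary_image, prevents the labels from spreading into the background.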
labels = watershed(-blurred_distance, markers, mask=binary_image)
imshow(labels, cmap='jet')
<matplotlib.image.AxesImage at 0x1f8e5a97f40>
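The effect of this step can be checked on a synthetic case where the expected outcome is known. The sketch below is an illustration under assumptions: two overlapping disks form one connected object, and the markers are placed by hand at the disk centres instead of being detected with `peak_local_max`.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.segmentation import watershed

# Two overlapping disks of radius 15 forming a single connected object
yy, xx = np.mgrid[0:40, 0:70]
disk_a = (yy - 20) ** 2 + (xx - 22) ** 2 <= 15 ** 2
disk_b = (yy - 20) ** 2 + (xx - 47) ** 2 <= 15 ** 2
touching = disk_a | disk_b

# Hand-placed seeds at the disk centres (assumption for this sketch)
toy_markers = np.zeros(touching.shape, dtype=int)
toy_markers[20, 22] = 1
toy_markers[20, 47] = 2

toy_distance = ndi.distance_transform_edt(touching)

# Negated distance: centres become the deepest points of the basins
toy_labels = watershed(-toy_distance, toy_markers, mask=touching)
print(np.unique(toy_labels))
```

Every foreground pixel ends up with label 1 or 2, and the background stays 0 because of the mask.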

To create a binary image again as ImageJ does, we now identify the edges between the labels.
# identify label-cutting edges
edges_labels = sobel(labels)
edges_binary = sobel(binary_image)
edges = np.logical_xor(edges_labels != 0, edges_binary != 0)
imshow(edges)
<matplotlib.image.AxesImage at 0x1f8e5aaea00>

Next, we subtract those edges from the original binary_image.
almost = np.logical_not(edges) * binary_image
imshow(almost)
<matplotlib.image.AxesImage at 0x1f8e55c8610>

As this result is not perfect yet, we apply a binary opening, which removes thin leftover foreground structures.
result = binary_opening(almost)
imshow(result)
<matplotlib.image.AxesImage at 0x1f8e3b81e50>
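As a possible alternative to the edge detection, subtraction and opening steps, scikit-image's `watershed` accepts `watershed_line=True`, which leaves a one-pixel-wide line of zeros between neighbouring labels. This is a different technique from the one shown above and is not claimed to reproduce the ImageJ or nsbatwm output exactly. The sketch uses synthetic disks and hand-placed seeds, both assumptions for illustration.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.segmentation import watershed

# Two overlapping disks forming a single connected object
yy, xx = np.mgrid[0:40, 0:70]
touching = ((yy - 20) ** 2 + (xx - 22) ** 2 <= 15 ** 2) | (
    (yy - 20) ** 2 + (xx - 47) ** 2 <= 15 ** 2
)

# Hand-placed seeds at the disk centres (assumption for this sketch)
seeds = np.zeros(touching.shape, dtype=int)
seeds[20, 22] = 1
seeds[20, 47] = 2

distance_toy = ndi.distance_transform_edt(touching)
labels_with_line = watershed(
    -distance_toy, seeds, mask=touching, watershed_line=True
)

# Pixels on the watershed line keep label 0, so thresholding gives a split mask
split = labels_with_line > 0

# ndi.label uses 4-connectivity by default
_, n_before = ndi.label(touching)
_, n_after = ndi.label(split)
print(n_before, n_after)
```

With the default `connectivity=1`, the line separates objects with respect to 4-connectivity; objects may still touch diagonally, which matters if the result is later labelled with 8-connectivity.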
