Dash for interactive image visualization

This notebook demonstrates how to load multiple images, analyze their intensity statistics, and create an interactive visualization that combines a scatter plot with an image viewer. The visualizations are built with Plotly, and the interactivity on top of them is provided by Dash.

The underlying technical challenge is that hovering the mouse over a data point in the plot must trigger Python code, which then updates the image viewer.

import os
import numpy as np
import pandas as pd
from skimage.io import imread
import stackview
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import HBox, VBox, Label
import plotly.express as px
import dash
from dash import dcc, html, Input, Output, callback
import json
import plotly.express as px

Load Images from Directory

We’ll load all images from the data/images directory and store them in a list along with their filenames.

# Define the image directory
image_dir = "data/images"

# Load all images.
# os.listdir() returns entries in arbitrary order, so the names are sorted to
# make the image order (and therefore image_index below) reproducible.
# Sub-directories and hidden files (e.g. .DS_Store) are skipped because they
# are not images that imread can load.
images = []
filenames = []

for filename in sorted(os.listdir(image_dir)):
    filepath = os.path.join(image_dir, filename)
    if filename.startswith(".") or not os.path.isfile(filepath):
        continue
    images.append(imread(filepath))
    filenames.append(filename)

# Keep one ndarray (with its original dtype) per entry in a 1-D object array.
# np.array(images, dtype=object) would instead build an (N, H, W, ...) array of
# scalar Python objects whenever all images share the same shape, so the
# array is pre-allocated and filled element by element.
image_array = np.empty(len(images), dtype=object)
for i, image in enumerate(images):
    image_array[i] = image
images = image_array
print(f"Loaded {len(images)} images")
Loaded 20 images

Calculate Intensity Statistics

For each image, we’ll calculate the mean and standard deviation of pixel intensities.

# Calculate mean and standard deviation of the pixel intensities of each image
# (computed over all pixels and channels of the image).
mean_intensities = [np.mean(image) for image in images]
std_intensities = [np.std(image) for image in images]

# min()/max() raise ValueError on an empty list, so only summarise when at
# least one image was loaded.
if mean_intensities:
    print(f"Mean intensity range: {min(mean_intensities):.2f} - {max(mean_intensities):.2f}")
    print(f"Standard deviation range: {min(std_intensities):.2f} - {max(std_intensities):.2f}")
else:
    print("No images loaded - no intensity statistics to report")
Mean intensity range: 1.65 - 93.18
Standard deviation range: 10.39 - 60.38

Create DataFrame with Measurements

We’ll organize our data into a pandas DataFrame for easier handling.

# Gather the per-image measurements into a single table with one row per
# image; image_index is the row's position in `images` / `filenames`.
measurement_columns = {
    'filename': filenames,
    'mean_intensity': mean_intensities,
    'std_intensity': std_intensities,
    'image_index': range(len(images)),
}
df = pd.DataFrame(measurement_columns)

display(df.head())
filename mean_intensity std_intensity image_index
0 synthetic_image_00.png 15.734467 30.791965 0
1 synthetic_image_01.png 7.520798 22.516088 1
2 synthetic_image_02.png 19.802139 35.109714 2
3 synthetic_image_03.png 1.667297 10.391785 3
4 synthetic_image_04.png 19.541229 34.701846 4

Interactive Visualization

We’ll create a Plotly scatter plot showing mean vs. standard deviation of intensity, alongside an image viewer. When you hover over a data point, the image viewer updates to show the corresponding image. The interaction is implemented using Dash.

# Scatter plot of mean vs. standard deviation intensity, one marker per image.
# Each marker carries its filename (shown in the tooltip) and its image_index
# (attached as customdata).
hover_template = (
    '<b>%{text}</b><br>'
    'Mean Intensity: %{x:.2f}<br>'
    'Std Intensity: %{y:.2f}<br>'
    '<extra></extra>'  # empty <extra> hides the secondary trace-name box
)

scatter_trace = go.Scatter(
    x=df['mean_intensity'],
    y=df['std_intensity'],
    mode='markers',
    marker={'size': 8, 'opacity': 0.7},
    text=df['filename'],
    customdata=df['image_index'],
    hovertemplate=hover_template,
)

fig = go.Figure()
fig.add_trace(scatter_trace)
fig.update_layout(
    title='Mean vs Standard Deviation Intensity',
    xaxis_title='Mean Intensity',
    yaxis_title='Standard Deviation Intensity',
    width=600,
    height=400,
)

# Function to create image figure
def create_image_figure(image, title="Image"):
    if len(image.shape) == 3 and image.shape[0] > 1:
        # For 3D images, show the middle slice
        img_slice = image[image.shape[0]//2]
    else:
        img_slice = image.squeeze()
    
    fig_img = px.imshow(img_slice, color_continuous_scale='gray', title=title)
    fig_img.update_layout(width=400, height=400)
    return fig_img

# Figure shown in the image viewer before the user has hovered over anything:
# the first loaded image, titled with its filename.
initial_image_fig = create_image_figure(images[0], title=filenames[0])


def _half_width_panel(graph):
    """Wrap *graph* in an inline-block Div that takes half of the page width."""
    return html.Div([graph], style={'display': 'inline-block', 'width': '50%'})


# Dash app: scatter plot and image viewer side by side, followed by a text
# element that the callback fills with status messages.
app = dash.Dash(__name__)

scatter_panel = _half_width_panel(dcc.Graph(id='scatter-plot', figure=fig))
viewer_panel = _half_width_panel(dcc.Graph(id='image-viewer', figure=initial_image_fig))
app.layout = html.Div([scatter_panel, viewer_panel, html.Div(id='debug-output')])

# Define the callback function to update the image viewer
@app.callback(
    [Output('image-viewer', 'figure'),
     Output('debug-output', 'children')],
    Input('scatter-plot', 'hoverData')
)
def update_image_viewer(hoverData):
    """Update the image viewer when hovering over plot points.

    Parameters
    ----------
    hoverData : dict or None
        Hover payload of the ``scatter-plot`` graph; ``None`` when nothing is
        hovered (e.g. on the initial call).

    Returns
    -------
    tuple
        ``(figure, message)`` for the ``image-viewer`` figure and the
        ``debug-output`` text.
    """
    # Treat a missing payload and a payload without points the same way;
    # indexing points[0] on an empty list would raise IndexError.
    points = (hoverData or {}).get('points') or []
    if not points:
        return initial_image_fig, "Hover over a point to see the corresponding image"

    point = points[0]
    # The scatter trace attaches each row's image_index as customdata, so use
    # it directly. Fall back to the point's position in the trace, which
    # follows the row order of df.
    if 'customdata' in point:
        image_index = int(point['customdata'])
    else:
        image_index = int(df.iloc[point['pointIndex']]['image_index'])

    # Create new image figure
    image_fig = create_image_figure(images[image_index], filenames[image_index])
    debug_msg = f"Showing image: {filenames[image_index]}"

    return image_fig, debug_msg

# Launch the app. In Dash's built-in notebook support, the display mode is
# selected with `jupyter_mode`; `mode` was the keyword of the older
# JupyterDash package and is not a documented `app.run` option.
app.run(debug=True, jupyter_mode='inline')

Note that the code above was AI-generated using bia-bob and model claude-sonnet-4-20250514.