Vision Large Language Models for Counting Objects
In this notebook we use OpenAI’s LLMs with vision capabilities to see how well they can count the blobs in blobs.tif.
Note: It is not recommended to use this approach for counting objects in microscopy images. The author of this notebook is not aware of any publication showing that this approach works well.
import openai
import PIL
import stackview
from skimage.io import imread
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
We will need some helper functions for assembling a prompt and submitting it to the OpenAI server.
def prompt_with_image(message:str, image, model="gpt-4o-2024-05-13"):
    """A prompt helper function that sends a text message and an image
    to OpenAI and returns the text response.
    """
    # convert message in the right format if necessary
    if isinstance(message, str):
        message = [{"role": "user", "content": message}]
    image_message = image_to_message(image)

    # setup connection to the LLM
    client = openai.OpenAI()

    # submit prompt
    response = client.chat.completions.create(
        model=model,
        messages=message + image_message
    )

    # extract answer
    return response.choices[0].message.content
def image_to_message(image):
    """Encode an image as a base64 data URL wrapped in a chat message."""
    import base64
    from stackview._image_widget import _img_to_rgb

    rgb_image = _img_to_rgb(image)
    byte_stream = numpy_to_bytestream(rgb_image)
    base64_image = base64.b64encode(byte_stream).decode('utf-8')

    # numpy_to_bytestream() writes PNG data, so the media type must be image/png
    return [{"role": "user", "content": [{
        "type": "image_url",
        "image_url": {
            "url": f"data:image/png;base64,{base64_image}"
        }
    }]}]
def numpy_to_bytestream(data):
    """Turn a NumPy array into a bytestream"""
    import numpy as np
    from PIL import Image
    import io

    # Convert the NumPy array to a PIL Image
    image = Image.fromarray(data.astype(np.uint8)).convert("RGBA")

    # Create a BytesIO object
    bytes_io = io.BytesIO()

    # Save the PIL image to the BytesIO object as a PNG
    image.save(bytes_io, format='PNG')

    # Rewind to the start of the buffer and return the whole file as bytes
    bytes_io.seek(0)
    return bytes_io.read()
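The media type declared in a data URL should match the bytes that are actually encoded. As a minimal sketch, the hypothetical helpers `guess_media_type` and `to_data_url` below (they are not part of this notebook or of any library) inspect the leading signature bytes of a bytestream and build the URL accordingly. Only the PNG and JPEG signatures are handled here.

```python
import base64

# File signatures ("magic bytes") at the start of PNG and JPEG files
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
JPEG_SIGNATURE = b"\xff\xd8\xff"

def guess_media_type(byte_stream: bytes) -> str:
    """Hypothetical helper: derive the media type from the leading bytes."""
    if byte_stream.startswith(PNG_SIGNATURE):
        return "image/png"
    if byte_stream.startswith(JPEG_SIGNATURE):
        return "image/jpeg"
    raise ValueError("Unsupported image format")

def to_data_url(byte_stream: bytes) -> str:
    """Hypothetical helper: build a base64 data URL with a matching media type."""
    media_type = guess_media_type(byte_stream)
    payload = base64.b64encode(byte_stream).decode("utf-8")
    return f"data:{media_type};base64,{payload}"
```

Applied to the output of `numpy_to_bytestream`, which saves PNG data, `to_data_url` would produce a URL starting with `data:image/png;base64,`.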
This is the example image we will be using.
image = imread("../../data/blobs.tif")
stackview.insight(image)
This is the prompt we submit to the server.
my_prompt = """
Analyse the following image by counting the bright blobs. Respond with the number only.
"""
prompt_with_image(my_prompt, image)
'64'
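Even when asked to respond with the number only, a model may occasionally wrap the number in a sentence. The following is a minimal sketch of a more tolerant parser, under the assumption that the first integer in the reply is the count. `parse_count` is a hypothetical helper, not part of this notebook.

```python
import re
from typing import Optional

def parse_count(reply: Optional[str]) -> Optional[int]:
    """Return the first integer found in the reply, or None if there is none."""
    if reply is None:
        return None
    match = re.search(r"\d+", reply)
    return int(match.group()) if match else None
```

Taking the first integer can misread replies that mention several numbers, so a strict `int(reply)` conversion remains the more conservative choice when the reply format matters.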
Benchmarking vision-LLMs
We can run this prompt in a loop for a couple of vision models.
num_samples = 25

models = {
    "gpt-4-vision-preview":[],
    "gpt-4-turbo-2024-04-09":[],
    "gpt-4o-2024-05-13":[],
}

for model in models.keys():
    samples = []

    while len(samples) < num_samples:
        # pass the model explicitly; otherwise the default model answers for every column
        result = prompt_with_image(my_prompt, image, model=model)

        try:
            samples.append(int(result))
        except (TypeError, ValueError):
            # the response was not a plain integer; report it and sample again
            print("Error processing result:", result)

    models[model] = samples

sampled_models = pd.DataFrame(models)
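The sampling loop above keeps requesting until enough integer responses have arrived, so it would never terminate if a model consistently replied with something other than a number. A minimal sketch of a bounded variant follows; `sample_counts`, `ask` and `max_attempts` are hypothetical names introduced for illustration, and `ask` stands for any zero-argument function that returns the model's reply.

```python
from typing import Callable, List

def sample_counts(ask: Callable[[], str], num_samples: int, max_attempts: int) -> List[int]:
    """Collect up to num_samples integer replies, giving up after max_attempts requests."""
    samples: List[int] = []
    for _ in range(max_attempts):
        if len(samples) >= num_samples:
            break
        reply = ask()
        try:
            samples.append(int(reply))
        except (TypeError, ValueError):
            # the reply was not a plain integer; report it and try again
            print("Error processing result:", reply)
    return samples
```

In this notebook, `ask` could be `lambda: prompt_with_image(my_prompt, image, model=model)`. The caller should check that every model returned `num_samples` values before building the DataFrame, because columns of unequal length cannot be combined directly.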
Let’s get an overview of the samples:
# Extract the columns for comparison (one per model)
columns_to_plot = sampled_models[list(models.keys())]
# Melt the dataframe to prepare for plotting
df_melted = columns_to_plot.melt(var_name='Model', value_name='Blob count')
# Draw the violin plot
plt.figure(figsize=(8, 4))
sns.violinplot(x='Model', y='Blob count', data=df_melted)
plt.title('Vision models counting blobs')
plt.show()
These are the results in detail:
sampled_models
Sample | gpt-4-vision-preview | gpt-4-turbo-2024-04-09 | gpt-4o-2024-05-13 |
---|---|---|---|
0 | 56 | 56 | 58 |
1 | 52 | 52 | 54 |
2 | 53 | 54 | 69 |
3 | 48 | 59 | 50 |
4 | 62 | 51 | 63 |
5 | 58 | 54 | 55 |
6 | 56 | 55 | 56 |
7 | 69 | 58 | 57 |
8 | 53 | 60 | 50 |
9 | 50 | 78 | 51 |
10 | 63 | 52 | 54 |
11 | 120 | 56 | 65 |
12 | 56 | 64 | 55 |
13 | 61 | 57 | 57 |
14 | 52 | 56 | 46 |
15 | 64 | 52 | 54 |
16 | 74 | 53 | 63 |
17 | 51 | 57 | 52 |
18 | 52 | 49 | 63 |
19 | 52 | 72 | 51 |
20 | 48 | 47 | 51 |
21 | 52 | 54 | 50 |
22 | 67 | 50 | 58 |
23 | 52 | 56 | 48 |
24 | 65 | 54 | 54 |
sampled_models.describe()
Statistic | gpt-4-vision-preview | gpt-4-turbo-2024-04-09 | gpt-4o-2024-05-13 |
---|---|---|---|
count | 25.000000 | 25.000000 | 25.000000 |
mean | 59.440000 | 56.240000 | 55.360000 |
std | 14.399306 | 6.765599 | 5.692685 |
min | 48.000000 | 47.000000 | 46.000000 |
25% | 52.000000 | 52.000000 | 51.000000 |
50% | 56.000000 | 55.000000 | 54.000000 |
75% | 63.000000 | 57.000000 | 58.000000 |
max | 120.000000 | 78.000000 | 69.000000 |
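The mean and standard deviation are sensitive to single extreme responses, such as the maximum of 120 in the gpt-4-vision-preview column, whereas the median is not. The sketch below recomputes these statistics with Python's standard library, using the values copied from the detailed results table above. It is meant as an illustration of how to read the summary, not as part of the original analysis.

```python
import statistics

# gpt-4-vision-preview column, copied from the detailed results table
counts = [56, 52, 53, 48, 62, 58, 56, 69, 53, 50, 63, 120, 56,
          61, 52, 64, 74, 51, 52, 52, 48, 52, 67, 52, 65]

mean_all = statistics.mean(counts)      # 59.44, as reported by describe()
median_all = statistics.median(counts)  # 56, the 50% row of describe()
std_all = statistics.stdev(counts)      # sample standard deviation (n - 1), like pandas

# Drop the single largest response to see how much it influences the summary
largest = max(counts)
without_largest = [c for c in counts if c != largest]
mean_trimmed = statistics.mean(without_largest)
std_trimmed = statistics.stdev(without_largest)

print(f"all samples:     mean={mean_all:.2f} median={median_all} std={std_all:.2f}")
print(f"without maximum: mean={mean_trimmed:.2f} std={std_trimmed:.2f}")
```

Comparing the two printed lines shows how strongly one response shifts the mean and the standard deviation, which is why the median rows are the more robust ones to compare across models.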