Classifying images with Gemini Flash 1.5

llm
gemini
Published: September 8, 2024
Modified: September 28, 2024

Most people think of In-Context Learning (ICL) — the ability of LLMs to learn from examples provided in the context — only as a component of RAG applications.

I used to think of it that way too, until I found out that Multimodal Large Language Models (MLLMs) with ICL can also handle more traditional ML tasks such as image classification.

I was skeptical at first, but was surprised to see that it worked pretty well both in the literature (see here and here) and in my own experiments.

You shouldn’t expect state-of-the-art results with it, but it can often give you pretty good results with very little effort and data.

In this tutorial, I’ll show you how to use ICL to classify images using Gemini Flash 1.5.

Why Gemini Flash 1.5?

You can use any MLLM for this task, but I chose Gemini Flash 1.5 because:

  1. It’s cheaper than Gemini Pro 1.5, GPT-4o, and Sonnet 3.5. For an image of 512x512 pixels, Gemini Flash 1.5 is roughly 50x cheaper than Gemini Pro 1.5, 5x to 16x cheaper than GPT-4o, and 26x cheaper than Sonnet 3.5 (see footnote 1).
  2. It lets you use up to 3,000 images per request. By trial and error, I found that GPT-4o seems to have a hard limit at 250 images per request and Sonnet 3.5’s documentation mentions a limit of 20 images per request.
  3. It works well. If you really want to squeeze the last bit of performance out of your model, you can use a bigger model, but for the purposes of this tutorial, Gemini Flash 1.5 will do just fine.
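The price ratios in the first point can be rechecked with a few lines of arithmetic. This is a minimal sketch that only reuses the per-image estimates from footnote 1; the helper name `cost_ratios` is made up for illustration.

```python
# Per-image cost estimates for a 512x512 image, copied from footnote 1
# (as of September 8, 2024). Values are in US dollars.
COST_FLASH = 0.000039
OTHER_COSTS = {
    "Gemini Pro 1.5": (0.0018, 0.0018),
    "GPT-4o": (0.000213, 0.000638),  # low and high estimates
    "Sonnet 3.5": (0.001047, 0.001047),
}


def cost_ratios(base_cost, other_costs):
    """Return how many times more expensive each model is than the base."""
    return {
        name: (low / base_cost, high / base_cost)
        for name, (low, high) in other_costs.items()
    }


ratios = cost_ratios(COST_FLASH, OTHER_COSTS)
for name, (low, high) in ratios.items():
    if low == high:
        print(f"{name}: about {low:.1f}x the cost of Gemini Flash 1.5")
    else:
        print(f"{name}: about {low:.1f}x to {high:.1f}x the cost of Gemini Flash 1.5")
```

With the footnote values, Gemini Pro 1.5 comes out at about 46x, which the list rounds to 50x. Prices change often, so rerun this with current numbers before relying on it.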

Regardless of the model you choose, this tutorial will be a good starting point for you to classify images using ICL.

Prerequisites

To follow this tutorial you’ll need to:

  1. Sign up and generate an API key in Google AI Studio.
  2. Set the API key as an environment variable called GEMINI_API_KEY.
  3. Download this dataset and save it to data/.
  4. Create a virtual environment and install the requirements:
python -m venv venv
source venv/bin/activate
pip install pandas numpy scikit-learn google-generativeai pillow

Set up

As usual, you start by importing the necessary libraries:

import json
import os
import warnings

import google.generativeai as genai
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score
from PIL import Image
from pathlib import Path

warnings.filterwarnings("ignore")

np.random.seed(42)

In addition to the usual data libraries (pandas and numpy), you’ll need:

  • google.generativeai for interacting with the Gemini API
  • PIL for handling images
  • sklearn for calculating performance metrics

Then, you’ll need to configure the Gemini API client with your API key:

genai.configure(api_key=os.environ["GEMINI_API_KEY"])

This will take the GEMINI_API_KEY environment variable and use it to authenticate your requests to the Gemini API.
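If the variable is missing, `os.environ["GEMINI_API_KEY"]` fails with a bare `KeyError`. A small guard can make that failure easier to read. The sketch below is optional, and `get_api_key` is a hypothetical helper, not part of the Gemini SDK.

```python
import os


def get_api_key(env=os.environ, name="GEMINI_API_KEY"):
    """Return the API key, or raise a readable error if it is missing."""
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(
            f"{name} is not set. Export it in your shell before running the tutorial."
        )
    return key


# With the real environment you would call:
#   genai.configure(api_key=get_api_key())
# Demonstration with a stand-in dictionary instead of the real environment:
print(get_api_key({"GEMINI_API_KEY": "dummy-key"}))
```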

Read data

To make a fair evaluation of the model’s performance, you should split the dataset into separate training and testing sets. The training set is used to provide context or examples to the model during inference. The testing set, comprised of unseen images, is then used to measure the model’s performance.

This process is different from the traditional “training” process, where you update the model’s weights or parameters. Here, you’re only providing the model with a set of images and asking it to learn from them at inference time.

This function will help you create the datasets:

def create_datasets(train_dir, test_dir, selected_classes, n_images_icl=3):
    train_data = []
    test_data = []

    for class_id, class_name in enumerate(selected_classes):
        train_class_dir = train_dir / class_name
        test_class_dir = test_dir / class_name

        if not train_class_dir.is_dir() or not test_class_dir.is_dir():
            continue

        # Train dataset
        train_image_files = list(train_class_dir.glob("*.jpg"))
        selected_train_images = np.random.choice(
            train_image_files,
            size=min(n_images_icl, len(train_image_files)),
            replace=False,
        )
        for img_path in selected_train_images:
            train_data.append(
                {
                    "image_path": str(img_path),
                    "class_id": f"class_{class_id}",
                    "class_name": class_name,
                }
            )

        # Test dataset
        test_image_files = list(test_class_dir.glob("*.jpg"))
        for img_path in test_image_files:
            test_data.append(
                {
                    "image_path": str(img_path),
                    "class_id": f"class_{class_id}",
                    "class_name": class_name,
                }
            )

    df_train = pd.DataFrame(train_data)
    df_test = pd.DataFrame(test_data).sample(frac=1).reset_index(drop=True)

    return df_train, df_test

This function will get a random selection of n_images_icl images per class from the train folder (that you’ll later use in the model’s context). For the testing set, which you’ll use to measure the model’s performance, you’ll use all the available images in the test folder from those classes.

To keep things simple, you’ll start by selecting 15 different classes and 1 image per class for the context (i.e., n_images_icl=1):

DATA_DIR = "../data/"
TRAIN_DIR = Path(DATA_DIR) / "train"
TEST_DIR = Path(DATA_DIR) / "test"

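# Keep only class folders (skips stray files such as .DS_Store) and sort them
# so the seeded random selection below is reproducible across machines
all_classes = sorted(p.name for p in TRAIN_DIR.iterdir() if p.is_dir())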
selected_classes = np.random.choice(all_classes, size=15, replace=False)

df_train, df_test = create_datasets(TRAIN_DIR, TEST_DIR, selected_classes=selected_classes, n_images_icl=1)

The training set will contain 1 image for each of the 15 classes (15 images in total), and the testing set will contain 5 images for each of those classes (75 images in total).
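Before sending anything to the model, it can be worth confirming the split looks the way you expect. The sketch below assumes rows shaped like the dictionaries built in create_datasets; `summarize_split` is a hypothetical helper, and the toy paths are made up.

```python
from collections import Counter


def summarize_split(train_rows, test_rows):
    """Count images per class and check that no image is in both sets."""
    train_paths = {row["image_path"] for row in train_rows}
    test_paths = {row["image_path"] for row in test_rows}
    overlap = train_paths & test_paths
    if overlap:
        raise ValueError(f"{len(overlap)} image(s) appear in both sets")
    return (
        Counter(row["class_id"] for row in train_rows),
        Counter(row["class_id"] for row in test_rows),
    )


# Toy rows: 3 classes, 1 context image and 5 test images per class.
train_rows = [
    {"image_path": f"train/c{c}/0.jpg", "class_id": f"class_{c}"} for c in range(3)
]
test_rows = [
    {"image_path": f"test/c{c}/{i}.jpg", "class_id": f"class_{c}"}
    for c in range(3)
    for i in range(5)
]
train_counts, test_counts = summarize_split(train_rows, test_rows)
print(train_counts)
print(test_counts)
```

With the real data, you could pass `df_train.to_dict("records")` and `df_test.to_dict("records")` instead of the toy rows.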

Gemini Flash 1.5

Next, you’ll need to define a system prompt and configure the model to use it.

Define prompt

You’ll use a system prompt that will tell the model how to classify the images and the format you want the output to be in:

CLASSIFIER_SYSTEM_PROMPT = """You are an expert lepidopterist.

Your task is to classify images of butterflies into one of the provided labels.

Provide your output as a JSON object using this format:

{
    "number_of_labeled_images": <integer>,
    "output": [
        {
            "image_id": <image id, integer, starts at 0>,
            "confidence": <number between 0 and 10, the higher the more confident, integer>,
            "label": <label of the correct butterfly species, string>
        }, 
        ...
    ]
}

## Guidelines

- ALWAYS produce valid JSON.
- Generate ONLY a single prediction per input image.
- The `number_of_labeled_images` MUST be the same as the number of input images.

This is an example of a valid output:
```
{
  "number_of_labeled_images": 5,
  "output": [
      {
        "image_id": 0,
        "confidence": 10,
        "label": "class_B"
      },
      {
        "image_id": 1,
        "confidence": 9,
        "label": "class_C"
      },
      {
        "image_id": 2,
        "confidence": 4,
        "label": "class_A"
      },
      {
        "image_id": 3,
        "confidence": 2,
        "label": "class_B"
      },
      {
        "image_id": 4,
        "confidence": 10,
        "label": "class_C"
      }
  ]
}
```
""".strip()

This prompt explains the task to the model. You’re providing it with a set of labels with corresponding images, and a set of images that should be classified into one of those labels. The model needs to output a single label for each image.

I included an additional field called number_of_labeled_images because I noticed that the model would often “forget” to include predictions for all the input images, and asking it to state the count was a simple way to reduce that.
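The prompt only asks the model to report the count; it does not guarantee the count is right. A small check on the parsed JSON can catch mismatches before you compute metrics. This is a sketch, and `validate_response` is a hypothetical helper name.

```python
def validate_response(response_json, n_expected):
    """Return a list of problems found in the model's JSON output."""
    problems = []
    output = response_json.get("output", [])
    declared = response_json.get("number_of_labeled_images")
    if declared != n_expected:
        problems.append(f"declared {declared} images, expected {n_expected}")
    if len(output) != n_expected:
        problems.append(f"got {len(output)} predictions, expected {n_expected}")
    ids = [item.get("image_id") for item in output]
    if len(set(ids)) != len(ids):
        problems.append("duplicate image_id values")
    missing = sorted(set(range(n_expected)) - set(ids))
    if missing:
        problems.append(f"missing image_id values: {missing}")
    return problems


# A response that claims 3 images but only contains 2 predictions.
incomplete = {
    "number_of_labeled_images": 3,
    "output": [
        {"image_id": 0, "confidence": 9, "label": "class_1"},
        {"image_id": 1, "confidence": 6, "label": "class_0"},
    ],
}
for problem in validate_response(incomplete, n_expected=3):
    print(problem)
```

An empty list means the response is structurally consistent; it says nothing about whether the labels are correct.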

Note

Fun fact: I didn’t know that lepidopterist was a word until I wrote this prompt.

Configure model

Then, you can define and configure the model:

generation_config = {
  "temperature": 1,
  "max_output_tokens": 8192,
  "response_mime_type": "application/json",
}
classification_model = genai.GenerativeModel(
    "gemini-1.5-flash", 
    system_instruction=CLASSIFIER_SYSTEM_PROMPT, 
    generation_config=generation_config
)

This sets up the model with the following configuration:

  • temperature=1: Controls the randomness of the model’s output. Lower values make the predictions more repeatable across runs.
  • max_output_tokens=8192: The maximum number of tokens the model can generate.
  • response_mime_type="application/json": Tells the model to produce JSON.

It also sets the system_instruction using the prompt you defined earlier and uses gemini-1.5-flash as the model.

Building the context

Gemini has a slightly different way of building the messages (context) used by the model.

Most providers have adjusted their API to match OpenAI’s messages format. Gemini, however, uses a list of strings and media files (if you’re including images).

You can use these functions for that:

def create_context_images_message(df):
    # Start with a header, then add each class's images followed by its label
    messages = ["Possible labels:"]
    grouped = df.groupby('class_id')
    for class_id, group in grouped:
        for _, row in group.iterrows():
            img = Image.open(row["image_path"])  # PIL image, passed to Gemini as-is
            messages.append(img)
        messages.append(f"label: {class_id}")
    return messages

context_images_message = create_context_images_message(df_train)

First, you’ll create a message with the context images and their corresponding labels. This is the “training” part of ICL.

In create_context_images_message, you’re iterating over the training dataset, grouping the images by class and appending the images and labels to the messages list.

The resulting message will look something like this:

context_images_message[:5]
['Possible labels:',
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224>,
 'label: class_0',
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224>,
 'label: class_1']

You might have noticed that instead of the actual names of the classes, you’re using class_0, class_1, etc. This is because I want to make the model’s predictions as “fair” as possible. See the baseline performance section for more details.

Then, you’ll create a message with the input images. These are the images for which the model will generate predictions.

Similar to the context images message, you’re iterating over the test dataset and appending the images to the messages list.

def create_input_images_message(df):
    # Each test image is followed by the id the model should use in its output
    messages = ["Input images:"]
    for i, image_path in enumerate(df.image_path):
        img = Image.open(image_path)  # PIL image, passed to Gemini as-is
        image_message = [
            img,
            f"input_image_id: {i}",
        ]
        messages.extend(image_message)
    messages.append(f"Please correctly classify all {df.shape[0]} images.")
    return messages

input_images_message = create_input_images_message(df_test)

The resulting message will look something like this:

input_images_message[:5]
['Input images:',
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224>,
 'input_image_id: 0',
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224>,
 'input_image_id: 1']
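If you want to try the same idea with a provider that follows OpenAI’s messages format, the flat list can be reshaped into a single user message. The sketch below assumes the commonly documented chat layout with "text" and "image_url" parts and base64 data URLs; check your provider’s documentation before relying on it. `to_openai_style_messages` is a hypothetical helper, and raw bytes stand in for real JPEG files.

```python
import base64


def to_openai_style_messages(system_prompt, parts):
    """Convert a flat list of strings and JPEG bytes into chat-style messages."""
    content = []
    for part in parts:
        if isinstance(part, str):
            content.append({"type": "text", "text": part})
        elif isinstance(part, bytes):
            encoded = base64.b64encode(part).decode("ascii")
            content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
                }
            )
        else:
            raise TypeError(f"Unsupported part type: {type(part).__name__}")
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": content},
    ]


# Stand-in bytes instead of a real JPEG, to keep the sketch self-contained.
fake_jpeg = b"\xff\xd8\xff\xe0 not a real image"
messages = to_openai_style_messages(
    "You are an expert lepidopterist.",
    ["Possible labels:", fake_jpeg, "label: class_0"],
)
print([part["type"] for part in messages[1]["content"]])
```

With real files, you could read the bytes with `Path(image_path).read_bytes()` rather than opening them with PIL.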

Results

Now, you can combine the context images message and the input images message to create the contents you’ll pass to the model:

contents = context_images_message + input_images_message
response = classification_model.generate_content(
    contents=contents
)
response_json = json.loads(response.text)

It’ll take a few seconds to run, and then you’ll have a JSON response with the model’s predictions:

response_json["output"][:3]
[{'image_id': 0, 'confidence': 10, 'label': 'class_7'},
 {'image_id': 1, 'confidence': 10, 'label': 'class_2'},
 {'image_id': 2, 'confidence': 10, 'label': 'class_4'}]

Then, you can calculate the accuracy and F1-score to evaluate the model’s performance:

def calculate_metrics(df_test, response_json):
    predictions = [item['label'] for item in response_json['output']]
    accuracy = accuracy_score(df_test.class_id, predictions)
    f1 = f1_score(df_test.class_id, predictions, average='weighted')
    return accuracy, f1

accuracy, f1 = calculate_metrics(df_test, response_json)
print(f"Accuracy: {accuracy:.4f}")
print(f"F1-score: {f1:.4f}")
Accuracy: 0.7333
F1-score: 0.7229

Using a single image in the context per class, you should get an accuracy around 73% and F1-score around 72%.
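One caveat: calculate_metrics assumes the predictions come back in the same order as the test set and that none are missing. A more defensive approach is to place each prediction by its image_id first. The sketch below uses plain Python; `align_predictions` and `simple_accuracy` are hypothetical helpers, and the "missing" placeholder label is an assumption.

```python
def align_predictions(output, n_images, missing_label="missing"):
    """Return one label per test image, ordered by image_id."""
    labels = [missing_label] * n_images
    for item in output:
        image_id = item.get("image_id")
        # Ignore ids that are not valid positions in the test set
        if isinstance(image_id, int) and 0 <= image_id < n_images:
            labels[image_id] = item.get("label", missing_label)
    return labels


def simple_accuracy(y_true, y_pred):
    """Fraction of positions where the prediction matches the truth."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


y_true = ["class_1", "class_0", "class_2", "class_1"]
# Out-of-order output with no prediction for image 3.
output = [
    {"image_id": 2, "confidence": 7, "label": "class_2"},
    {"image_id": 0, "confidence": 9, "label": "class_1"},
    {"image_id": 1, "confidence": 4, "label": "class_2"},
]
y_pred = align_predictions(output, n_images=len(y_true))
print(y_pred)
print(simple_accuracy(y_true, y_pred))
```

Missing predictions count as errors here, which keeps the metric honest. The aligned list can also be passed to accuracy_score and f1_score in place of the raw predictions.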

Not bad, but you can probably do better.

Using 5 images per class in the context

One quick way to improve the performance of the model is to use more images per class in the context. Try with 5 images per class:

df_train, df_test = create_datasets(TRAIN_DIR, TEST_DIR, selected_classes=selected_classes, n_images_icl=5)

# Create the context and input messages
context_images_message = create_context_images_message(df_train)
input_images_message = create_input_images_message(df_test)
contents = context_images_message + input_images_message

# Generate the response
response = classification_model.generate_content(
    contents=contents
)
response_json = json.loads(response.text)

# Calculate the metrics
accuracy, f1 = calculate_metrics(df_test, response_json)
print(f"Accuracy: {accuracy:.4f}")
print(f"F1-score: {f1:.4f}")
Accuracy: 0.9067
F1-score: 0.9013

With this change, you should get an accuracy and F1-score around 90%.

Nice gains in performance for such a small change!
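Adding more context images also increases the number of images per request: 75 context images plus 75 test images here. That is well within Gemini’s limit, but with larger test sets or stricter providers you may need to split the test images across several requests. The sketch below shows one way to do it; `batch_image_paths` is a hypothetical helper, and the limit of 100 is an arbitrary value chosen for the example.

```python
def batch_image_paths(image_paths, n_context_images, max_images_per_request):
    """Split test image paths into batches that fit next to the context images."""
    batch_size = max_images_per_request - n_context_images
    if batch_size <= 0:
        raise ValueError("Context images alone reach the per-request limit")
    return [
        image_paths[start:start + batch_size]
        for start in range(0, len(image_paths), batch_size)
    ]


test_paths = [f"test/img_{i}.jpg" for i in range(75)]
batches = batch_image_paths(
    test_paths, n_context_images=75, max_images_per_request=100
)
print([len(batch) for batch in batches])
```

Each batch would be sent as its own request with the same context images. Because the image ids restart at 0 in every batch, keep track of which batch each prediction came from.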

Data leakage and baseline performance

You might be thinking: “MLLMs have been trained on a lot of data, so they’ve probably already seen many of the images in the dataset, which means that these results are inflated.”

That’s a good point. To address it, I did two things:

  1. Anonymize the names of the classes (e.g., class_0 instead of Sleepy Orange), so that the model doesn’t have any information about the actual labels.
  2. Run a quick experiment using a zero-shot setup (see footnote 2) without anonymizing the labels to see how well the model does on its own.

Here’s the code for the zero-shot baseline and the results:

possible_labels = "Possible labels: " + ", ".join(df_train.class_name.unique().tolist())
class_name_to_id = dict(zip(df_train['class_name'], df_train['class_id']))

response = classification_model.generate_content(
    contents=[possible_labels] + input_images_message
)
response_json = json.loads(response.text)

for item in response_json["output"]:
    item['label'] = class_name_to_id.get(item['label'], item['label'])

accuracy, f1 = calculate_metrics(df_test, response_json)
print(f"Accuracy: {accuracy:.4f}")
print(f"F1-score: {f1:.4f}")
Accuracy: 0.4800
F1-score: 0.4619

You should get an accuracy of about 48% and an F1-score of about 46%. Both are significantly higher than the ~7% you’d expect from random guessing among 15 classes, but still far from the 90%+ accuracy you obtained earlier.

This demonstrates that ICL can indeed enhance the model’s performance.
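To put the numbers side by side, here is a short sketch that computes the random-guessing baseline and compares each reported accuracy against it. It assumes balanced classes (5 test images for each of the 15 classes, 75 in total) and reuses the accuracies printed in this post; your own runs will vary.

```python
n_classes = 15
n_test_images = 75  # 15 classes x 5 test images each
random_baseline = 1 / n_classes

# Accuracies reported earlier in this post
results = {
    "random guessing": random_baseline,
    "zero-shot, real class names": 0.4800,
    "ICL, 1 image per class": 0.7333,
    "ICL, 5 images per class": 0.9067,
}
for name, accuracy in results.items():
    n_correct = round(accuracy * n_test_images)
    print(
        f"{name}: {accuracy:.1%} "
        f"(~{n_correct}/{n_test_images} correct, "
        f"{accuracy / random_baseline:.1f}x random)"
    )
```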

Conclusion

That’s all!

I still find it amazing that without any “real” training and just a few minutes of work, you can achieve pretty good results in a non-trivial image classification task using ICL with Gemini Flash 1.5 (or most other MLLMs).

This is a mostly unexplored area. There’s a lot of room for trying out different ideas and seeing what works best. This tutorial is just a starting point.

Hope you found it useful! Let me know if you have any questions in the comments below.

Footnotes

  1. Estimated costs as of September 8, 2024:

    Model             Cost (512x512 image)
    Gemini Flash 1.5  $0.000039
    Gemini Pro 1.5    $0.0018
    GPT-4o            $0.000213 - $0.000638
    Sonnet 3.5        $0.001047

  2. That is, without providing any context images.