Introducing PaliGemma: Google’s Latest Visual Language Model

PaliGemma pushes the boundaries of efficient multi-modality in Visual Language Models, using task-specific fine-tuning to remain highly competitive with larger architectures.

Leonard So
Editor

What is PaliGemma?

PaliGemma is a new family of Visual Language Models (VLMs) released by Google as part of its push for multi-modal capability and model efficiency. While VLMs have been trending larger and larger in pursuit of universal input and output, PaliGemma instead focuses on a relatively compact architecture that can be fine-tuned to perform competitively with much larger models.

Model Architecture

PaliGemma follows the same setup as PaLI-3, an earlier multi-modal VLM, and is built from two main components. The first is a visual encoder, SigLIP (Zhai et al., 2023), a robust contrastively trained image encoder that competes with OpenAI’s CLIP while using a simpler and computationally cheaper objective: a sigmoid loss in place of CLIP’s softmax cross-entropy loss. The second is a relatively compact decoder-only language model, Google’s Gemma (Mesnard et al., 2024). PaliGemma uses Gemma’s tokenizer to tokenize the accompanying input text and to process all tokens. Gemma’s tokenizer contains 256,000 tokens, and PaliGemma extends this vocabulary with 1,024 entries representing coordinates in a normalized image space (<loc0000>, …, <loc1023>) and 128 entries (<seg000>, …, <seg127>) drawn from the latent codebook of a vector-quantized visual auto-encoder introduced in All in Tokens: Unifying Output Space of Visual Tasks via Soft Token (Ning et al., 2023). These entries act as soft tokens: the next predicted token in the auto-regressive pipeline is formed from a weighted average over tokens.

The image and text tokens are fed into Gemma’s transformer-based decoder, which is largely similar to the original transformer decoder of Vaswani et al. (2017) with a few modern modifications, such as multi-head or multi-query attention, rotary positional embeddings (RoPE), GeGLU activations in place of the traditional ReLU, and Root Mean Square Normalization (RMSNorm). The decoder then outputs tokens up to a limit that depends on the size of the model.
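To give a sense of one of these modifications, here is a minimal RMSNorm sketch (our own illustration, not code from the Gemma implementation): activations are rescaled by their root mean square and a learned per-channel weight, with no mean-centring as in LayerNorm.

import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-6):
    # Scale activations by the inverse of their root mean square along the
    # feature axis, then apply a learned per-channel weight.
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return (x / rms) * weight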

For detection and segmentation outputs, location tokens such as <loc0000> represent the top-left and bottom-right coordinates of a bounding box, normalized to a 1024x1024 image, while segmentation tokens can be mapped back to local binary masks. For regular text output, the token decoder emits text tokens that map back to standard text.
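To make that convention concrete, the snippet below is a small sketch of our own (not code from the PaliGemma release) showing how a single location token index could be mapped back to a pixel coordinate in the original image, following the 1024-bin normalization described above:

def loc_bin_to_pixel(bin_index, original_size):
    # <locXXXX> encodes a coordinate on a 1024-bin grid spanning the image;
    # rescale the bin index to the original image dimension (width or height).
    return bin_index / 1024 * original_size

# Example: <loc0512> along the width of a 640-pixel-wide image lands at 320 px.
x = loc_bin_to_pixel(512, 640)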

Training Stages

Similar to PaLI-3, PaliGemma is trained in multiple stages, which readers can replicate on their own:

  • Uni-Modal Pre-Training: PaliGemma uses the SigLIP and Gemma models with their pre-trained weights as is, which were trained on uni-modal datasets.
  • Multi-Modal Pre-Training: The combined PaliGemma model is pre-trained on a fully vision and language training dataset using a relatively small resolution of 224x224 image inputs and a text sequence length of 128 tokens (prefix and suffix included). This pre-training outputs the pre-trained base model weights that can then be taken for fine-tuning.
  • High-Resolution Pre-Training: Further pre-training can be performed at larger resolutions and longer token sequence lengths. The researchers performed pre-training at 448x448 and 896x896 input image resolutions with a text sequence length of 512 tokens. The training data is the same as in the previous stage, but it is re-weighted to place more emphasis on examples that benefit from higher resolution or longer sequence lengths.
  • Fine-Tuning: The base models produced by the previous two stages can then be fine-tuned for specific tasks, as described below.

Task Loss Function

Loss functions can vary across tasks. For instance, PaliGemma’s demo fine-tuning notebook uses mean per-token perplexity as the generic loss across all tasks, which is the natural sequence-level extension of the typical cross-entropy loss for individual token prediction. However, we would advocate for incorporating task-specific losses that better reflect the practical performance of the model, so that it learns to produce more useful responses over the course of training.

Available Model Checkpoints

Google DeepMind has provided an extensive list of checkpoints for immediate inference or further fine-tuning, primarily in three categories.

  • Pre-trained Checkpoints: Pre-trained models that can be fine-tuned for downstream tasks. These come in three input resolutions: 224x224, 448x448, and 896x896.
  • Mix Checkpoints: Pre-trained models that have been fine-tuned on a mixture of tasks within a single training run. They are designed for ease of use with general-purpose, free-text prompting, and can be used for research purposes only.
  • Fine-tuned Checkpoints: A set of fine-tuned models, each one specialized on a different academic benchmark. They are available in various resolutions and are intended for research purposes only.

Use Cases

PaliGemma has been trained across a breadth of practical and academic computer vision tasks, demonstrating its flexibility and the adaptability of its foundational knowledge. Multi-modal VLMs open up data extraction and comprehension across textual and visual data, categories that most industrial data falls into. The following are examples of PaliGemma executing specific tasks:

  • Image Captioning: Tasking the model to describe the image in a caption.
  • Visual Question Answering: Asking the model specific questions about details shown in the image.
  • Object Detection: Based on the prompt “detect [object description]”, output bounding box coordinates that bound the described object(s).
  • Object Segmentation: Based on the prompt “segment [object description]”, output segmentation tokens that can be decoded into a mask outlining the described object(s).
  • Document Understanding: Asking the model specific questions about text shown in the image.
  • Diagram Understanding: Asking the model specific questions about diagram information shown in the image.
  • Science Question Answering: Asking the model field-specific questions in which objective answers are necessary for correctness.

Running Inference on PaliGemma

PaliGemma Mix Checkpoints can be conveniently tested on supported downstream tasks such as question answering, object detection, and segmentation. We tested the pretrained 224 and 448 checkpoints on a safety monitoring dataset to identify helmets using the provided HuggingFace demo. On the question answering task, the model successfully detected that the man was wearing a helmet and correctly identified its colour.

To kick things up a notch, we asked PaliGemma to draw a bounding box around the helmet, as well as segment its outline, both of which PaliGemma did with impressive precision.

As shown above, when a detect prefix is used, the output consists of four location tokens, which correspond to (390, 23) and (678, 327) as the top-left and bottom-right coordinates of the bounding box if the image were resized to 1024x1024.

When a segment prefix is used, the output is four location tokens representing the minimum bounding box for the mask, followed by 16 segmentation tokens that can be decoded into a binary mask within that bounding box.
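To get a feel for how this raw output can be mapped back onto the original image, below is a rough parsing sketch of our own, not the official decoding utility. It assumes the detect output has the form "<loc....><loc....><loc....><loc....> label" and a (ymin, xmin, ymax, xmax) ordering, which you should verify against your checkpoint's output format.

import re

def parse_detect_output(text, image_width, image_height):
    # Extract the four <locXXXX> bin indices and the trailing object label.
    bins = [int(b) for b in re.findall(r"<loc(\d{4})>", text)]
    label = re.sub(r"<loc\d{4}>", "", text).strip()
    y_min, x_min, y_max, x_max = bins  # assumed ordering; check your output format
    # Rescale from the 1024-bin grid described above to pixel coordinates.
    box = (x_min / 1024 * image_width, y_min / 1024 * image_height,
           x_max / 1024 * image_width, y_max / 1024 * image_height)
    return box, label

# Illustrative call on a made-up output string.
box, label = parse_detect_output("<loc0102><loc0205><loc0768><loc0900> helmet", 1280, 720)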

To test it out on your local system, you can install the transformers library and run the following code snippet:


from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")

prompt = "answer en Where is the cow standing?"
url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt")

# Generate an answer and decode the output tokens back into text
generate_ids = model.generate(**inputs, max_length=30)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)

>> "answer en Where is the cow standing?\nbeach"

PaliGemma does struggle with niche datasets and environments that were not present at pre-training time, which is expected given that its base model did not receive the kind of comprehensive, exhaustive pre-training that larger peers such as GPT-4o did. However, PaliGemma can unleash its true potential through further fine-tuning on specific downstream tasks.

Finetuning PaliGemma

To fine-tune one of the checkpoints on your own custom dataset, you can check out this Colab notebook. You will need to provide the model with both images and text inputs. These text inputs consist of two parts: prefixes and suffixes (a short example of assembling them follows the list below). Here's how they work:

  • Prefixes: These specify the task you want the model to perform. Supported tasks on the Mix Checkpoints include:
    • "caption en": To indicate that the model should generate an English caption for an image;
    • "detect": To instruct the model to perform object detection;
    • "segment": To instruct the model to perform image segmentation;
    • "how many animals are there?": Question format for the model to provide contextual answers regarding the image.
  • Suffixes: These provide the ground truths that the model should learn to predict. Examples of suffixes based on the supported tasks above include:
    • Image Captioning: "This is an image of a dog and cat posing for a picture";
    • Object Detection: "<loc0591><loc0252><loc0941><loc0784> dog" (bounding box corner coordinates encoded as location tokens);
    • Object Segmentation: "<loc0591><loc0252><loc0941><loc0784> ... <loc0823> cat" (coordinate and mask tokens for the segmented object);
    • Question Answering: "Two".
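Putting the prefix/suffix convention together, a training example is simply an image paired with a prefix and a suffix. Below is a minimal sketch of how such records could be laid out for a custom dataset; the field names and file paths are our own illustration, not a required schema:

train_examples = [
    {
        "image": "images/dog_and_cat.jpg",  # path to the input image (hypothetical)
        "prefix": "caption en",             # task instruction
        "suffix": "This is an image of a dog and cat posing for a picture",  # ground truth
    },
    {
        "image": "images/farm_animals.jpg",
        "prefix": "how many animals are there?",
        "suffix": "Two",
    },
]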

When preprocessing these inputs, different attention mechanisms are applied to the prefixes and suffixes. This distinction is crucial because it allows the model to fully understand the task (prefix) while generating the appropriate output (suffix) one token at a time.

  • Prefixes: Full (bidirectional) attention is used. This means that every token in the prefix can attend to every other prefix token. This is important because the prefix determines the model's task and requires a comprehensive understanding of the entire context.
  • Suffixes: Causal attention (masked attention) is used. This means that each token in the suffix can only attend to previous tokens and itself. This is typical for sequential generation tasks where the model generates text one token at a time, ensuring that future tokens do not influence the current token.

import jax
import numpy as np

# `tokenizer` is assumed to be the Gemma tokenizer loaded earlier in the notebook.
def preprocess_tokens(prefix, suffix=None, seqlen=None):
  # Model has been trained to handle tokenized text composed of a prefix with
  # full attention and a suffix with causal attention.
  separator = "\n"
  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

  if suffix:
    suffix = tokenizer.encode(suffix, add_eos=True)
    tokens += suffix
    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

  mask_input = [1] * len(tokens)    # 1 if its a token, 0 if padding.
  if seqlen:
    padding = [0] * max(0, seqlen - len(tokens))
    tokens = tokens[:seqlen] + padding
    mask_ar = mask_ar[:seqlen] + padding
    mask_loss = mask_loss[:seqlen] + padding
    mask_input = mask_input[:seqlen] + padding

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
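For example, a captioning record like the ones sketched above could be tokenized as follows (a usage example under the same assumptions as the notebook; the maximum sequence length is up to you):

# Prefix carries the task, suffix carries the ground truth to be learned.
tokens, mask_ar, mask_loss, mask_input = preprocess_tokens(
    prefix="caption en",
    suffix="This is an image of a dog and cat posing for a picture",
    seqlen=128,
)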

Similar to most other models, you can specify the key hyperparameters of your training and evaluation loops, such as the number of training steps, batch size, learning rate, and learning rate scheduler. The loss function is defined by default as next-token prediction over the suffix (prefix and padded tokens are excluded from the calculation), since the notebook assumes a simple image captioning task where the output consists of text tokens.


import jax
import jax.numpy as jnp

def loss_fn(params):
    # `imgs`, `txts`, and `mask_ar` are taken from the training batch in the
    # enclosing training step (batched images, token ids, and attention-type mask).
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens we normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

However, you can define a custom loss function that best suits your downstream task, such as a localization loss for bounding boxes or a pixel classification loss for instance segmentation.
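As one hedged example of what such a task-specific loss could look like for detection, the sketch below treats the location tokens as soft tokens (in the spirit of the AIT paper referenced earlier) and penalizes the L1 distance between the expected location bins and the ground-truth bins. LOC_TOKEN_START and the coordinate ordering are assumptions that you would need to verify against your tokenizer and data format.

import jax
import jax.numpy as jnp

LOC_TOKEN_START = 256_000  # hypothetical vocabulary index of <loc0000>; check your tokenizer.

def soft_box_loss(loc_step_logits, target_bins):
    """L1 loss between expected and ground-truth location bins.

    loc_step_logits: [batch, 4, vocab_size] decoder logits at the four steps
                     where the bounding box tokens should be emitted.
    target_bins:     [batch, 4] ground-truth bin indices in [0, 1023].
    """
    # Restrict to the 1024 location tokens and turn logits into a distribution.
    loc_logits = loc_step_logits[..., LOC_TOKEN_START:LOC_TOKEN_START + 1024]
    probs = jax.nn.softmax(loc_logits, axis=-1)
    # "Soft token": the expected bin index under that distribution is differentiable.
    bins = jnp.arange(1024, dtype=probs.dtype)
    expected = jnp.sum(probs * bins, axis=-1)         # [batch, 4]
    return jnp.mean(jnp.abs(expected - target_bins))  # mean L1 over box coordinates

Such a term could be added to the token-level loss above with a small weight, so the model still learns the output format while being nudged toward geometrically accurate boxes.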

What’s Next?

If you have questions, feel free to post them in our Community Slack, or contact us to fine-tune your own PaliGemma model on Datature Nexus.

For more detailed information about the model functionality, customization options, or answers to any common questions you might have about VLMs, read more on our Developer Portal.

Developer’s Roadmap

Datature recognizes the importance of multi-modal VLMs in practical use cases. With this in mind, we will be incorporating the PaliGemma architecture for fine-tuning, so that Datature Nexus users will be able to import and annotate multi-modal datasets and use them to train a PaliGemma model for their specific use cases. To learn more about how to annotate multi-modal datasets on Datature, you can check out this article on Nexus support for ontologies and metadata. We also note that while achieving academic benchmarks is significant, practical deployment still requires other considerations, such as guardrails to ensure consistency and quality.
