Inference with RecurrentGemma using JAX and Flax


This tutorial demonstrates how to perform basic sampling/inference with the RecurrentGemma 2B Instruct model using Google DeepMind's recurrentgemma library. The library is written with JAX (a high-performance numerical computing library), Flax (the JAX-based neural network library), Orbax (a JAX-based library for training utilities like checkpointing), and SentencePiece (a tokenizer/detokenizer library). Although this notebook does not use Flax directly, Flax was used to create Gemma and RecurrentGemma (the Griffin model).

This notebook can run on Google Colab with the T4 GPU (go to Edit > Notebook settings > Under Hardware accelerator select T4 GPU).

Setup

The following sections explain the steps for preparing a notebook to use a RecurrentGemma model, including getting model access, getting an API key, and configuring the notebook runtime.

Set up Kaggle access for Gemma

To complete this tutorial, you first need to follow setup instructions similar to those in Gemma setup, with a few exceptions:

  • Get access to RecurrentGemma (instead of Gemma) on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the RecurrentGemma model.
  • Generate and configure a Kaggle username and API key.

After you've completed the RecurrentGemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY. When prompted with the "Grant access?" messages, agree to provide secret access.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
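Because kagglehub reads these two variables from the environment, a quick check before downloading can catch an unset or empty value early. The following is a minimal sketch; `missing_kaggle_vars` is a hypothetical helper name, not part of Colab or kagglehub.

```python
import os

def missing_kaggle_vars(env=os.environ):
    """Return the names of required Kaggle variables that are unset or empty."""
    required = ("KAGGLE_USERNAME", "KAGGLE_KEY")
    return [name for name in required if not env.get(name)]

missing = missing_kaggle_vars()
if missing:
    print("Missing credentials:", ", ".join(missing))
else:
    print("Kaggle credentials are set.")
```

Passing a plain dictionary instead of `os.environ` makes the helper easy to try without touching real credentials.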

Install the recurrentgemma library

This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click on Edit > Notebook settings > Select T4 GPU > Save.

Next, you need to install the Google DeepMind recurrentgemma library from github.com/google-deepmind/recurrentgemma. If you get an error about "pip's dependency resolver", you can usually ignore it.

pip install git+https://github.com/google-deepmind/recurrentgemma.git
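Once the install finishes, it can help to confirm that JAX actually sees the GPU selected in the notebook settings. The sketch below uses `jax.devices()` and the `platform` attribute of each device; `visible_platforms` is a hypothetical helper name, and the fallback to an empty list when JAX cannot be imported is an assumption made so the snippet runs anywhere.

```python
def visible_platforms():
    """Return the sorted platform names JAX can see, or [] if JAX is not importable."""
    try:
        import jax
    except ImportError:
        return []
    # Each device reports a platform string such as "cpu", "gpu", or "tpu".
    return sorted({device.platform for device in jax.devices()})

platforms = visible_platforms()
print("JAX platforms:", platforms)
if "gpu" not in platforms:
    print("No GPU visible to JAX; check the hardware accelerator setting.")
```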

Load and prepare the RecurrentGemma model

  1. Load the RecurrentGemma model with kagglehub.model_download, which takes three arguments:
  • handle: The model handle from Kaggle
  • path: (Optional string) The local path
  • force_download: (Optional boolean) Forces a re-download of the model
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
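The handle passed to kagglehub.model_download is a plain string, so it can be built and validated before any download starts. The sketch below is illustrative: `recurrentgemma_handle` and `ALLOWED_VARIANTS` are hypothetical names, and the optional trailing version segment is an assumption based on the version number that appears in the download URL above.

```python
ALLOWED_VARIANTS = ("2b", "2b-it")  # the two options offered by the form field

def recurrentgemma_handle(variant, version=None):
    """Build a Kaggle model handle, optionally pinned to a version (hypothetical helper)."""
    if variant not in ALLOWED_VARIANTS:
        raise ValueError(f"unknown variant {variant!r}; expected one of {ALLOWED_VARIANTS}")
    handle = f"google/recurrentgemma/flax/{variant}"
    return handle if version is None else f"{handle}/{version}"

print(recurrentgemma_handle("2b-it"))
print(recurrentgemma_handle("2b-it", version=1))
```

Rejecting an unknown variant up front avoids waiting on a failed request for a handle that does not exist.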
  1. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file will be in the main directory where you downloaded the model, while the model weights will be in a sub-directory. For example:
  • The tokenizer.model file will be in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1.
  • The model checkpoint will be in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
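Before loading anything, it may be worth confirming that both paths exist, since a wrong variant name or an interrupted download would otherwise only surface later as a loading error. This is a minimal sketch; `missing_model_files` is a hypothetical helper, and the temporary directory only stands in for a real download.

```python
import os
import tempfile

def missing_model_files(model_root, variant):
    """Return the expected paths that do not exist under model_root."""
    expected = [
        os.path.join(model_root, variant),            # checkpoint directory
        os.path.join(model_root, "tokenizer.model"),  # SentencePiece model file
    ]
    return [path for path in expected if not os.path.exists(path)]

# Demonstration with a stand-in directory that has a checkpoint folder
# but no tokenizer file, so only the tokenizer path is reported.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "2b-it"))
    print(missing_model_files(root, "2b-it"))
```

In the notebook, the same helper could be called with RECURRENTGEMMA_PATH and RECURRENTGEMMA_VARIANT; an empty list means both files are in place.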

Perform sampling/inference

  1. Load the RecurrentGemma model checkpoint with the recurrentgemma.jax.load_parameters function. Setting the sharding argument to "single_device" loads all model parameters on a single device.
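# The JAX implementation lives in the recurrentgemma.jax submodule. It is
# aliased to `recurrentgemma` here so the calls below stay short.
from recurrentgemma import jax as recurrentgemma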

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Load the RecurrentGemma model tokenizer, constructed using sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
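A quick way to sanity-check a loaded tokenizer is to encode a string and decode it back. The sketch below relies on the `EncodeAsIds` and `DecodeIds` methods of SentencePieceProcessor; `roundtrip` is a hypothetical helper, and `ToyProcessor` is a made-up stand-in (it just splits on whitespace) so the snippet runs without the real tokenizer file.

```python
def roundtrip(processor, text):
    """Encode text to token IDs and decode it back with a SentencePiece-style processor."""
    ids = processor.EncodeAsIds(text)
    return ids, processor.DecodeIds(ids)

class ToyProcessor:
    """Hypothetical stand-in exposing the two SentencePiece methods used above."""

    def __init__(self):
        self._pieces = []

    def EncodeAsIds(self, text):
        ids = []
        for word in text.split():
            if word not in self._pieces:
                self._pieces.append(word)
            ids.append(self._pieces.index(word))
        return ids

    def DecodeIds(self, ids):
        return " ".join(self._pieces[i] for i in ids)

ids, decoded = roundtrip(ToyProcessor(), "5 + 9 = ?")
print(ids, decoded)
```

In the notebook, passing the real `vocab` object in place of `ToyProcessor()` shows the token IDs the model will actually receive.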
  1. To automatically load the correct configuration from the RecurrentGemma model checkpoint, use recurrentgemma.GriffinConfig.from_flax_params_or_variables. Then, instantiate the Griffin model with recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Create a sampler with recurrentgemma.jax.Sampler on top of the RecurrentGemma model checkpoint/weights and the tokenizer:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Write a prompt in the prompt list and perform inference. You can adjust total_generation_steps, the number of steps performed when generating a response; this example uses 50 to preserve host memory.
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.
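The example above sends a raw string to the instruction-tuned variant. Instruction-tuned Gemma models are typically prompted with turn markers, and the sketch below assumes that the RecurrentGemma 2b-it checkpoint follows the same `<start_of_turn>` / `<end_of_turn>` convention; check the model card before relying on it. `format_instruct_prompt` is a hypothetical helper, not part of the recurrentgemma library.

```python
def format_instruct_prompt(user_message):
    """Wrap one user message in Gemma-style turn markers (assumed format)."""
    return (
        "<start_of_turn>user\n"
        f"{user_message}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

# The result is a list of strings, the same shape the sampler's
# input_strings argument receives in the example above.
prompt = [format_instruct_prompt("What is 5+9?")]
print(prompt[0])
```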

Learn more