RecurrentGemma is a family of open-weights language models by Google DeepMind, based on the novel Griffin architecture. This architecture achieves fast inference when generating long sequences by replacing global attention with a mixture of local attention and linear recurrences.
This repository contains the model implementation and examples for sampling and fine-tuning. We recommend most users adopt the Flax implementation, which is highly optimized. We also provide an un-optimized PyTorch implementation for reference.
- The RecurrentGemma technical report gives specific details on the training and evaluation of RecurrentGemma.
- The Griffin paper describes the underlying model architecture.
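For intuition, the recurrent layers replace global attention with a gated diagonal linear recurrence of roughly the form h_t = a_t ⊙ h_{t−1} + b_t ⊙ x_t. Below is a minimal JAX sketch of such a recurrence; it is illustrative only, and the shapes and names are simplifications rather than the actual Griffin layer:

```python
import jax
import jax.numpy as jnp

def linear_recurrence(a, b, x):
  """Toy diagonal linear recurrence: h_t = a_t * h_{t-1} + b_t * x_t.

  Illustrative only -- the real RecurrentGemma layer adds gating and runs
  a fused Pallas kernel on TPU. Shapes: a, b, x are [seq_len, width].
  """
  def step(h, inputs):
    a_t, b_t, x_t = inputs
    h = a_t * h + b_t * x_t
    return h, h  # carry the new state and also emit it as this step's output

  h0 = jnp.zeros_like(x[0])
  _, hs = jax.lax.scan(step, h0, (a, b, x))
  return hs  # [seq_len, width]

# Toy usage: decay factors close to 1 preserve information over long ranges.
x = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
hs = linear_recurrence(jnp.full_like(x, 0.9), jnp.ones_like(x), x)
```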
RecurrentGemma uses Poetry for dependency management.
To install dependencies for the full project:
- Check out the code.
- Run `poetry install -E full` to create a virtual environment with all dependencies.
- Run `poetry shell` to activate the created virtual environment.
If you only need a subset of the dependencies, use one of the alternative, library-specific commands below.
If you want to use pip instead of Poetry, create and activate a virtual environment:

```shell
python -m venv recurrentgemma-demo
. recurrentgemma-demo/bin/activate
```

Then check out the code and run:

```shell
pip install .[full]
```
To install dependencies only for the JAX pathway, use `poetry install -E jax` (or `pip install .[jax]`).
To install dependencies only for the PyTorch pathway, use `poetry install -E torch` (or `pip install .[torch]`).
To install dependencies required for running the unit tests, use `poetry install -E test` (or `pip install .[test]`).
The model checkpoints are available through Kaggle at http://kaggle.com/models/google/recurrentgemma. Select the Flax or PyTorch model variation, click the ⤓ button to download the model archive, then extract the contents to a local directory.
In both cases, the archive contains both the model weights and the tokenizer.
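Alternatively, you can fetch a checkpoint programmatically with the `kagglehub` package (after authenticating with your Kaggle credentials). The model handle below is an assumed example; check the Kaggle page for the exact variation you want:

```python
import kagglehub

# Downloads the model files and returns the local directory they were
# extracted to. The handle "google/recurrentgemma/flax/2b" is an assumed
# example; substitute the Flax or PyTorch variation you need.
weights_dir = kagglehub.model_download("google/recurrentgemma/flax/2b")
print(weights_dir)  # should contain the model weights and tokenizer.model
```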
To run the tests, install the optional `[test]` dependencies (e.g. using `pip install .[test]` from the root of the source tree), then run:

```shell
pytest .
```
To run the example sampling script, pass the paths to the weights directory and tokenizer:
```shell
python examples/sampling_jax.py \
  --path_checkpoint=/path/to/archive/contents/2b/ \
  --path_tokenizer=/path/to/archive/contents/tokenizer.model
```
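You can also sample from Python directly. The sketch below mirrors what the example script does; the `recurrentgemma` function and class names here are assumptions taken from the JAX pathway, so check `examples/sampling_jax.py` if they differ:

```python
import sentencepiece as spm
from recurrentgemma import jax as recurrentgemma

path_checkpoint = "/path/to/archive/contents/2b/"
path_tokenizer = "/path/to/archive/contents/tokenizer.model"

# Load the parameters and build the model (API names assumed from the
# example script).
params = recurrentgemma.load_parameters(path_checkpoint, "single_device")
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
model = recurrentgemma.Griffin(config)

# The tokenizer is a standard SentencePiece model.
vocab = spm.SentencePieceProcessor()
vocab.Load(path_tokenizer)

# Build a sampler and generate a short continuation.
sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)
output = sampler(["What are the planets of the solar system?"],
                 total_generation_steps=30)
print(output.text)
```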
- `colabs/sampling_tutorial_jax.ipynb` contains a Colab notebook with a sampling example using JAX.
- `colabs/sampling_tutorial_pytorch.ipynb` contains a Colab notebook with a sampling example using PyTorch.
- `colabs/fine_tuning_tutorial_jax.ipynb` contains a Colab notebook with a basic tutorial on how to fine-tune RecurrentGemma for a task, such as English-to-French translation, using JAX; a minimal sketch of the kind of training step involved is shown below.
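The fine-tuning notebook builds a standard JAX training loop on top of the model. An illustrative train step with `optax` might look like the following; the `model.apply` signature and the loss are simplifying assumptions, so see the notebook for the real setup:

```python
import jax
import jax.numpy as jnp
import optax

def make_train_step(model, optimizer):
  """Returns a jitted train step. Illustrative only: the real notebook
  also handles padding masks, positions and the recurrent state."""

  def loss_fn(params, tokens):
    # Next-token prediction: logits for tokens[:, :-1] predict tokens[:, 1:].
    # The model.apply signature here is a simplifying assumption.
    logits, _ = model.apply({"params": params}, tokens[:, :-1])
    labels = tokens[:, 1:]
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  @jax.jit
  def train_step(params, opt_state, tokens):
    loss, grads = jax.value_and_grad(loss_fn)(params, tokens)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

  return train_step
```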
To run these notebooks you will need a Kaggle account, and you must first read and accept the Gemma license terms and conditions on the RecurrentGemma model page. Once you have done so, you can run the notebooks, which will automatically download the weights and tokenizer from Kaggle.
The notebooks are currently supported on the following hardware:

| Hardware | T4 | P100 | V100 | A100 | TPUv2 | TPUv3 |
|---|---|---|---|---|---|---|
| Sampling in JAX | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Sampling in PyTorch | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Fine-tuning in JAX | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
RecurrentGemma code can run on CPU, GPU or TPU. The code has been optimized for TPU via the Flax implementation, which contains a low-level Pallas kernel that performs the linear scan in the recurrent layers.
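The reason this scan is fast is that the recurrence h_t = a_t ⊙ h_{t−1} + b_t is associative (composing affine maps yields another affine map), so it can be evaluated in O(log T) parallel steps rather than T sequential ones. Here is a toy sketch using `jax.lax.associative_scan`; the real kernel is the hand-written Pallas implementation, not this:

```python
import jax
import jax.numpy as jnp

def parallel_linear_scan(a, b):
  """Computes h_t = a_t * h_{t-1} + b_t with h_0 = 0, in parallel.

  Toy illustration of why the recurrence parallelizes well on TPU;
  RecurrentGemma's actual implementation is a fused Pallas kernel.
  """
  def combine(left, right):
    a_l, b_l = left
    a_r, b_r = right
    # Composing h -> a_l*h + b_l followed by h -> a_r*h + b_r gives
    # another affine map, which is what makes the scan associative.
    return a_r * a_l, a_r * b_l + b_r

  _, h = jax.lax.associative_scan(combine, (a, b))
  return h  # h[t] equals the sequential recurrence unrolled to step t

h = parallel_linear_scan(jnp.full((16, 8), 0.9), jnp.ones((16, 8)))
```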
We welcome bug reports and issues. Please see CONTRIBUTING.md for details on how to submit PRs.
Copyright 2024 DeepMind Technologies Limited
This code is licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
This is not an official Google product.