This repository provides a JAX/Flax implementation of the Hyena architecture introduced in Poli et al. (2023). A full training run of a small 1.5M-parameter model on the Shakespeare dataset can be found in the included `intro.ipynb`. It achieves a best validation loss of ~1.45, on par with the results in nanoGPT.
Specifically, the following is implemented:
- The Hyena layer itself can be found in `hyena/hyena.py` as `HyenaOperator`.
- The efficient, FFT-based convolution is implemented in the `fftconv` method, giving O(N log N) complexity in sequence length (a minimal sketch of the idea appears after this list). It is used for training and for the pre-fill stage during inference.
- Caching is also implemented, so `fftconv` is called only once during inference pre-fill; subsequent individual tokens are computed using the alternate implementation (see below).
- An alternate implementation with O(N) complexity per token is provided for the auto-regressive decoding stage during inference, in the `inference_conv` method (see the second sketch below). It is particularly fast when generating a small number of tokens from a very long input (e.g. a full document).
- A standard decoder tower is implemented in `hyena/decoder.py` as `Decoder`. The implementation is largely similar to the one in nanoGPT, with the self-attention layers swapped out for Hyena layers (see the third sketch below).
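
For reference, here is a minimal sketch of an FFT-based causal convolution in JAX. The function name and shapes are illustrative, not the actual `fftconv` signature; the key point is the zero-padding to twice the sequence length, which turns the FFT's circular convolution into a causal one.

```python
import jax.numpy as jnp

def fft_causal_conv(u, k):
    """Causal convolution of an input u with a long filter k via FFT.

    u: (..., seq_len) input signal
    k: (seq_len,) implicitly parameterized filter

    Costs O(N log N) in the sequence length, vs. O(N^2) for a direct
    convolution with a filter as long as the sequence.
    """
    seq_len = u.shape[-1]
    fft_size = 2 * seq_len  # zero-pad to avoid circular wrap-around
    u_f = jnp.fft.rfft(u, n=fft_size)
    k_f = jnp.fft.rfft(k, n=fft_size)
    y = jnp.fft.irfft(u_f * k_f, n=fft_size)
    return y[..., :seq_len]  # keep only the causal part
```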
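
Similarly, a sketch of the O(N)-per-token idea behind the decoding path, assuming a cache of previously seen inputs (the names and the caching scheme here are hypothetical; the repository's `inference_conv` may differ in detail):

```python
import jax.numpy as jnp

def decode_step(u_cache, k, t):
    """Direct (time-domain) convolution for a single decoding step.

    u_cache: (seq_len,) inputs seen so far; positions 0..t are filled
    k: (seq_len,) filter
    t: current position (0-indexed)

    Computes y[t] = sum_{i <= t} k[i] * u[t - i] as one O(t) dot
    product, rather than re-running the full FFT convolution.
    """
    # Reverse the filled prefix so k[0] aligns with the newest token.
    window = u_cache[: t + 1][::-1]
    return jnp.dot(k[: t + 1], window)
```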
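
Finally, a schematic Flax decoder block in the nanoGPT style, with the token mixer abstracted out. The `mixer` argument stands in for a `HyenaOperator` instance; the layer names and MLP details here are illustrative, not the exact `Decoder` implementation:

```python
import flax.linen as nn

class Block(nn.Module):
    """Pre-norm decoder block: token mixing followed by an MLP.

    `mixer` is any (batch, seq, d_model) -> (batch, seq, d_model)
    module; here it would be a Hyena layer instead of self-attention.
    """
    d_model: int
    mixer: nn.Module

    @nn.compact
    def __call__(self, x):
        # Residual branch 1: token mixing (Hyena in place of attention).
        x = x + self.mixer(nn.LayerNorm()(x))
        # Residual branch 2: position-wise MLP with 4x expansion.
        h = nn.gelu(nn.Dense(4 * self.d_model)(nn.LayerNorm()(x)))
        return x + nn.Dense(self.d_model)(h)
```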