
GPT-2 in Jax/Flax

This is a Jax/Flax reimplementation of the GPT-2 family of models, trained on the FineWeb-Edu dataset, inspired by karpathy/build_nanoGPT.
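
The default configuration appears to target the 124M-parameter GPT-2 variant (12 layers, 12 heads, 768-dim embeddings, 1024-token context). Below is a minimal sketch of such a config, assuming an ml_collections-style get_config() as implied by the --config flag used in the Train section; field names are illustrative and may not match configs/default.py.

# Hypothetical GPT-2 124M config sketch; field names are illustrative,
# not necessarily those used in configs/default.py.
import ml_collections

def get_config():
  config = ml_collections.ConfigDict()
  config.num_layers = 12      # transformer blocks
  config.num_heads = 12       # attention heads
  config.embed_dim = 768      # hidden size
  config.vocab_size = 50257   # GPT-2 BPE vocabulary
  config.block_size = 1024    # context length
  config.dtype = "bfloat16"   # bfloat16 compute, as noted in Updates
  return config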

Updates:

  • Add support for tf.data pipelines over TFRecords.
  • Add support for bfloat16 computation.
  • SPMD (multi-node) training support using pmap.
  • Expose configurables via CLI flags (or config dict).
  • Use cuDNN flash attention kernel (SDPA API) (jax-ml/jax#22546); see the sketch after this list.
  • Fix nn.Embed typecast performance issue.
  • Use scaled init for residual paths.
  • Fix large gradient norm spikes for longer training runs.
  • Test accumulate_gradient.
  • Update docstrings.
  • Add shard_map support for model and data sharding.
  • KV cache decoding.
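
The cuDNN flash attention item above refers to JAX's fused scaled-dot-product-attention API. Here is a minimal sketch of routing causal attention through the cuDNN kernel via jax.nn.dot_product_attention; it is illustrative only and not necessarily the exact call site used in the model code.

import jax
import jax.numpy as jnp

# jax.nn.dot_product_attention expects (batch, seq_len, num_heads, head_dim).
B, T, N, H = 8, 1024, 12, 64
q = k = v = jnp.zeros((B, T, N, H), dtype=jnp.bfloat16)

# implementation="cudnn" requests the fused flash-attention kernel on GPU;
# implementation=None lets JAX choose a backend automatically.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation="cudnn")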

Setup

Create a virtual environment and install packages.

pip install -r requirements.txt

For SPMD support (multi-node training), install OpenMPI.

sudo apt install openmpi-bin openmpi-doc libopenmpi-dev

Prepare TFRecords

python fineweb.py --outdir /path/to/store/tfrecord
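
fineweb.py writes tokenized FineWeb-Edu shards as TFRecords, which the tf.data input pipeline reads during training. A minimal sketch of reading such records back is shown below; the feature name "tokens", the file pattern, and the sequence length are assumptions, not necessarily what fineweb.py produces.

import tensorflow as tf

SEQ_LEN = 1024  # assumed tokens per record

def parse_example(serialized):
  # Assumes each record holds a fixed-length int64 feature named "tokens".
  features = {"tokens": tf.io.FixedLenFeature([SEQ_LEN], tf.int64)}
  return tf.io.parse_single_example(serialized, features)["tokens"]

files = tf.io.gfile.glob("/path/to/store/tfrecord/*.tfrecord")
ds = (tf.data.TFRecordDataset(files)
      .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(8, drop_remainder=True)
      .prefetch(tf.data.AUTOTUNE))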

Train

# Single process, multi-GPU.
python train.py --workdir artifacts/gpt2_124M --config configs/default.py

# Multi-process on the same host using OpenMPI.
mpirun -n 8 \
          -bind-to socket \
          python train.py --workdir artifacts/gpt2_124M --config configs/default.py

# Multi-node across 8 hosts (requires passwordless SSH between hosts).
mpirun -n 8 \
          -pernode \
          -H hostname1,hostname2,...,hostname8 \
          -bind-to socket \
          python train.py --workdir artifacts/gpt2_124M --config configs/default.py
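
Under mpirun, each launched Python process is one JAX process, and train.py is expected to bring up the JAX distributed runtime before building the model. A minimal sketch of how this is commonly done with OpenMPI environment variables follows; the coordinator address is hypothetical and the exact mechanism in train.py may differ.

import os
import jax

# OpenMPI exposes rank and world size through these environment variables.
process_id = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
num_processes = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))

# Hypothetical coordinator: rank 0's host and an arbitrary free port.
jax.distributed.initialize(
    coordinator_address="hostname1:12345",
    num_processes=num_processes,
    process_id=process_id,
)

# Each process drives its local GPUs; jax.devices() spans every host.
print(jax.process_index(), jax.local_device_count(), jax.device_count())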

License

MIT
