This is the training code for a two-stage autoregressive video model. It includes:
- Chunked scatter/gather/init functions
- Parallel model save/load
- Dtype conversions in the scatter/gather/init functions
- Distributed data loading
- Distributed model training
- Multi-platform file backend via PyFilesystem2
- GPU Support
- SLURM Support
- Kubernetes Support
- Text conditional diffusion Transformer
- (5/6)-D parallelism
- FSDP
- Ring attention
- Pipeline parallelism
- Async swarm
- Llama 3 support
- Sophisticated logging (Logfire/SQL database)
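The chunked scatter/gather with dtype conversion listed above can be sketched as follows. This is a minimal illustration, not the repository's API: `chunked_cast_gather` and its arguments are hypothetical names, and the real code presumably operates on device shards rather than NumPy arrays. The idea is to cast each chunk as it is gathered so a full-precision copy of the parameters never has to exist in host memory at once.

```python
import numpy as np

def chunked_cast_gather(shards, dtype=np.float16, chunk_rows=2):
    """Hypothetical sketch: gather parameter shards chunk by chunk,
    casting each chunk on the fly so peak host memory stays bounded
    by one chunk rather than the full-precision parameter."""
    out = []
    for shard in shards:
        for start in range(0, shard.shape[0], chunk_rows):
            # Cast a small slice, then release the full-precision view.
            out.append(shard[start:start + chunk_rows].astype(dtype))
    return np.concatenate(out, axis=0)
```

A chunked scatter would run the same loop in reverse: slice the source array, cast, and place each chunk onto its target device.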
Parameter scaling:
- A Spectral Condition for Feature Learning
- Scaling Exponents Across Parameterizations and Optimizers (NTK init with a global learning rate is used for most experiments)
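The NTK-init-with-global-LR setup referenced above can be sketched as a dense layer with unit-variance weights and a 1/sqrt(fan_in) multiplier in the forward pass; with this parameterization a single global learning rate behaves consistently as width changes. `ntk_dense` is a hypothetical name for illustration, not a function from this codebase.

```python
import numpy as np

def ntk_dense(rng, x, fan_out):
    """Hypothetical sketch of an NTK-parameterized dense layer:
    weights are drawn with unit variance, and the 1/sqrt(fan_in)
    scale lives in the forward pass rather than in the init."""
    fan_in = x.shape[-1]
    w = rng.standard_normal((fan_in, fan_out)).astype(np.float32)
    return (x @ w) / np.sqrt(fan_in)
```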
JAX sharding:
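A minimal sketch of laying out an array with `jax.sharding`, assuming a 1-D device mesh; the full training code presumably combines several mesh axes (data, FSDP, ring attention, pipeline) rather than the single `data` axis shown here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are visible (just 1 on a CPU host).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('data',))

# Shard the batch dimension across the 'data' axis; the second
# dimension (None) is replicated on every device.
x = jnp.ones((8, 4))
x_sharded = jax.device_put(x, NamedSharding(mesh, P('data', None)))
```

FSDP-style parameter sharding uses the same machinery with the mesh axis placed on a weight dimension instead of the batch dimension.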
Data loader design:
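One common design for distributed data loading is for each host process to read a disjoint, strided shard of the dataset, so the hosts together cover every sample exactly once per epoch. A minimal sketch, with `host_shard_indices` as a hypothetical name (in JAX the process coordinates would come from `jax.process_index()` and `jax.process_count()`):

```python
def host_shard_indices(num_samples, process_index, process_count):
    """Hypothetical sketch: return the strided shard of sample indices
    owned by one host, so shards across all hosts partition the dataset."""
    return list(range(process_index, num_samples, process_count))
```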