Latest News 🔥
- [2024/11/6] We release v0.4.0: Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
- [2024/9/6] We release v0.2.1 (X post). 2500 Stars, 10 New Contributors, 50 PRs, 50k Downloads in two weeks!
- [2024/8/31] CUDA MODE talk, Liger-Kernel: Real-world Triton kernel for LLM Training, Slides
- [2024/8/23] Official release: check out our X post
Liger Kernel is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. We have implemented Hugging Face-compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernels work out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.
With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.
[Figures: Speed Up and Memory Reduction benchmark plots]
Note:
- Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
- Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
Use Case | Description |
---|---|
Hugging Face Trainer | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
Lightning Trainer | Increase throughput by 15% and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
Medusa Multi-head LLM (Retraining Phase) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
Vision-Language Model SFT | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
- Ease of use: Simply patch your Hugging Face model with one line of code, or compose your own model using our Liger Kernel modules.
- Time and memory efficient: In the same spirit as Flash-Attn, but for layers like RMSNorm, RoPE, SwiGLU, and CrossEntropy! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% with kernel fusion, in-place replacement, and chunking techniques.
- Exact: Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
- Lightweight: Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
- Multi-GPU supported: Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
- Trainer Framework Integration: Axolotl, LLaMa-Factory, SFTTrainer, Hugging Face Trainer, SWIFT
CUDA:
- `torch >= 2.1.2`
- `triton >= 2.3.0`

ROCm:
- `torch >= 2.5.0`: Install according to the instructions on the official PyTorch webpage.
- `triton >= 3.0.0`: Install from PyPI (e.g., `pip install triton==3.0.0`).

Optional dependencies:
- `transformers >= 4.x`: Required if you plan to use the transformers model patching APIs. The specific model you are working with will dictate the minimum version of transformers.
Note: Our kernels inherit the full spectrum of hardware compatibility offered by Triton.
To install the stable version:
```bash
$ pip install liger-kernel
```
To install the nightly version:
```bash
$ pip install liger-kernel-nightly
```
To install from source:
```bash
git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel
pip install -e .
# or if using transformers (quoted so shells such as zsh do not expand the brackets)
pip install -e ".[transformers]"
```
There are a couple of ways to apply Liger kernels, depending on the level of customization required.
Using `AutoLigerKernelForCausalLM` is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code will be automatically patched using the default settings.
```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# This AutoModel wrapper class automatically monkey-patches the
# model with the optimized Liger kernels if the model is supported.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
```
Using the patching APIs, you can swap Hugging Face models with optimized Liger Kernels.
```python
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# 1a. Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()

# 1b. You could alternatively specify exactly which kernels are applied
apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
    rms_norm=False,
)

# 2. Instantiate patched model (Auto classes are loaded via from_pretrained)
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
```
You can take individual kernels to compose your models.
```python
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
import torch.nn as nn
import torch

model = nn.Linear(128, 256).cuda()

# fuses linear cross entropy layers together and performs chunk-by-chunk computation to reduce memory
loss_fn = LigerFusedLinearCrossEntropyLoss()

input = torch.randn(4, 128, requires_grad=True, device="cuda")
target = torch.randint(256, (4, ), device="cuda")

loss = loss_fn(model.weight, input, target)
loss.backward()
```
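To make the "chunk-by-chunk computation" in the example above more concrete, here is a minimal plain-Python sketch of the idea. It is not Liger's Triton implementation: the function name `fused_linear_cross_entropy` and the `chunk_size` parameter are hypothetical, nested lists stand in for tensors, and the loss is assumed to be averaged over rows. The point it illustrates is that logits are only ever built for one chunk of rows, and their gradient is pushed back through the linear layer before the next chunk is processed.

```python
import math


def fused_linear_cross_entropy(weight, inputs, targets, chunk_size=2):
    """Mean cross entropy of (inputs @ weight^T) vs. targets, one chunk at a time.

    weight:  V x H nested list (one row per vocabulary entry)
    inputs:  N x H nested list (one row per token)
    targets: N class indices
    Returns (loss, grad_inputs, grad_weight).
    """
    n, vocab, hidden = len(inputs), len(weight), len(weight[0])
    loss = 0.0
    grad_inputs = [[0.0] * hidden for _ in range(n)]
    grad_weight = [[0.0] * hidden for _ in range(vocab)]

    for start in range(0, n, chunk_size):
        rows = range(start, min(start + chunk_size, n))
        # Only chunk_size x V logits exist at once, never the full N x V matrix.
        chunk_logits = [
            [sum(w * x for w, x in zip(w_row, inputs[i])) for w_row in weight]
            for i in rows
        ]
        for i, row in zip(rows, chunk_logits):
            row_max = max(row)
            lse = row_max + math.log(sum(math.exp(v - row_max) for v in row))
            loss += (lse - row[targets[i]]) / n
            # Overwrite the logits with d(loss)/d(logits) = (softmax - one_hot) / n.
            for v in range(vocab):
                row[v] = math.exp(row[v] - lse) / n
            row[targets[i]] -= 1.0 / n
            # Back-propagate through the linear layer before the chunk is dropped.
            for v in range(vocab):
                for k in range(hidden):
                    grad_inputs[i][k] += row[v] * weight[v][k]
                    grad_weight[v][k] += row[v] * inputs[i][k]
    return loss, grad_inputs, grad_weight


# Same argument order as loss_fn(model.weight, input, target) above.
weight = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]   # V=3, H=2
inputs = [[1.0, 2.0], [0.5, -1.0], [-1.5, 0.25]]  # N=3
targets = [2, 0, 1]
loss, grad_inputs, grad_weight = fused_linear_cross_entropy(weight, inputs, targets)
```

The result does not depend on `chunk_size`; only the size of the temporary logits buffer does, which is why this kind of chunking pays off most when the vocabulary is large.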
AutoModel Variant | API |
---|---|
AutoModelForCausalLM | liger_kernel.transformers.AutoLigerKernelForCausalLM |
Model | API | Supported Operations |
---|---|---|
LLaMA 2 & 3 | liger_kernel.transformers.apply_liger_kernel_to_llama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
LLaMA 3.2-Vision | liger_kernel.transformers.apply_liger_kernel_to_mllama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Mistral | liger_kernel.transformers.apply_liger_kernel_to_mistral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Mixtral | liger_kernel.transformers.apply_liger_kernel_to_mixtral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Gemma1 | liger_kernel.transformers.apply_liger_kernel_to_gemma | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Gemma2 | liger_kernel.transformers.apply_liger_kernel_to_gemma2 | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Qwen2 & Qwen2.5 | liger_kernel.transformers.apply_liger_kernel_to_qwen2 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Qwen2-VL | liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Phi3 & Phi3.5 | liger_kernel.transformers.apply_liger_kernel_to_phi3 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Kernel | API |
---|---|
RMSNorm | liger_kernel.transformers.LigerRMSNorm |
LayerNorm | liger_kernel.transformers.LigerLayerNorm |
RoPE | liger_kernel.transformers.liger_rotary_pos_emb |
SwiGLU | liger_kernel.transformers.LigerSwiGLUMLP |
GeGLU | liger_kernel.transformers.LigerGEGLUMLP |
CrossEntropy | liger_kernel.transformers.LigerCrossEntropyLoss |
FusedLinearCrossEntropy | liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss |
KLDivergence | liger_kernel.transformers.LigerKLDIVLoss |
JSD | liger_kernel.transformers.LigerJSD |
FusedLinearJSD | liger_kernel.transformers.LigerFusedLinearJSD |
- RMSNorm: RMSNorm, which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- LayerNorm: LayerNorm, which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
- GroupNorm: GroupNorm, which splits the channels of each sample into K groups and normalizes activations within each group, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
- RoPE: Rotary Positional Embedding is implemented by fusing the query and key rotations into a single kernel with in-place replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- SwiGLU: Swish Gated Linear Units, given by $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- GeGLU: GELU Gated Linear Units, given by $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the tanh approximation form of GELU is used.
- CrossEntropy: Cross entropy loss is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
- FusedLinearCrossEntropy: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by Efficient Cross Entropy. It achieves >4X memory reduction for 128k vocab size. This is highly effective for large batch size, large sequence length, and large vocabulary sizes. Please refer to the Medusa example for individual kernel usage.
- KLDivergence: KL Divergence is implemented by fusing the forward pass into a single Triton kernel, with reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for 128K vocab size.
- JSD: Generalized JSD (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speedup and ~54% memory reduction for 128k vocab size. NOTE: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
- FusedLinearJSD: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. NOTE: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
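The CrossEntropy entry above relies on the fact that the gradient of the loss with respect to the logits is `softmax(logits) - one_hot(target)`, so it can be produced in the same pass as the loss and written over the logits. The following is a minimal plain-Python sketch of that trick, assuming a mean reduction over rows; the function name `cross_entropy_inplace` is hypothetical and nested lists stand in for the logits tensor, so this is an illustration rather than Liger's actual kernel.

```python
import math


def cross_entropy_inplace(logits, targets):
    """Return the mean cross-entropy loss and overwrite `logits` with its gradient.

    logits:  list of rows, each a list of floats (one score per vocabulary entry)
    targets: list of class indices, one per row
    """
    n_rows = len(logits)
    total_loss = 0.0
    for row, target in zip(logits, targets):
        # Numerically stable log-sum-exp of the row.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total_loss += log_sum_exp - row[target]
        # Reuse the logits buffer for the gradient: (softmax - one_hot) / n_rows.
        for j in range(len(row)):
            row[j] = math.exp(row[j] - log_sum_exp) / n_rows
        row[target] -= 1.0 / n_rows
    return total_loss / n_rows


logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]
targets = [0, 1]
loss = cross_entropy_inplace(logits, targets)
# `logits` now holds d(loss)/d(logits); no second buffer of the same size was allocated.
```

Because the gradient replaces the logits, a buffer of shape (rows, vocab) is needed once instead of twice, which is where the memory saving for large vocabularies comes from.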
Kernel | API |
---|---|
Embedding | liger_kernel.transformers.experimental.LigerEmbedding |
Matmul int2xint8 | liger_kernel.transformers.experimental.matmul |
- Embedding: Embedding is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
- Matmul int2xint8: Implemented using cache-tiled matrix multiplication and by fusing the matmul with the unpacking process, which achieves a considerable speedup and performs on par with `@torch.compile`.
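As a rough illustration of what "fusing the matmul with the unpacking process" means, the sketch below stores four 2-bit weights per byte and unpacks each one only at the moment it is multiplied, so no unpacked weight array is ever built. The bit layout (unsigned values 0 to 3, lowest bits first) and the helper names `pack_int2` and `dot_packed` are assumptions made for this example; the layout used by the experimental Liger kernel may differ.

```python
def pack_int2(values):
    """Pack unsigned 2-bit values (0..3) four per byte, lowest bits first."""
    if len(values) % 4:
        raise ValueError("length must be a multiple of 4")
    packed = []
    for i in range(0, len(values), 4):
        byte = 0
        for slot, v in enumerate(values[i:i + 4]):
            if not 0 <= v <= 3:
                raise ValueError("value does not fit in 2 bits")
            byte |= v << (2 * slot)
        packed.append(byte)
    return packed


def dot_packed(activations, packed_weights):
    """Dot product that unpacks each 2-bit weight right before it is used."""
    total = 0
    for idx, a in enumerate(activations):
        byte = packed_weights[idx // 4]
        w = (byte >> (2 * (idx % 4))) & 0b11  # extract the idx-th 2-bit field
        total += a * w
    return total


weights = [0, 1, 2, 3, 3, 2, 1, 0]
activations = [1, -2, 3, 4, 5, 6, -7, 8]  # int8-range activations
packed = pack_int2(weights)               # 8 weights stored in 2 bytes
result = dot_packed(activations, packed)
```

A real kernel would do the same shift-and-mask inside each tile of the matrix multiplication; the scalar loop here only shows the data flow.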
- For issues, create a GitHub issue in this repository
- For open discussion, join our Discord channel
- For formal collaboration, send an email to [email protected]
Biblatex entry:
```bib
@article{hsu2024ligerkernelefficienttriton,
  title={Liger Kernel: Efficient Triton Kernels for LLM Training},
  author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
  year={2024},
  eprint={2410.10989},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.10989},
  journal={arXiv preprint arXiv:2410.10989},
}
```