Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Abstract
The inference process in Large Language Models (LLMs) is often limited due to the absence of parallelism in the auto-regressive decoding process, resulting in most operations being restricted by the memory bandwidth of accelerators. While methods such as speculative decoding have been suggested to address this issue, their implementation is impeded by the challenges associated with acquiring and maintaining a separate draft model. In this paper, we present Medusa, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, Medusa constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, Medusa introduces only minimal overhead in single-step latency while substantially reducing the number of decoding steps required. We present two levels of fine-tuning procedures for Medusa to meet the needs of different use cases: (1) Medusa-1, where the Medusa heads are fine-tuned directly on top of a frozen backbone LLM, enabling lossless inference acceleration; and (2) Medusa-2, where the Medusa heads are fine-tuned together with the backbone LLM, enabling better prediction accuracy and higher speedup but requiring a special training recipe that preserves the backbone model's capabilities. Moreover, we propose several extensions that improve or expand the utility of Medusa, including a self-distillation approach to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate Medusa on models of various sizes and training procedures. Our experiments demonstrate that Medusa-1 can achieve over 2.2x speedup without compromising generation quality, while Medusa-2 further improves the speedup to 2.3-3.6x.
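To make the head design concrete, below is a minimal PyTorch-style sketch of what a set of Medusa decoding heads could look like, assuming a backbone that exposes its last hidden state. The class names and the `hidden_dim`, `vocab_size`, and `num_heads` parameters are illustrative assumptions, not the authors' reference implementation. Each head is a small residual feed-forward block followed by a vocabulary projection; in a Medusa-1-style setup only these head parameters would receive gradients while the backbone stays frozen, whereas Medusa-2 would fine-tune both jointly.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """One residual feed-forward block; a sketch of a single Medusa head's trunk."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The residual connection keeps each head's output close to the
        # backbone's representation, which makes the heads cheap to train.
        return x + self.act(self.linear(x))


class MedusaHeads(nn.Module):
    """K extra decoding heads predicting the tokens K steps ahead of the
    backbone's own next-token prediction at the current position (sketch)."""

    def __init__(self, hidden_dim: int, vocab_size: int, num_heads: int = 4):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(
                ResBlock(hidden_dim),
                nn.Linear(hidden_dim, vocab_size, bias=False),
            )
            for _ in range(num_heads)
        )

    def forward(self, last_hidden: torch.Tensor) -> torch.Tensor:
        # last_hidden: (batch, hidden_dim) -> logits: (num_heads, batch, vocab_size).
        # The top-k tokens from each head form the candidate continuations that
        # the tree attention step verifies in a single forward pass.
        return torch.stack([head(last_hidden) for head in self.heads])
```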
Community
This is an automated message from the Librarian Bot. I found the following papers similar to this paper.
The following papers were recommended by the Semantic Scholar API
- Unlocking Efficiency in Large Language Model Inference: A Comprehensive Survey of Speculative Decoding (2024)
- APAR: LLMs Can Do Auto-Parallel Auto-Regressive Decoding (2024)
- Multi-Candidate Speculative Decoding (2024)
- Inferflow: an Efficient and Highly Configurable Inference Engine for Large Language Models (2024)
- SparQ Attention: Bandwidth-Efficient LLM Inference (2023)
Unlocking Faster AI: Medusa's Multi-Head Decoding for LLMs
Links:
- Subscribe: https://www.youtube.com/@Arxflix
- Twitter: https://x.com/arxflix
- LMNT (Partner): https://lmnt.com/
Awesome paper!
When evaluating candidate sequences, is there an optimization that involves computing only the next token for each candidate against the original model? I am confused: I initially thought that the total number of tokens to compute would simply be the sum of the top-K values used at each layer. For example, given:
- h1 predictions: [h11, h12, h13]
- h2 predictions: [h21, h22, h23]
Wouldn't only 6 new tokens need to be computed, assuming that the tokens from h11 to h13 can be cached and reused when generating tokens at the h2 layer? I.e., can we compute h11, add its value to the cached state, and then evaluate all candidate sequences stemming from h11 without recomputing the previous tokens (in this example, without recomputing h11 each time)?
I think I must be missing something basic!
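For readers puzzling over the same counting, here is a hedged sketch of the full Cartesian-product tree the paper describes: every candidate prefix appears as exactly one tree node, and a tree attention mask lets each node attend only to its ancestors (plus the ordinary KV cache for the real prefix), so all nodes are verified in one forward pass. With top-3 proposals from each of two heads that is 3 + 3x3 = 12 new token positions rather than 6, since each second-head token must be computed conditioned on a specific first-head ancestor; in practice Medusa prunes this tree to a sparser set. The helper below is illustrative, not taken from the paper's code.

```python
import itertools

import torch


def build_tree_attention_mask(topk_per_head):
    """Sketch: enumerate a full Cartesian-product candidate tree and its attention mask.

    topk_per_head: e.g. [3, 3] means head 1 proposes 3 tokens and head 2 proposes 3.
    Returns (paths, mask) where `paths` lists every tree node as a tuple of per-head
    choice indices, and mask[i, j] is True iff node j is node i or one of its ancestors.
    """
    # Every non-empty prefix of every candidate path is its own tree node.
    paths = []
    for depth in range(1, len(topk_per_head) + 1):
        for choice in itertools.product(*(range(k) for k in topk_per_head[:depth])):
            paths.append(choice)

    n = len(paths)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i, p in enumerate(paths):
        for j, q in enumerate(paths):
            # In addition to the usual causal attention over the cached prefix,
            # a candidate token may only attend to its ancestors within the tree.
            if len(q) <= len(p) and p[: len(q)] == q:
                mask[i, j] = True
    return paths, mask


paths, mask = build_tree_attention_mask([3, 3])
print(len(paths))  # 12 nodes: 3 first-head choices + 9 (first, second) pairs
```

All of these positions are appended after the cached prefix and scored in the same forward pass; only the KV entries of the accepted path are kept for the next step.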