# Speculative Decoding
This document shows how to use speculative decoding with vLLM to reduce inter-token latency for memory-bound workloads at medium-to-low QPS (queries per second).
To train your own draft models for optimized speculative decoding, see vllm-project/speculators for seamless training and integration with vLLM.
## vLLM Speculation Methods
vLLM supports a variety of speculative decoding methods. Model-based methods such as EAGLE, MTP, draft models, PARD, and MLP speculators provide the best latency reduction, while simpler, model-free methods such as n-gram and suffix decoding provide modest speedups without adding draft-model compute during peak traffic.
- EAGLE
- Multi-Token Prediction (MTP)
- Draft Model
- Parallel Draft Model (PARD)
- Multi-Layer Perceptron
- N-Gram
- Suffix Decoding
## Method Selection at a Glance
Use this qualitative table as a starting point for method selection. Real gains depend on your model family, traffic pattern, hardware, and sampling settings.
| Method | Low QPS (latency focused) | High QPS (throughput focused) | Notes |
|---|---|---|---|
| EAGLE | High gain | Medium to high gain | Strong general-purpose model-based method. |
| MTP | High gain | Medium to high gain | Best when the target model has native MTP support. |
| Draft model | High gain | Medium gain | Needs a separate draft model. |
| Parallel Draft Model | High gain | Medium to high gain | Low draft model latency. |
| MLP speculator | Medium to high gain | Medium gain | Good when compatible MLP speculators are available. |
| N-gram | Low to medium gain | Medium gain | Lightweight and easy to enable. |
| Suffix decoding | Low to medium gain | Medium gain | No extra draft model; dynamic speculation depth. |
For reproducible measurements in your environment, use examples/offline_inference/spec_decode.py or the benchmark CLI guide.
## `--speculative-config` schema
Use --speculative-config to pass speculative decoding settings as a JSON object on the CLI:
```bash
vllm serve <target-model> \
    --speculative-config '{
        "method": "draft_model",
        "model": "<draft-model>",
        "num_speculative_tokens": 5
    }'
```
The same keys are accepted from Python via `LLM(..., speculative_config={...})`. The tables below highlight common user-facing keys accepted in this JSON object; they are not an exhaustive schema reference. For more details, see the generated engine arguments reference and the API docs for `vllm.config.SpeculativeConfig`.
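To sanity-check that a CLI JSON string and a Python `speculative_config` dict describe the same settings, here is a minimal stdlib-only sketch (the draft-model name is a placeholder, not a real model):

```python
import json

# The JSON object as it would be passed to --speculative-config on the CLI.
cli_json = """
{
    "method": "draft_model",
    "model": "my-org/draft-model",
    "num_speculative_tokens": 5
}
"""

# The equivalent dict you would pass as LLM(..., speculative_config=...).
speculative_config = {
    "method": "draft_model",
    "model": "my-org/draft-model",
    "num_speculative_tokens": 5,
}

# Both forms parse to the same mapping, so either entry point
# configures the engine identically.
assert json.loads(cli_json) == speculative_config
print("CLI JSON and Python dict are equivalent")
```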
### Common keys
These keys are commonly used across speculative decoding setups, though some only apply to model-based methods such as draft_model, mtp, eagle3, and dflash.
| Key | Type | Default | Allowed values / meaning |
|---|---|---|---|
| `method` | string | None | Speculation method. Common values include `draft_model`, `ngram`, `suffix`, `mtp`, `eagle3`, and `dflash`. If omitted, vLLM infers the method from the provided configuration when possible. |
| `model` | string | None | Draft model, EAGLE head, or auxiliary model identifier. For `ngram`, `ngram_gpu`, `suffix`, and `mtp`, this can often be omitted. |
| `num_speculative_tokens` | integer > 0 | None | Number of speculative tokens to propose per step. Required for methods that do not infer it from model metadata. |
| `draft_tensor_parallel_size` | integer >= 1 | None | Tensor parallel size for the draft model. |
| `max_model_len` | integer >= 1 | None | Maximum context length for the draft model. |
| `parallel_drafting` | boolean | false | Enable parallel draft token generation. Only compatible with EAGLE and draft-model methods. |
| `rejection_sample_method` | string | strict | `strict`, `probabilistic`, or `synthetic`. |
| `synthetic_acceptance_rate` | float | None | Average acceptance rate to target when `rejection_sample_method` is `synthetic`. Valid range is [0, 1]. |
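As an illustration of combining the common keys, the table above implies a configuration like the following (the model names are placeholders, and per the table `parallel_drafting` applies only to EAGLE and draft-model methods):

```bash
vllm serve <target-model> \
    --speculative-config '{
        "method": "eagle3",
        "model": "<eagle3-head>",
        "num_speculative_tokens": 4,
        "parallel_drafting": true
    }'
```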
### Method-specific keys

#### N-gram
| Key | Type | Default | Meaning |
|---|---|---|---|
| `prompt_lookup_max` | integer >= 1 | 5 if both lookup bounds are omitted; otherwise mirrors `prompt_lookup_min` when omitted | Maximum n-gram window size. |
| `prompt_lookup_min` | integer >= 1 | 5 if both lookup bounds are omitted; otherwise mirrors `prompt_lookup_max` when omitted | Minimum n-gram window size. |
Example:
```bash
vllm serve <target-model> \
    --speculative-config '{
        "method": "ngram",
        "num_speculative_tokens": 4,
        "prompt_lookup_min": 2,
        "prompt_lookup_max": 5
    }'
```
#### Suffix decoding
| Key | Type | Default | Meaning |
|---|---|---|---|
| `suffix_decoding_max_tree_depth` | integer | 24 | Maximum combined prefix-match and speculation tree depth. |
| `suffix_decoding_max_cached_requests` | integer | 10000 | Maximum number of requests cached in the global suffix tree. Set 0 to disable the global cache. |
| `suffix_decoding_max_spec_factor` | float | 1.0 | Caps speculative length as a multiple of prefix-match length. |
| `suffix_decoding_min_token_prob` | float | 0.1 | Minimum estimated token probability required to speculate a token. |
Example:
```bash
vllm serve <target-model> \
    --speculative-config '{
        "method": "suffix",
        "num_speculative_tokens": 8,
        "suffix_decoding_max_tree_depth": 24,
        "suffix_decoding_max_cached_requests": 10000,
        "suffix_decoding_max_spec_factor": 1.0,
        "suffix_decoding_min_token_prob": 0.1
    }'
```
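To build intuition for how `suffix_decoding_max_spec_factor` bounds speculation length, here is an illustrative toy calculation. The helper function and the exact capping rule are simplified assumptions for illustration, not vLLM's actual internal logic:

```python
def capped_spec_len(num_speculative_tokens: int,
                    prefix_match_len: int,
                    max_spec_factor: float) -> int:
    """Illustrative cap: speculate at most max_spec_factor times the
    length of the matched prefix, and never more than the configured
    token budget. (Simplified assumption; vLLM's real rule may differ.)"""
    cap = int(max_spec_factor * prefix_match_len)
    return min(num_speculative_tokens, cap)

# With an 8-token budget and factor 1.0, a 3-token prefix match
# allows at most 3 speculative tokens.
print(capped_spec_len(8, 3, 1.0))   # -> 3
# A long 20-token match is still bounded by the 8-token budget.
print(capped_spec_len(8, 20, 1.0))  # -> 8
```

The idea is that short, weak prefix matches should not trigger long speculations, since those drafts are unlikely to be accepted.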
### Notes
- `--speculative-config` expects a JSON object on the CLI. In YAML config files, use a nested mapping instead of an escaped JSON string.
- `tensor_parallel_size` is not a valid key in `speculative_config`. Use `draft_tensor_parallel_size` instead.
- Keys such as `temperature` and `top_p` are sampling parameters, not `--speculative-config` fields.
- Internal fields such as `target_model_config`, `draft_model_config`, `target_parallel_config`, `draft_parallel_config`, and `draft_load_config` are populated by vLLM and are not intended to be set by users.
## Lossless guarantees of Speculative Decoding
In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of speculative decoding, breaking down the guarantees into three key areas:
- **Theoretical Losslessness** - Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might cause slight variations in output distributions, as discussed in Accelerating Large Language Model Decoding with Speculative Sampling.
- **Algorithmic Losslessness** - vLLM's implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:
    - Rejection Sampler Convergence: Ensures that samples from vLLM's rejection sampler align with the target distribution. View Test Code
    - Greedy Sampling Equality: Confirms that greedy sampling with speculative decoding matches greedy sampling without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, provides a lossless guarantee. Almost all of the tests in tests/spec_decode/e2e verify this property using this assertion implementation.
- **vLLM Logprob Stability** - vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the same request across runs. For more details, see the FAQ entry "Can the output of a prompt vary across runs in vLLM?".
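The acceptance rule underlying the theoretical guarantee can be shown with a short, self-contained sketch of standard speculative (rejection) sampling. This is the textbook algorithm from the literature cited above, not vLLM's actual rejection-sampler implementation:

```python
import random

def accept_or_resample(p, q, draft_token, rng=random):
    """One standard speculative-sampling step for a single draft token.

    p: target-model distribution over the vocabulary (list of floats)
    q: draft-model distribution over the vocabulary (list of floats)
    Accepts draft_token with probability min(1, p[t]/q[t]); on rejection,
    resamples from the residual distribution max(0, p - q), normalized.
    The resulting marginal distribution over emitted tokens is exactly p,
    which is why the method is lossless in theory.
    """
    accept_prob = min(1.0, p[draft_token] / q[draft_token])
    if rng.random() < accept_prob:
        return draft_token
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    residual = [r / total for r in residual]
    return rng.choices(range(len(p)), weights=residual)[0]

# Toy 3-token vocabulary where the target p and draft q disagree.
p = [0.6, 0.3, 0.1]
q = [0.2, 0.3, 0.5]
# Token 0 is always accepted here since p[0] >= q[0].
assert accept_or_resample(p, q, 0) == 0
```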
While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding can occur due to the following factors:
- Floating-Point Precision: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution.
- Batch Size and Numerical Stability: Changes in batch size may cause variations in logprobs and output probabilities, potentially due to non-deterministic behavior in batched operations or numerical instability.
For mitigation strategies, please refer to the FAQ entry "Can the output of a prompt vary across runs in vLLM?".
## Known Feature Incompatibility
- Pipeline parallelism is not compatible with speculative decoding as of `vllm<=0.15.0`.
- Speculative decoding with a draft model is not supported in `vllm<=0.10.0`.