[Doc] Add MTP docs and update speculative decoding guidance (#35197)

Signed-off-by: liuxing <945764858@qq.com>
This commit is contained in:
Xing Liu
2026-03-05 01:23:34 +08:00
committed by GitHub
parent 28028dff2f
commit 7cc6058ac6
3 changed files with 79 additions and 4 deletions

View File

@@ -6,14 +6,33 @@ To train your own draft models for optimized speculative decoding, see [vllm-pro
## vLLM Speculation Methods
vLLM supports a variety of methods of speculative decoding. Model-based methods such as EAGLE, draft models, and mlp provide the best latency reduction, while simpler methods such as n-gram and and suffix decoding provide modest speedups without increasing workload during peak traffic.
vLLM supports a variety of methods of speculative decoding. Model-based methods such as EAGLE, MTP, draft models, and MLP provide the best latency reduction, while simpler methods such as n-gram and suffix decoding provide modest speedups without increasing workload during peak traffic.
- [EAGLE](eagle.md)
- [Multi-Token Prediction (MTP)](mtp.md)
- [Draft Model](draft_model.md)
- [Multi-Layer Perceptron](mlp.md)
- [N-Gram](n_gram.md)
- [Suffix Decoding](suffix.md)
## Method Selection at a Glance
Use this qualitative table as a starting point for method selection. Real gains
depend on your model family, traffic pattern, hardware, and sampling settings.
| Method | Low QPS (latency focused) | High QPS (throughput focused) | Notes |
| --- | --- | --- | --- |
| EAGLE | High gain | Medium to high gain | Strong general-purpose model-based method. |
| MTP | High gain | Medium to high gain | Best when the target model has native MTP support. |
| Draft model | High gain | Medium gain | Needs a separate draft model. |
| MLP speculator | Medium to high gain | Medium gain | Good when compatible MLP speculators are available. |
| N-gram | Low to medium gain | Medium gain | Lightweight and easy to enable. |
| Suffix decoding | Low to medium gain | Medium gain | No extra draft model; dynamic speculation depth. |
For reproducible measurements in your environment, use
[`examples/offline_inference/spec_decode.py`](../../../examples/offline_inference/spec_decode.py)
or the [benchmark CLI guide](../../benchmarking/cli.md).
## Lossless guarantees of Speculative Decoding
In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of

View File

@@ -11,10 +11,10 @@ prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3.1-70B-Instruct",
tensor_parallel_size=4,
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
tensor_parallel_size=1,
speculative_config={
"model": "ibm-ai-platform/llama3-70b-accelerator",
"model": "ibm-ai-platform/llama3-8b-accelerator",
"draft_tensor_parallel_size": 1,
"method": "mlp_speculator",
},
@@ -27,6 +27,12 @@ for output in outputs:
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
!!! warning "Known issue"
`ibm-ai-platform/llama3-70b-accelerator` can fail with:
`AttributeError: 'MLPSpeculatorConfig' object has no attribute 'num_attention_heads'`.
Track status in [#34106](https://github.com/vllm-project/vllm/issues/34106)
and [#34163](https://github.com/vllm-project/vllm/pull/34163).
## Pre-Trained MLP Drafter Models
A variety of speculative models of this type are available on HF hub:

View File

@@ -0,0 +1,50 @@
# MTP (Multi-Token Prediction)
MTP is a speculative decoding method where the target model includes native
multi-token prediction capability. Unlike draft-model-based methods, you do not
need to provide a separate draft model.
MTP is useful when:
- Your model natively supports MTP.
- You want model-based speculative decoding with minimal extra configuration.
## Offline Example
```python
from vllm import LLM, SamplingParams
prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="XiaomiMiMo/MiMo-7B-Base",
tensor_parallel_size=1,
speculative_config={
"method": "mtp",
"num_speculative_tokens": 1,
},
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
## Online Example
```bash
vllm serve XiaomiMiMo/MiMo-7B-Base \
--tensor-parallel-size 1 \
--speculative_config '{"method":"mtp","num_speculative_tokens":1}'
```
## Notes
- MTP only works for model families that support MTP in vLLM.
- `num_speculative_tokens` controls speculative depth. A small value like `1`
is a good default to start with.
- If your model does not support MTP, use another method such as EAGLE or draft
model speculation.