[CLI][Doc] Formalize --mm-encoder-tp-mode (#23190)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-20 23:42:28 +08:00
committed by GitHub
parent b17109beea
commit 5efd6905bc
7 changed files with 104 additions and 24 deletions

View File

@@ -129,6 +129,51 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
### Batch-level DP for Multi-Modal Encoders
By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
in order to reduce the memory and compute load on each GPU.
However, since the size of multi-modal encoders is very small compared to language decoders,
there is relatively little gain from TP. On the other hand, TP incurs significant communication
overhead because of all-reduce being performed after every layer.
Given this, it may be advantageous to instead shard the batched input data using TP, essentially
performing batch-level DP. This has been shown to improve the throughput by around 10% for
`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
batch-level DP can provide another 40% increase to throughput compared to regular TP.
Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank,
there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already.
You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:
```python
from vllm import LLM
llm = LLM(
model="Qwen/Qwen2.5-VL-72B-Instruct",
# Create two EngineCore instances, one per DP rank
data_parallel_size=2,
# Within each EngineCore instance:
# The vision encoder uses TP=4 (not DP=2) to shard the input data
# The language decoder uses TP=4 to shard the weights as usual
tensor_parallel_size=4,
mm_encoder_tp_mode="data",
)
```
!! important
Batch-level DP is not to be confused with API request-level DP
(which is instead controlled by `data_parallel_size`).
The availablilty of batch-level DP is based on model implementation.
Currently, the following models support `mm_encoder_tp_mode="data"`:
- Llama4 (<gh-pr:18368>)
- Qwen2.5-VL (<gh-pr:22742>)
- Step3 (<gh-pr:22697>)
## Input Processing
### Parallel Processing