[Feature] NUMA binding support for GPU workers (#38635)
Signed-off-by: Shengqi Chen <harry-chen@outlook.com>
Co-authored-by: Jason Li <jasonlizhengjian@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.

Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.

### NUMA Binding for Multi-Socket GPU Nodes

On multi-socket GPU servers, GPU worker processes can lose performance if their CPU execution and memory allocation drift away from the NUMA node nearest to the GPU. vLLM can pin each worker with `numactl` before the Python subprocess starts, so the interpreter, imports, and early allocator state are created with the desired NUMA policy from the beginning.

Use `--numa-bind` to enable the feature. By default, vLLM auto-detects the GPU-to-NUMA mapping and uses `--cpunodebind=<node> --membind=<node>` for each worker. When you need a custom CPU policy, add `--numa-bind-cpus` and vLLM will switch to `--physcpubind=<cpu-list> --membind=<node>`.

These `--numa-bind*` options only apply to GPU execution processes. They do not configure the CPU backend's separate thread-affinity controls. Automatic GPU-to-NUMA detection is currently implemented for CUDA/NVML-based platforms; other GPU backends must provide explicit binding lists if they use these options.

`--numa-bind-nodes` takes one non-negative NUMA node index per visible GPU, in the same order as the GPU indices.

`--numa-bind-cpus` takes one `numactl` CPU list per visible GPU, in the same order as the GPU indices. Each CPU list must use `numactl --physcpubind` syntax such as `0-3`, `0,2,4-7`, or `16-31,48-63`.

```bash
# Auto-detect NUMA nodes for visible GPUs
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --tensor-parallel-size 4 \
    --numa-bind

# Explicit NUMA-node mapping
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --tensor-parallel-size 4 \
    --numa-bind \
    --numa-bind-nodes 0 0 1 1

# Explicit CPU pinning, useful for PCT or other high-frequency core layouts
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --tensor-parallel-size 4 \
    --numa-bind \
    --numa-bind-nodes 0 0 1 1 \
    --numa-bind-cpus 0-3 4-7 48-51 52-55
```
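
The CPU lists passed to `--numa-bind-cpus` follow `numactl --physcpubind` syntax. A sketch of how such a list expands into individual CPU ids (illustrative parsing, not vLLM's own validation code):

```python
def parse_cpu_list(spec: str) -> set[int]:
    """Expand a numactl-style CPU list such as '0-3' or '0,2,4-7'."""
    cpus: set[int] = set()
    for part in spec.split(","):
        if "-" in part:
            lo, hi = (int(x) for x in part.split("-"))
            if lo > hi:
                raise ValueError(f"invalid range in CPU list: {part!r}")
            cpus.update(range(lo, hi + 1))
        else:
            cpus.add(int(part))
    return cpus
```

For example, `parse_cpu_list("0,2,4-7")` yields `{0, 2, 4, 5, 6, 7}`.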

Notes:

- CLI usage forces multiprocessing to use the `spawn` method automatically. If you enable NUMA binding through the Python API, also set `VLLM_WORKER_MULTIPROC_METHOD=spawn`.
- Automatic detection relies on NVML and NUMA support from the host. If it cannot determine the mapping reliably, pass `--numa-bind-nodes` explicitly.
- Explicit `--numa-bind-nodes` and `--numa-bind-cpus` values must be valid `numactl` inputs. vLLM performs basic validation, but the effective binding semantics are still determined by `numactl`.
- The current implementation binds GPU execution processes such as `EngineCore` and multiprocessing workers. It does not apply NUMA binding to frontend API server processes or the DP coordinator.
- In containerized environments, NUMA policy syscalls may require extra permissions, such as `--cap-add SYS_NICE` when running via `docker run`.

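For the Python API path, the spawn method must be selected before any worker process is created. A minimal sketch (the `LLM` construction is shown commented out, and the exact engine-argument names for NUMA binding should be checked against the CLI flags above):

```python
import os

# The CLI sets this automatically when --numa-bind is used; the Python API
# does not, so export it before vLLM spawns its workers.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

# from vllm import LLM
# llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
#           tensor_parallel_size=4)  # plus the NUMA binding engine arguments
```
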
### CPU Backend Thread Affinity

The CPU backend uses a different mechanism from `--numa-bind`. CPU execution is configured through CPU-specific environment variables such as `VLLM_CPU_OMP_THREADS_BIND`, `VLLM_CPU_NUM_OF_RESERVED_CPU`, and `CPU_VISIBLE_MEMORY_NODES`, rather than the GPU-oriented `--numa-bind*` CLI options.

By default, `VLLM_CPU_OMP_THREADS_BIND=auto` derives OpenMP placement from the available CPU and NUMA topology for each CPU worker. To override the automatic policy, set `VLLM_CPU_OMP_THREADS_BIND` explicitly using the CPU list format documented for the CPU backend, or use `nobind` to disable this behavior.

For the current CPU backend setup and tuning guidance, see:

- [Related runtime environment variables](../getting_started/installation/cpu.md#related-runtime-environment-variables)
- [How to decide `VLLM_CPU_OMP_THREADS_BIND`](../getting_started/installation/cpu.md#how-to-decide-vllm_cpu_omp_threads_bind)

The GPU-only `--numa-bind`, `--numa-bind-nodes`, and `--numa-bind-cpus` options do not configure CPU worker affinity.

### Batch-level DP for Multi-Modal Encoders

By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,