Fix multi-node allreduce fusion (#38136)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: root <root@theia0053.lyris.clusters.nvidia.com>
This commit is contained in:
Wei Zhao
2026-03-26 16:24:36 -04:00
committed by GitHub
parent f26fcdfb9e
commit 0904b6550d
2 changed files with 63 additions and 9 deletions

View File

@@ -13,11 +13,13 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.config.compilation import PassConfig
from vllm.distributed.parallel_state import get_node_count
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
fi_ar_available = False
try:
import flashinfer.comm as flashinfer_comm # type: ignore[no-redef]
@@ -87,6 +89,27 @@ def _create_workspace(
return workspace
def _resolve_fi_ar_backend() -> str:
    """Pick the FlashInfer allreduce backend to use.

    Honors an explicit ``VLLM_FLASHINFER_ALLREDUCE_BACKEND`` setting; when it
    is left at ``"auto"``, selects a backend based on the node count and logs
    the choice once.

    Returns:
        The backend name: the user-configured value, or ``"mnnvl"`` for
        multi-node setups (trtllm does not support multi-node allreduce),
        otherwise ``"trtllm"``.
    """
    chosen = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
    if chosen != "auto":
        logger.info_once(f"Using flashinfer allreduce backend: {chosen}")
        return chosen
    # Multi-node must use mnnvl: trtllm has no multi-node allreduce support.
    # Single-node defaults to trtllm because mnnvl currently breaks with
    # cudagraph: https://github.com/vllm-project/vllm/issues/35772
    # Should switch back to auto when that issue is resolved.
    chosen = "mnnvl" if get_node_count() > 1 else "trtllm"
    logger.info_once(f"Auto-selected flashinfer allreduce backend: {chosen}")
    return chosen
def get_fi_ar_workspace(
world_size: int,
rank: int,
@@ -106,7 +129,13 @@ def get_fi_ar_workspace(
if _fi_ar_workspace is not None:
return _fi_ar_workspace
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
backend = _resolve_fi_ar_backend()
if get_node_count() > 1 and backend == "trtllm":
raise ValueError(
"Flashinfer allreduce is not supported for multi-node allreduce with "
"'trtllm' backend. Please use 'mnnvl' backend instead."
)
# Reuse the quant workspace if it was already created with the same backend
if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend:
@@ -116,6 +145,17 @@ def get_fi_ar_workspace(
_fi_ar_workspace = _create_workspace(
backend, world_size, rank, max_token_num, hidden_dim, dtype, group
)
if _fi_ar_workspace is not None:
logger.info_once(
"Initialized FlashInfer Allreduce norm fusion workspace "
f"with backend={backend}"
)
else:
logger.warning_once(
"Failed to initialize FlashInfer Allreduce norm fusion workspace "
f"with backend={backend}"
)
return _fi_ar_workspace
@@ -131,12 +171,20 @@ def get_fi_ar_quant_workspace(
Return the allreduce workspace for quant patterns, initializing if needed.
Always uses trtllm backend as it is the only one supporting quantization
fusion (FP8/FP4).
fusion (FP8/FP4). Returns None for multi-node setups since not supported
by trtllm backend.
"""
global _fi_ar_quant_workspace
if _fi_ar_quant_workspace is not None:
return _fi_ar_quant_workspace
if get_node_count() > 1:
logger.warning_once(
"Flashinfer allreduce quantization fusion is not supported for "
"multi-node allreduce. Disabling quant fusion."
)
return None
# Reuse the non-quant workspace if it was already created with trtllm
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
_fi_ar_quant_workspace = _fi_ar_workspace
@@ -145,6 +193,17 @@ def get_fi_ar_quant_workspace(
_fi_ar_quant_workspace = _create_workspace(
"trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
)
if _fi_ar_quant_workspace is not None:
logger.info_once(
"Initialized FlashInfer Allreduce norm quantization "
"fusion workspace with backend=trtllm"
)
else:
logger.warning_once(
"Failed to initialize FlashInfer Allreduce norm quantization "
"fusion workspace with backend=trtllm"
)
return _fi_ar_quant_workspace

View File

@@ -169,7 +169,7 @@ if TYPE_CHECKING:
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
"latency"
)
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm"
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -1305,14 +1305,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
["throughput", "latency", "masked_gemm"],
),
# Flashinfer fused allreduce backend.
# "auto" will default to "mnnvl", which performs mostly same/better than "trtllm".
# But "mnnvl" backend does not support fuse with quantization.
# TODO: Default is "trtllm" right now because "mnnvl" has issues with cudagraph:
# https://github.com/vllm-project/vllm/issues/35772
# Should switch back to "auto" if the issue is resolved.
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_ALLREDUCE_BACKEND",
"trtllm",
"auto",
["auto", "trtllm", "mnnvl"],
),
# Control the workspace buffer size for the FlashInfer backend.