diff --git a/vllm/distributed/device_communicators/flashinfer_all_reduce.py b/vllm/distributed/device_communicators/flashinfer_all_reduce.py
index b2edfc15d..a65789a28 100644
--- a/vllm/distributed/device_communicators/flashinfer_all_reduce.py
+++ b/vllm/distributed/device_communicators/flashinfer_all_reduce.py
@@ -13,11 +13,13 @@ from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
 from vllm.config.compilation import PassConfig
+from vllm.distributed.parallel_state import get_node_count
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
+
 fi_ar_available = False
 try:
     import flashinfer.comm as flashinfer_comm  # type: ignore[no-redef]
@@ -87,6 +89,27 @@ def _create_workspace(
     return workspace
 
 
+def _resolve_fi_ar_backend() -> str:
+    backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
+    if backend != "auto":
+        logger.info_once(f"Using flashinfer allreduce backend: {backend}")
+        return backend
+
+    if get_node_count() > 1:  # noqa: SIM108
+        # Use mnnvl backend for multi-node setup since
+        # trtllm backend does not support multi-node allreduce
+        backend = "mnnvl"
+    else:
+        # Currently defaulting to trtllm backend for single-node
+        # setup since mnnvl has issues with cudagraph:
+        # https://github.com/vllm-project/vllm/issues/35772
+        # Should switch back to auto when the issue is resolved.
+        backend = "trtllm"
+
+    logger.info_once(f"Auto-selected flashinfer allreduce backend: {backend}")
+    return backend
+
+
 def get_fi_ar_workspace(
     world_size: int,
     rank: int,
@@ -106,7 +129,13 @@ def get_fi_ar_workspace(
     if _fi_ar_workspace is not None:
         return _fi_ar_workspace
 
-    backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
+    backend = _resolve_fi_ar_backend()
+
+    if get_node_count() > 1 and backend == "trtllm":
+        raise ValueError(
+            "Flashinfer allreduce is not supported for multi-node allreduce with "
+            "'trtllm' backend. Please use 'mnnvl' backend instead."
+        )
 
     # Reuse the quant workspace if it was already created with the same backend
     if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend:
@@ -116,6 +145,17 @@ def get_fi_ar_workspace(
     _fi_ar_workspace = _create_workspace(
         backend, world_size, rank, max_token_num, hidden_dim, dtype, group
     )
+    if _fi_ar_workspace is not None:
+        logger.info_once(
+            "Initialized FlashInfer Allreduce norm fusion workspace "
+            f"with backend={backend}"
+        )
+    else:
+        logger.warning_once(
+            "Failed to initialize FlashInfer Allreduce norm fusion workspace "
+            f"with backend={backend}"
+        )
+
     return _fi_ar_workspace
 
 
@@ -131,12 +171,20 @@ def get_fi_ar_quant_workspace(
     Return the allreduce workspace for quant patterns, initializing if needed.
 
     Always uses trtllm backend as it is the only one supporting quantization
-    fusion (FP8/FP4).
+    fusion (FP8/FP4). Returns None for multi-node setups since not supported
+    by trtllm backend.
     """
     global _fi_ar_quant_workspace
     if _fi_ar_quant_workspace is not None:
         return _fi_ar_quant_workspace
 
+    if get_node_count() > 1:
+        logger.warning_once(
+            "Flashinfer allreduce quantization fusion is not supported for "
+            "multi-node allreduce. Disabling quant fusion."
+        )
+        return None
+
     # Reuse the non-quant workspace if it was already created with trtllm
     if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
         _fi_ar_quant_workspace = _fi_ar_workspace
@@ -145,6 +193,17 @@ def get_fi_ar_quant_workspace(
     _fi_ar_quant_workspace = _create_workspace(
         "trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
     )
+    if _fi_ar_quant_workspace is not None:
+        logger.info_once(
+            "Initialized FlashInfer Allreduce norm quantization "
+            "fusion workspace with backend=trtllm"
+        )
+    else:
+        logger.warning_once(
+            "Failed to initialize FlashInfer Allreduce norm quantization "
+            "fusion workspace with backend=trtllm"
+        )
+
     return _fi_ar_quant_workspace
 
 
diff --git a/vllm/envs.py b/vllm/envs.py
index a531b0e77..d29e367bc 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -169,7 +169,7 @@ if TYPE_CHECKING:
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
         "latency"
     )
-    VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm"
+    VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
     VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -1305,14 +1305,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
         ["throughput", "latency", "masked_gemm"],
     ),
     # Flashinfer fused allreduce backend.
-    # "auto" will default to "mnnvl", which performs mostly same/better than "trtllm".
-    # But "mnnvl" backend does not support fuse with quantization.
-    # TODO: Default is "trtllm" right now because "mnnvl" has issues with cudagraph:
-    # https://github.com/vllm-project/vllm/issues/35772
-    # Should switch back to "auto" if the issue is resolved.
     "VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
         "VLLM_FLASHINFER_ALLREDUCE_BACKEND",
-        "trtllm",
+        "auto",
         ["auto", "trtllm", "mnnvl"],
    ),
     # Control the workspace buffer size for the FlashInfer backend.
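
With this change, VLLM_FLASHINFER_ALLREDUCE_BACKEND defaults to "auto": _resolve_fi_ar_backend() picks "mnnvl" for multi-node deployments and "trtllm" for single-node ones, and allreduce quantization fusion is disabled across nodes. A minimal sketch of how a deployment could still pin the backend explicitly if the auto-selection needs to be overridden (the variable name and choices come from vllm/envs.py above; setting it through os.environ before vLLM reads its environment variables is just one illustrative way to do it):

    import os

    # Pin the FlashInfer allreduce backend instead of relying on the new
    # "auto" default; must be set before vLLM reads its environment variables.
    os.environ["VLLM_FLASHINFER_ALLREDUCE_BACKEND"] = "mnnvl"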