Fix multi-node allreduce fusion (#38136)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: root <root@theia0053.lyris.clusters.nvidia.com>
This commit is contained in:
@@ -13,11 +13,13 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.compilation import PassConfig
|
||||
from vllm.distributed.parallel_state import get_node_count
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
fi_ar_available = False
|
||||
try:
|
||||
import flashinfer.comm as flashinfer_comm # type: ignore[no-redef]
|
||||
@@ -87,6 +89,27 @@ def _create_workspace(
|
||||
return workspace
|
||||
|
||||
|
||||
def _resolve_fi_ar_backend() -> str:
    """Pick the FlashInfer allreduce backend to use.

    Reads ``envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND``. Any value other than
    ``"auto"`` is treated as an explicit choice and returned unchanged.
    For ``"auto"`` the backend is derived from ``get_node_count()``:
    ``"mnnvl"`` when more than one node is reported, ``"trtllm"`` otherwise.

    The decision is logged through ``logger.info_once``.

    Returns:
        The backend name to use for the allreduce workspace.
    """
    backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND

    if backend == "auto":
        # Multi-node: use mnnvl, since the trtllm backend does not support
        # multi-node allreduce.
        # Single-node: currently defaulting to trtllm because mnnvl has
        # issues with cudagraph:
        # https://github.com/vllm-project/vllm/issues/35772
        # Single-node should go back to mnnvl once that issue is resolved.
        backend = "mnnvl" if get_node_count() > 1 else "trtllm"
        logger.info_once(f"Auto-selected flashinfer allreduce backend: {backend}")
    else:
        # Explicit user selection: honor it as-is.
        logger.info_once(f"Using flashinfer allreduce backend: {backend}")

    return backend
|
||||
|
||||
|
||||
def get_fi_ar_workspace(
|
||||
world_size: int,
|
||||
rank: int,
|
||||
@@ -106,7 +129,13 @@ def get_fi_ar_workspace(
|
||||
if _fi_ar_workspace is not None:
|
||||
return _fi_ar_workspace
|
||||
|
||||
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
|
||||
backend = _resolve_fi_ar_backend()
|
||||
|
||||
if get_node_count() > 1 and backend == "trtllm":
|
||||
raise ValueError(
|
||||
"Flashinfer allreduce is not supported for multi-node allreduce with "
|
||||
"'trtllm' backend. Please use 'mnnvl' backend instead."
|
||||
)
|
||||
|
||||
# Reuse the quant workspace if it was already created with the same backend
|
||||
if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend:
|
||||
@@ -116,6 +145,17 @@ def get_fi_ar_workspace(
|
||||
_fi_ar_workspace = _create_workspace(
|
||||
backend, world_size, rank, max_token_num, hidden_dim, dtype, group
|
||||
)
|
||||
if _fi_ar_workspace is not None:
|
||||
logger.info_once(
|
||||
"Initialized FlashInfer Allreduce norm fusion workspace "
|
||||
f"with backend={backend}"
|
||||
)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Failed to initialize FlashInfer Allreduce norm fusion workspace "
|
||||
f"with backend={backend}"
|
||||
)
|
||||
|
||||
return _fi_ar_workspace
|
||||
|
||||
|
||||
@@ -131,12 +171,20 @@ def get_fi_ar_quant_workspace(
|
||||
Return the allreduce workspace for quant patterns, initializing if needed.
|
||||
|
||||
Always uses trtllm backend as it is the only one supporting quantization
|
||||
fusion (FP8/FP4).
|
||||
fusion (FP8/FP4). Returns None for multi-node setups since not supported
|
||||
by trtllm backend.
|
||||
"""
|
||||
global _fi_ar_quant_workspace
|
||||
if _fi_ar_quant_workspace is not None:
|
||||
return _fi_ar_quant_workspace
|
||||
|
||||
if get_node_count() > 1:
|
||||
logger.warning_once(
|
||||
"Flashinfer allreduce quantization fusion is not supported for "
|
||||
"multi-node allreduce. Disabling quant fusion."
|
||||
)
|
||||
return None
|
||||
|
||||
# Reuse the non-quant workspace if it was already created with trtllm
|
||||
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
|
||||
_fi_ar_quant_workspace = _fi_ar_workspace
|
||||
@@ -145,6 +193,17 @@ def get_fi_ar_quant_workspace(
|
||||
_fi_ar_quant_workspace = _create_workspace(
|
||||
"trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
|
||||
)
|
||||
if _fi_ar_quant_workspace is not None:
|
||||
logger.info_once(
|
||||
"Initialized FlashInfer Allreduce norm quantization "
|
||||
"fusion workspace with backend=trtllm"
|
||||
)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Failed to initialize FlashInfer Allreduce norm quantization "
|
||||
"fusion workspace with backend=trtllm"
|
||||
)
|
||||
|
||||
return _fi_ar_quant_workspace
|
||||
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@ if TYPE_CHECKING:
|
||||
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
|
||||
"latency"
|
||||
)
|
||||
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm"
|
||||
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
|
||||
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
|
||||
VLLM_XGRAMMAR_CACHE_MB: int = 0
|
||||
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
|
||||
@@ -1305,14 +1305,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
["throughput", "latency", "masked_gemm"],
|
||||
),
|
||||
# Flashinfer fused allreduce backend.
|
||||
# "auto" will default to "mnnvl", which performs mostly same/better than "trtllm".
|
||||
# But "mnnvl" backend does not support fuse with quantization.
|
||||
# TODO: Default is "trtllm" right now because "mnnvl" has issues with cudagraph:
|
||||
# https://github.com/vllm-project/vllm/issues/35772
|
||||
# Should switch back to "auto" if the issue is resolved.
|
||||
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
|
||||
"VLLM_FLASHINFER_ALLREDUCE_BACKEND",
|
||||
"trtllm",
|
||||
"auto",
|
||||
["auto", "trtllm", "mnnvl"],
|
||||
),
|
||||
# Control the workspace buffer size for the FlashInfer backend.
|
||||
|
||||
Reference in New Issue
Block a user