[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)

Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
This commit is contained in:
Hanjie Qiu
2026-02-25 19:17:20 -08:00
committed by GitHub
parent 2aa4140402
commit 71dfce6aa6
7 changed files with 593 additions and 180 deletions

View File

@@ -30,6 +30,9 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
FlashInferAllReduce,
)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
register_nccl_symmetric_ops,
@@ -44,7 +47,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
# Default sequence lengths to benchmark
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192]
# Fixed hidden size and dtype for all benchmarks
HIDDEN_SIZE = 8192
@@ -81,6 +84,7 @@ class CommunicatorBenchmark:
self.symm_mem_comm = None
self.symm_mem_comm_multimem = None
self.symm_mem_comm_two_shot = None
self.fi_ar_comm = None
self._init_communicators()
@@ -161,6 +165,22 @@ class CommunicatorBenchmark:
)
self.symm_mem_comm_two_shot = None
try:
self.fi_ar_comm = FlashInferAllReduce(
group=self.cpu_group,
device=self.device,
)
if not self.fi_ar_comm.disabled:
logger.info("Rank %s: FlashInferAllReduce initialized", self.rank)
else:
logger.info("Rank %s: FlashInferAllReduce disabled", self.rank)
self.fi_ar_comm = None
except Exception as e:
logger.warning(
"Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e
)
self.fi_ar_comm = None
def benchmark_allreduce(
self, sequence_length: int, num_warmup: int, num_trials: int
) -> dict[str, float]:
@@ -180,7 +200,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.custom_all_reduce(t),
lambda t, c=comm: c.should_custom_ar(t),
comm.capture(),
"1stage", # env variable value
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"},
None, # no destroy function
)
)
# CustomAllreduce two-shot
@@ -190,7 +211,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.custom_all_reduce(t),
lambda t, c=comm: c.should_custom_ar(t),
comm.capture(),
"2stage", # env variable value
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"},
None, # no destroy function
)
)
@@ -202,7 +224,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)
communicators.append(
@@ -211,7 +234,8 @@ class CommunicatorBenchmark:
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)
@@ -223,7 +247,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_symm_mem(t),
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)
@@ -235,29 +260,67 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_symm_mem(t),
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function needed
)
)
if self.fi_ar_comm is not None:
comm = self.fi_ar_comm
communicators.append(
(
"flashinfer_trtllm",
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_fi_ar(t),
nullcontext(),
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"},
lambda c=comm: c.destroy(),
)
)
communicators.append(
(
"flashinfer_mnnvl",
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_fi_ar(t),
nullcontext(),
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"},
lambda c=comm: c.destroy(),
)
)
# Benchmark each communicator
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
# Set environment variable if needed
if env_var is not None:
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
else:
# Clear the environment variable to avoid interference
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
latency = self.benchmark_allreduce_single(
sequence_length,
allreduce_fn,
should_use_fn,
context,
num_warmup,
num_trials,
)
if latency is not None:
results[name] = latency
for (
name,
allreduce_fn,
should_use_fn,
context,
env_dict,
destroy_fn,
) in communicators:
# Save original values and apply new environment variables
saved_env = {key: os.environ.get(key) for key in env_dict}
for key, value in env_dict.items():
os.environ[key] = value
try:
latency = self.benchmark_allreduce_single(
sequence_length,
allreduce_fn,
should_use_fn,
context,
num_warmup,
num_trials,
)
if latency is not None:
results[name] = latency
finally:
if destroy_fn is not None:
destroy_fn()
# Restore environment variables to their original state
for key, original_value in saved_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
return results

View File

@@ -5,8 +5,11 @@
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
1. FlashInfer's allreduce_fusion with trtllm backend
(fused allreduce + rmsnorm + optional FP8/FP4 quant)
2. FlashInfer's allreduce_fusion with mnnvl backend
(fused allreduce + rmsnorm only, no quantization support)
3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark_fused_collective.py
@@ -48,8 +51,12 @@ SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant
logger = init_logger(__name__)
# Try to import FlashInfer
TorchDistBackend = None
try:
import flashinfer.comm as flashinfer_comm # type: ignore
from flashinfer.comm.mnnvl import ( # type: ignore
TorchDistBackend,
)
if not (
hasattr(flashinfer_comm, "allreduce_fusion")
@@ -74,11 +81,15 @@ _FI_MAX_SIZES = {
8: 64 * MiB, # 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE = None
# Global workspace tensors for FlashInfer (keyed by backend name)
_FI_WORKSPACES: dict = {}
# Backends to benchmark
FLASHINFER_BACKENDS = ["trtllm", "mnnvl"]
def setup_flashinfer_workspace(
    backend: str,
    world_size: int,
    rank: int,
    hidden_dim: int,
    max_token_num: int,
    dtype: torch.dtype,
):
    """Setup a FlashInfer workspace for fused allreduce operations.

    Creates a workspace for the given backend ("trtllm" or "mnnvl"),
    caches it in the module-level ``_FI_WORKSPACES`` dict keyed by backend
    name, and returns it. Returns None when FlashInfer is unavailable,
    the world size is unsupported, or workspace creation fails.
    """
    # BUGFIX: was `global FI_WORKSPACES` (missing leading underscore) —
    # that declared a nonexistent name; the actual cache is _FI_WORKSPACES.
    global _FI_WORKSPACES
    if flashinfer_comm is None:
        return None

    if world_size not in _FI_MAX_SIZES:
        logger.warning("FlashInfer not supported for world size %s", world_size)
        return None

    try:
        kwargs = {}
        # mnnvl needs an explicit torch.distributed comm backend when available.
        if TorchDistBackend is not None:
            kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD)
        workspace = flashinfer_comm.create_allreduce_fusion_workspace(
            backend=backend,
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            dtype=dtype,
            **kwargs,
        )
        _FI_WORKSPACES[backend] = workspace
        return workspace
    except Exception as e:
        logger.error(
            "Failed to setup FlashInfer workspace (backend=%s): %s", backend, e
        )
        return None
def cleanup_flashinfer_workspace(workspace):
"""Cleanup FlashInfer workspace."""
if flashinfer_comm is None or workspace is None:
def cleanup_flashinfer_workspaces():
"""Cleanup all FlashInfer workspaces."""
if flashinfer_comm is None:
return
try:
workspace.destroy()
except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e)
for backend, workspace in _FI_WORKSPACES.items():
try:
workspace.destroy()
except Exception as e:
logger.error(
"Failed to cleanup FlashInfer workspace (backend=%s): %s",
backend,
e,
)
_FI_WORKSPACES.clear()
class FlashInferFusedAllReduceParams:
@@ -134,7 +158,7 @@ class FlashInferFusedAllReduceParams:
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self):
def get_flashinfer_fused_allreduce_kwargs(self):
return {
"launch_with_pdl": self.launch_with_pdl,
"fp32_acc": self.fp32_acc,
@@ -147,11 +171,12 @@ def flashinfer_fused_allreduce_rmsnorm(
rms_gamma: torch.Tensor,
rms_eps: float,
allreduce_params: "FlashInferFusedAllReduceParams",
workspace: object,
use_oneshot: bool,
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
@@ -160,9 +185,13 @@ def flashinfer_fused_allreduce_rmsnorm(
else:
residual_out = input_tensor
layout_code = None
if workspace.backend == "trtllm":
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
residual_in=residual,
residual_out=residual_out,
@@ -171,10 +200,10 @@ def flashinfer_fused_allreduce_rmsnorm(
rms_eps=rms_eps,
quant_out=None,
scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
layout_code=layout_code,
scale_factor=None,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)
@@ -185,12 +214,16 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
rms_eps: float,
scale_factor: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
workspace: object,
use_oneshot: bool = True,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization.
Note: Only supported by the trtllm backend.
"""
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
@@ -201,7 +234,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
residual_in=residual,
residual_out=residual_out,
@@ -213,7 +246,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=scale_factor,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)
@@ -224,13 +257,17 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
rms_eps: float,
input_global_scale: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
workspace: object,
quant_out: torch.Tensor,
use_oneshot: bool,
output_scale: torch.Tensor,
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization.
Note: Only supported by the trtllm backend.
"""
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
@@ -241,7 +278,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
residual_in=residual,
residual_out=residual_out,
@@ -253,7 +290,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=input_global_scale,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)
@@ -386,13 +423,16 @@ def run_benchmarks(
dtype: torch.dtype,
use_residual: bool,
allreduce_params: FlashInferFusedAllReduceParams | None,
workspaces: dict,
quant_modes: set[str],
no_oneshot: bool,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
allreduce_params: Shared parameters for FlashInfer fused allreduce.
workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace.
quant_modes: Set of quantization modes: "none", "fp8", "fp4".
"""
(
input_tensor,
@@ -454,10 +494,11 @@ def run_benchmarks(
logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm (all backends)
for backend, workspace in workspaces.items():
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
@@ -467,14 +508,17 @@ def run_benchmarks(
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
workspace=workspace,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms
results[key] = time_ms
except Exception as e:
logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e)
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float(
"inf"
logger.error(
"FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s",
backend,
e,
)
results[key] = float("inf")
if "fp8" in quant_modes:
# Standard AllReduce + RMSNorm + FP8 Quant
@@ -540,10 +584,12 @@ def run_benchmarks(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only)
if "trtllm" in workspaces:
trtllm_ws = workspaces["trtllm"]
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
@@ -555,19 +601,16 @@ def run_benchmarks(
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
workspace=trtllm_ws,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
time_ms
)
results[key] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
"FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s",
e,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
float("inf")
)
results[key] = float("inf")
if "fp4" in quant_modes and current_platform.has_device_capability(100):
# Standard AllReduce + RMSNorm + FP4 Quant
@@ -627,10 +670,12 @@ def run_benchmarks(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only)
if "trtllm" in workspaces:
trtllm_ws = workspaces["trtllm"]
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
@@ -641,49 +686,18 @@ def run_benchmarks(
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
workspace=trtllm_ws,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
time_ms
)
results[key] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
"FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s",
e,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
float("inf")
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if flashinfer_comm is not None and allreduce_params is not None:
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
"inf"
)
results[key] = float("inf")
return results
@@ -1021,8 +1035,7 @@ def main():
configs = list(itertools.product(args.num_tokens, dtypes, residual_options))
# Setup FlashInfer workspace if available
workspace = None
# Setup FlashInfer workspaces for all backends
allreduce_params = None
if flashinfer_comm is not None:
@@ -1037,15 +1050,17 @@ def main():
args.hidden_dim * max_element_size
)
workspace = setup_flashinfer_workspace(
world_size,
rank,
args.hidden_dim,
max_num_token,
dtype=workspace_dtype,
)
for backend in FLASHINFER_BACKENDS:
setup_flashinfer_workspace(
backend=backend,
world_size=world_size,
rank=rank,
hidden_dim=args.hidden_dim,
max_token_num=max_num_token,
dtype=workspace_dtype,
)
if workspace is not None:
if _FI_WORKSPACES:
allreduce_params = FlashInferFusedAllReduceParams(
max_token_num=max_num_token,
)
@@ -1071,6 +1086,7 @@ def main():
dtype,
use_residual,
allreduce_params,
workspaces=_FI_WORKSPACES,
quant_modes=quant_modes,
no_oneshot=args.no_oneshot,
)
@@ -1109,11 +1125,13 @@ def main():
finally:
# Cleanup
if workspace is not None:
cleanup_flashinfer_workspace(workspace)
cleanup_flashinfer_workspaces()
dist.barrier()
if __name__ == "__main__":
main()
from vllm.config import VllmConfig, set_current_vllm_config
with set_current_vllm_config(VllmConfig()):
main()

View File

@@ -142,7 +142,6 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
)
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
print(f"{self.wq=}, {self.wscale=}")
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
@@ -199,6 +198,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
@@ -215,6 +215,7 @@ def test_all_reduce_fusion_pass_replace(
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
):
num_processes = 2
if (
@@ -238,6 +239,7 @@ def test_all_reduce_fusion_pass_replace(
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
),
nprocs=nprocs,
)
@@ -255,6 +257,7 @@ def all_reduce_fusion_pass_on_test_model(
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
):
set_random_seed(0)
@@ -270,6 +273,7 @@ def all_reduce_fusion_pass_on_test_model(
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": flashinfer_allreduce_backend,
}
)
@@ -317,6 +321,10 @@ def all_reduce_fusion_pass_on_test_model(
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
results_unfused = model(hidden_states)
results_fused = compiled_model(hidden_states)
torch.testing.assert_close(results_unfused, results_fused, atol=1e-2, rtol=1e-2)
assert all_reduce_fusion_pass.matched_count == 4, (
f"{all_reduce_fusion_pass.matched_count=}"
)

View File

@@ -22,7 +22,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
@@ -44,8 +46,6 @@ if find_spec("flashinfer"):
except ImportError:
pass
logger = init_logger(__name__)
if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
@@ -82,7 +82,16 @@ _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
if flashinfer_comm is not None:
_FI_WORKSPACE = None
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
destroy_fi_ar_workspace,
get_fi_ar_quant_workspace,
get_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
initialize_fi_ar_workspace,
)
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
MiB = 1024 * 1024
def call_trtllm_fused_allreduce_norm(
@@ -122,9 +131,19 @@ if flashinfer_comm is not None:
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
)
assert _FI_WORKSPACE is not None, (
"Flashinfer must be enabled when using flashinfer"
# Select workspace based on pattern: quant patterns use the
# trtllm quant workspace, non-quant patterns use the primary workspace.
if pattern_code in (
ar_fusion_patterns.kARResidualRMSNormFP8Quant,
ar_fusion_patterns.kARResidualRMSNormFP4Quant,
):
workspace = get_fi_ar_quant_workspace()
else:
workspace = get_fi_ar_workspace()
assert workspace is not None, (
"Flashinfer workspace must be initialized when using flashinfer"
)
assert flashinfer_comm is not None
if norm_out is None:
norm_out = allreduce_in
residual_out = residual
@@ -133,25 +152,30 @@ if flashinfer_comm is not None:
# as flashinfer does not support rms_norm
# and allreduce_out together
residual_out = allreduce_in
# For the sizes that are smaller than the max size,
# we only use flashinfer one shot allreduce
layout_code = None
# layout_code only supported by trtllm backend
if workspace.backend == "trtllm":
# in vllm we only support swizzled layout
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
flashinfer_comm.allreduce_fusion(
input=allreduce_in,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=pattern_code,
residual_in=residual,
launch_with_pdl=launch_with_pdl,
output=None,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
launch_with_pdl=launch_with_pdl,
use_oneshot=use_oneshot,
fp32_acc=fp32_acc,
quant_out=quant_out,
scale_out=scale_out,
# in vllm we only support swizzled layout
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
residual_in=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_factor,
layout_code=layout_code,
use_oneshot=use_oneshot,
fp32_acc=fp32_acc,
)
def call_trtllm_fused_allreduce_norm_fake(
@@ -729,29 +753,36 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
scope="global",
)
try:
self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
)
except RuntimeError as e:
if "multicast" not in str(e).lower():
raise
logger.warning_once(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
return
for workspace_init_fn in [
initialize_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
]:
try:
workspace_init_fn(
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
group=self.group,
)
except Exception as e:
if "multicast" in str(e).lower():
logger.warning(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
else:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"AllReduce fusion pass will be disabled.",
e,
)
return
global _FI_WORKSPACE
_FI_WORKSPACE = self.workspace
self.allreduce_params = FlashInferFusedAllReduceParams(
world_size=self.tp_size,
max_token_num=self.max_token_num,
@@ -762,32 +793,34 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
@enable_fake_mode
def register_patterns(self) -> None:
supports_quantization = get_fi_ar_quant_workspace() is not None
for epsilon in [1e-5, 1e-6]:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
if current_platform.has_device_capability(100):
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
if supports_quantization:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
if current_platform.has_device_capability(100):
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceRMSNormPattern(
epsilon,
self.model_dtype,
@@ -825,6 +858,5 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
def __del__(self) -> None:
if getattr(self, "disabled", True):
return
if getattr(self, "workspace", None) is not None:
with contextlib.suppress(Exception):
self.workspace.destroy()
with contextlib.suppress(Exception):
destroy_fi_ar_workspace()

View File

@@ -34,19 +34,25 @@ class CudaCommunicator(DeviceCommunicatorBase):
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
use_torch_symm_mem = False
use_flashinfer_allreduce = False
else:
from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
use_flashinfer_allreduce = envs.VLLM_ALLREDUCE_USE_FLASHINFER
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
self.use_flashinfer_allreduce = use_flashinfer_allreduce
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
)
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
FlashInferAllReduce,
)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce,
@@ -65,12 +71,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: CustomAllreduce | None = None
self.qr_comm: QuickAllReduce | None = None
self.symm_mem_comm: SymmMemCommunicator | None = None
self.fi_ar_comm: FlashInferAllReduce | None = None
if use_torch_symm_mem and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_flashinfer_allreduce and self.world_size > 1:
self.fi_ar_comm = FlashInferAllReduce(
group=self.cpu_group,
device=self.device,
)
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
@@ -136,7 +150,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
return out
# always try quick reduce first, then custom allreduce,
# always try quick reduce first, then flashinfer, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
if (
@@ -147,6 +161,15 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = qr_comm.quick_all_reduce(input_)
assert out is not None
return out
fi_ar_comm = self.fi_ar_comm
if (
fi_ar_comm is not None
and not fi_ar_comm.disabled
and fi_ar_comm.should_use_fi_ar(input_)
):
out = fi_ar_comm.all_reduce(input_)
assert out is not None
return out
ca_comm = self.ca_comm
if (
ca_comm is not None
@@ -270,6 +293,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.fi_ar_comm is not None:
self.fi_ar_comm.destroy()
self.fi_ar_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None

View File

@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.config.compilation import PassConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
fi_ar_available = False
try:
import flashinfer.comm as flashinfer_comm # type: ignore[no-redef]
from flashinfer.comm.mnnvl import (
TorchDistBackend, # type: ignore[import-not-found, no-redef]
)
fi_ar_available = hasattr(flashinfer_comm, "allreduce_fusion")
except ImportError:
pass
# Global workspace for standalone allreduce and non-quant ar+rms fusion
_fi_ar_workspace = None
# Extra workspace for quant fusion patterns (only supported by trtllm backend)
# Only created if primary workspace is not already trtllm
_fi_ar_quant_workspace = None
def get_fi_ar_workspace():
    """Return the primary FlashInfer allreduce workspace, or None if not initialized."""
    return _fi_ar_workspace
def get_fi_ar_quant_workspace():
    """Return the quant-fusion workspace (trtllm-backed), or None if not initialized."""
    return _fi_ar_quant_workspace
def initialize_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
) -> None:
    """Create the shared FlashInfer allreduce workspace on first use.

    Called by either the AllReduceFusionPass (when the fusion pass is
    enabled via --compilation-config.pass_config.fuse_allreduce_rms=true)
    or the standalone FlashInferAllReduce backend. Whichever runs first
    creates the workspace; subsequent calls are no-ops and reuse it.
    """
    global _fi_ar_workspace
    # Idempotent: an already-created workspace is reused as-is.
    if _fi_ar_workspace is not None:
        return

    selected_backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
    _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
        backend=selected_backend,
        world_size=world_size,
        rank=rank,
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        dtype=dtype,
        comm_backend=TorchDistBackend(group=group),
    )
    assert _fi_ar_workspace is not None
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=%s, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        selected_backend,
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )
def initialize_fi_ar_quant_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
) -> None:
    """
    Initialize the workspace used by quantization fusion patterns.

    Always creates a trtllm-backend workspace, since only trtllm supports
    quantization fusion (FP8/FP4). When the primary workspace already uses
    the trtllm backend, the quant workspace aliases it instead of
    allocating a second buffer.
    """
    global _fi_ar_quant_workspace
    # Idempotent: keep the existing quant workspace, if any.
    if _fi_ar_quant_workspace is not None:
        return

    # Alias the primary workspace when it is already a trtllm one.
    primary = _fi_ar_workspace
    if primary is not None and primary.backend == "trtllm":
        _fi_ar_quant_workspace = primary
        return

    _fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
        backend="trtllm",
        world_size=world_size,
        rank=rank,
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        dtype=dtype,
        comm_backend=TorchDistBackend(group=group),
    )
    assert _fi_ar_quant_workspace is not None
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=trtllm, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )
def destroy_fi_ar_workspace():
    """Destroy the FlashInfer allreduce workspace(s) and reset the globals.

    The quant workspace may alias the primary workspace, so it is only
    destroyed separately when it is a distinct object. Both globals are
    cleared unconditionally: previously, when the quant workspace aliased
    the primary one, `_fi_ar_quant_workspace` was left pointing at the
    (destroyed) primary workspace, so the accessor could hand out a dead
    workspace after teardown.
    """
    global _fi_ar_workspace
    global _fi_ar_quant_workspace
    if (
        _fi_ar_quant_workspace is not None
        and _fi_ar_quant_workspace is not _fi_ar_workspace
    ):
        _fi_ar_quant_workspace.destroy()
    # Clear even in the alias case so no stale reference survives.
    _fi_ar_quant_workspace = None
    if _fi_ar_workspace is not None:
        _fi_ar_workspace.destroy()
    _fi_ar_workspace = None
class FlashInferAllReduce:
def __init__(
self,
group: ProcessGroup,
device: int | str | torch.device,
):
self.disabled = True
if not fi_ar_available:
logger.info(
"FlashInfer All Reduce is disabled because flashinfer is not available"
)
return
if not current_platform.is_cuda():
logger.info(
"FlashInfer All Reduce is disabled because it requires CUDA platform"
)
return
self.group = group
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.device = device
if self.world_size == 1:
return
# Use the same threshold as the allreduce-rms fusion pass
# TODO: tune the threshold
MiB = 1024 * 1024
max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(
self.world_size, None
)
if not max_workspace_size:
logger.warning(
"FlashInfer All Reduce is disabled because it "
"is not supported for world_size=%d.",
self.world_size,
)
return
self.max_workspace_size = max_workspace_size * MiB
self.max_num_tokens = 0
self.disabled = False
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
"""Ensure the all reduce workspace is initialized."""
if get_fi_ar_workspace() is not None:
return True
if self.max_num_tokens == 0:
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
try:
initialize_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=dtype,
group=self.group,
)
return True
except Exception as e:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"FlashInfer All Reduce will be disabled.",
e,
)
self.disabled = True
return False
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
if self.disabled:
return False
if not input_tensor.is_cuda:
return False
if not input_tensor.is_contiguous():
return False
if len(input_tensor.shape) != 2:
return False
num_tokens, hidden_dim = input_tensor.shape
if not self.max_num_tokens:
element_size = torch.tensor([], dtype=input_tensor.dtype).element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
if num_tokens > self.max_num_tokens:
return False
return self._ensure_workspace(hidden_dim, input_tensor.dtype)
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
workspace = get_fi_ar_workspace()
return flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
)
def destroy(self):
if not self.disabled:
destroy_fi_ar_workspace()

View File

@@ -168,6 +168,7 @@ if TYPE_CHECKING:
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
"latency"
)
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -206,6 +207,7 @@ if TYPE_CHECKING:
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_ALLREDUCE_USE_FLASHINFER: bool = False
VLLM_TUNED_CONFIG_FOLDER: str | None = None
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False
@@ -1290,6 +1292,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"latency",
["throughput", "latency", "masked_gemm"],
),
# Flashinfer fused allreduce backend.
# "auto" defaults to "mnnvl", which performs the same as or better than "trtllm".
# But "mnnvl" backend does not support fuse with quantization.
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_ALLREDUCE_BACKEND",
"auto",
["auto", "trtllm", "mnnvl"],
),
# Control the workspace buffer size for the FlashInfer backend.
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024))
@@ -1448,6 +1458,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool(
int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))
),
# Whether to use FlashInfer allreduce
"VLLM_ALLREDUCE_USE_FLASHINFER": lambda: bool(
int(os.getenv("VLLM_ALLREDUCE_USE_FLASHINFER", "0"))
),
# Experimental: use this to enable MCP tool calling for non harmony models
"VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool(
int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0"))