[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)

Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
Hanjie Qiu
2026-02-25 19:17:20 -08:00
committed by GitHub
parent 2aa4140402
commit 71dfce6aa6
7 changed files with 593 additions and 180 deletions


@@ -5,8 +5,11 @@
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
1. FlashInfer's allreduce_fusion with trtllm backend
(fused allreduce + rmsnorm + optional FP8/FP4 quant)
2. FlashInfer's allreduce_fusion with mnnvl backend
(fused allreduce + rmsnorm only, no quantization support)
3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark_fused_collective.py
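
For orientation, option 3 (the unfused baseline) reduces to three separate kernel launches. A minimal sketch, assuming vLLM's tensor_model_parallel_all_reduce and a reference RMSNorm (names illustrative, not part of the diff):

import torch
from vllm.distributed import tensor_model_parallel_all_reduce

def unfused_allreduce_rmsnorm(x, residual, gamma, eps=1e-6):
    # Step 1: all-reduce the partial sums across tensor-parallel ranks.
    x = tensor_model_parallel_all_reduce(x)
    # Step 2: residual add.
    x = x + residual
    # Step 3: RMSNorm, accumulated in fp32 as the fused kernels do.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * gamma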
@@ -48,8 +51,12 @@ SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant
logger = init_logger(__name__)
# Try to import FlashInfer
TorchDistBackend = None
try:
import flashinfer.comm as flashinfer_comm # type: ignore
from flashinfer.comm.mnnvl import ( # type: ignore
TorchDistBackend,
)
if not (
hasattr(flashinfer_comm, "allreduce_fusion")
@@ -74,11 +81,15 @@ _FI_MAX_SIZES = {
8: 64 * MiB, # 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE = None
# Global workspace tensors for FlashInfer (keyed by backend name)
_FI_WORKSPACES: dict = {}
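# Illustrative shape after setup on a supported world size:
#   {"trtllm": <workspace>, "mnnvl": <workspace>}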
# Backends to benchmark
FLASHINFER_BACKENDS = ["trtllm", "mnnvl"]
def setup_flashinfer_workspace(
backend: str,
world_size: int,
rank: int,
hidden_dim: int,
@@ -86,41 +97,54 @@ def setup_flashinfer_workspace(
dtype: torch.dtype,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global _FI_WORKSPACE
global _FI_WORKSPACES
if flashinfer_comm is None:
return None, None
return None
if world_size not in _FI_MAX_SIZES:
logger.warning("FlashInfer not supported for world size %s", world_size)
return None, None
return None
try:
kwargs = {}
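# Presumably what lets the mnnvl backend bootstrap over the default
# torch.distributed WORLD group (TorchDistBackend comes from flashinfer.comm.mnnvl).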
if TorchDistBackend is not None:
kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD)
workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
backend=backend,
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
**kwargs,
)
_FI_WORKSPACE = workspace
_FI_WORKSPACES[backend] = workspace
return workspace
except Exception as e:
logger.error("Failed to setup FlashInfer workspace: %s", e)
logger.error(
"Failed to setup FlashInfer workspace (backend=%s): %s", backend, e
)
return None
def cleanup_flashinfer_workspace(workspace):
"""Cleanup FlashInfer workspace."""
if flashinfer_comm is None or workspace is None:
def cleanup_flashinfer_workspaces():
"""Cleanup all FlashInfer workspaces."""
if flashinfer_comm is None:
return
try:
workspace.destroy()
except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e)
for backend, workspace in _FI_WORKSPACES.items():
try:
workspace.destroy()
except Exception as e:
logger.error(
"Failed to cleanup FlashInfer workspace (backend=%s): %s",
backend,
e,
)
_FI_WORKSPACES.clear()
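# Hypothetical lifecycle sketch (illustrative, names from this file):
#   for backend in FLASHINFER_BACKENDS:            # ["trtllm", "mnnvl"]
#       setup_flashinfer_workspace(backend=backend, world_size=world_size,
#                                  rank=rank, hidden_dim=hidden_dim,
#                                  max_token_num=max_token_num, dtype=dtype)
#   ...  # benchmark against _FI_WORKSPACES
#   cleanup_flashinfer_workspaces()                # destroys every workspace once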
class FlashInferFusedAllReduceParams:
@@ -134,7 +158,7 @@ class FlashInferFusedAllReduceParams:
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self):
def get_flashinfer_fused_allreduce_kwargs(self):
return {
"launch_with_pdl": self.launch_with_pdl,
"fp32_acc": self.fp32_acc,
@@ -147,11 +171,12 @@ def flashinfer_fused_allreduce_rmsnorm(
rms_gamma: torch.Tensor,
rms_eps: float,
allreduce_params: "FlashInferFusedAllReduceParams",
workspace: object,
use_oneshot: bool,
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
@@ -160,9 +185,13 @@ def flashinfer_fused_allreduce_rmsnorm(
else:
residual_out = input_tensor
layout_code = None
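# Only trtllm's quantizing patterns use the swizzled scale-factor layout;
# mnnvl has no quantization support, so it gets no layout code.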
if workspace.backend == "trtllm":
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
residual_in=residual,
residual_out=residual_out,
@@ -171,10 +200,10 @@ def flashinfer_fused_allreduce_rmsnorm(
rms_eps=rms_eps,
quant_out=None,
scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
layout_code=layout_code,
scale_factor=None,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)
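
A hypothetical call into this non-quantizing path with the mnnvl workspace (tensor names illustrative; mnnvl supports only this pattern, so layout_code stays None):

flashinfer_fused_allreduce_rmsnorm(
    hidden,                                   # [num_tokens, hidden_dim] partial sums
    residual=residual,
    rms_gamma=gamma,
    rms_eps=1e-6,
    allreduce_params=params,
    workspace=_FI_WORKSPACES["mnnvl"],
    use_oneshot=True,
)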
@@ -185,12 +214,16 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
rms_eps: float,
scale_factor: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
workspace: object,
use_oneshot: bool = True,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization.
Note: Only supported by the trtllm backend.
"""
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
@@ -201,7 +234,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
residual_in=residual,
residual_out=residual_out,
@@ -213,7 +246,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=scale_factor,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)
@@ -224,13 +257,17 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
rms_eps: float,
input_global_scale: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
workspace: object,
quant_out: torch.Tensor,
use_oneshot: bool,
output_scale: torch.Tensor,
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization.
Note: Only supported by the trtllm backend.
"""
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
@@ -241,7 +278,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
residual_in=residual,
residual_out=residual_out,
@@ -253,7 +290,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=input_global_scale,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)
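
The two quantizing variants are trtllm-only; a hypothetical FP8 call would fetch that backend's workspace by key (tensors illustrative; FP4 is analogous, adding quant_out/output_scale):

flashinfer_fused_allreduce_rmsnorm_fp8_quant(
    hidden,                                   # [num_tokens, hidden_dim] partial sums
    residual=residual,
    rms_gamma=gamma,
    rms_eps=1e-6,
    scale_factor=fp8_scale,                   # FP8 quantization scale
    allreduce_params=params,
    workspace=_FI_WORKSPACES["trtllm"],       # quant patterns: trtllm only
    use_oneshot=True,
    quant_out=quant_out_fp8,
)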
@@ -386,13 +423,16 @@ def run_benchmarks(
dtype: torch.dtype,
use_residual: bool,
allreduce_params: FlashInferFusedAllReduceParams | None,
workspaces: dict,
quant_modes: set[str],
no_oneshot: bool,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
allreduce_params: Shared parameters for FlashInfer fused allreduce.
workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace.
quant_modes: Set of quantization modes: "none", "fp8", "fp4".
"""
(
input_tensor,
@@ -454,10 +494,11 @@ def run_benchmarks(
logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm (all backends)
for backend, workspace in workspaces.items():
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
@@ -467,14 +508,17 @@ def run_benchmarks(
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
workspace=workspace,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms
results[key] = time_ms
except Exception as e:
logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e)
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float(
"inf"
logger.error(
"FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s",
backend,
e,
)
results[key] = float("inf")
if "fp8" in quant_modes:
# Standard AllReduce + RMSNorm + FP8 Quant
@@ -540,10 +584,12 @@ def run_benchmarks(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only)
if "trtllm" in workspaces:
trtllm_ws = workspaces["trtllm"]
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
@@ -555,19 +601,16 @@ def run_benchmarks(
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
workspace=trtllm_ws,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
time_ms
)
results[key] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
"FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s",
e,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
float("inf")
)
results[key] = float("inf")
if "fp4" in quant_modes and current_platform.has_device_capability(100):
# Standard AllReduce + RMSNorm + FP4 Quant
@@ -627,10 +670,12 @@ def run_benchmarks(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only)
if "trtllm" in workspaces:
trtllm_ws = workspaces["trtllm"]
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
@@ -641,49 +686,18 @@ def run_benchmarks(
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
workspace=trtllm_ws,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
time_ms
)
results[key] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
"FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s",
e,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
float("inf")
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if flashinfer_comm is not None and allreduce_params is not None:
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
"inf"
)
results[key] = float("inf")
return results
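
With the refactor, result keys embed the backend name, so the two FlashInfer paths stay distinguishable in the report; the patterns produced above are:

# Result-key patterns (values in ms; float("inf") on failure):
#   flashinfer_{trtllm|mnnvl}_fused_allreduce_rmsnorm{_oneshot|_twoshot}
#   flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{_oneshot|_twoshot}
#   flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{_oneshot|_twoshot}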
@@ -1021,8 +1035,7 @@ def main():
configs = list(itertools.product(args.num_tokens, dtypes, residual_options))
# Setup FlashInfer workspace if available
workspace = None
# Setup FlashInfer workspaces for all backends
allreduce_params = None
if flashinfer_comm is not None:
@@ -1037,15 +1050,17 @@ def main():
args.hidden_dim * max_element_size
)
workspace = setup_flashinfer_workspace(
world_size,
rank,
args.hidden_dim,
max_num_token,
dtype=workspace_dtype,
)
for backend in FLASHINFER_BACKENDS:
setup_flashinfer_workspace(
backend=backend,
world_size=world_size,
rank=rank,
hidden_dim=args.hidden_dim,
max_token_num=max_num_token,
dtype=workspace_dtype,
)
if workspace is not None:
if _FI_WORKSPACES:
allreduce_params = FlashInferFusedAllReduceParams(
max_token_num=max_num_token,
)
@@ -1071,6 +1086,7 @@ def main():
dtype,
use_residual,
allreduce_params,
workspaces=_FI_WORKSPACES,
quant_modes=quant_modes,
no_oneshot=args.no_oneshot,
)
@@ -1109,11 +1125,13 @@ def main():
finally:
# Cleanup
if workspace is not None:
cleanup_flashinfer_workspace(workspace)
cleanup_flashinfer_workspaces()
dist.barrier()
if __name__ == "__main__":
main()
from vllm.config import VllmConfig, set_current_vllm_config
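# Presumably required so vLLM layers and custom ops used by the benchmark
# see an active vLLM config when run outside a full engine.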
with set_current_vllm_config(VllmConfig()):
main()