[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)
Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com> Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
This commit is contained in:
@@ -5,8 +5,11 @@
|
||||
Benchmark for FlashInfer fused collective operations vs standard operations.
|
||||
|
||||
This benchmark compares:
|
||||
1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant)
|
||||
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
|
||||
1. FlashInfer's allreduce_fusion with trtllm backend
|
||||
(fused allreduce + rmsnorm + optional FP8/FP4 quant)
|
||||
2. FlashInfer's allreduce_fusion with mnnvl backend
|
||||
(fused allreduce + rmsnorm only, no quantization support)
|
||||
3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
|
||||
|
||||
Usage with torchrun:
|
||||
torchrun --nproc_per_node=2 benchmark_fused_collective.py
|
||||
@@ -48,8 +51,12 @@ SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Try to import FlashInfer
|
||||
TorchDistBackend = None
|
||||
try:
|
||||
import flashinfer.comm as flashinfer_comm # type: ignore
|
||||
from flashinfer.comm.mnnvl import ( # type: ignore
|
||||
TorchDistBackend,
|
||||
)
|
||||
|
||||
if not (
|
||||
hasattr(flashinfer_comm, "allreduce_fusion")
|
||||
@@ -74,11 +81,15 @@ _FI_MAX_SIZES = {
|
||||
8: 64 * MiB, # 64MB
|
||||
}
|
||||
|
||||
# Global workspace tensor for FlashInfer
|
||||
_FI_WORKSPACE = None
|
||||
# Global workspace tensors for FlashInfer (keyed by backend name)
|
||||
_FI_WORKSPACES: dict = {}
|
||||
|
||||
# Backends to benchmark
|
||||
FLASHINFER_BACKENDS = ["trtllm", "mnnvl"]
|
||||
|
||||
|
||||
def setup_flashinfer_workspace(
|
||||
backend: str,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
hidden_dim: int,
|
||||
@@ -86,41 +97,54 @@ def setup_flashinfer_workspace(
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Setup FlashInfer workspace for fused allreduce operations."""
|
||||
global _FI_WORKSPACE
|
||||
global FI_WORKSPACES
|
||||
|
||||
if flashinfer_comm is None:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
if world_size not in _FI_MAX_SIZES:
|
||||
logger.warning("FlashInfer not supported for world size %s", world_size)
|
||||
return None, None
|
||||
return None
|
||||
|
||||
try:
|
||||
kwargs = {}
|
||||
if TorchDistBackend is not None:
|
||||
kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD)
|
||||
|
||||
workspace = flashinfer_comm.create_allreduce_fusion_workspace(
|
||||
backend="trtllm",
|
||||
backend=backend,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
max_token_num=max_token_num,
|
||||
hidden_dim=hidden_dim,
|
||||
dtype=dtype,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
_FI_WORKSPACE = workspace
|
||||
_FI_WORKSPACES[backend] = workspace
|
||||
return workspace
|
||||
except Exception as e:
|
||||
logger.error("Failed to setup FlashInfer workspace: %s", e)
|
||||
logger.error(
|
||||
"Failed to setup FlashInfer workspace (backend=%s): %s", backend, e
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_flashinfer_workspace(workspace):
|
||||
"""Cleanup FlashInfer workspace."""
|
||||
if flashinfer_comm is None or workspace is None:
|
||||
def cleanup_flashinfer_workspaces():
|
||||
"""Cleanup all FlashInfer workspaces."""
|
||||
if flashinfer_comm is None:
|
||||
return
|
||||
|
||||
try:
|
||||
workspace.destroy()
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup FlashInfer workspace: %s", e)
|
||||
for backend, workspace in _FI_WORKSPACES.items():
|
||||
try:
|
||||
workspace.destroy()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to cleanup FlashInfer workspace (backend=%s): %s",
|
||||
backend,
|
||||
e,
|
||||
)
|
||||
_FI_WORKSPACES.clear()
|
||||
|
||||
|
||||
class FlashInferFusedAllReduceParams:
|
||||
@@ -134,7 +158,7 @@ class FlashInferFusedAllReduceParams:
|
||||
self.fp32_acc = True
|
||||
self.max_token_num = max_token_num
|
||||
|
||||
def get_trtllm_fused_allreduce_kwargs(self):
|
||||
def get_flashinfer_fused_allreduce_kwargs(self):
|
||||
return {
|
||||
"launch_with_pdl": self.launch_with_pdl,
|
||||
"fp32_acc": self.fp32_acc,
|
||||
@@ -147,11 +171,12 @@ def flashinfer_fused_allreduce_rmsnorm(
|
||||
rms_gamma: torch.Tensor,
|
||||
rms_eps: float,
|
||||
allreduce_params: "FlashInferFusedAllReduceParams",
|
||||
workspace: object,
|
||||
use_oneshot: bool,
|
||||
norm_out: torch.Tensor | None = None,
|
||||
):
|
||||
"""FlashInfer fused allreduce + rmsnorm operation."""
|
||||
if flashinfer_comm is None or _FI_WORKSPACE is None:
|
||||
if flashinfer_comm is None or workspace is None:
|
||||
raise RuntimeError("FlashInfer not available or workspace not initialized")
|
||||
|
||||
if norm_out is None:
|
||||
@@ -160,9 +185,13 @@ def flashinfer_fused_allreduce_rmsnorm(
|
||||
else:
|
||||
residual_out = input_tensor
|
||||
|
||||
layout_code = None
|
||||
if workspace.backend == "trtllm":
|
||||
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
|
||||
|
||||
flashinfer_comm.allreduce_fusion(
|
||||
input=input_tensor,
|
||||
workspace=_FI_WORKSPACE,
|
||||
workspace=workspace,
|
||||
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
|
||||
residual_in=residual,
|
||||
residual_out=residual_out,
|
||||
@@ -171,10 +200,10 @@ def flashinfer_fused_allreduce_rmsnorm(
|
||||
rms_eps=rms_eps,
|
||||
quant_out=None,
|
||||
scale_out=None,
|
||||
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
|
||||
layout_code=layout_code,
|
||||
scale_factor=None,
|
||||
use_oneshot=use_oneshot,
|
||||
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
|
||||
)
|
||||
|
||||
|
||||
@@ -185,12 +214,16 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
|
||||
rms_eps: float,
|
||||
scale_factor: torch.Tensor,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
workspace: object,
|
||||
use_oneshot: bool = True,
|
||||
norm_out: torch.Tensor | None = None,
|
||||
quant_out: torch.Tensor | None = None,
|
||||
):
|
||||
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
|
||||
if flashinfer_comm is None or _FI_WORKSPACE is None:
|
||||
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization.
|
||||
|
||||
Note: Only supported by the trtllm backend.
|
||||
"""
|
||||
if flashinfer_comm is None or workspace is None:
|
||||
raise RuntimeError("FlashInfer not available or workspace not initialized")
|
||||
|
||||
if norm_out is None:
|
||||
@@ -201,7 +234,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
|
||||
|
||||
flashinfer_comm.allreduce_fusion(
|
||||
input=input_tensor,
|
||||
workspace=_FI_WORKSPACE,
|
||||
workspace=workspace,
|
||||
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
|
||||
residual_in=residual,
|
||||
residual_out=residual_out,
|
||||
@@ -213,7 +246,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
|
||||
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
|
||||
scale_factor=scale_factor,
|
||||
use_oneshot=use_oneshot,
|
||||
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
|
||||
)
|
||||
|
||||
|
||||
@@ -224,13 +257,17 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
|
||||
rms_eps: float,
|
||||
input_global_scale: torch.Tensor,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
workspace: object,
|
||||
quant_out: torch.Tensor,
|
||||
use_oneshot: bool,
|
||||
output_scale: torch.Tensor,
|
||||
norm_out: torch.Tensor | None = None,
|
||||
):
|
||||
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
|
||||
if flashinfer_comm is None or _FI_WORKSPACE is None:
|
||||
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization.
|
||||
|
||||
Note: Only supported by the trtllm backend.
|
||||
"""
|
||||
if flashinfer_comm is None or workspace is None:
|
||||
raise RuntimeError("FlashInfer not available or workspace not initialized")
|
||||
|
||||
if norm_out is None:
|
||||
@@ -241,7 +278,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
|
||||
|
||||
flashinfer_comm.allreduce_fusion(
|
||||
input=input_tensor,
|
||||
workspace=_FI_WORKSPACE,
|
||||
workspace=workspace,
|
||||
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
|
||||
residual_in=residual,
|
||||
residual_out=residual_out,
|
||||
@@ -253,7 +290,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
|
||||
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
|
||||
scale_factor=input_global_scale,
|
||||
use_oneshot=use_oneshot,
|
||||
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
|
||||
)
|
||||
|
||||
|
||||
@@ -386,13 +423,16 @@ def run_benchmarks(
|
||||
dtype: torch.dtype,
|
||||
use_residual: bool,
|
||||
allreduce_params: FlashInferFusedAllReduceParams | None,
|
||||
workspaces: dict,
|
||||
quant_modes: set[str],
|
||||
no_oneshot: bool,
|
||||
):
|
||||
"""Run all benchmarks for given configuration.
|
||||
|
||||
Args:
|
||||
quant_mode: "none", "fp8_only", "fp4_only", or "all"
|
||||
allreduce_params: Shared parameters for FlashInfer fused allreduce.
|
||||
workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace.
|
||||
quant_modes: Set of quantization modes: "none", "fp8", "fp4".
|
||||
"""
|
||||
(
|
||||
input_tensor,
|
||||
@@ -454,10 +494,11 @@ def run_benchmarks(
|
||||
logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
|
||||
results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")
|
||||
|
||||
# FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
|
||||
if flashinfer_comm is not None and allreduce_params is not None:
|
||||
# FlashInfer Fused AllReduce + RMSNorm (all backends)
|
||||
for backend, workspace in workspaces.items():
|
||||
for use_oneshot in use_oneshot_options:
|
||||
suffix = "_oneshot" if use_oneshot else "_twoshot"
|
||||
key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}"
|
||||
try:
|
||||
time_ms = benchmark_operation(
|
||||
flashinfer_fused_allreduce_rmsnorm,
|
||||
@@ -467,14 +508,17 @@ def run_benchmarks(
|
||||
rms_gamma=rms_gamma,
|
||||
rms_eps=rms_eps,
|
||||
allreduce_params=allreduce_params,
|
||||
workspace=workspace,
|
||||
use_oneshot=use_oneshot,
|
||||
)
|
||||
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms
|
||||
results[key] = time_ms
|
||||
except Exception as e:
|
||||
logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e)
|
||||
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float(
|
||||
"inf"
|
||||
logger.error(
|
||||
"FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s",
|
||||
backend,
|
||||
e,
|
||||
)
|
||||
results[key] = float("inf")
|
||||
|
||||
if "fp8" in quant_modes:
|
||||
# Standard AllReduce + RMSNorm + FP8 Quant
|
||||
@@ -540,10 +584,12 @@ def run_benchmarks(
|
||||
"inf"
|
||||
)
|
||||
|
||||
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
|
||||
if flashinfer_comm is not None and allreduce_params is not None:
|
||||
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only)
|
||||
if "trtllm" in workspaces:
|
||||
trtllm_ws = workspaces["trtllm"]
|
||||
for use_oneshot in use_oneshot_options:
|
||||
suffix = "_oneshot" if use_oneshot else "_twoshot"
|
||||
key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}"
|
||||
try:
|
||||
time_ms = benchmark_operation(
|
||||
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
|
||||
@@ -555,19 +601,16 @@ def run_benchmarks(
|
||||
scale_factor=scale_fp8,
|
||||
quant_out=quant_out_fp8,
|
||||
allreduce_params=allreduce_params,
|
||||
workspace=trtllm_ws,
|
||||
use_oneshot=use_oneshot,
|
||||
)
|
||||
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
|
||||
time_ms
|
||||
)
|
||||
results[key] = time_ms
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
|
||||
"FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s",
|
||||
e,
|
||||
)
|
||||
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
|
||||
float("inf")
|
||||
)
|
||||
results[key] = float("inf")
|
||||
|
||||
if "fp4" in quant_modes and current_platform.has_device_capability(100):
|
||||
# Standard AllReduce + RMSNorm + FP4 Quant
|
||||
@@ -627,10 +670,12 @@ def run_benchmarks(
|
||||
"inf"
|
||||
)
|
||||
|
||||
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
|
||||
if flashinfer_comm is not None and allreduce_params is not None:
|
||||
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only)
|
||||
if "trtllm" in workspaces:
|
||||
trtllm_ws = workspaces["trtllm"]
|
||||
for use_oneshot in use_oneshot_options:
|
||||
suffix = "_oneshot" if use_oneshot else "_twoshot"
|
||||
key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}"
|
||||
try:
|
||||
time_ms = benchmark_operation(
|
||||
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
|
||||
@@ -641,49 +686,18 @@ def run_benchmarks(
|
||||
rms_eps=rms_eps,
|
||||
input_global_scale=scale_fp4,
|
||||
allreduce_params=allreduce_params,
|
||||
workspace=trtllm_ws,
|
||||
quant_out=fp4_quant_out,
|
||||
output_scale=fp4_output_scale,
|
||||
use_oneshot=use_oneshot,
|
||||
)
|
||||
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
|
||||
time_ms
|
||||
)
|
||||
results[key] = time_ms
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
|
||||
"FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s",
|
||||
e,
|
||||
)
|
||||
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
|
||||
float("inf")
|
||||
)
|
||||
|
||||
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
|
||||
if flashinfer_comm is not None and allreduce_params is not None:
|
||||
try:
|
||||
time_ms = benchmark_operation(
|
||||
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
|
||||
input_tensor,
|
||||
residual=residual,
|
||||
norm_out=norm_out,
|
||||
rms_gamma=rms_gamma,
|
||||
rms_eps=rms_eps,
|
||||
input_global_scale=scale_fp4,
|
||||
allreduce_params=allreduce_params,
|
||||
quant_out=fp4_quant_out,
|
||||
output_scale=fp4_output_scale,
|
||||
use_oneshot=False,
|
||||
)
|
||||
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
|
||||
time_ms
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
|
||||
e,
|
||||
)
|
||||
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
|
||||
"inf"
|
||||
)
|
||||
results[key] = float("inf")
|
||||
|
||||
return results
|
||||
|
||||
@@ -1021,8 +1035,7 @@ def main():
|
||||
|
||||
configs = list(itertools.product(args.num_tokens, dtypes, residual_options))
|
||||
|
||||
# Setup FlashInfer workspace if available
|
||||
workspace = None
|
||||
# Setup FlashInfer workspaces for all backends
|
||||
allreduce_params = None
|
||||
|
||||
if flashinfer_comm is not None:
|
||||
@@ -1037,15 +1050,17 @@ def main():
|
||||
args.hidden_dim * max_element_size
|
||||
)
|
||||
|
||||
workspace = setup_flashinfer_workspace(
|
||||
world_size,
|
||||
rank,
|
||||
args.hidden_dim,
|
||||
max_num_token,
|
||||
dtype=workspace_dtype,
|
||||
)
|
||||
for backend in FLASHINFER_BACKENDS:
|
||||
setup_flashinfer_workspace(
|
||||
backend=backend,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
hidden_dim=args.hidden_dim,
|
||||
max_token_num=max_num_token,
|
||||
dtype=workspace_dtype,
|
||||
)
|
||||
|
||||
if workspace is not None:
|
||||
if _FI_WORKSPACES:
|
||||
allreduce_params = FlashInferFusedAllReduceParams(
|
||||
max_token_num=max_num_token,
|
||||
)
|
||||
@@ -1071,6 +1086,7 @@ def main():
|
||||
dtype,
|
||||
use_residual,
|
||||
allreduce_params,
|
||||
workspaces=_FI_WORKSPACES,
|
||||
quant_modes=quant_modes,
|
||||
no_oneshot=args.no_oneshot,
|
||||
)
|
||||
@@ -1109,11 +1125,13 @@ def main():
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if workspace is not None:
|
||||
cleanup_flashinfer_workspace(workspace)
|
||||
cleanup_flashinfer_workspaces()
|
||||
|
||||
dist.barrier()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user