[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)
Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
@@ -30,6 +30,9 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
FlashInferAllReduce,
)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
register_nccl_symmetric_ops,

@@ -44,7 +47,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
# Default sequence lengths to benchmark
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192]
# Fixed hidden size and dtype for all benchmarks
HIDDEN_SIZE = 8192

@@ -81,6 +84,7 @@ class CommunicatorBenchmark:
self.symm_mem_comm = None
self.symm_mem_comm_multimem = None
self.symm_mem_comm_two_shot = None
self.fi_ar_comm = None
self._init_communicators()

@@ -161,6 +165,22 @@ class CommunicatorBenchmark:
)
self.symm_mem_comm_two_shot = None
try:
self.fi_ar_comm = FlashInferAllReduce(
group=self.cpu_group,
device=self.device,
)
if not self.fi_ar_comm.disabled:
logger.info("Rank %s: FlashInferAllReduce initialized", self.rank)
else:
logger.info("Rank %s: FlashInferAllReduce disabled", self.rank)
self.fi_ar_comm = None
except Exception as e:
logger.warning(
"Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e
)
self.fi_ar_comm = None
def benchmark_allreduce(
self, sequence_length: int, num_warmup: int, num_trials: int
) -> dict[str, float]:

@@ -180,7 +200,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.custom_all_reduce(t),
lambda t, c=comm: c.should_custom_ar(t),
comm.capture(),
"1stage", # env variable value
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"},
None, # no destroy function
)
)
# CustomAllreduce two-shot

@@ -190,7 +211,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.custom_all_reduce(t),
lambda t, c=comm: c.should_custom_ar(t),
comm.capture(),
"2stage", # env variable value
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"},
None, # no destroy function
)
)

@@ -202,7 +224,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)
communicators.append(

@@ -211,7 +234,8 @@ class CommunicatorBenchmark:
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)

@@ -223,7 +247,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_symm_mem(t),
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)

@@ -235,29 +260,67 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_symm_mem(t),
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function needed
)
)
if self.fi_ar_comm is not None:
comm = self.fi_ar_comm
communicators.append(
(
"flashinfer_trtllm",
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_fi_ar(t),
nullcontext(),
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"},
lambda c=comm: c.destroy(),
)
)
communicators.append(
(
"flashinfer_mnnvl",
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_fi_ar(t),
nullcontext(),
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"},
lambda c=comm: c.destroy(),
)
)
# Benchmark each communicator
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
# Set environment variable if needed
if env_var is not None:
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
else:
# Clear the environment variable to avoid interference
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
latency = self.benchmark_allreduce_single(
sequence_length,
allreduce_fn,
should_use_fn,
context,
num_warmup,
num_trials,
)
if latency is not None:
results[name] = latency
for (
name,
allreduce_fn,
should_use_fn,
context,
env_dict,
destroy_fn,
) in communicators:
# Save original values and apply new environment variables
saved_env = {key: os.environ.get(key) for key in env_dict}
for key, value in env_dict.items():
os.environ[key] = value
try:
latency = self.benchmark_allreduce_single(
sequence_length,
allreduce_fn,
should_use_fn,
context,
num_warmup,
num_trials,
)
if latency is not None:
results[name] = latency
finally:
if destroy_fn is not None:
destroy_fn()
# Restore environment variables to their original state
for key, original_value in saved_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
return results
@@ -5,8 +5,11 @@
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
1. FlashInfer's allreduce_fusion with trtllm backend
(fused allreduce + rmsnorm + optional FP8/FP4 quant)
2. FlashInfer's allreduce_fusion with mnnvl backend
(fused allreduce + rmsnorm only, no quantization support)
3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark_fused_collective.py

@@ -48,8 +51,12 @@ SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant
logger = init_logger(__name__)
# Try to import FlashInfer
TorchDistBackend = None
try:
import flashinfer.comm as flashinfer_comm # type: ignore
from flashinfer.comm.mnnvl import ( # type: ignore
TorchDistBackend,
)
if not (
hasattr(flashinfer_comm, "allreduce_fusion")

@@ -74,11 +81,15 @@ _FI_MAX_SIZES = {
8: 64 * MiB, # 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE = None
# Global workspace tensors for FlashInfer (keyed by backend name)
_FI_WORKSPACES: dict = {}
# Backends to benchmark
FLASHINFER_BACKENDS = ["trtllm", "mnnvl"]
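# After setup, _FI_WORKSPACES maps each entry of FLASHINFER_BACKENDS to its
# workspace object, e.g. {"trtllm": <workspace>, "mnnvl": <workspace>};
# a backend only gets an entry if setup_flashinfer_workspace() succeeded for it.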
def setup_flashinfer_workspace(
backend: str,
world_size: int,
rank: int,
hidden_dim: int,

@@ -86,41 +97,54 @@ def setup_flashinfer_workspace(
dtype: torch.dtype,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global _FI_WORKSPACE
global FI_WORKSPACES
if flashinfer_comm is None:
return None, None
return None
if world_size not in _FI_MAX_SIZES:
logger.warning("FlashInfer not supported for world size %s", world_size)
return None, None
return None
try:
kwargs = {}
if TorchDistBackend is not None:
kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD)
workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
backend=backend,
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
**kwargs,
)
_FI_WORKSPACE = workspace
_FI_WORKSPACES[backend] = workspace
return workspace
except Exception as e:
logger.error("Failed to setup FlashInfer workspace: %s", e)
logger.error(
"Failed to setup FlashInfer workspace (backend=%s): %s", backend, e
)
return None
def cleanup_flashinfer_workspace(workspace):
"""Cleanup FlashInfer workspace."""
if flashinfer_comm is None or workspace is None:
def cleanup_flashinfer_workspaces():
"""Cleanup all FlashInfer workspaces."""
if flashinfer_comm is None:
return
try:
workspace.destroy()
except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e)
for backend, workspace in _FI_WORKSPACES.items():
try:
workspace.destroy()
except Exception as e:
logger.error(
"Failed to cleanup FlashInfer workspace (backend=%s): %s",
backend,
e,
)
_FI_WORKSPACES.clear()
class FlashInferFusedAllReduceParams:

@@ -134,7 +158,7 @@ class FlashInferFusedAllReduceParams:
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self):
def get_flashinfer_fused_allreduce_kwargs(self):
return {
"launch_with_pdl": self.launch_with_pdl,
"fp32_acc": self.fp32_acc,

@@ -147,11 +171,12 @@ def flashinfer_fused_allreduce_rmsnorm(
rms_gamma: torch.Tensor,
rms_eps: float,
allreduce_params: "FlashInferFusedAllReduceParams",
workspace: object,
use_oneshot: bool,
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:

@@ -160,9 +185,13 @@ def flashinfer_fused_allreduce_rmsnorm(
else:
residual_out = input_tensor
layout_code = None
if workspace.backend == "trtllm":
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
residual_in=residual,
residual_out=residual_out,

@@ -171,10 +200,10 @@ def flashinfer_fused_allreduce_rmsnorm(
rms_eps=rms_eps,
quant_out=None,
scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
layout_code=layout_code,
scale_factor=None,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)
@@ -185,12 +214,16 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
rms_eps: float,
scale_factor: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
workspace: object,
use_oneshot: bool = True,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization.
Note: Only supported by the trtllm backend.
"""
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:

@@ -201,7 +234,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
residual_in=residual,
residual_out=residual_out,

@@ -213,7 +246,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=scale_factor,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)

@@ -224,13 +257,17 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
rms_eps: float,
input_global_scale: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
workspace: object,
quant_out: torch.Tensor,
use_oneshot: bool,
output_scale: torch.Tensor,
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE is None:
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization.
Note: Only supported by the trtllm backend.
"""
if flashinfer_comm is None or workspace is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:

@@ -241,7 +278,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
residual_in=residual,
residual_out=residual_out,

@@ -253,7 +290,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
scale_factor=input_global_scale,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
**allreduce_params.get_flashinfer_fused_allreduce_kwargs(),
)

@@ -386,13 +423,16 @@ def run_benchmarks(
dtype: torch.dtype,
use_residual: bool,
allreduce_params: FlashInferFusedAllReduceParams | None,
workspaces: dict,
quant_modes: set[str],
no_oneshot: bool,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
allreduce_params: Shared parameters for FlashInfer fused allreduce.
workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace.
quant_modes: Set of quantization modes: "none", "fp8", "fp4".
"""
(
input_tensor,

@@ -454,10 +494,11 @@ def run_benchmarks(
logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm (all backends)
for backend, workspace in workspaces.items():
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,

@@ -467,14 +508,17 @@ def run_benchmarks(
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
workspace=workspace,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms
results[key] = time_ms
except Exception as e:
logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e)
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float(
"inf"
logger.error(
"FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s",
backend,
e,
)
results[key] = float("inf")
if "fp8" in quant_modes:
# Standard AllReduce + RMSNorm + FP8 Quant

@@ -540,10 +584,12 @@ def run_benchmarks(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only)
if "trtllm" in workspaces:
trtllm_ws = workspaces["trtllm"]
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,

@@ -555,19 +601,16 @@ def run_benchmarks(
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
workspace=trtllm_ws,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
time_ms
)
results[key] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
"FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s",
e,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
float("inf")
)
results[key] = float("inf")
if "fp4" in quant_modes and current_platform.has_device_capability(100):
# Standard AllReduce + RMSNorm + FP4 Quant

@@ -627,10 +670,12 @@ def run_benchmarks(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only)
if "trtllm" in workspaces:
trtllm_ws = workspaces["trtllm"]
for use_oneshot in use_oneshot_options:
suffix = "_oneshot" if use_oneshot else "_twoshot"
key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}"
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,

@@ -641,49 +686,18 @@ def run_benchmarks(
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
workspace=trtllm_ws,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=use_oneshot,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
time_ms
)
results[key] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
"FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s",
e,
)
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
float("inf")
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if flashinfer_comm is not None and allreduce_params is not None:
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
"inf"
)
results[key] = float("inf")
return results
@@ -1021,8 +1035,7 @@ def main():
configs = list(itertools.product(args.num_tokens, dtypes, residual_options))
# Setup FlashInfer workspace if available
workspace = None
# Setup FlashInfer workspaces for all backends
allreduce_params = None
if flashinfer_comm is not None:

@@ -1037,15 +1050,17 @@ def main():
args.hidden_dim * max_element_size
)
workspace = setup_flashinfer_workspace(
world_size,
rank,
args.hidden_dim,
max_num_token,
dtype=workspace_dtype,
)
for backend in FLASHINFER_BACKENDS:
setup_flashinfer_workspace(
backend=backend,
world_size=world_size,
rank=rank,
hidden_dim=args.hidden_dim,
max_token_num=max_num_token,
dtype=workspace_dtype,
)
if workspace is not None:
if _FI_WORKSPACES:
allreduce_params = FlashInferFusedAllReduceParams(
max_token_num=max_num_token,
)

@@ -1071,6 +1086,7 @@ def main():
dtype,
use_residual,
allreduce_params,
workspaces=_FI_WORKSPACES,
quant_modes=quant_modes,
no_oneshot=args.no_oneshot,
)

@@ -1109,11 +1125,13 @@ def main():
finally:
# Cleanup
if workspace is not None:
cleanup_flashinfer_workspace(workspace)
cleanup_flashinfer_workspaces()
dist.barrier()
if __name__ == "__main__":
main()
from vllm.config import VllmConfig, set_current_vllm_config
with set_current_vllm_config(VllmConfig()):
main()
@@ -142,7 +142,6 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
)
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
print(f"{self.wq=}, {self.wscale=}")
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly

@@ -199,6 +198,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")

@@ -215,6 +215,7 @@ def test_all_reduce_fusion_pass_replace(
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
):
num_processes = 2
if (

@@ -238,6 +239,7 @@ def test_all_reduce_fusion_pass_replace(
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
),
nprocs=nprocs,
)

@@ -255,6 +257,7 @@ def all_reduce_fusion_pass_on_test_model(
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
):
set_random_seed(0)

@@ -270,6 +273,7 @@ def all_reduce_fusion_pass_on_test_model(
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": flashinfer_allreduce_backend,
}
)

@@ -317,6 +321,10 @@ def all_reduce_fusion_pass_on_test_model(
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
results_unfused = model(hidden_states)
results_fused = compiled_model(hidden_states)
torch.testing.assert_close(results_unfused, results_fused, atol=1e-2, rtol=1e-2)
assert all_reduce_fusion_pass.matched_count == 4, (
f"{all_reduce_fusion_pass.matched_count=}"
)
@@ -22,7 +22,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

@@ -44,8 +46,6 @@ if find_spec("flashinfer"):
except ImportError:
pass
logger = init_logger(__name__)
if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default

@@ -82,7 +82,16 @@ _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
if flashinfer_comm is not None:
_FI_WORKSPACE = None
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
destroy_fi_ar_workspace,
get_fi_ar_quant_workspace,
get_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
initialize_fi_ar_workspace,
)
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
MiB = 1024 * 1024
def call_trtllm_fused_allreduce_norm(

@@ -122,9 +131,19 @@ if flashinfer_comm is not None:
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
)
assert _FI_WORKSPACE is not None, (
"Flashinfer must be enabled when using flashinfer"
# Select workspace based on pattern: quant patterns use the
# trtllm quant workspace, non-quant patterns use the primary workspace.
if pattern_code in (
ar_fusion_patterns.kARResidualRMSNormFP8Quant,
ar_fusion_patterns.kARResidualRMSNormFP4Quant,
):
workspace = get_fi_ar_quant_workspace()
else:
workspace = get_fi_ar_workspace()
assert workspace is not None, (
"Flashinfer workspace must be initialized when using flashinfer"
)
assert flashinfer_comm is not None
if norm_out is None:
norm_out = allreduce_in
residual_out = residual

@@ -133,25 +152,30 @@ if flashinfer_comm is not None:
# as flashinfer does not support rms_norm
# and allreduce_out together
residual_out = allreduce_in
# For the sizes that are smaller than the max size,
# we only use flashinfer one shot allreduce
layout_code = None
# layout_code only supported by trtllm backend
if workspace.backend == "trtllm":
# in vllm we only support swizzled layout
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
flashinfer_comm.allreduce_fusion(
input=allreduce_in,
workspace=_FI_WORKSPACE,
workspace=workspace,
pattern=pattern_code,
residual_in=residual,
launch_with_pdl=launch_with_pdl,
output=None,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
launch_with_pdl=launch_with_pdl,
use_oneshot=use_oneshot,
fp32_acc=fp32_acc,
quant_out=quant_out,
scale_out=scale_out,
# in vllm we only support swizzled layout
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
residual_in=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_factor,
layout_code=layout_code,
use_oneshot=use_oneshot,
fp32_acc=fp32_acc,
)
def call_trtllm_fused_allreduce_norm_fake(
@@ -729,29 +753,36 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
scope="global",
)
try:
self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
)
except RuntimeError as e:
if "multicast" not in str(e).lower():
raise
logger.warning_once(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
return
for workspace_init_fn in [
initialize_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
]:
try:
workspace_init_fn(
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
group=self.group,
)
except Exception as e:
if "multicast" in str(e).lower():
logger.warning(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
else:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"AllReduce fusion pass will be disabled.",
e,
)
return
global _FI_WORKSPACE
_FI_WORKSPACE = self.workspace
self.allreduce_params = FlashInferFusedAllReduceParams(
world_size=self.tp_size,
max_token_num=self.max_token_num,

@@ -762,32 +793,34 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
@enable_fake_mode
def register_patterns(self) -> None:
supports_quantization = get_fi_ar_quant_workspace() is not None
for epsilon in [1e-5, 1e-6]:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
if current_platform.has_device_capability(100):
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
if supports_quantization:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
if current_platform.has_device_capability(100):
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceRMSNormPattern(
epsilon,
self.model_dtype,

@@ -825,6 +858,5 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
def __del__(self) -> None:
if getattr(self, "disabled", True):
return
if getattr(self, "workspace", None) is not None:
with contextlib.suppress(Exception):
self.workspace.destroy()
with contextlib.suppress(Exception):
destroy_fi_ar_workspace()
@@ -34,19 +34,25 @@ class CudaCommunicator(DeviceCommunicatorBase):
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
use_torch_symm_mem = False
use_flashinfer_allreduce = False
else:
from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
use_flashinfer_allreduce = envs.VLLM_ALLREDUCE_USE_FLASHINFER
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
self.use_flashinfer_allreduce = use_flashinfer_allreduce
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
)
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
FlashInferAllReduce,
)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce,

@@ -65,12 +71,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: CustomAllreduce | None = None
self.qr_comm: QuickAllReduce | None = None
self.symm_mem_comm: SymmMemCommunicator | None = None
self.fi_ar_comm: FlashInferAllReduce | None = None
if use_torch_symm_mem and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_flashinfer_allreduce and self.world_size > 1:
self.fi_ar_comm = FlashInferAllReduce(
group=self.cpu_group,
device=self.device,
)
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(

@@ -136,7 +150,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
return out
# always try quick reduce first, then custom allreduce,
# always try quick reduce first, then flashinfer, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
if (

@@ -147,6 +161,15 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = qr_comm.quick_all_reduce(input_)
assert out is not None
return out
fi_ar_comm = self.fi_ar_comm
if (
fi_ar_comm is not None
and not fi_ar_comm.disabled
and fi_ar_comm.should_use_fi_ar(input_)
):
out = fi_ar_comm.all_reduce(input_)
assert out is not None
return out
ca_comm = self.ca_comm
if (
ca_comm is not None

@@ -270,6 +293,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.fi_ar_comm is not None:
self.fi_ar_comm.destroy()
self.fi_ar_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
vllm/distributed/device_communicators/flashinfer_all_reduce.py (new file, 252 lines)

@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.config.compilation import PassConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
fi_ar_available = False
try:
import flashinfer.comm as flashinfer_comm # type: ignore[no-redef]
from flashinfer.comm.mnnvl import (
TorchDistBackend, # type: ignore[import-not-found, no-redef]
)
fi_ar_available = hasattr(flashinfer_comm, "allreduce_fusion")
except ImportError:
pass
# Global workspace for standalone allreduce and non-quant ar+rms fusion
_fi_ar_workspace = None
# Extra workspace for quant fusion patterns (only supported by trtllm backend)
# Only created if primary workspace is not already trtllm
_fi_ar_quant_workspace = None
def get_fi_ar_workspace():
return _fi_ar_workspace
def get_fi_ar_quant_workspace():
return _fi_ar_quant_workspace
def initialize_fi_ar_workspace(
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
) -> None:
"""
Initialize the workspace if not already initialized.
Currently, this function is called by either the AllReduceFusionPass
or the FlashInferAllReduce backend for standalone allreduce.
If the fusion pass is enabled via
--compilation-config.pass_config.fuse_allreduce_rms=true,
it will create the workspace first, and the standalone backend
will reuse the workspace. Otherwise, the standalone backend will
create the workspace.
"""
global _fi_ar_workspace
if _fi_ar_workspace is not None:
return
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
comm_backend = TorchDistBackend(group=group)
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend=backend,
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
assert _fi_ar_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=%s, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
backend,
world_size,
rank,
max_token_num,
hidden_dim,
dtype,
)
def initialize_fi_ar_quant_workspace(
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
) -> None:
"""
Initialize the workspace used by quantization fusion patterns.
Currently this always creates a workspace for trtllm backend as only it
supports quantization fusion (FP8/FP4). If the primary workspace
is already trtllm, the quant workspace aliases to it.
"""
global _fi_ar_quant_workspace
if _fi_ar_quant_workspace is not None:
return
# If primary workspace is already trtllm, reuse it
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
_fi_ar_quant_workspace = _fi_ar_workspace
return
comm_backend = TorchDistBackend(group=group)
_fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
assert _fi_ar_quant_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=trtllm, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
world_size,
rank,
max_token_num,
hidden_dim,
dtype,
)
def destroy_fi_ar_workspace():
global _fi_ar_workspace
global _fi_ar_quant_workspace
if (
_fi_ar_quant_workspace is not None
and _fi_ar_quant_workspace is not _fi_ar_workspace
):
_fi_ar_quant_workspace.destroy()
_fi_ar_quant_workspace = None
if _fi_ar_workspace is not None:
_fi_ar_workspace.destroy()
_fi_ar_workspace = None
class FlashInferAllReduce:
def __init__(
self,
group: ProcessGroup,
device: int | str | torch.device,
):
self.disabled = True
if not fi_ar_available:
logger.info(
"FlashInfer All Reduce is disabled because flashinfer is not available"
)
return
if not current_platform.is_cuda():
logger.info(
"FlashInfer All Reduce is disabled because it requires CUDA platform"
)
return
self.group = group
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.device = device
if self.world_size == 1:
return
# Use the same threshold as the allreduce-rms fusion pass
# TODO: tune the threshold
MiB = 1024 * 1024
max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(
self.world_size, None
)
if not max_workspace_size:
logger.warning(
"FlashInfer All Reduce is disabled because it "
"is not supported for world_size=%d.",
self.world_size,
)
return
self.max_workspace_size = max_workspace_size * MiB
self.max_num_tokens = 0
self.disabled = False
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
"""Ensure the all reduce workspace is initialized."""
if get_fi_ar_workspace() is not None:
return True
if self.max_num_tokens == 0:
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
try:
initialize_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=dtype,
group=self.group,
)
return True
except Exception as e:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"FlashInfer All Reduce will be disabled.",
e,
)
self.disabled = True
return False
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
if self.disabled:
return False
if not input_tensor.is_cuda:
return False
if not input_tensor.is_contiguous():
return False
if len(input_tensor.shape) != 2:
return False
num_tokens, hidden_dim = input_tensor.shape
if not self.max_num_tokens:
element_size = torch.tensor([], dtype=input_tensor.dtype).element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
if num_tokens > self.max_num_tokens:
return False
return self._ensure_workspace(hidden_dim, input_tensor.dtype)
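# Worked example of the token cap above, assuming the 64 MiB workspace limit
# that the benchmark's _FI_MAX_SIZES table uses for world_size=8, with a
# bfloat16 input and hidden_dim=8192 (2 bytes per element):
#   max_workspace_size = 64 * 1024 * 1024 = 67,108,864 bytes
#   bytes per token    = 8192 * 2         = 16,384
#   max_num_tokens     = 67,108,864 // 16,384 = 4,096
# so inputs with more than 4,096 tokens fall through to the next allreduce backend.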
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
workspace = get_fi_ar_workspace()
return flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
)
def destroy(self):
if not self.disabled:
destroy_fi_ar_workspace()
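# A minimal usage sketch of this communicator (hypothetical group/tensor names;
# it mirrors how benchmark_device_communicators.py in this diff drives it):
#
#   comm = FlashInferAllReduce(group=cpu_group, device=torch.device("cuda", rank))
#   if not comm.disabled and comm.should_use_fi_ar(hidden_states):
#       hidden_states = comm.all_reduce(hidden_states)
#   comm.destroy()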
vllm/envs.py
@@ -168,6 +168,7 @@ if TYPE_CHECKING:
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
"latency"
)
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256

@@ -206,6 +207,7 @@ if TYPE_CHECKING:
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_ALLREDUCE_USE_FLASHINFER: bool = False
VLLM_TUNED_CONFIG_FOLDER: str | None = None
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False

@@ -1290,6 +1292,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"latency",
["throughput", "latency", "masked_gemm"],
),
# Flashinfer fused allreduce backend.
# "auto" will default to "mnnvl", which performs mostly same/better than "trtllm".
# But "mnnvl" backend does not support fuse with quantization.
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_ALLREDUCE_BACKEND",
"auto",
["auto", "trtllm", "mnnvl"],
),
# Control the workspace buffer size for the FlashInfer backend.
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024))

@@ -1448,6 +1458,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool(
int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))
),
# Whether to use FlashInfer allreduce
"VLLM_ALLREDUCE_USE_FLASHINFER": lambda: bool(
int(os.getenv("VLLM_ALLREDUCE_USE_FLASHINFER", "0"))
),
# Experimental: use this to enable MCP tool calling for non harmony models
"VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool(
int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0"))
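
To exercise the new standalone path end to end, the two environment variables introduced in this diff have to be set before vLLM builds its device communicators. A minimal sketch (the chosen values are illustrative assumptions; note the mnnvl backend skips the FP8/FP4 quantization fusions, per the comment above):

    import os

    # Opt in to routing eligible allreduces through FlashInferAllReduce (default is "0" = off).
    os.environ["VLLM_ALLREDUCE_USE_FLASHINFER"] = "1"
    # Pick the fused-allreduce backend; "auto" resolves to "mnnvl" per this diff.
    os.environ["VLLM_FLASHINFER_ALLREDUCE_BACKEND"] = "mnnvl"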