[Kernel] FlashInfer: switch allreduce fusion to unified API (#33985)

Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>

parent cb62e86f83
commit d4f123cc48
@@ -5,7 +5,7 @@
 Benchmark for FlashInfer fused collective operations vs standard operations.
 
 This benchmark compares:
-1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
+1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant)
 2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
 
 Usage with torchrun:
@@ -24,7 +24,6 @@ import torch.distributed as dist  # type: ignore
 
 from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.distributed import (
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from vllm.distributed.parallel_state import (
@@ -52,11 +51,12 @@ logger = init_logger(__name__)
 try:
     import flashinfer.comm as flashinfer_comm  # type: ignore
 
-    if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"):
+    if not (
+        hasattr(flashinfer_comm, "allreduce_fusion")
+        and hasattr(flashinfer_comm, "create_allreduce_fusion_workspace")
+    ):
         flashinfer_comm = None
-        logger.warning(
-            "FlashInfer comm module found but missing trtllm_allreduce_fusion"
-        )
+        logger.warning("FlashInfer comm module found but missing allreduce_fusion API")
 except ImportError:
     flashinfer_comm = None
     logger.warning("FlashInfer not found, only benchmarking standard operations")
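
The guard degrades gracefully on older FlashInfer builds: the unified path is taken only when both entry points exist. A minimal standalone sketch of the same capability probe (illustrative, not vLLM's exact code):

    # Probe for the unified allreduce-fusion API; fall back to None otherwise.
    try:
        import flashinfer.comm as flashinfer_comm
    except ImportError:
        flashinfer_comm = None

    _UNIFIED_API = ("allreduce_fusion", "create_allreduce_fusion_workspace")
    if flashinfer_comm is not None and not all(
        hasattr(flashinfer_comm, name) for name in _UNIFIED_API
    ):
        flashinfer_comm = None  # a partial API is treated the same as no FlashInfer
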
@@ -75,7 +75,7 @@ _FI_MAX_SIZES = {
 }
 
 # Global workspace tensor for FlashInfer
-_FI_WORKSPACE_TENSOR = None
+_FI_WORKSPACE = None
 
 
 def setup_flashinfer_workspace(
@@ -83,10 +83,10 @@ def setup_flashinfer_workspace(
     rank: int,
     hidden_dim: int,
     max_token_num: int,
-    use_fp32_lamport: bool = False,
+    dtype: torch.dtype,
 ):
     """Setup FlashInfer workspace for fused allreduce operations."""
-    global _FI_WORKSPACE_TENSOR
+    global _FI_WORKSPACE
 
    if flashinfer_comm is None:
        return None, None
@@ -96,33 +96,29 @@ def setup_flashinfer_workspace(
         return None, None
 
     try:
         # Create IPC workspace
-        ipc_handles, workspace_tensor = (
-            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
-                tp_rank=rank,
-                tp_size=world_size,
-                max_token_num=max_token_num,
-                hidden_dim=hidden_dim,
-                group=get_tp_group().device_group,
-                use_fp32_lamport=use_fp32_lamport,
-            )
-        )
+        workspace = flashinfer_comm.create_allreduce_fusion_workspace(
+            backend="trtllm",
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            dtype=dtype,
+        )
 
-        _FI_WORKSPACE_TENSOR = workspace_tensor
-        return ipc_handles, workspace_tensor
+        _FI_WORKSPACE = workspace
+        return workspace
     except Exception as e:
         logger.error("Failed to setup FlashInfer workspace: %s", e)
-        return None, None
+        return None
 
 
-def cleanup_flashinfer_workspace(ipc_handles):
+def cleanup_flashinfer_workspace(workspace):
     """Cleanup FlashInfer workspace."""
-    if flashinfer_comm is None or ipc_handles is None:
+    if flashinfer_comm is None or workspace is None:
         return
 
     try:
-        group = get_tp_group().device_group
-        flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
+        workspace.destroy()
     except Exception as e:
         logger.error("Failed to cleanup FlashInfer workspace: %s", e)
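
With the unified API the workspace is a single object that owns its IPC state, so setup and teardown collapse to one call each and the separate ipc_handles bookkeeping disappears. A minimal lifecycle sketch, assuming torch.distributed is already initialized and the capability check above passed (the size values are placeholders):

    import torch

    # Create once per process; the workspace carries rank/world_size internally.
    workspace = flashinfer_comm.create_allreduce_fusion_workspace(
        backend="trtllm",
        world_size=world_size,
        rank=rank,
        max_token_num=1024,   # placeholder token budget
        hidden_dim=4096,      # placeholder model width
        dtype=torch.bfloat16,
    )
    try:
        ...  # issue fused collectives against `workspace`
    finally:
        # One call replaces trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group).
        workspace.destroy()
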
@@ -132,25 +128,15 @@ class FlashInferFusedAllReduceParams:
     def __init__(
         self,
         rank: int,
         world_size: int,
-        use_fp32_lamport: bool = False,
         max_token_num: int = 1024,
     ):
         self.rank = rank
         self.world_size = world_size
-        self.use_fp32_lamport = use_fp32_lamport
         self.trigger_completion_at_end = True
         self.launch_with_pdl = True
         self.fp32_acc = True
         self.max_token_num = max_token_num
 
     def get_trtllm_fused_allreduce_kwargs(self):
         return {
-            "world_rank": self.rank,
-            "world_size": self.world_size,
             "launch_with_pdl": self.launch_with_pdl,
             "trigger_completion_at_end": self.trigger_completion_at_end,
             "fp32_acc": self.fp32_acc,
         }
@@ -165,7 +151,7 @@ def flashinfer_fused_allreduce_rmsnorm(
     norm_out: torch.Tensor | None = None,
 ):
     """FlashInfer fused allreduce + rmsnorm operation."""
-    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
+    if flashinfer_comm is None or _FI_WORKSPACE is None:
         raise RuntimeError("FlashInfer not available or workspace not initialized")
 
     if norm_out is None:
@@ -174,18 +160,15 @@ def flashinfer_fused_allreduce_rmsnorm(
     else:
         residual_out = input_tensor
 
-    flashinfer_comm.trtllm_allreduce_fusion(
-        allreduce_in=input_tensor,
-        token_num=input_tensor.shape[0],
+    flashinfer_comm.allreduce_fusion(
+        input=input_tensor,
+        workspace=_FI_WORKSPACE,
+        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
         residual_in=residual,
         residual_out=residual_out,
         norm_out=norm_out,
         rms_gamma=rms_gamma,
         rms_eps=rms_eps,
-        hidden_dim=input_tensor.shape[-1],
-        workspace_ptrs=_FI_WORKSPACE_TENSOR,
-        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
-        allreduce_out=None,
         quant_out=None,
         scale_out=None,
         layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
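
The same mechanical mapping repeats at every call site below: token_num and hidden_dim are no longer passed explicitly (presumably recovered from the input tensor's shape), and the raw pointer tensor gives way to the workspace object. At a glance:

    trtllm_allreduce_fusion (old)            allreduce_fusion (new)
    ---------------------------------------  -------------------------------
    allreduce_in=x                           input=x
    token_num=x.shape[0]                     (dropped; implied by input)
    hidden_dim=x.shape[-1]                   (dropped; implied by input)
    workspace_ptrs=_FI_WORKSPACE_TENSOR      workspace=_FI_WORKSPACE
    pattern_code=<AllReduceFusionPattern>    pattern=<AllReduceFusionPattern>
    allreduce_out=None                       (dropped)
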
@@ -207,7 +190,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
     quant_out: torch.Tensor | None = None,
 ):
     """FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
-    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
+    if flashinfer_comm is None or _FI_WORKSPACE is None:
         raise RuntimeError("FlashInfer not available or workspace not initialized")
 
     if norm_out is None:
@@ -216,18 +199,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
     else:
         residual_out = input_tensor
 
-    flashinfer_comm.trtllm_allreduce_fusion(
-        allreduce_in=input_tensor,
-        token_num=input_tensor.shape[0],
+    flashinfer_comm.allreduce_fusion(
+        input=input_tensor,
+        workspace=_FI_WORKSPACE,
+        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
         residual_in=residual,
         residual_out=residual_out,
         norm_out=norm_out,
         rms_gamma=rms_gamma,
         rms_eps=rms_eps,
-        hidden_dim=input_tensor.shape[-1],
-        workspace_ptrs=_FI_WORKSPACE_TENSOR,
-        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
-        allreduce_out=None,
         quant_out=quant_out,
         scale_out=None,
         layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
@@ -250,7 +230,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
     norm_out: torch.Tensor | None = None,
 ):
     """FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
-    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
+    if flashinfer_comm is None or _FI_WORKSPACE is None:
         raise RuntimeError("FlashInfer not available or workspace not initialized")
 
     if norm_out is None:
@@ -259,18 +239,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
     else:
         residual_out = input_tensor
 
-    flashinfer_comm.trtllm_allreduce_fusion(
-        allreduce_in=input_tensor,
-        token_num=input_tensor.shape[0],
+    flashinfer_comm.allreduce_fusion(
+        input=input_tensor,
+        workspace=_FI_WORKSPACE,
+        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
         residual_in=residual,
         residual_out=residual_out,
         norm_out=norm_out,
         rms_gamma=rms_gamma,
         rms_eps=rms_eps,
-        hidden_dim=input_tensor.shape[-1],
-        workspace_ptrs=_FI_WORKSPACE_TENSOR,
-        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
-        allreduce_out=None,
         quant_out=quant_out,
         scale_out=output_scale,
         layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
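
The three benchmark wrappers differ only in the fusion pattern and in which output buffers they supply; everything else is identical. A condensed view of that dispatch (pattern names as they appear in the diff; the dict and its string placeholders are illustrative):

    P = flashinfer_comm.AllReduceFusionPattern
    FUSION_VARIANTS = {
        # variant: (pattern, quant_out, scale_out)
        "rmsnorm": (P.kARResidualRMSNorm, None, None),
        "rmsnorm+fp8": (P.kARResidualRMSNormFP8Quant, "fp8 buffer", None),
        "rmsnorm+fp4": (P.kARResidualRMSNormFP4Quant, "fp4 buffer", "scale buffer"),
    }
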
@@ -1040,23 +1017,31 @@ def main():
     configs = list(itertools.product(args.num_tokens, dtypes, residual_options))
 
     # Setup FlashInfer workspace if available
-    ipc_handles = None
+    workspace = None
     allreduce_params = None
 
     if flashinfer_comm is not None:
         # Use the largest hidden dimension for workspace setup
+        max_element_size = max(torch.finfo(dt).bits // 8 for dt in dtypes)
+        workspace_dtype = (
+            torch.float32
+            if max_element_size == 4
+            else (torch.bfloat16 if torch.bfloat16 in dtypes else torch.float16)
+        )
         max_num_token = _FI_MAX_SIZES.get(world_size) // (
-            args.hidden_dim * world_size * 2
+            args.hidden_dim * max_element_size
         )
 
-        ipc_handles, workspace_tensor = setup_flashinfer_workspace(
-            world_size, rank, args.hidden_dim, max_num_token
+        workspace = setup_flashinfer_workspace(
+            world_size,
+            rank,
+            args.hidden_dim,
+            max_num_token,
+            dtype=workspace_dtype,
         )
 
-        if workspace_tensor is not None:
+        if workspace is not None:
             allreduce_params = FlashInferFusedAllReduceParams(
                 rank=rank,
                 world_size=world_size,
                 max_token_num=max_num_token,
             )
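
The token budget now scales with the actual element size of the benchmarked dtypes instead of the old world_size * 2 heuristic. A worked example with assumed numbers, a 32 MiB per-rank budget, hidden_dim=8192, and bf16 inputs:

    budget_bytes = 32 * 1024 * 1024   # stand-in for _FI_MAX_SIZES[world_size]
    hidden_dim = 8192
    max_element_size = 2              # bytes per bf16/fp16 element
    max_num_token = budget_bytes // (hidden_dim * max_element_size)  # -> 2048 tokens
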
@@ -1119,8 +1104,8 @@ def main():
 
     finally:
         # Cleanup
-        if ipc_handles is not None:
-            cleanup_flashinfer_workspace(ipc_handles)
+        if workspace is not None:
+            cleanup_flashinfer_workspace(workspace)
 
         dist.barrier()
@@ -202,9 +202,10 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 @pytest.mark.skipif(
     not find_spec("flashinfer")
-    or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
+    or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
+    or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
     reason="flashinfer is not found or flashinfer "
-    "is not compiled with trtllm_allreduce_fusion",
+    "is not compiled with allreduce_fusion",
 )
 def test_all_reduce_fusion_pass_replace(
     test_model: torch.nn.Module,
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
 from importlib.util import find_spec
 from types import ModuleType
 
@@ -36,7 +37,9 @@ if find_spec("flashinfer"):
     try:
         import flashinfer.comm as _flashinfer_comm
 
-        if hasattr(_flashinfer_comm, "trtllm_allreduce_fusion"):
+        if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr(
+            _flashinfer_comm, "create_allreduce_fusion_workspace"
+        ):
             flashinfer_comm = _flashinfer_comm
     except ImportError:
         pass
@@ -79,7 +82,7 @@ _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
 
 
 if flashinfer_comm is not None:
-    _FI_WORKSPACE_TENSOR = None
+    _FI_WORKSPACE = None
     MiB = 1024 * 1024
 
     def call_trtllm_fused_allreduce_norm(
@@ -87,10 +90,8 @@ if flashinfer_comm is not None:
         residual: torch.Tensor,
         rms_gamma: torch.Tensor,
         rms_eps: float,
-        world_rank: int,
-        world_size: int,
         launch_with_pdl: bool,
         trigger_completion_at_end: bool,
         fp32_acc: bool,
         max_token_num: int,
         pattern_code: int,
@@ -121,7 +122,7 @@ if flashinfer_comm is not None:
             max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
         )
 
-        assert _FI_WORKSPACE_TENSOR is not None, (
+        assert _FI_WORKSPACE is not None, (
             "Flashinfer must be enabled when using flashinfer"
         )
         if norm_out is None:
@@ -134,24 +135,18 @@ if flashinfer_comm is not None:
             residual_out = allreduce_in
         # For the sizes that are smaller than the max size,
         # we only use flashinfer one shot allreduce
-        flashinfer_comm.trtllm_allreduce_fusion(
-            allreduce_in=allreduce_in,
-            token_num=allreduce_in.shape[0],
+        flashinfer_comm.allreduce_fusion(
+            input=allreduce_in,
+            workspace=_FI_WORKSPACE,
+            pattern=pattern_code,
             residual_in=residual,
             residual_out=residual_out,
             norm_out=norm_out,
             rms_gamma=rms_gamma,
             rms_eps=rms_eps,
-            world_rank=world_rank,
-            world_size=world_size,
-            hidden_dim=allreduce_in.shape[-1],
-            workspace_ptrs=_FI_WORKSPACE_TENSOR,
             launch_with_pdl=launch_with_pdl,
             use_oneshot=use_oneshot,
             trigger_completion_at_end=trigger_completion_at_end,
             fp32_acc=fp32_acc,
-            pattern_code=pattern_code,
-            allreduce_out=None,
             quant_out=quant_out,
             scale_out=scale_out,
             # in vllm we only support swizzled layout
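
Note that use_oneshot is still decided before the call, from the tensor's byte size against the per-world-size one-shot ceiling, and is now forwarded to allreduce_fusion unchanged. A restatement of the heuristic visible in the hunk above (assuming the usual numel times element-size computation for current_tensor_size):

    MiB = 1024 * 1024
    current_tensor_size = allreduce_in.numel() * allreduce_in.element_size()
    use_oneshot = (
        max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
    )
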
@@ -164,10 +159,8 @@ if flashinfer_comm is not None:
         residual: torch.Tensor,
         rms_gamma: torch.Tensor,
         rms_eps: float,
-        world_rank: int,
-        world_size: int,
         launch_with_pdl: bool,
         trigger_completion_at_end: bool,
         fp32_acc: bool,
         max_token_num: int,
         pattern_code: int,
@@ -200,25 +193,18 @@ class FlashInferFusedAllReduceParams:
     def __init__(
         self,
         rank: int,
         world_size: int,
         use_fp32_lamport: bool = False,
         max_token_num: int = 1024,
     ) -> None:
         self.rank = rank
         self.world_size = world_size
         self.use_fp32_lamport = use_fp32_lamport
         self.trigger_completion_at_end = True
         self.launch_with_pdl = True
         self.fp32_acc = True
         self.max_token_num = max_token_num
 
     def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
         return {
-            "world_rank": self.rank,
-            "world_size": self.world_size,
             "launch_with_pdl": self.launch_with_pdl,
             "trigger_completion_at_end": self.trigger_completion_at_end,
             "fp32_acc": self.fp32_acc,
             "max_token_num": self.max_token_num,
         }
@@ -712,7 +698,6 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
         self.hidden_dim = config.model_config.get_hidden_size()
-        self.group = get_tp_group().device_group
         rank = get_tensor_model_parallel_rank()
         use_fp32_lamport = self.model_dtype == torch.float32
         if flashinfer_comm is None:
             logger.warning(
                 "Flashinfer is not installed or comm module not found, "
@@ -730,7 +715,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
                 self.tp_size,
             )
             return
-        element_size = 4 if use_fp32_lamport else 2
+        element_size = torch.tensor([], dtype=self.model_dtype).element_size()
         self.max_token_num = max_size // (self.hidden_dim * element_size)
         # take the min to save workspace size and we'll never use more
         # than max_num_batched_tokens anyways
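
Deriving element_size from the model dtype generalizes the old two-branch constant (4 for fp32, 2 otherwise) to any dtype without special-casing:

    import torch

    # element_size() reports bytes per element for the given dtype.
    for dt in (torch.float32, torch.bfloat16, torch.float16):
        print(dt, torch.tensor([], dtype=dt).element_size())
    # torch.float32 4, torch.bfloat16 2, torch.float16 2
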
@@ -744,23 +729,19 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
             scope="global",
         )
 
-        self.ipc_handles, workspace_tensor = (
-            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
-                tp_rank=rank,
-                tp_size=self.tp_size,
-                max_token_num=self.max_token_num,
-                hidden_dim=self.hidden_dim,
-                group=self.group,
-                use_fp32_lamport=use_fp32_lamport,
-            )
-        )
+        self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
+            backend="trtllm",
+            world_size=self.tp_size,
+            rank=rank,
+            max_token_num=self.max_token_num,
+            hidden_dim=self.hidden_dim,
+            dtype=self.model_dtype,
+        )
 
-        global _FI_WORKSPACE_TENSOR
-        _FI_WORKSPACE_TENSOR = workspace_tensor
+        global _FI_WORKSPACE
+        _FI_WORKSPACE = self.workspace
         self.allreduce_params = FlashInferFusedAllReduceParams(
             rank=rank,
             world_size=self.tp_size,
             use_fp32_lamport=use_fp32_lamport,
             max_token_num=self.max_token_num,
         )
@@ -832,7 +813,6 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
     def __del__(self) -> None:
         if getattr(self, "disabled", True):
             return
-        if flashinfer_comm is not None:
-            flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
-                self.ipc_handles, self.group
-            )
+        if getattr(self, "workspace", None) is not None:
+            with contextlib.suppress(Exception):
+                self.workspace.destroy()
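
Teardown in __del__ is now defensive: getattr covers a pass that never finished initializing, and contextlib.suppress keeps destruction errors from surfacing during interpreter shutdown. The same pattern in isolation (hypothetical helper name):

    import contextlib

    def _safe_destroy(pass_obj):
        ws = getattr(pass_obj, "workspace", None)
        if ws is not None:
            with contextlib.suppress(Exception):
                ws.destroy()
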