[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)

Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
Author: Hanjie Qiu
Date: 2026-02-25 19:17:20 -08:00 (committed by GitHub)
Commit: 71dfce6aa6 (parent: 2aa4140402)
7 changed files with 593 additions and 180 deletions


@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.config.compilation import PassConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
fi_ar_available = False
try:
import flashinfer.comm as flashinfer_comm # type: ignore[no-redef]
from flashinfer.comm.mnnvl import (
TorchDistBackend, # type: ignore[import-not-found, no-redef]
)
fi_ar_available = hasattr(flashinfer_comm, "allreduce_fusion")
except ImportError:
pass
# Global workspace for standalone allreduce and non-quant ar+rms fusion
_fi_ar_workspace = None
# Extra workspace for quant fusion patterns (only supported by trtllm backend)
# Only created if primary workspace is not already trtllm
_fi_ar_quant_workspace = None
def get_fi_ar_workspace():
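    """Return the shared allreduce workspace, or None if it has not been created."""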
return _fi_ar_workspace
def get_fi_ar_quant_workspace():
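    """Return the quant-fusion (trtllm) workspace, or None if it has not been created."""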
return _fi_ar_quant_workspace
def initialize_fi_ar_workspace(
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
) -> None:
"""
Initialize the workspace if not already initialized.
Currently, this function is called by either the AllReduceFusionPass
or the FlashInferAllReduce backend for standalone allreduce.
If the fusion pass is enabled via
--compilation-config.pass_config.fuse_allreduce_rms=true,
it will create the workspace first, and the standalone backend
will reuse the workspace. Otherwise, the standalone backend will
create the workspace.
"""
global _fi_ar_workspace
if _fi_ar_workspace is not None:
return
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
comm_backend = TorchDistBackend(group=group)
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend=backend,
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
assert _fi_ar_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=%s, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
backend,
world_size,
rank,
max_token_num,
hidden_dim,
dtype,
)
def initialize_fi_ar_quant_workspace(
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
) -> None:
"""
Initialize the workspace used by quantization fusion patterns.
Currently this always creates a workspace for trtllm backend as only it
supports quantization fusion (FP8/FP4). If the primary workspace
is already trtllm, the quant workspace aliases to it.
"""
global _fi_ar_quant_workspace
if _fi_ar_quant_workspace is not None:
return
# If primary workspace is already trtllm, reuse it
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
_fi_ar_quant_workspace = _fi_ar_workspace
return
comm_backend = TorchDistBackend(group=group)
_fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
assert _fi_ar_quant_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=trtllm, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
world_size,
rank,
max_token_num,
hidden_dim,
dtype,
)
def destroy_fi_ar_workspace():
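    """Destroy both workspaces (if created) and reset the module-level globals."""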
global _fi_ar_workspace
global _fi_ar_quant_workspace
if (
_fi_ar_quant_workspace is not None
and _fi_ar_quant_workspace is not _fi_ar_workspace
):
_fi_ar_quant_workspace.destroy()
_fi_ar_quant_workspace = None
if _fi_ar_workspace is not None:
_fi_ar_workspace.destroy()
_fi_ar_workspace = None
class FlashInferAllReduce:
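    """Standalone all-reduce backend built on FlashInfer's fused allreduce kernels.

    The instance stays disabled unless flashinfer is importable, the platform
    is CUDA, the group has more than one rank, and a workspace size is
    configured for that world size.
    """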
def __init__(
self,
group: ProcessGroup,
device: int | str | torch.device,
):
self.disabled = True
if not fi_ar_available:
            logger.info(
                "FlashInfer All Reduce is disabled because flashinfer is not "
                "available or does not provide allreduce_fusion"
            )
return
if not current_platform.is_cuda():
logger.info(
"FlashInfer All Reduce is disabled because it requires CUDA platform"
)
return
self.group = group
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.device = device
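        # A single-rank group has nothing to reduce; leave the backend disabled.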
if self.world_size == 1:
return
# Use the same threshold as the allreduce-rms fusion pass
# TODO: tune the threshold
MiB = 1024 * 1024
max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(
self.world_size, None
)
if not max_workspace_size:
logger.warning(
"FlashInfer All Reduce is disabled because it "
"is not supported for world_size=%d.",
self.world_size,
)
return
self.max_workspace_size = max_workspace_size * MiB
self.max_num_tokens = 0
self.disabled = False
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
"""Ensure the all reduce workspace is initialized."""
if get_fi_ar_workspace() is not None:
return True
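        # Derive the workspace's token capacity from the configured byte budget.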
if self.max_num_tokens == 0:
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
try:
initialize_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=dtype,
group=self.group,
)
return True
except Exception as e:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"FlashInfer All Reduce will be disabled.",
e,
)
self.disabled = True
return False
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
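        """Return True if the FlashInfer path can handle this input.

        Requires a contiguous, 2-D CUDA tensor whose token count fits in the
        workspace; initializes the workspace lazily on first use.
        """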
if self.disabled:
return False
if not input_tensor.is_cuda:
return False
if not input_tensor.is_contiguous():
return False
if len(input_tensor.shape) != 2:
return False
num_tokens, hidden_dim = input_tensor.shape
if not self.max_num_tokens:
element_size = torch.tensor([], dtype=input_tensor.dtype).element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
if num_tokens > self.max_num_tokens:
return False
return self._ensure_workspace(hidden_dim, input_tensor.dtype)
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
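        """Run a fused all-reduce; only call after should_use_fi_ar() returned True."""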
workspace = get_fi_ar_workspace()
return flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
)
def destroy(self):
if not self.disabled:
destroy_fi_ar_workspace()
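
Example (illustrative sketch, not part of this commit): one way a tensor-parallel
worker might drive the standalone backend, assuming torchrun launched one process
per GPU. The import path and tensor shapes below are placeholders, since the new
file's location is not shown in this excerpt.

import os

import torch
import torch.distributed as dist

# Hypothetical import path; adjust to wherever the new module lives in vllm.
from vllm.distributed.device_communicators.flashinfer_allreduce import (
    FlashInferAllReduce,
)


def main() -> None:
    # One process per GPU, e.g. launched via `torchrun --nproc-per-node=<N>`.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    backend = FlashInferAllReduce(group=dist.group.WORLD, device=local_rank)

    # A contiguous 2-D CUDA tensor: [num_tokens, hidden_dim].
    hidden_states = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda")

    if backend.should_use_fi_ar(hidden_states):
        # Lazily creates (or reuses) the shared workspace and runs the
        # fused kAllReduce pattern.
        hidden_states = backend.all_reduce(hidden_states)
    else:
        # Fall back to the regular NCCL all-reduce.
        dist.all_reduce(hidden_states, group=dist.group.WORLD)

    backend.destroy()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()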