Files
vllm/vllm/distributed/device_communicators/flashinfer_all_reduce.py
2026-03-17 15:19:52 -04:00

270 lines
8.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import os
import random
import threading
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.config.compilation import PassConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)

# FlashInfer is an optional dependency: enable the fused allreduce path only
# when the package imports cleanly AND exposes the `allreduce_fusion` API.
fi_ar_available = False
try:
    import flashinfer.comm as flashinfer_comm  # type: ignore[no-redef]
    from flashinfer.comm.mnnvl import (
        TorchDistBackend,  # type: ignore[import-not-found, no-redef]
    )

    fi_ar_available = hasattr(flashinfer_comm, "allreduce_fusion")
except ImportError:
    pass

# Global workspace for standalone allreduce and non-quant ar+rms fusion.
# Created lazily by initialize_fi_ar_workspace(); None until then.
_fi_ar_workspace = None
# Extra workspace for quant fusion patterns (only supported by trtllm backend).
# Only created if primary workspace is not already trtllm; may alias the
# primary workspace (see initialize_fi_ar_quant_workspace).
_fi_ar_quant_workspace = None
def get_fi_ar_workspace():
    """Return the process-wide FlashInfer allreduce workspace, or None if
    initialize_fi_ar_workspace() has not been called yet."""
    return _fi_ar_workspace
def get_fi_ar_quant_workspace():
    """Return the quant-fusion (trtllm) workspace, or None if
    initialize_fi_ar_quant_workspace() has not been called yet."""
    return _fi_ar_quant_workspace
def initialize_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
) -> None:
    """
    Initialize the workspace if not already initialized.

    Currently, this function is called by either the AllReduceFusionPass
    or the FlashInferAllReduce backend for standalone allreduce.
    If the fusion pass is enabled via
    --compilation-config.pass_config.fuse_allreduce_rms=true,
    it will create the workspace first, and the standalone backend
    will reuse the workspace. Otherwise, the standalone backend will
    create the workspace.

    Args:
        world_size: Number of ranks participating in the allreduce.
        rank: This process's rank within ``group``.
        max_token_num: Maximum number of tokens the workspace must support.
        hidden_dim: Hidden dimension of the tensors to be reduced.
        dtype: Element dtype of the tensors to be reduced.
        group: Process group the workspace communicates over.
    """
    global _fi_ar_workspace
    # Idempotent: the first caller creates the workspace, later calls no-op.
    # NOTE(review): unlike destroy_fi_ar_workspace(), this check-then-create
    # is not guarded by _fi_ar_workspace_lock -- confirm initialization only
    # happens from a single thread.
    if _fi_ar_workspace is not None:
        return
    backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
    comm_backend = TorchDistBackend(group=group)
    # Temporarily reseed Python's global RNG from os.urandom for the duration
    # of workspace creation, then restore the caller's RNG state so this call
    # has no observable effect on the `random` module.
    # NOTE(review): presumably workspace creation draws from `random` and
    # needs an unpredictable value -- confirm against flashinfer.
    rng_state = random.getstate()
    try:
        random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
        _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
            backend=backend,
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            dtype=dtype,
            comm_backend=comm_backend,
        )
    finally:
        random.setstate(rng_state)
    assert _fi_ar_workspace is not None
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=%s, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        backend,
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )
def initialize_fi_ar_quant_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
) -> None:
    """
    Initialize the workspace used by quantization fusion patterns.

    Currently this always creates a workspace for trtllm backend as only it
    supports quantization fusion (FP8/FP4). If the primary workspace
    is already trtllm, the quant workspace aliases to it.

    Args:
        world_size: Number of ranks participating in the allreduce.
        rank: This process's rank within ``group``.
        max_token_num: Maximum number of tokens the workspace must support.
        hidden_dim: Hidden dimension of the tensors to be reduced.
        dtype: Element dtype of the tensors to be reduced.
        group: Process group the workspace communicates over.
    """
    global _fi_ar_quant_workspace
    # Idempotent: first caller wins.
    if _fi_ar_quant_workspace is not None:
        return
    # If primary workspace is already trtllm, reuse it instead of allocating
    # a second workspace.
    if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
        _fi_ar_quant_workspace = _fi_ar_workspace
        return
    comm_backend = TorchDistBackend(group=group)
    # Mirror initialize_fi_ar_workspace(): temporarily reseed Python's global
    # RNG from os.urandom around workspace creation, then restore the
    # caller's RNG state so this call has no observable effect on `random`.
    rng_state = random.getstate()
    try:
        random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
        _fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
            backend="trtllm",
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            dtype=dtype,
            comm_backend=comm_backend,
        )
    finally:
        random.setstate(rng_state)
    assert _fi_ar_quant_workspace is not None
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=trtllm, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )
# Serializes teardown of the module-level workspaces; destroy may be reached
# both from atexit and from FlashInferAllReduce.destroy().
_fi_ar_workspace_lock = threading.Lock()
def destroy_fi_ar_workspace():
    """Destroy both module-level workspaces. Idempotent and lock-guarded."""
    global _fi_ar_workspace
    global _fi_ar_quant_workspace
    with _fi_ar_workspace_lock:
        # Destroy the quant workspace only when it is a distinct object: it
        # may alias the primary workspace (see
        # initialize_fi_ar_quant_workspace), in which case a double destroy
        # must be avoided.
        if (
            _fi_ar_quant_workspace is not None
            and _fi_ar_quant_workspace is not _fi_ar_workspace
        ):
            _fi_ar_quant_workspace.destroy()
        # Clear the alias/reference unconditionally.
        _fi_ar_quant_workspace = None
        if _fi_ar_workspace is not None:
            _fi_ar_workspace.destroy()
            _fi_ar_workspace = None
# Release communicator resources at interpreter shutdown even if no explicit
# destroy() call is made.
atexit.register(destroy_fi_ar_workspace)
class FlashInferAllReduce:
def __init__(
self,
group: ProcessGroup,
device: int | str | torch.device,
):
self.disabled = True
if not fi_ar_available:
logger.info(
"FlashInfer All Reduce is disabled because flashinfer is not available"
)
return
if not current_platform.is_cuda():
logger.info(
"FlashInfer All Reduce is disabled because it requires CUDA platform"
)
return
self.group = group
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.device = device
if self.world_size == 1:
return
# Use the same threshold as the allreduce-rms fusion pass
# TODO: tune the threshold
MiB = 1024 * 1024
max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(
self.world_size, None
)
if not max_workspace_size:
logger.warning(
"FlashInfer All Reduce is disabled because it "
"is not supported for world_size=%d.",
self.world_size,
)
return
self.max_workspace_size = max_workspace_size * MiB
self.max_num_tokens = 0
self.disabled = False
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
"""Ensure the all reduce workspace is initialized."""
if get_fi_ar_workspace() is not None:
return True
if self.max_num_tokens == 0:
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
try:
initialize_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=dtype,
group=self.group,
)
return True
except Exception as e:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"FlashInfer All Reduce will be disabled.",
e,
)
self.disabled = True
return False
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
if self.disabled:
return False
if not input_tensor.is_cuda:
return False
if not input_tensor.is_contiguous():
return False
if len(input_tensor.shape) != 2:
return False
num_tokens, hidden_dim = input_tensor.shape
if not self.max_num_tokens:
element_size = torch.tensor([], dtype=input_tensor.dtype).element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
if num_tokens > self.max_num_tokens:
return False
return self._ensure_workspace(hidden_dim, input_tensor.dtype)
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
workspace = get_fi_ar_workspace()
return flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=workspace,
pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
)
def destroy(self):
if not self.disabled:
destroy_fi_ar_workspace()