270 lines
8.1 KiB
Python
270 lines
8.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
import atexit
|
|
import os
|
|
import random
|
|
import threading
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.distributed import ProcessGroup
|
|
|
|
import vllm.envs as envs
|
|
from vllm.config.compilation import PassConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import current_platform
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
fi_ar_available = False
try:
    # flashinfer is an optional dependency: probe for the package and then
    # for the allreduce_fusion entry point specifically, since an installed
    # flashinfer build may still lack it.
    import flashinfer.comm as flashinfer_comm  # type: ignore[no-redef]
    from flashinfer.comm.mnnvl import (
        TorchDistBackend,  # type: ignore[import-not-found, no-redef]
    )

    fi_ar_available = hasattr(flashinfer_comm, "allreduce_fusion")
except ImportError:
    # fi_ar_available stays False; FlashInferAllReduce checks it and
    # disables itself.
    pass


# Global workspace for standalone allreduce and non-quant ar+rms fusion
_fi_ar_workspace = None
# Extra workspace for quant fusion patterns (only supported by trtllm backend)
# Only created if primary workspace is not already trtllm
_fi_ar_quant_workspace = None
|
|
|
|
|
|
def get_fi_ar_workspace():
    """Return the process-global FlashInfer allreduce workspace, or None if
    initialize_fi_ar_workspace() has not been called yet."""
    return _fi_ar_workspace
|
|
|
|
|
|
def get_fi_ar_quant_workspace():
    """Return the workspace used by quantization fusion patterns (may alias
    the primary workspace), or None if it has not been initialized yet."""
    return _fi_ar_quant_workspace
|
|
|
|
|
|
def initialize_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
) -> None:
    """Create the process-global allreduce workspace if it does not exist.

    Called by either the AllReduceFusionPass or the FlashInferAllReduce
    backend for standalone allreduce. If the fusion pass is enabled via
    --compilation-config.pass_config.fuse_allreduce_rms=true, it creates
    the workspace first and the standalone backend reuses it; otherwise
    the standalone backend creates it. Idempotent: a second call returns
    immediately.
    """
    global _fi_ar_workspace
    if _fi_ar_workspace is not None:
        # Someone else (e.g. the fusion pass) already built it.
        return

    backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
    comm = TorchDistBackend(group=group)

    # Snapshot the Python RNG state around workspace creation and restore it
    # afterwards — presumably so flashinfer's setup does not perturb (or get
    # pinned by) a user-fixed seed; TODO confirm against flashinfer docs.
    saved_rng_state = random.getstate()
    try:
        random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
        _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
            backend=backend,
            comm_backend=comm,
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            dtype=dtype,
        )
    finally:
        random.setstate(saved_rng_state)

    assert _fi_ar_workspace is not None
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=%s, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        backend,
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )
|
|
|
|
|
|
def initialize_fi_ar_quant_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
) -> None:
    """Create the workspace used by quantization fusion patterns.

    Always targets the trtllm backend, as only it supports quantization
    fusion (FP8/FP4). When the primary workspace is already trtllm, the
    quant workspace simply aliases it. Idempotent: a second call returns
    immediately.
    """
    global _fi_ar_quant_workspace
    if _fi_ar_quant_workspace is not None:
        return

    # Alias the primary workspace when it is already a trtllm one.
    primary = _fi_ar_workspace
    if primary is not None and primary.backend == "trtllm":
        _fi_ar_quant_workspace = primary
        return

    _fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
        backend="trtllm",
        comm_backend=TorchDistBackend(group=group),
        world_size=world_size,
        rank=rank,
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        dtype=dtype,
    )
    assert _fi_ar_quant_workspace is not None
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=trtllm, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )
|
|
|
|
|
|
_fi_ar_workspace_lock = threading.Lock()


def destroy_fi_ar_workspace():
    """Destroy the global allreduce workspace(s), if any.

    Idempotent and serialized by _fi_ar_workspace_lock; safe to call both
    explicitly and from the atexit handler.
    """
    global _fi_ar_workspace
    global _fi_ar_quant_workspace
    with _fi_ar_workspace_lock:
        if (
            _fi_ar_quant_workspace is not None
            and _fi_ar_quant_workspace is not _fi_ar_workspace
        ):
            # Separately-created quant workspace: destroy it on its own.
            _fi_ar_quant_workspace.destroy()
        # BUGFIX: always drop the quant reference. When it aliases the
        # primary workspace, the primary's destroy() below would otherwise
        # leave _fi_ar_quant_workspace pointing at a destroyed workspace,
        # making a later initialize_fi_ar_quant_workspace() early-return
        # on the stale object.
        _fi_ar_quant_workspace = None
        if _fi_ar_workspace is not None:
            _fi_ar_workspace.destroy()
            _fi_ar_workspace = None
|
|
|
|
|
|
# Best-effort cleanup at interpreter shutdown; destroy_fi_ar_workspace is
# idempotent, so an earlier explicit destroy() is harmless.
atexit.register(destroy_fi_ar_workspace)
|
|
|
|
|
|
class FlashInferAllReduce:
|
|
def __init__(
|
|
self,
|
|
group: ProcessGroup,
|
|
device: int | str | torch.device,
|
|
):
|
|
self.disabled = True
|
|
|
|
if not fi_ar_available:
|
|
logger.info(
|
|
"FlashInfer All Reduce is disabled because flashinfer is not available"
|
|
)
|
|
return
|
|
|
|
if not current_platform.is_cuda():
|
|
logger.info(
|
|
"FlashInfer All Reduce is disabled because it requires CUDA platform"
|
|
)
|
|
return
|
|
|
|
self.group = group
|
|
self.world_size = dist.get_world_size(self.group)
|
|
self.rank = dist.get_rank(self.group)
|
|
self.device = device
|
|
if self.world_size == 1:
|
|
return
|
|
|
|
# Use the same threshold as the allreduce-rms fusion pass
|
|
# TODO: tune the threshold
|
|
MiB = 1024 * 1024
|
|
max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(
|
|
self.world_size, None
|
|
)
|
|
if not max_workspace_size:
|
|
logger.warning(
|
|
"FlashInfer All Reduce is disabled because it "
|
|
"is not supported for world_size=%d.",
|
|
self.world_size,
|
|
)
|
|
return
|
|
self.max_workspace_size = max_workspace_size * MiB
|
|
self.max_num_tokens = 0
|
|
self.disabled = False
|
|
|
|
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
|
|
"""Ensure the all reduce workspace is initialized."""
|
|
if get_fi_ar_workspace() is not None:
|
|
return True
|
|
if self.max_num_tokens == 0:
|
|
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
|
|
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
|
|
try:
|
|
initialize_fi_ar_workspace(
|
|
world_size=self.world_size,
|
|
rank=self.rank,
|
|
max_token_num=self.max_num_tokens,
|
|
hidden_dim=hidden_dim,
|
|
dtype=dtype,
|
|
group=self.group,
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
logger.warning(
|
|
"Failed to initialize FlashInfer All Reduce workspace: %s. "
|
|
"FlashInfer All Reduce will be disabled.",
|
|
e,
|
|
)
|
|
self.disabled = True
|
|
return False
|
|
|
|
def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
|
|
if self.disabled:
|
|
return False
|
|
|
|
if not input_tensor.is_cuda:
|
|
return False
|
|
|
|
if not input_tensor.is_contiguous():
|
|
return False
|
|
|
|
if len(input_tensor.shape) != 2:
|
|
return False
|
|
|
|
num_tokens, hidden_dim = input_tensor.shape
|
|
if not self.max_num_tokens:
|
|
element_size = torch.tensor([], dtype=input_tensor.dtype).element_size()
|
|
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
|
|
|
|
if num_tokens > self.max_num_tokens:
|
|
return False
|
|
|
|
return self._ensure_workspace(hidden_dim, input_tensor.dtype)
|
|
|
|
def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
|
workspace = get_fi_ar_workspace()
|
|
return flashinfer_comm.allreduce_fusion(
|
|
input=input_tensor,
|
|
workspace=workspace,
|
|
pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
|
|
)
|
|
|
|
def destroy(self):
|
|
if not self.disabled:
|
|
destroy_fi_ar_workspace()
|