[CI] Fix mypy for vllm/distributed (#26593)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -15,9 +15,11 @@ from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
if has_flashinfer_all2all():
|
||||
from flashinfer.comm import Mapping
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig
|
||||
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
|
||||
from flashinfer.comm import Mapping # type: ignore[import-not-found]
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
|
||||
from flashinfer.comm.trtllm_alltoall import (
|
||||
MnnvlMoe, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -65,6 +67,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(
|
||||
@@ -81,6 +84,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
@@ -113,7 +117,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
@@ -130,7 +137,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
|
||||
@@ -155,7 +165,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import (
|
||||
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
@@ -182,7 +192,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx
|
||||
import pplx_kernels as pplx # type: ignore[import-not-found]
|
||||
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs,
|
||||
@@ -208,7 +218,9 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_finalize, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
@@ -288,7 +300,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"args are computed in the Manager itself."
|
||||
)
|
||||
|
||||
import deep_ep
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
buffer_kwargs = self._make_all2all_kwargs()
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
@@ -298,7 +310,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
return handle
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
import deep_ep
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
# Right now the buffers are sized for only what the kernels were
|
||||
# created with. So we can only reduce the number of SMS used
|
||||
@@ -332,7 +344,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_global_experts: Number of experts in the model.
|
||||
num_local_experts: Number of experts in an EP rank.
|
||||
"""
|
||||
import deep_ep
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
@@ -358,7 +370,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
The kwargs for DeepEPLLAll2AllManager is dictated by
|
||||
_make_all2all_kwargs.
|
||||
"""
|
||||
import deep_ep
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -118,15 +119,18 @@ class CustomAllreduce:
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
device_capability = current_platform.get_device_capability().as_version_str()
|
||||
device_capability = current_platform.get_device_capability()
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and symm_mem_enabled
|
||||
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
|
||||
and device_capability is not None
|
||||
):
|
||||
max_size = min(
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
|
||||
)
|
||||
device_capability_str = device_capability.as_version_str()
|
||||
if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
|
||||
max_size = min(
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
|
||||
max_size,
|
||||
)
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
@@ -213,6 +217,7 @@ class CustomAllreduce:
|
||||
# We cannot directly use `dist.all_gather_object` here
|
||||
# because it is incompatible with `gloo` backend under inference mode.
|
||||
# see https://github.com/pytorch/pytorch/issues/126032 for details.
|
||||
all_data: list[list[list[int] | None]]
|
||||
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
|
||||
all_data[self.rank] = [handle, offset]
|
||||
ranks = sorted(dist.get_process_group_ranks(group=self.group))
|
||||
@@ -221,8 +226,8 @@ class CustomAllreduce:
|
||||
all_data[i], src=rank, group=self.group, device="cpu"
|
||||
)
|
||||
# Unpack list of tuples to tuple of lists.
|
||||
handles = [d[0] for d in all_data] # type: ignore
|
||||
offsets = [d[1] for d in all_data] # type: ignore
|
||||
handles = cast(list[list[int]], [d[0] for d in all_data])
|
||||
offsets = cast(list[list[int]], [d[1] for d in all_data])
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
|
||||
@@ -52,9 +52,14 @@ class SymmMemCommunicator:
|
||||
self.device = device
|
||||
self.group = group
|
||||
self.world_size = dist.get_world_size(self.group)
|
||||
self.device_capability = (
|
||||
current_platform.get_device_capability().as_version_str()
|
||||
)
|
||||
capability = current_platform.get_device_capability()
|
||||
if capability is None:
|
||||
logger.warning(
|
||||
"SymmMemCommunicator: device capability is unknown, "
|
||||
"communicator is not available."
|
||||
)
|
||||
return
|
||||
self.device_capability = capability.as_version_str()
|
||||
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
|
||||
logger.warning(
|
||||
"SymmMemCommunicator: Device capability %s not supported, "
|
||||
|
||||
Reference in New Issue
Block a user