[CI] Fix mypy for vllm/distributed (#26593)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Wentao Ye
2025-10-13 16:02:24 -04:00
committed by GitHub
parent d2a7938582
commit 314285d4f2
14 changed files with 122 additions and 65 deletions

View File

@@ -15,9 +15,11 @@ from vllm.utils.flashinfer import has_flashinfer_all2all
from .base_device_communicator import All2AllManagerBase, Cache
if has_flashinfer_all2all():
from flashinfer.comm import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)
logger = init_logger(__name__)
@@ -65,6 +67,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
@@ -113,7 +117,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -130,7 +137,10 @@ class AgRsAll2AllManager(All2AllManagerBase):
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@@ -155,7 +165,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
@@ -182,7 +192,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx
import pplx_kernels as pplx # type: ignore[import-not-found]
return self.handle_cache.get_or_create(
kwargs,
@@ -208,7 +218,9 @@ class PPLXAll2AllManager(All2AllManagerBase):
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
@@ -288,7 +300,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
"args are computed in the Manager itself."
)
import deep_ep
import deep_ep # type: ignore[import-not-found]
buffer_kwargs = self._make_all2all_kwargs()
logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
return handle
def set_num_sms(self, num_sms: int):
import deep_ep
import deep_ep # type: ignore[import-not-found]
# Right now the buffers are sized for only what the kernels were
# created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_global_experts: Number of experts in the model.
num_local_experts: Number of experts in an EP rank.
"""
import deep_ep
import deep_ep # type: ignore[import-not-found]
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
The kwargs for DeepEPLLAll2AllManager is dictated by
_make_all2all_kwargs.
"""
import deep_ep
import deep_ep # type: ignore[import-not-found]
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)