Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
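
The bulk of the 1508-file change is mechanical restyling. ruff's formatter follows Black's layout algorithm: a statement stays on one line when it fits the line-length limit (88 columns, judging by the reformatted lines below); otherwise it breaks after the opening bracket with a four-space hanging indent, and a trailing comma locks in the one-argument-per-line form. yapf instead aligned continuation lines under the opening parenthesis. A minimal sketch of the two styles (illustrative; the names are placeholders, not taken from this diff):

    def dispatch(self,                        # yapf style: continuation arguments
                 hidden_states,               # aligned under the opening paren
                 is_sequence_parallel=False):
        ...


    def dispatch(
        self,                         # ruff/Black style: break after "(",
        hidden_states,                # four-space hanging indent, and a magic
        is_sequence_parallel=False,   # trailing comma keeps one arg per line
    ):
        ...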


@@ -33,17 +33,19 @@ class NaiveAll2AllManager(All2AllManagerBase):
     def __init__(self, cpu_group):
         super().__init__(cpu_group)
 
-    def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_sp_cpu: torch.Tensor,
-                        is_sequence_parallel: bool) -> torch.Tensor:
-        assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
-                             device=x.device,
-                             dtype=x.dtype)
+    def naive_multicast(
+        self,
+        x: torch.Tensor,
+        cu_tokens_across_sp_cpu: torch.Tensor,
+        is_sequence_parallel: bool,
+    ) -> torch.Tensor:
+        assert len(x.shape) == 2
+        buffer = torch.empty(
+            (cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
+        )
 
         rank = self.rank if is_sequence_parallel else self.dp_rank
-        world_size = (self.world_size
-                      if is_sequence_parallel else self.dp_world_size)
+        world_size = self.world_size if is_sequence_parallel else self.dp_world_size
 
         start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
         end = cu_tokens_across_sp_cpu[rank]
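
Beyond layout, this hunk drops the redundant parentheses in assert (len(x.shape) == 2). That style is worth avoiding because parenthesizing a condition together with a message creates a non-empty tuple, which is always truthy; a runnable demonstration of the pitfall (not from the diff):

    x_shape = (4,)  # deliberately not 2-D

    # ruff's preferred form: bare condition, message after the comma
    # assert len(x_shape) == 2, "expected a 2-D tensor"  # would raise AssertionError

    # The pitfall the parenthesized style invites: wrapping condition and
    # message together makes a non-empty tuple, which is always truthy, so
    # this assert can never fail (modern Python emits a SyntaxWarning here)
    assert (len(x_shape) == 2, "expected a 2-D tensor")  # passes silently
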
@@ -59,24 +61,23 @@ class NaiveAll2AllManager(All2AllManagerBase):
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
-        is_sequence_parallel: bool = False
+        is_sequence_parallel: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         dp_metadata = get_forward_context().dp_metadata
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
 
-        hidden_states = self.naive_multicast(hidden_states,
-                                             cu_tokens_across_sp_cpu,
-                                             is_sequence_parallel)
-        router_logits = self.naive_multicast(router_logits,
-                                             cu_tokens_across_sp_cpu,
-                                             is_sequence_parallel)
+        hidden_states = self.naive_multicast(
+            hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
+        )
+        router_logits = self.naive_multicast(
+            router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
+        )
         return hidden_states, router_logits
 
-    def combine(self,
-                hidden_states: torch.Tensor,
-                is_sequence_parallel: bool = False) -> torch.Tensor:
-
+    def combine(
+        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
+    ) -> torch.Tensor:
         ep_rank = self.rank if is_sequence_parallel else self.dp_rank
 
         dp_metadata = get_forward_context().dp_metadata
@@ -107,13 +108,12 @@ class AgRsAll2AllManager(All2AllManagerBase):
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
-        is_sequence_parallel: bool = False
+        is_sequence_parallel: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
 
         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -124,19 +124,16 @@ class AgRsAll2AllManager(All2AllManagerBase):
         )
         return hidden_states, router_logits
 
-    def combine(self,
-                hidden_states: torch.Tensor,
-                is_sequence_parallel: bool = False) -> torch.Tensor:
+    def combine(
+        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
+    ) -> torch.Tensor:
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
 
         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
-        hidden_states = dist_group.reduce_scatterv(hidden_states,
-                                                   dim=0,
-                                                   sizes=sizes)
+        hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
         return hidden_states
 
     def destroy(self):
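
The AgRs hunks show the opposite direction: chains that yapf split inside an empty call's parentheses are rejoined because they fit within the wider limit. A self-contained sketch with stub classes (assumed stand-ins; vLLM's real forward context is richer):

    class _DPMetadata:
        def get_chunk_sizes_across_dp_rank(self):
            return [4, 4, 4, 4]  # dummy per-rank token counts


    class _Ctx:
        dp_metadata = _DPMetadata()


    def get_forward_context():
        return _Ctx()


    # Inside the 88-column limit, so ruff keeps the whole chain on one line
    # instead of breaking between the empty call's parentheses:
    sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
    print(sizes)  # [4, 4, 4, 4]
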
@@ -149,24 +146,35 @@ class PPLXAll2AllManager(All2AllManagerBase):
     """
 
     def __init__(self, cpu_group):
-        assert has_pplx(
-        ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
+        assert has_pplx(), (
+            "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."
+        )  # noqa
         super().__init__(cpu_group)
 
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
-            from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
-                                              nvshmem_get_unique_id,
-                                              nvshmem_init)
+            from pplx_kernels.nvshmem import (
+                nvshmem_alloc_empty_unique_id,
+                nvshmem_get_unique_id,
+                nvshmem_init,
+            )
+
             logger.debug(
-                "Initialize NVSHMEM for pplx_kernels: "
-                "rank=%d, world size=%d", self.rank, self.world_size)
-            uid = nvshmem_get_unique_id(
-            ) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
-            dist.broadcast(uid,
-                           src=dist.get_process_group_ranks(self.cpu_group)[0],
-                           group=self.cpu_group)
+                "Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
+                self.rank,
+                self.world_size,
+            )
+            uid = (
+                nvshmem_get_unique_id()
+                if self.rank == 0
+                else nvshmem_alloc_empty_unique_id()
+            )
+            dist.broadcast(
+                uid,
+                src=dist.get_process_group_ranks(self.cpu_group)[0],
+                group=self.cpu_group,
+            )
             logger.debug("PPLX NVSHMEM UID = %s", uid)
             nvshmem_init(uid, self.rank, self.world_size)
 
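
This hunk bundles three ruff idioms: a long assert message moves into its own parenthesized block rather than splitting the call's parentheses, an over-long conditional expression (the uid assignment) is parenthesized with the if/else clauses on their own lines, and multi-name imports gain a trailing comma. A runnable sketch of the assert style with a stub condition (message shortened; the real one carries the install URL and keeps its # noqa):

    def has_pplx() -> bool:
        return True  # stub standing in for vLLM's real availability check


    # The condition stays intact on the assert line; only the long message
    # moves into its own parenthesized block
    assert has_pplx(), (
        "pplx_kernels not found. Please follow the install instructions to set it up."
    )
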
@@ -174,21 +182,23 @@ class PPLXAll2AllManager(All2AllManagerBase):
 
     def get_handle(self, kwargs):
         import pplx_kernels as pplx
+
         return self.handle_cache.get_or_create(
-            kwargs, pplx.AllToAll.internode
-            if self.internode else pplx.AllToAll.intranode)
+            kwargs,
+            pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
+        )
 
     def dispatch(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
-        is_sequence_parallel: bool = False
+        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
 
-    def combine(self,
-                hidden_states: torch.Tensor,
-                is_sequence_parallel: bool = False) -> torch.Tensor:
+    def combine(
+        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
+    ) -> torch.Tensor:
         raise NotImplementedError
 
     def destroy(self):
@@ -198,6 +208,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
 
         if self.internode:
             from pplx_kernels.nvshmem import nvshmem_finalize
+
             logger.debug("PPLX NVSHMEM finalize")
             nvshmem_finalize()
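
Import sorting also moves from isort to ruff. The formatted style for multi-name imports, seen in the pplx_kernels hunks above, is parenthesized with one name per line and a trailing comma, so adding or removing a name later is a one-line diff. The same convention applied to a standard-library import (runnable anywhere):

    # Parenthesized, one name per line, trailing comma:
    from os.path import (
        basename,
        dirname,
        join,
    )

    print(join(dirname("/tmp/a/b.txt"), basename("/tmp/a/b.txt")))
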
@@ -208,8 +219,9 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     """
 
     def __init__(self, cpu_group):
-        assert has_deep_ep(
-        ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
+        assert has_deep_ep(), (
+            "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."
+        )  # noqa
         super().__init__(cpu_group)
         self.handle_cache = Cache()
@@ -224,13 +236,13 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
-        is_sequence_parallel: bool = False
+        is_sequence_parallel: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
 
-    def combine(self,
-                hidden_states: torch.Tensor,
-                is_sequence_parallel: bool = False) -> torch.Tensor:
+    def combine(
+        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
+    ) -> torch.Tensor:
         raise NotImplementedError
 
     def destroy(self):
@@ -260,23 +272,27 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
         assert num_rdma_bytes is not None
         assert num_qps_per_rank is not None
-        return dict(group=self.cpu_group,
-                    num_nvl_bytes=num_nvl_bytes,
-                    num_rdma_bytes=num_rdma_bytes,
-                    low_latency_mode=False,
-                    num_qps_per_rank=num_qps_per_rank)
+        return dict(
+            group=self.cpu_group,
+            num_nvl_bytes=num_nvl_bytes,
+            num_rdma_bytes=num_rdma_bytes,
+            low_latency_mode=False,
+            num_qps_per_rank=num_qps_per_rank,
+        )
 
     def get_handle(self, kwargs):
         assert len(kwargs) == 0, (
             "DeepEPHTAll2AllManager expects no arguments. All the required "
-            "args are computed in the Manager itself.")
+            "args are computed in the Manager itself."
+        )
 
         import deep_ep
+
         buffer_kwargs = self._make_all2all_kwargs()
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
-            buffer_kwargs, deep_ep.Buffer)
+            buffer_kwargs, deep_ep.Buffer
+        )
         return handle
 
     def set_num_sms(self, num_sms: int):
@@ -323,14 +339,17 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
             num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
             hidden=token_hidden_size,
             num_ranks=num_ep_ranks,
-            num_experts=num_global_experts)
+            num_experts=num_global_experts,
+        )
         assert num_rdma_bytes is not None
-        return dict(group=self.cpu_group,
-                    num_nvl_bytes=num_nvl_bytes,
-                    num_rdma_bytes=num_rdma_bytes,
-                    low_latency_mode=True,
-                    num_qps_per_rank=num_qps_per_rank)
+        return dict(
+            group=self.cpu_group,
+            num_nvl_bytes=num_nvl_bytes,
+            num_rdma_bytes=num_rdma_bytes,
+            low_latency_mode=True,
+            num_qps_per_rank=num_qps_per_rank,
+        )
 
     def get_handle(self, kwargs):
         """
@@ -338,10 +357,12 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
         _make_all2all_kwargs.
         """
         import deep_ep
+
         buffer_kwargs = self._make_all2all_kwargs(**kwargs)
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
-            buffer_kwargs, deep_ep.Buffer)
+            buffer_kwargs, deep_ep.Buffer
+        )
         return handle
 
     # DeepEP LL uses RDMA so no SMs are used for communication
@@ -355,12 +376,15 @@ class FlashInferAllToAllManager(All2AllManagerBase):
     """
 
     def __init__(self, cpu_group):
-        assert has_flashinfer_all2all(
-        ), "flashinfer all2all module not found. Please install/check flashinfer"  # noqa
+        assert has_flashinfer_all2all(), (
+            "flashinfer all2all module not found. Please install/check flashinfer"
+        )  # noqa
         super().__init__(cpu_group)
         logger.debug(
-            "Initialize for flashinfer All2All "
-            "rank=%d, world size=%d", self.rank, self.world_size)
+            "Initialize for flashinfer All2All rank=%d, world size=%d",
+            self.rank,
+            self.world_size,
+        )
 
         self.initialized = False
         self.alltoall_info = None
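
A note on the log-call changes: the old style split long messages across adjacent string literals (implicit concatenation), a pattern ruff can flag via its flake8-implicit-str-concat (ISC) rules because an accidentally missing comma between call arguments produces the same shape; whether that lint motivated this particular merge is an assumption. The new form keeps one literal and gives each lazy %-style argument its own line. A runnable sketch using the standard logging module as a stand-in for vLLM's logger:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)
    rank, world_size = 0, 8

    # Single literal, one argument per line, trailing comma:
    logger.debug(
        "Initialize for flashinfer All2All rank=%d, world size=%d",
        rank,
        world_size,
    )
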
@@ -375,8 +399,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
             return
 
         self.cleanup()
-        logger.debug("making map: "
-                     "rank=%d, world size=%d", rank, world_size)
+        logger.debug("making map: rank=%d, world size=%d", rank, world_size)
         self.mapping = Mapping(
             world_size,
             rank,
@@ -385,25 +408,28 @@ class FlashInferAllToAllManager(All2AllManagerBase):
         )
 
         from vllm.distributed.device_communicators.mnnvl_compat import (
-            CustomCommunicator)
+            CustomCommunicator,
+        )
+
         dp_config = MnnvlConfig(
             comm_backend=CustomCommunicator(get_dp_group().cpu_group),
             fabric_page_size=1 << 29,  # 512MB
-            allocation_granularity=0  # Auto-detect
+            allocation_granularity=0,  # Auto-detect
         )
 
-        self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
-            self.mapping, dp_config)
+        self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config)
         self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
-            self.mapping, dp_config)
+            self.mapping, dp_config
+        )
 
         self.world_size = world_size
         self.rank = rank
         self.gpus_per_node = gpus_per_node
         self.initialized = True
 
-        logger.info("FlashInfer All2All initialized for rank %s, size %s",
-                    rank, world_size)
+        logger.info(
+            "FlashInfer All2All initialized for rank %s, size %s", rank, world_size
+        )
 
     def ensure_alltoall_workspace_initialized(self):
         """Ensure workspace is initialized"""
@@ -426,8 +452,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
 
     def cleanup(self):
         """Clean up workspace"""
-        if self.initialized and self.workspace_tensor is not None \
-                and self.prepare_workspace_tensor is not None:
+        if (
+            self.initialized
+            and self.workspace_tensor is not None
+            and self.prepare_workspace_tensor is not None
+        ):
             try:
                 del self.workspace_tensor
                 del self.prepare_workspace_tensor
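
Finally, the cleanup() hunk shows how multi-clause conditions are handled: Black-style formatting avoids backslash continuations, wrapping the condition in parentheses with each and-clause on its own line. A runnable sketch with a dummy object standing in for the manager (not vLLM's real class):

    class _Mgr:
        initialized = True
        workspace_tensor = object()
        prepare_workspace_tensor = object()

    mgr = _Mgr()

    # Parenthesized multi-clause condition, one clause per line (ruff style):
    if (
        mgr.initialized
        and mgr.workspace_tensor is not None
        and mgr.prepare_workspace_tensor is not None
    ):
        print("would clean up workspace tensors here")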