Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -10,7 +10,6 @@ from torch.distributed import ProcessGroup
class Cache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
@@ -35,9 +34,11 @@ class All2AllManagerBase:
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
in_the_same_node_as)
from vllm.distributed.parallel_state import (
get_dp_group,
get_tp_group,
in_the_same_node_as,
)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
@@ -63,10 +64,12 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
):
raise NotImplementedError
def set_num_sms(self, num_sms: int):
@@ -75,9 +78,7 @@ class All2AllManagerBase:
def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False):
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
raise NotImplementedError
def destroy(self):
@@ -92,11 +93,13 @@ class DeviceCommunicatorBase:
communication backend), the `device_group` will also be given.
"""
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = "",
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
@@ -106,11 +109,11 @@ class DeviceCommunicatorBase:
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
# as long as we use data parallel (coupled data parallel
@@ -134,41 +137,39 @@ class DeviceCommunicatorBase:
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
output_size = (input_size[0] * self.world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
output_tensor = output_tensor.reshape(
input_size[:dim]
+ (self.world_size * input_size[dim],)
+ input_size[dim + 1 :]
)
return output_tensor
def all_gatherv(
self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None
sizes: Optional[list[int]] = None,
) -> Union[torch.Tensor, list[torch.Tensor]]:
raise NotImplementedError
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
)
if dim < 0:
# Convert negative dim to positive.
@@ -180,30 +181,28 @@ class DeviceCommunicatorBase:
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output_shape = (chunk_size,) + input_tensor.shape[1:]
output_tensor = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
output_tensor = torch.empty(
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
# Perform reduce-scatter operation
torch.distributed.reduce_scatter_tensor(output_tensor,
input_tensor,
group=self.device_group)
torch.distributed.reduce_scatter_tensor(
output_tensor, input_tensor, group=self.device_group
)
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None) -> torch.Tensor:
def reduce_scatterv(
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
) -> torch.Tensor:
raise NotImplementedError
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
@@ -211,7 +210,8 @@ class DeviceCommunicatorBase:
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
@@ -222,10 +222,9 @@ class DeviceCommunicatorBase:
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
@@ -239,10 +238,9 @@ class DeviceCommunicatorBase:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
def recv(
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
@@ -255,8 +253,7 @@ class DeviceCommunicatorBase:
def destroy(self):
pass
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
@@ -264,11 +261,14 @@ class DeviceCommunicatorBase:
return
moe_modules = [
module for module in model.modules()
module
for module in model.modules()
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module)
@@ -277,7 +277,7 @@ class DeviceCommunicatorBase:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
@@ -285,9 +285,9 @@ class DeviceCommunicatorBase:
"""
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.