Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-10-05 15:06:22 +01:00, committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

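The hunks below are almost entirely mechanical re-formatting. As a rough illustration (a sketch, not code from this commit), the toy module below shows the style differences that account for most of the changes: long signatures and call sites are exploded to one argument per line with a trailing comma and a dedented closing parenthesis, backslash continuations become single-line or parenthesized expressions, and multi-name imports are split one name per line. Such a migration is normally applied with `ruff format` plus `ruff check --fix` with the import-sorting (`I`) rules enabled; the exact settings live in the repository's pyproject.toml, which is not part of this excerpt.

    # Illustrative sketch only -- not taken from this commit.
    from typing import Optional

    import torch


    # yapf (old, 80-column limit):
    #     def scaled_add(x: torch.Tensor, y: torch.Tensor,
    #                    scale: Optional[float] = None,
    #                    clamp_max: Optional[float] = None) -> torch.Tensor:
    # ruff format (new, 88-column default): one parameter per line,
    # magic trailing comma, dedented closing parenthesis.
    def scaled_add(
        x: torch.Tensor,
        y: torch.Tensor,
        scale: Optional[float] = None,
        clamp_max: Optional[float] = None,
    ) -> torch.Tensor:
        # yapf (old): backslash continuation
        #     assert x.shape == y.shape, \
        #         f"{x.shape} != {y.shape}"
        # ruff (new): fits on one line, or is wrapped in parentheses instead.
        assert x.shape == y.shape, f"{x.shape} != {y.shape}"
        out = x + (y if scale is None else y * scale)
        return out if clamp_max is None else out.clamp(max=clamp_max)


    if __name__ == "__main__":
        print(scaled_add(torch.ones(2, 3), torch.full((2, 3), 2.0), scale=0.5))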

@@ -11,11 +11,17 @@ import torch
 import vllm.envs as envs
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
-    _resize_cache, count_expert_num_tokens)
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache,
+    count_expert_num_tokens,
+)
 from vllm.utils import cdiv
-from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
-                                      dbo_maybe_run_recv_hook,
-                                      dbo_register_recv_hook, dbo_yield)
+from vllm.v1.worker.ubatching import (
+    dbo_current_ubatch_id,
+    dbo_enabled,
+    dbo_maybe_run_recv_hook,
+    dbo_register_recv_hook,
+    dbo_yield,
+)

 #
 # This file defines a set of base classes used to make MoE kernels more modular.
@@ -59,31 +65,34 @@ class FusedMoEActivationFormat(Enum):
"""
The standard activation format (num_tokens, hidden dim).
"""
Standard = "standard",
Standard = ("standard",)
"""
The batched experts format (num experts, max tokens per expert, hidden dim)
"""
BatchedExperts = "batched_experts",
BatchedExperts = ("batched_experts",)
@dataclass
class ExpertTokensMetadata:
"""
Metadata regarding expert-token routing.
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: Optional[torch.Tensor]
@staticmethod
def make_from_list(expert_num_tokens_list: list[int],
device: str) -> "ExpertTokensMetadata":
expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
device="cpu",
dtype=torch.int32)
def make_from_list(
expert_num_tokens_list: list[int], device: str
) -> "ExpertTokensMetadata":
expert_num_tokens_cpu = torch.tensor(
expert_num_tokens_list, device="cpu", dtype=torch.int32
)
return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens_cpu.to(device,
non_blocking=True),
expert_num_tokens_cpu=expert_num_tokens_cpu)
expert_num_tokens=expert_num_tokens_cpu.to(device, non_blocking=True),
expert_num_tokens_cpu=expert_num_tokens_cpu,
)
class TopKWeightAndReduce(ABC):
@@ -92,10 +101,14 @@ class TopKWeightAndReduce(ABC):
"""
@abstractmethod
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
def apply(
self,
output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> torch.Tensor:
"""
Apply topk_weights to the fused_experts_outputs and/or reduce.
If an output tensor is not passed, it will be created in the
@@ -200,16 +213,16 @@ class FusedMoEPrepareAndFinalize(ABC):
         - apply_router_weight_on_input: When True, apply the weights to the
           activations, before quantization + dispatching.

-        Returns a callback or a hook callback pair that when invoked waits for
-        results from other workers and has the same return signature as
+        Returns a callback or a hook callback pair that when invoked waits for
+        results from other workers and has the same return signature as
         `prepare`, if a hook is returned this is more lightweight check that
-        the recv is complete without doing extra work (used by DBO, will be
+        the recv is complete without doing extra work (used by DBO, will be
         refactored in the very near future)

         e.g.

         ret = obj.prepare_async(...)

         if isinstance(ret, tuple):
             hook, receiver = ret
             hook()
@@ -270,10 +283,10 @@ class FusedMoEPrepareAndFinalize(ABC):
         - weight_and_reduce_impl: An optional TopKWeightAndReduce
           implementation.

-        Returns a callback or a hook callback pair that when invoked waits for
-        results from other workers and has the same return signature as
+        Returns a callback or a hook callback pair that when invoked waits for
+        results from other workers and has the same return signature as
         `finalize`, if a hook is returned this is more lightweight check that
-        the recv is complete without doing extra work (used by DBO, will be
+        the recv is complete without doing extra work (used by DBO, will be
         refactored in the very near future)

         ret = obj.finalize_async(output, ...)
@@ -344,7 +357,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     @property
     @abstractmethod
     def activation_formats(
-            self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
+        self,
+    ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
         """
         A property which is a tuple of the input and output activation formats
         for the 'apply' method.
@@ -382,8 +396,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         if a1.dim() == 2:
             # Make sure we are using the correct a1 (pre-permute).
-            assert topk_ids.size(0) == a1.size(0), \
-                f"{topk_ids.size(0)} != {a1.size(0)}"
+            assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
             M = a1.size(0)
         else:
             assert a1.dim() == 3
@@ -511,8 +524,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
def activation(self, activation: str, output: torch.Tensor,
input: torch.Tensor) -> None:
def activation(
self, activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
assert output.size(-1) * 2 == input.size(-1)
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
@@ -522,8 +536,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
             raise ValueError(f"Unsupported FusedMoe activation: {activation}")

     def enable_chunking(self):
-        return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
-            self.supports_chunking()
+        return (
+            envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
+        )

     def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
         raise NotImplementedError
@@ -585,8 +600,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         raise NotImplementedError


-def _chunk_scales(scales: Optional[torch.Tensor], start: int,
-                  end: int) -> Optional[torch.Tensor]:
+def _chunk_scales(
+    scales: Optional[torch.Tensor], start: int, end: int
+) -> Optional[torch.Tensor]:
     if scales is not None:
         if scales.numel() == 1:
             return scales
@@ -596,17 +612,19 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
 class SharedResizableBuffer:
     def __init__(self):
         self.buffer = None

-    def get(self, shape: tuple[int, ...], device: torch.device,
-            dtype: torch.dtype):
+    def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype):
         if shape == () or shape is None:
             return None
         shape_numel = prod(shape)
-        if (self.buffer is None or self.buffer.numel() < shape_numel
-                or self.buffer.device != device or self.buffer.dtype != dtype):
+        if (
+            self.buffer is None
+            or self.buffer.numel() < shape_numel
+            or self.buffer.device != device
+            or self.buffer.dtype != dtype
+        ):
             self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
         return self.buffer[:shape_numel].view(*shape)
@@ -626,7 +644,6 @@ class FusedMoEModularKernel(torch.nn.Module):
"""
class SharedBuffers:
def __init__(self) -> None:
self.fused_out = SharedResizableBuffer()
self.workspace13 = SharedResizableBuffer()
@@ -652,12 +669,14 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
-        assert prepare_finalize.activation_format == \
-            fused_experts.activation_formats[0], (
-                f"{prepare_finalize.__class__.__name__}."
-                f"{prepare_finalize.activation_format} == "
-                f"{fused_experts.__class__.__name__}."
-                f"{fused_experts.activation_formats[0]}")
+        assert (
+            prepare_finalize.activation_format == fused_experts.activation_formats[0]
+        ), (
+            f"{prepare_finalize.__class__.__name__}."
+            f"{prepare_finalize.activation_format} == "
+            f"{fused_experts.__class__.__name__}."
+            f"{fused_experts.activation_formats[0]}"
+        )

     def _do_fused_experts(
         self,
@@ -677,14 +696,21 @@ class FusedMoEModularKernel(torch.nn.Module):
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
-        _, M, N, K, top_k = self.fused_experts.moe_problem_size(
-            a1q, w1, w2, topk_ids)
+        _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids)

-        (workspace13_shape, workspace2_shape, fused_out_shape,
-         workspace_dtype) = self.fused_experts.workspace_shapes(
-             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
-             expert_tokens_meta)
+        (workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = (
+            self.fused_experts.workspace_shapes(
+                a1,
+                a1q,
+                M,
+                N,
+                K,
+                top_k,
+                global_num_experts,
+                local_num_experts,
+                expert_tokens_meta,
+            )
+        )

         # select per-ubatch buffers to avoid cross-ubatch reuse under DBO
         ubatch_idx = dbo_current_ubatch_id()
@@ -692,15 +718,16 @@ class FusedMoEModularKernel(torch.nn.Module):
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = buffers.workspace13.get(workspace13_shape,
-                                              device=a1.device,
-                                              dtype=workspace_dtype)
-        workspace2 = buffers.workspace2.get(workspace2_shape,
-                                            device=a1.device,
-                                            dtype=workspace_dtype)
+        workspace13 = buffers.workspace13.get(
+            workspace13_shape, device=a1.device, dtype=workspace_dtype
+        )
+        workspace2 = buffers.workspace2.get(
+            workspace2_shape, device=a1.device, dtype=workspace_dtype
+        )

         assert fused_out is None or fused_out.shape == fused_out_shape, (
-            f"fused_out {fused_out.shape} but expected {fused_out_shape}")
+            f"fused_out {fused_out.shape} but expected {fused_out_shape}"
+        )

         if fused_out is None:
             # reuse workspace13 for the output
             fused_out = _resize_cache(workspace13, fused_out_shape)
@@ -741,9 +768,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
-        _, M, N, K, top_k = self.fused_experts.moe_problem_size(
-            a1q, w1, w2, topk_ids)
+        _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids)

         CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
         num_chunks = cdiv(M, CHUNK_SIZE)
@@ -775,18 +800,31 @@ class FusedMoEModularKernel(torch.nn.Module):
         # Construct the entire output that can then be processed in chunks.
         (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
-            a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
-            expert_tokens_meta)
+            a1,
+            a1q,
+            M,
+            N,
+            K,
+            top_k,
+            global_num_experts,
+            local_num_experts,
+            expert_tokens_meta,
+        )
         ubatch_idx = dbo_current_ubatch_id()
         buffers = self.shared_buffers[ubatch_idx]
-        fused_out = buffers.fused_out.get(fused_out_shape,
-                                          device=a1q.device,
-                                          dtype=a1.dtype)
+        fused_out = buffers.fused_out.get(
+            fused_out_shape, device=a1q.device, dtype=a1.dtype
+        )

         def slice_input_tensors(
-            chunk_idx: int
-        ) -> tuple[torch.Tensor, Optional[torch.Tensor],
-                   Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
+            chunk_idx: int,
+        ) -> tuple[
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+            torch.Tensor,
+            torch.Tensor,
+        ]:
             s = chunk_idx * CHUNK_SIZE
             e = min(s + CHUNK_SIZE, M)
             return (
@@ -799,7 +837,8 @@ class FusedMoEModularKernel(torch.nn.Module):
         def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
             assert fused_out.size(0) % M == 0, (
-                f"fused_out shape {fused_out.shape} vs M {M}")
+                f"fused_out shape {fused_out.shape} vs M {M}"
+            )
             factor = fused_out.size(0) // M
             out_chunk_size = CHUNK_SIZE * factor
             s = chunk_idx * out_chunk_size
@@ -807,38 +846,45 @@ class FusedMoEModularKernel(torch.nn.Module):
             return fused_out[s:e]

         def slice_expert_tokens_metadata(
-                full_expert_tokens_meta: ExpertTokensMetadata,
-                chunk_topk_ids: torch.Tensor, local_num_experts: int,
-                expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
+            full_expert_tokens_meta: ExpertTokensMetadata,
+            chunk_topk_ids: torch.Tensor,
+            local_num_experts: int,
+            expert_map: Optional[torch.Tensor],
+        ) -> ExpertTokensMetadata:
             # The existing expert_num_tokens is for the entire a1q
             # input. Chunking forces recomputation of the number
             # of tokens assigned to each expert.
             c_expert_num_tokens = count_expert_num_tokens(
-                chunk_topk_ids, local_num_experts, expert_map)
+                chunk_topk_ids, local_num_experts, expert_map
+            )

             c_expert_num_tokens_cpu = None
             need_expert_num_tokens_cpu = (
-                full_expert_tokens_meta.expert_num_tokens_cpu is not None)
+                full_expert_tokens_meta.expert_num_tokens_cpu is not None
+            )
             if need_expert_num_tokens_cpu:
                 # This is blocking as some implementations need the count
                 # on the CPU to determine appropriate input/out fused-moe
                 # buffers
                 c_expert_num_tokens_cpu = c_expert_num_tokens.to(
-                    "cpu", non_blocking=False)
+                    "cpu", non_blocking=False
+                )

             return ExpertTokensMetadata(
                 expert_num_tokens=c_expert_num_tokens,
-                expert_num_tokens_cpu=c_expert_num_tokens_cpu)
+                expert_num_tokens_cpu=c_expert_num_tokens_cpu,
+            )

         for chunk_idx in range(num_chunks):
             c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
-                slice_input_tensors(chunk_idx))
+                slice_input_tensors(chunk_idx)
+            )
             c_expert_tokens_meta = None
             if expert_tokens_meta is not None:
                 c_expert_tokens_meta = slice_expert_tokens_metadata(
-                    expert_tokens_meta, c_topk_ids, local_num_experts,
-                    expert_map)
+                    expert_tokens_meta, c_topk_ids, local_num_experts, expert_map
+                )

             self._do_fused_experts(
                 fused_out=slice_output_tensor(chunk_idx),
@@ -917,16 +963,21 @@ class FusedMoEModularKernel(torch.nn.Module):
             # TODO(lucas): enable in follow-up
             assert not dbo_enabled()

-            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
-             _expert_topk_weights) = self.prepare_finalize.prepare(
-                 a1,
-                 topk_weights,
-                 topk_ids,
-                 global_num_experts,
-                 expert_map,
-                 apply_router_weight_on_input,
-                 self.fused_experts.quant_config,
-             )
+            (
+                a1q,
+                a1q_scale,
+                expert_tokens_meta,
+                _expert_topk_ids,
+                _expert_topk_weights,
+            ) = self.prepare_finalize.prepare(
+                a1,
+                topk_weights,
+                topk_ids,
+                global_num_experts,
+                expert_map,
+                apply_router_weight_on_input,
+                self.fused_experts.quant_config,
+            )
         else:
             # Overlap shared expert compute with all2all dispatch.
             dbo_maybe_run_recv_hook()
@@ -943,8 +994,9 @@ class FusedMoEModularKernel(torch.nn.Module):
             # TODO(lucas): refactor this in the alternative schedules followup
             # currently unpack if we have hook + receiver pair or just
             # receiver (see finalize_async docstring)
-            hook, receiver = prepare_ret \
-                if isinstance(prepare_ret, tuple) else (None, prepare_ret)
+            hook, receiver = (
+                prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
+            )

             if hook is not None:
                 if dbo_enabled():
@@ -956,13 +1008,19 @@ class FusedMoEModularKernel(torch.nn.Module):
                 else:
                     hook()

-            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
-             _expert_topk_weights) = receiver()
+            (
+                a1q,
+                a1q_scale,
+                expert_tokens_meta,
+                _expert_topk_ids,
+                _expert_topk_weights,
+            ) = receiver()

         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
         topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
-        topk_weights = (topk_weights if _expert_topk_weights is None else
-                        _expert_topk_weights)
+        topk_weights = (
+            topk_weights if _expert_topk_weights is None else _expert_topk_weights
+        )

         fused_out = None
@@ -1022,8 +1080,11 @@ class FusedMoEModularKernel(torch.nn.Module):
             # TODO(lucas): refactor this in the alternative schedules followup
             # currently unpack if we have hook + receiver pair or just
             # receiver (see finalize_async docstring)
-            hook, receiver = finalize_ret \
-                if isinstance(finalize_ret, tuple) else (None, finalize_ret)
+            hook, receiver = (
+                finalize_ret
+                if isinstance(finalize_ret, tuple)
+                else (None, finalize_ret)
+            )

             if hook is not None:
                 if dbo_enabled():
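
For reference, a minimal self-contained sketch of the hook/receiver convention that the `prepare_async`/`finalize_async` docstrings describe and that the reformatted call sites above unpack (toy stand-in functions and names, not vLLM code):

    # Toy stand-ins for illustration; the real objects live in vLLM.
    from typing import Callable, Union

    Receiver = Callable[[], str]
    Hook = Callable[[], None]


    def fake_prepare_async(with_hook: bool) -> Union[Receiver, tuple[Hook, Receiver]]:
        receiver: Receiver = lambda: "dispatch results"
        if with_hook:
            return (lambda: None, receiver)  # (hook, receiver) pair
        return receiver  # receiver only


    prepare_ret = fake_prepare_async(with_hook=True)
    # The same unpacking that the diff rewrites from a backslash continuation
    # into a parenthesized conditional expression:
    hook, receiver = (
        prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
    )
    if hook is not None:
        hook()  # lightweight check that the recv completed
    print(receiver())  # waits for results from other workers and returns them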