Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
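For reference, the yapf + isort responsibilities both move to ruff. A minimal sketch of the equivalent invocations, assuming a standard ruff setup; the exact flags and pyproject.toml configuration are assumptions, not shown in this diff:

# Sketch of the tool swap this commit makes (assumed invocation):
# ruff's "I" (import-sorting) rules replace isort, `ruff format` replaces yapf.
import subprocess

subprocess.run(["ruff", "check", "--select", "I", "--fix", "."], check=True)
subprocess.run(["ruff", "format", "."], check=True)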
@@ -11,11 +11,17 @@ import torch
 import vllm.envs as envs
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
-    _resize_cache, count_expert_num_tokens)
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache,
+    count_expert_num_tokens,
+)
 from vllm.utils import cdiv
-from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
-                                      dbo_maybe_run_recv_hook,
-                                      dbo_register_recv_hook, dbo_yield)
+from vllm.v1.worker.ubatching import (
+    dbo_current_ubatch_id,
+    dbo_enabled,
+    dbo_maybe_run_recv_hook,
+    dbo_register_recv_hook,
+    dbo_yield,
+)

 #
 # This file defines a set of base classes used to make MoE kernels more modular.
@@ -59,31 +65,34 @@ class FusedMoEActivationFormat(Enum):
     """
     The standard activation format (num_tokens, hidden dim).
     """
-    Standard = "standard",
+
+    Standard = ("standard",)
     """
     The batched experts format (num experts, max tokens per expert, hidden dim)
     """
-    BatchedExperts = "batched_experts",
+    BatchedExperts = ("batched_experts",)


 @dataclass
 class ExpertTokensMetadata:
     """
     Metadata regarding expert-token routing.
     """
+
     expert_num_tokens: torch.Tensor
     expert_num_tokens_cpu: Optional[torch.Tensor]

     @staticmethod
-    def make_from_list(expert_num_tokens_list: list[int],
-                       device: str) -> "ExpertTokensMetadata":
-        expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
-                                             device="cpu",
-                                             dtype=torch.int32)
+    def make_from_list(
+        expert_num_tokens_list: list[int], device: str
+    ) -> "ExpertTokensMetadata":
+        expert_num_tokens_cpu = torch.tensor(
+            expert_num_tokens_list, device="cpu", dtype=torch.int32
+        )
         return ExpertTokensMetadata(
-            expert_num_tokens=expert_num_tokens_cpu.to(device,
-                                                       non_blocking=True),
-            expert_num_tokens_cpu=expert_num_tokens_cpu)
+            expert_num_tokens=expert_num_tokens_cpu.to(device, non_blocking=True),
+            expert_num_tokens_cpu=expert_num_tokens_cpu,
+        )


 class TopKWeightAndReduce(ABC):
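The reshaped make_from_list factory above keeps a host-side count tensor while kicking off a non-blocking copy to the device. A hedged usage sketch; the counts and device string are illustrative, and the import path assumes this diff touches vllm/model_executor/layers/fused_moe/modular_kernel.py:

# Usage sketch (assumed values and import path):
import torch
from vllm.model_executor.layers.fused_moe.modular_kernel import ExpertTokensMetadata

meta = ExpertTokensMetadata.make_from_list([4, 0, 7], device="cuda")
assert meta.expert_num_tokens_cpu is not None        # host copy retained
assert meta.expert_num_tokens.device.type == "cuda"  # async device copy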
@@ -92,10 +101,14 @@ class TopKWeightAndReduce(ABC):
     """

     @abstractmethod
-    def apply(self, output: Optional[torch.Tensor],
-              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor,
-              apply_router_weight_on_input: bool) -> torch.Tensor:
+    def apply(
+        self,
+        output: Optional[torch.Tensor],
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+    ) -> torch.Tensor:
         """
         Apply topk_weights to the fused_experts_outputs and/or reduce.
         If an output tensor is not passed, it will be created in the
@@ -200,16 +213,16 @@ class FusedMoEPrepareAndFinalize(ABC):
         - apply_router_weight_on_input: When True, apply the weights to the
           activations, before quantization + dispatching.

-        Returns a callback or a hook callback pair that when invoked waits for
-        results from other workers and has the same return signature as
+        Returns a callback or a hook callback pair that when invoked waits for
+        results from other workers and has the same return signature as
         `prepare`, if a hook is returned this is more lightweight check that
-        the recv is complete without doing extra work (used by DBO, will be
+        the recv is complete without doing extra work (used by DBO, will be
         refactored in the very near future)

         e.g.

         ret = obj.prepare_async(...)

         if isinstance(ret, tuple):
             hook, receiver = ret
             hook()
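The docstring's example stops at hook(); a sketch of the full protocol it describes, mirroring the unpacking done later in this diff (the same protocol applies to finalize_async below; `obj` stands for any FusedMoEPrepareAndFinalize implementation and is an assumption):

# Sketch of the documented hook/receiver protocol (not code from this commit):
ret = obj.prepare_async(...)
hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
if hook is not None:
    hook()  # lightweight check that the recv is complete
prepared = receiver()  # blocks if needed; same return signature as `prepare`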
@@ -270,10 +283,10 @@ class FusedMoEPrepareAndFinalize(ABC):
         - weight_and_reduce_impl: An optional TopKWeightAndReduce
           implementation.

-        Returns a callback or a hook callback pair that when invoked waits for
-        results from other workers and has the same return signature as
+        Returns a callback or a hook callback pair that when invoked waits for
+        results from other workers and has the same return signature as
         `finalize`, if a hook is returned this is more lightweight check that
-        the recv is complete without doing extra work (used by DBO, will be
+        the recv is complete without doing extra work (used by DBO, will be
         refactored in the very near future)

         ret = obj.finalize_async(output, ...)
@@ -344,7 +357,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     @property
     @abstractmethod
     def activation_formats(
-            self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
+        self,
+    ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
         """
         A property which is a tuple of the input and output activation formats
         for the 'apply' method.
@@ -382,8 +396,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):

         if a1.dim() == 2:
             # Make sure we are using the correct a1 (pre-permute).
-            assert topk_ids.size(0) == a1.size(0), \
-                f"{topk_ids.size(0)} != {a1.size(0)}"
+            assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
             M = a1.size(0)
         else:
             assert a1.dim() == 3
@@ -511,8 +524,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         """
         raise NotImplementedError

-    def activation(self, activation: str, output: torch.Tensor,
-                   input: torch.Tensor) -> None:
+    def activation(
+        self, activation: str, output: torch.Tensor, input: torch.Tensor
+    ) -> None:
         assert output.size(-1) * 2 == input.size(-1)
         if activation == "silu":
             torch.ops._C.silu_and_mul(output, input)
@@ -522,8 +536,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
             raise ValueError(f"Unsupported FusedMoe activation: {activation}")

     def enable_chunking(self):
-        return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
-            self.supports_chunking()
+        return (
+            envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
+        )

     def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
         raise NotImplementedError
@@ -585,8 +600,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         raise NotImplementedError


-def _chunk_scales(scales: Optional[torch.Tensor], start: int,
-                  end: int) -> Optional[torch.Tensor]:
+def _chunk_scales(
+    scales: Optional[torch.Tensor], start: int, end: int
+) -> Optional[torch.Tensor]:
     if scales is not None:
         if scales.numel() == 1:
             return scales
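_chunk_scales passes a single per-tensor scale through untouched and slices per-token scales by row range; the slicing and None branches are implied by the Optional signature rather than shown in this hunk. A small illustration with assumed shapes, assuming `_chunk_scales` is importable from this module:

# Illustration of the three cases (shapes assumed):
import torch

per_tensor = torch.tensor([0.5])
per_token = torch.ones(16, 1)
assert _chunk_scales(per_tensor, 4, 8) is per_tensor    # numel() == 1: unchanged
assert _chunk_scales(per_token, 4, 8).shape == (4, 1)   # rows 4..8 sliced out
assert _chunk_scales(None, 4, 8) is None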
@@ -596,17 +612,19 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,


 class SharedResizableBuffer:
-
     def __init__(self):
         self.buffer = None

-    def get(self, shape: tuple[int, ...], device: torch.device,
-            dtype: torch.dtype):
+    def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype):
         if shape == () or shape is None:
             return None
         shape_numel = prod(shape)
-        if (self.buffer is None or self.buffer.numel() < shape_numel
-                or self.buffer.device != device or self.buffer.dtype != dtype):
+        if (
+            self.buffer is None
+            or self.buffer.numel() < shape_numel
+            or self.buffer.device != device
+            or self.buffer.dtype != dtype
+        ):
             self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
         return self.buffer[:shape_numel].view(*shape)

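SharedResizableBuffer grows a single flat allocation only when a request exceeds it, and otherwise hands out reshaped views of the existing storage. A usage sketch with assumed shapes and device:

# Usage sketch (shapes/device assumed): the second request fits inside the
# first allocation, so both views share the same underlying storage.
import torch

buf = SharedResizableBuffer()
a = buf.get((4, 32), device=torch.device("cpu"), dtype=torch.float32)
b = buf.get((2, 32), device=torch.device("cpu"), dtype=torch.float32)
assert a.data_ptr() == b.data_ptr()  # no new allocation for the smaller shape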
@@ -626,7 +644,6 @@ class FusedMoEModularKernel(torch.nn.Module):
     """

     class SharedBuffers:
-
         def __init__(self) -> None:
             self.fused_out = SharedResizableBuffer()
             self.workspace13 = SharedResizableBuffer()
@@ -652,12 +669,14 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
-        assert prepare_finalize.activation_format == \
-            fused_experts.activation_formats[0], (
-                f"{prepare_finalize.__class__.__name__}."
-                f"{prepare_finalize.activation_format} == "
-                f"{fused_experts.__class__.__name__}."
-                f"{fused_experts.activation_formats[0]}")
+        assert (
+            prepare_finalize.activation_format == fused_experts.activation_formats[0]
+        ), (
+            f"{prepare_finalize.__class__.__name__}."
+            f"{prepare_finalize.activation_format} == "
+            f"{fused_experts.__class__.__name__}."
+            f"{fused_experts.activation_formats[0]}"
+        )

     def _do_fused_experts(
         self,
@@ -677,14 +696,21 @@ class FusedMoEModularKernel(torch.nn.Module):
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
-
-        _, M, N, K, top_k = self.fused_experts.moe_problem_size(
-            a1q, w1, w2, topk_ids)
+        _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids)

-        (workspace13_shape, workspace2_shape, fused_out_shape,
-         workspace_dtype) = self.fused_experts.workspace_shapes(
-             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
-             expert_tokens_meta)
+        (workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = (
+            self.fused_experts.workspace_shapes(
+                a1,
+                a1q,
+                M,
+                N,
+                K,
+                top_k,
+                global_num_experts,
+                local_num_experts,
+                expert_tokens_meta,
+            )
+        )

         # select per-ubatch buffers to avoid cross-ubatch reuse under DBO
         ubatch_idx = dbo_current_ubatch_id()
@@ -692,15 +718,16 @@ class FusedMoEModularKernel(torch.nn.Module):

         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = buffers.workspace13.get(workspace13_shape,
-                                              device=a1.device,
-                                              dtype=workspace_dtype)
-        workspace2 = buffers.workspace2.get(workspace2_shape,
-                                            device=a1.device,
-                                            dtype=workspace_dtype)
+        workspace13 = buffers.workspace13.get(
+            workspace13_shape, device=a1.device, dtype=workspace_dtype
+        )
+        workspace2 = buffers.workspace2.get(
+            workspace2_shape, device=a1.device, dtype=workspace_dtype
+        )

         assert fused_out is None or fused_out.shape == fused_out_shape, (
-            f"fused_out {fused_out.shape} but expected {fused_out_shape}")
+            f"fused_out {fused_out.shape} but expected {fused_out_shape}"
+        )
         if fused_out is None:
             # reuse workspace13 for the output
             fused_out = _resize_cache(workspace13, fused_out_shape)
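The cache1/cache3 comment above is about non-overlapping live ranges: one flat workspace can back both intermediates because the first is fully consumed before the second is written. An illustrative sketch with assumed shapes (not the actual kernel buffers):

# Illustrative only: the same idea behind sharing workspace13.
import torch

workspace = torch.empty(1024)
cache1 = workspace[:512].view(16, 32)  # produced and consumed first
# ... cache1 fully consumed here ...
cache3 = workspace[:256].view(8, 32)   # safe: cache1 is no longer needed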
@@ -741,9 +768,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
-
-        _, M, N, K, top_k = self.fused_experts.moe_problem_size(
-            a1q, w1, w2, topk_ids)
+        _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids)

         CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
         num_chunks = cdiv(M, CHUNK_SIZE)
@@ -775,18 +800,31 @@ class FusedMoEModularKernel(torch.nn.Module):

         # Construct the entire output that can then be processed in chunks.
         (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
-            a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
-            expert_tokens_meta)
+            a1,
+            a1q,
+            M,
+            N,
+            K,
+            top_k,
+            global_num_experts,
+            local_num_experts,
+            expert_tokens_meta,
+        )
         ubatch_idx = dbo_current_ubatch_id()
         buffers = self.shared_buffers[ubatch_idx]
-        fused_out = buffers.fused_out.get(fused_out_shape,
-                                          device=a1q.device,
-                                          dtype=a1.dtype)
+        fused_out = buffers.fused_out.get(
+            fused_out_shape, device=a1q.device, dtype=a1.dtype
+        )

         def slice_input_tensors(
-            chunk_idx: int
-        ) -> tuple[torch.Tensor, Optional[torch.Tensor],
-                   Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
+            chunk_idx: int,
+        ) -> tuple[
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+            torch.Tensor,
+            torch.Tensor,
+        ]:
             s = chunk_idx * CHUNK_SIZE
             e = min(s + CHUNK_SIZE, M)
             return (
@@ -799,7 +837,8 @@ class FusedMoEModularKernel(torch.nn.Module):

         def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
             assert fused_out.size(0) % M == 0, (
-                f"fused_out shape {fused_out.shape} vs M {M}")
+                f"fused_out shape {fused_out.shape} vs M {M}"
+            )
             factor = fused_out.size(0) // M
             out_chunk_size = CHUNK_SIZE * factor
             s = chunk_idx * out_chunk_size
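A worked example of the output-slicing arithmetic above, with assumed values:

# Assumed: M = 10 tokens, CHUNK_SIZE = 4, fused_out has 20 rows.
#   num_chunks     = cdiv(10, 4) = 3
#   factor         = 20 // 10 = 2
#   out_chunk_size = 4 * 2 = 8
#   chunk 0 -> fused_out[0:8], chunk 1 -> fused_out[8:16],
#   chunk 2 -> fused_out[16:20]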
@@ -807,38 +846,45 @@ class FusedMoEModularKernel(torch.nn.Module):
             return fused_out[s:e]

         def slice_expert_tokens_metadata(
-                full_expert_tokens_meta: ExpertTokensMetadata,
-                chunk_topk_ids: torch.Tensor, local_num_experts: int,
-                expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
+            full_expert_tokens_meta: ExpertTokensMetadata,
+            chunk_topk_ids: torch.Tensor,
+            local_num_experts: int,
+            expert_map: Optional[torch.Tensor],
+        ) -> ExpertTokensMetadata:
             # The existing expert_num_tokens is for the entire a1q
             # input. Chunking forces recomputation of the number
             # of tokens assigned to each expert.
             c_expert_num_tokens = count_expert_num_tokens(
-                chunk_topk_ids, local_num_experts, expert_map)
+                chunk_topk_ids, local_num_experts, expert_map
+            )

             c_expert_num_tokens_cpu = None
             need_expert_num_tokens_cpu = (
-                full_expert_tokens_meta.expert_num_tokens_cpu is not None)
+                full_expert_tokens_meta.expert_num_tokens_cpu is not None
+            )
             if need_expert_num_tokens_cpu:
                 # This is blocking as some implementations need the count
                 # on the CPU to determine appropriate input/out fused-moe
                 # buffers
                 c_expert_num_tokens_cpu = c_expert_num_tokens.to(
-                    "cpu", non_blocking=False)
+                    "cpu", non_blocking=False
+                )

             return ExpertTokensMetadata(
                 expert_num_tokens=c_expert_num_tokens,
-                expert_num_tokens_cpu=c_expert_num_tokens_cpu)
+                expert_num_tokens_cpu=c_expert_num_tokens_cpu,
+            )

         for chunk_idx in range(num_chunks):
             c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
-                slice_input_tensors(chunk_idx))
+                slice_input_tensors(chunk_idx)
+            )

             c_expert_tokens_meta = None
             if expert_tokens_meta is not None:
                 c_expert_tokens_meta = slice_expert_tokens_metadata(
-                    expert_tokens_meta, c_topk_ids, local_num_experts,
-                    expert_map)
+                    expert_tokens_meta, c_topk_ids, local_num_experts, expert_map
+                )

             self._do_fused_experts(
                 fused_out=slice_output_tensor(chunk_idx),
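Per-chunk recomputation of expert counts works as the comment in the hunk above explains: each chunk's topk_ids are re-tallied. A functional stand-in for count_expert_num_tokens on one chunk (an assumption; the real helper also applies expert_map-based remapping):

# Functional stand-in, not the actual kernel:
import torch

chunk_topk_ids = torch.tensor([[0, 2], [2, 1]])  # 2 tokens, top-2 routing
counts = torch.bincount(chunk_topk_ids.flatten(), minlength=3)
assert counts.tolist() == [1, 1, 2]  # tokens routed to experts 0, 1, 2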
@@ -917,16 +963,21 @@ class FusedMoEModularKernel(torch.nn.Module):
             # TODO(lucas): enable in follow-up
             assert not dbo_enabled()

-            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
-             _expert_topk_weights) = self.prepare_finalize.prepare(
-                 a1,
-                 topk_weights,
-                 topk_ids,
-                 global_num_experts,
-                 expert_map,
-                 apply_router_weight_on_input,
-                 self.fused_experts.quant_config,
-             )
+            (
+                a1q,
+                a1q_scale,
+                expert_tokens_meta,
+                _expert_topk_ids,
+                _expert_topk_weights,
+            ) = self.prepare_finalize.prepare(
+                a1,
+                topk_weights,
+                topk_ids,
+                global_num_experts,
+                expert_map,
+                apply_router_weight_on_input,
+                self.fused_experts.quant_config,
+            )
         else:
             # Overlap shared expert compute with all2all dispatch.
             dbo_maybe_run_recv_hook()
@@ -943,8 +994,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         # TODO(lucas): refactor this in the alternative schedules followup
         # currently unpack if we have hook + receiver pair or just
         # receiver (see finalize_async docstring)
-        hook, receiver = prepare_ret \
-            if isinstance(prepare_ret, tuple) else (None, prepare_ret)
+        hook, receiver = (
+            prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
+        )

         if hook is not None:
             if dbo_enabled():
@@ -956,13 +1008,19 @@ class FusedMoEModularKernel(torch.nn.Module):
             else:
                 hook()

-        (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
-         _expert_topk_weights) = receiver()
+        (
+            a1q,
+            a1q_scale,
+            expert_tokens_meta,
+            _expert_topk_ids,
+            _expert_topk_weights,
+        ) = receiver()

         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
         topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
-        topk_weights = (topk_weights if _expert_topk_weights is None else
-                        _expert_topk_weights)
+        topk_weights = (
+            topk_weights if _expert_topk_weights is None else _expert_topk_weights
+        )

         fused_out = None

@@ -1022,8 +1080,11 @@ class FusedMoEModularKernel(torch.nn.Module):
             # TODO(lucas): refactor this in the alternative schedules followup
             # currently unpack if we have hook + receiver pair or just
             # receiver (see finalize_async docstring)
-            hook, receiver = finalize_ret \
-                if isinstance(finalize_ret, tuple) else (None, finalize_ret)
+            hook, receiver = (
+                finalize_ret
+                if isinstance(finalize_ret, tuple)
+                else (None, finalize_ret)
+            )

             if hook is not None:
                 if dbo_enabled():