[Kernels] Modular kernel refactor (#24812)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -337,6 +337,14 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of finalize is reduced across all
|
||||
ranks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: add supported activations method (return string)
|
||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
@@ -493,11 +501,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
Workspace type: The dtype to use for the workspace tensors.
|
||||
"""
|
||||
return act_dtype
|
||||
|
||||
@abstractmethod
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
@@ -505,22 +517,33 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
"""
|
||||
Compute the shapes for the temporary and final outputs of the two gemms
|
||||
and activation in the fused expert function. Since the gemms are
|
||||
independent, the workspace for the first gemm can be shared with the
|
||||
workspace for the last gemm.
|
||||
|
||||
Inputs:
|
||||
- M: number of tokens.
|
||||
- N: Row (or column) dimension of expert weights.
|
||||
- K: hidden dimension
|
||||
- topk: The number of top-k experts to select.
|
||||
- global_num_experts: global number of experts.
|
||||
- local_num_experts: local number of experts due to DP/EP.
|
||||
- expert_tokens_meta: number of tokens per expert metadata for batched
|
||||
format.
|
||||
|
||||
Returns a tuple of:
|
||||
- workspace13 shape tuple: must be large enough to hold the
|
||||
result of either expert gemm.
|
||||
- workspace2 shape tuple: must be large enough to hold the
|
||||
result of the activation function.
|
||||
- output shape tuple: must be exact size of the final gemm output.
|
||||
- Workspace type: The dtype to use for the workspace tensors.
|
||||
- Note: in order for activation chunking to work, the first dimension
|
||||
of each tuple must be the number of tokens.
|
||||
- Note: workspace shapes can be 0 if the workspace is not needed.
|
||||
But in order for activation chunking to work, the first dimension
|
||||
of each tuple must be the number of tokens when the shape is
|
||||
not 0.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -561,7 +584,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
This function computes the intermediate result of a Mixture of Experts
|
||||
(MoE) layer using two sets of weights, w1 and w2.
|
||||
@@ -600,7 +623,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _chunk_scales(
|
||||
def _slice_scales(
|
||||
scales: Optional[torch.Tensor], start: int, end: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if scales is not None:
|
||||
@@ -615,9 +638,10 @@ class SharedResizableBuffer:
|
||||
def __init__(self):
|
||||
self.buffer = None
|
||||
|
||||
def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype):
|
||||
if shape == () or shape is None:
|
||||
return None
|
||||
def get(
|
||||
self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
assert shape != ()
|
||||
shape_numel = prod(shape)
|
||||
if (
|
||||
self.buffer is None
|
||||
@@ -678,131 +702,63 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
f"{fused_experts.activation_formats[0]}"
|
||||
)
|
||||
|
||||
def _do_fused_experts(
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of fused MoE kernel
|
||||
is reduced across all ranks.
|
||||
"""
|
||||
return self.prepare_finalize.output_is_reduced()
|
||||
|
||||
def _chunk_info(self, M: int) -> tuple[int, int]:
|
||||
"""
|
||||
Compute number of chunks and chunk size for given M.
|
||||
If chunking is not supported, set the CHUNK_SIZE to M so we
|
||||
get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
|
||||
If there are no tokens to process, the number of chunks will be zero.
|
||||
"""
|
||||
CHUNK_SIZE = (
|
||||
max(M, 1)
|
||||
if not self.fused_experts.supports_chunking()
|
||||
else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
)
|
||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||
# If there are no tokens, then there should be no loop iterations.
|
||||
assert M > 0 or num_chunks == 0
|
||||
return num_chunks, CHUNK_SIZE
|
||||
|
||||
def _allocate_buffers(
|
||||
self,
|
||||
fused_out: Optional[torch.Tensor],
|
||||
a1: torch.Tensor,
|
||||
a1q: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
out_dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
M_chunk: int,
|
||||
M_full: int,
|
||||
N: int,
|
||||
K: int,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
_, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Allocate temporary and output buffers for the fused experts op.
|
||||
Inputs:
|
||||
- out_dtype: output type of workspace and output tensors.
|
||||
- device: the device of the workspace and output tensors.
|
||||
See `workspace_shapes` for a description of the remainder of arguments.
|
||||
Returns a tuple of (workspace13, workspace2, output) tensors.
|
||||
"""
|
||||
assert M_full > 0 and M_chunk > 0
|
||||
|
||||
(workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = (
|
||||
self.fused_experts.workspace_shapes(
|
||||
a1,
|
||||
a1q,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
)
|
||||
num_chunks, _ = self._chunk_info(M_full)
|
||||
|
||||
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = buffers.workspace13.get(
|
||||
workspace13_shape, device=a1.device, dtype=workspace_dtype
|
||||
)
|
||||
workspace2 = buffers.workspace2.get(
|
||||
workspace2_shape, device=a1.device, dtype=workspace_dtype
|
||||
)
|
||||
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}"
|
||||
)
|
||||
if fused_out is None:
|
||||
# reuse workspace13 for the output
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
|
||||
self.fused_experts.apply(
|
||||
fused_out,
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return fused_out
|
||||
|
||||
def _maybe_chunk_fused_experts(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1q: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
_, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||
|
||||
# TODO(bnell): get rid of one level here, update slice functions
|
||||
# to nops on num_chunks==1
|
||||
|
||||
if not self.fused_experts.supports_chunking() or num_chunks == 1:
|
||||
return self._do_fused_experts(
|
||||
fused_out=None,
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=self.fused_experts.a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
# Chunking required case
|
||||
assert num_chunks > 1
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1,
|
||||
a1q,
|
||||
M,
|
||||
# Get intermediate workspace shapes based off the chunked M size.
|
||||
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
|
||||
M_chunk,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
@@ -810,102 +766,338 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
fused_out = buffers.fused_out.get(
|
||||
fused_out_shape, device=a1q.device, dtype=a1.dtype
|
||||
|
||||
# Get final output shape based on the full M size.
|
||||
_, _, fused_out_shape = self.fused_experts.workspace_shapes(
|
||||
M_full,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
return (
|
||||
a1q[s:e],
|
||||
_chunk_scales(a1q_scale, s, e),
|
||||
_chunk_scales(self.fused_experts.a2_scale, s, e),
|
||||
topk_ids[s:e],
|
||||
topk_weights[s:e],
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = buffers.workspace13.get(
|
||||
workspace13_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
workspace2 = buffers.workspace2.get(
|
||||
workspace2_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
# Reuse workspace13 for the output in the non-chunked case as long
|
||||
# as it is large enough. This will not always be the case for standard
|
||||
# format experts and with experts that have empty workspaces.
|
||||
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
else:
|
||||
fused_out = buffers.fused_out.get(
|
||||
fused_out_shape, device=device, dtype=out_dtype
|
||||
)
|
||||
|
||||
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
||||
assert fused_out.size(0) % M == 0, (
|
||||
f"fused_out shape {fused_out.shape} vs M {M}"
|
||||
)
|
||||
factor = fused_out.size(0) // M
|
||||
out_chunk_size = CHUNK_SIZE * factor
|
||||
s = chunk_idx * out_chunk_size
|
||||
e = min(s + out_chunk_size, fused_out.size(0))
|
||||
return fused_out[s:e]
|
||||
return workspace13, workspace2, fused_out
|
||||
|
||||
def slice_expert_tokens_metadata(
|
||||
full_expert_tokens_meta: ExpertTokensMetadata,
|
||||
chunk_topk_ids: torch.Tensor,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
) -> ExpertTokensMetadata:
|
||||
# The existing expert_num_tokens is for the entire a1q
|
||||
# input. Chunking forces recomputation of the number
|
||||
# of tokens assigned to each expert.
|
||||
c_expert_num_tokens = count_expert_num_tokens(
|
||||
chunk_topk_ids, local_num_experts, expert_map
|
||||
@staticmethod
|
||||
def _slice_output_tensor(
|
||||
fused_out: torch.Tensor,
|
||||
chunk_idx: int,
|
||||
num_chunks: int,
|
||||
CHUNK_SIZE: int,
|
||||
M: int,
|
||||
) -> torch.Tensor:
|
||||
if num_chunks == 1:
|
||||
return fused_out
|
||||
|
||||
assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}"
|
||||
factor = fused_out.size(0) // M
|
||||
out_chunk_size = CHUNK_SIZE * factor
|
||||
s = chunk_idx * out_chunk_size
|
||||
e = min(s + out_chunk_size, fused_out.size(0))
|
||||
return fused_out[s:e]
|
||||
|
||||
@staticmethod
|
||||
def _slice_expert_tokens_metadata(
|
||||
num_chunks: int,
|
||||
full_expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
chunk_topk_ids: torch.Tensor,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
) -> Optional[ExpertTokensMetadata]:
|
||||
if num_chunks == 1 or full_expert_tokens_meta is None:
|
||||
return full_expert_tokens_meta
|
||||
|
||||
# The existing expert_num_tokens is for the entire a1q
|
||||
# input. Chunking forces recomputation of the number
|
||||
# of tokens assigned to each expert.
|
||||
c_expert_num_tokens = count_expert_num_tokens(
|
||||
chunk_topk_ids, local_num_experts, expert_map
|
||||
)
|
||||
|
||||
c_expert_num_tokens_cpu = None
|
||||
need_expert_num_tokens_cpu = (
|
||||
full_expert_tokens_meta.expert_num_tokens_cpu is not None
|
||||
)
|
||||
if need_expert_num_tokens_cpu:
|
||||
# This is blocking as some implementations need the count
|
||||
# on the CPU to determine appropriate input/out fused-moe
|
||||
# buffers
|
||||
c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False)
|
||||
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=c_expert_num_tokens,
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
|
||||
)
|
||||
|
||||
def _prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
"""
|
||||
The _prepare method is a wrapper around self.prepare_finalize.prepare
|
||||
that handles DBO and async.
|
||||
"""
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
# TODO(lucas): enable in follow-up
|
||||
assert not dbo_enabled()
|
||||
|
||||
(
|
||||
a1q,
|
||||
a1q_scale,
|
||||
expert_tokens_meta,
|
||||
_expert_topk_ids,
|
||||
_expert_topk_weights,
|
||||
) = self.prepare_finalize.prepare(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
dbo_maybe_run_recv_hook()
|
||||
prepare_ret = self.prepare_finalize.prepare_async(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
c_expert_num_tokens_cpu = None
|
||||
need_expert_num_tokens_cpu = (
|
||||
full_expert_tokens_meta.expert_num_tokens_cpu is not None
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = (
|
||||
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
|
||||
)
|
||||
if need_expert_num_tokens_cpu:
|
||||
# This is blocking as some implementations need the count
|
||||
# on the CPU to determine appropriate input/out fused-moe
|
||||
# buffers
|
||||
c_expert_num_tokens_cpu = c_expert_num_tokens.to(
|
||||
"cpu", non_blocking=False
|
||||
)
|
||||
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=c_expert_num_tokens,
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
(
|
||||
a1q,
|
||||
a1q_scale,
|
||||
expert_tokens_meta,
|
||||
_expert_topk_ids,
|
||||
_expert_topk_weights,
|
||||
) = receiver()
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
topk_weights = (
|
||||
topk_weights if _expert_topk_weights is None else _expert_topk_weights
|
||||
)
|
||||
|
||||
return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights
|
||||
|
||||
def _fused_experts(
|
||||
self,
|
||||
in_dtype: torch.dtype,
|
||||
a1q: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
) -> torch.Tensor:
|
||||
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
|
||||
a1q, w1, w2, topk_ids
|
||||
)
|
||||
|
||||
num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
|
||||
|
||||
def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
|
||||
if num_chunks == 1:
|
||||
# Use a1q.size(0) here since batched format does not
|
||||
# keep M in the first dimension.
|
||||
return 0, a1q.size(0)
|
||||
else:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M_full)
|
||||
return s, e
|
||||
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
# incompatible all2all kernels like the DeepEP high-throughput
|
||||
# kernels. CUDAGraph compatible all2all kernels like the pplx
|
||||
# kernels and the DeepEP low-latency kernels are always batched
|
||||
# and can never run into the tensor.numel() == 0 case.
|
||||
if M_full == 0:
|
||||
assert num_chunks == 0
|
||||
workspace13 = None
|
||||
workspace2 = None
|
||||
fused_out = torch.empty_like(a1q)
|
||||
else:
|
||||
assert num_chunks > 0
|
||||
workspace13, workspace2, fused_out = self._allocate_buffers(
|
||||
in_dtype,
|
||||
a1q.device,
|
||||
CHUNK_SIZE,
|
||||
M_full,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
|
||||
slice_input_tensors(chunk_idx)
|
||||
s, e = input_chunk_range(chunk_idx)
|
||||
|
||||
c_expert_tokens_meta = self._slice_expert_tokens_metadata(
|
||||
num_chunks,
|
||||
expert_tokens_meta,
|
||||
topk_ids[s:e],
|
||||
local_num_experts,
|
||||
expert_map,
|
||||
)
|
||||
|
||||
c_expert_tokens_meta = None
|
||||
if expert_tokens_meta is not None:
|
||||
c_expert_tokens_meta = slice_expert_tokens_metadata(
|
||||
expert_tokens_meta, c_topk_ids, local_num_experts, expert_map
|
||||
)
|
||||
c_fused_out = self._slice_output_tensor(
|
||||
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
|
||||
)
|
||||
|
||||
self._do_fused_experts(
|
||||
fused_out=slice_output_tensor(chunk_idx),
|
||||
a1=a1,
|
||||
a1q=c_a1q,
|
||||
self.fused_experts.apply(
|
||||
output=c_fused_out,
|
||||
hidden_states=a1q[s:e],
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=c_topk_weights,
|
||||
topk_ids=c_topk_ids,
|
||||
topk_weights=topk_weights[s:e],
|
||||
topk_ids=topk_ids[s:e],
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
a1q_scale=_slice_scales(a1q_scale, s, e),
|
||||
a2_scale=_slice_scales(self.fused_experts.a2_scale, e, e),
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return fused_out
|
||||
|
||||
def _finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_out: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
The _finalize method is a wrapper around self.prepare_finalize.finalize
|
||||
that handles DBO, async and shared expert overlap.
|
||||
"""
|
||||
shared_output: Optional[torch.Tensor] = None
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
assert not dbo_enabled()
|
||||
|
||||
self.prepare_finalize.finalize(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
finalize_ret = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = (
|
||||
finalize_ret
|
||||
if isinstance(finalize_ret, tuple)
|
||||
else (None, finalize_ret)
|
||||
)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
receiver()
|
||||
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
else:
|
||||
assert shared_output is not None
|
||||
return shared_output, output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -947,156 +1139,45 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
a1 = hidden_states
|
||||
output = a1 if inplace and self.shared_experts is None else torch.zeros_like(a1)
|
||||
if inplace and self.shared_experts is None:
|
||||
output = hidden_states
|
||||
else:
|
||||
output = torch.zeros_like(hidden_states)
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
# TODO(lucas): enable in follow-up
|
||||
assert not dbo_enabled()
|
||||
|
||||
(
|
||||
a1q,
|
||||
a1q_scale,
|
||||
expert_tokens_meta,
|
||||
_expert_topk_ids,
|
||||
_expert_topk_weights,
|
||||
) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
dbo_maybe_run_recv_hook()
|
||||
prepare_ret = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = (
|
||||
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
|
||||
)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
(
|
||||
a1q,
|
||||
a1q_scale,
|
||||
expert_tokens_meta,
|
||||
_expert_topk_ids,
|
||||
_expert_topk_weights,
|
||||
) = receiver()
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
topk_weights = (
|
||||
topk_weights if _expert_topk_weights is None else _expert_topk_weights
|
||||
a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = self._prepare(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
fused_out = None
|
||||
fused_out = self._fused_experts(
|
||||
in_dtype=hidden_states.dtype,
|
||||
a1q=a1q,
|
||||
a1q_scale=a1q_scale,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
)
|
||||
|
||||
if a1q.numel() == 0:
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
# incompatible all2all kernels like the DeepEP high-throughput
|
||||
# kernels. CUDAGraph compatible all2all kernels like the pplx
|
||||
# kernels and the DeepEP low-latency kernels are always batched
|
||||
# and can never run into the tensor.numel() == 0 case.
|
||||
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
|
||||
else:
|
||||
fused_out = self._maybe_chunk_fused_experts(
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
shared_output: Optional[torch.Tensor] = None
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
assert not dbo_enabled()
|
||||
|
||||
self.prepare_finalize.finalize(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
else:
|
||||
finalize_ret = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = (
|
||||
finalize_ret
|
||||
if isinstance(finalize_ret, tuple)
|
||||
else (None, finalize_ret)
|
||||
)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
receiver()
|
||||
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
else:
|
||||
assert shared_output is not None
|
||||
return shared_output, output
|
||||
return self._finalize(
|
||||
output,
|
||||
fused_out,
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user