Add unpermute-aware fused MoE path and small-batch fallback (#29354)
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -351,6 +351,7 @@ def fused_moe_kernel(
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
naive_block_assignment: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
@@ -386,6 +387,9 @@ def fused_moe_kernel(
|
||||
- expert_ids: A tensor containing the indices of the expert for each
|
||||
block. It determines which expert matrix from B should be used for
|
||||
each block in A.
|
||||
- naive_block_assignment: A boolean flag indicating whether to use naive
|
||||
token wise block assignment. If True, each block corresponds to a
|
||||
single token.
|
||||
This kernel performs the multiplication of a token by its corresponding
|
||||
expert matrix as determined by `expert_ids`. The sorting of
|
||||
`sorted_token_ids` by expert index and padding ensures divisibility by
|
||||
@@ -411,11 +415,20 @@ def fused_moe_kernel(
|
||||
# and accumulate
|
||||
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||
offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
return
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
if not naive_block_assignment:
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + offs
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
else:
|
||||
offs_token = tl.where(
|
||||
offs == 0,
|
||||
pid_m, # first element = pid_m
|
||||
num_valid_tokens, # remaining elements = constant
|
||||
)
|
||||
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
|
||||
@@ -557,7 +570,7 @@ def invoke_fused_moe_wna16_cuda_kernel(
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
@@ -705,7 +718,7 @@ def invoke_fused_moe_triton_kernel(
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
@@ -722,7 +735,7 @@ def invoke_fused_moe_triton_kernel(
|
||||
):
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
@@ -741,14 +754,18 @@ def invoke_fused_moe_triton_kernel(
|
||||
|
||||
M = A.size(0)
|
||||
num_tokens = M * top_k
|
||||
|
||||
EM = sorted_token_ids.size(0)
|
||||
if A.size(0) < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
# We assume that top_ids of each token is unique,
|
||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||
# and we can skip some invalid blocks.
|
||||
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
|
||||
if sorted_token_ids is not None:
|
||||
EM = sorted_token_ids.size(0)
|
||||
if A.size(0) < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
# We assume that top_ids of each token is unique,
|
||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||
# and we can skip some invalid blocks.
|
||||
EM = min(
|
||||
sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]
|
||||
)
|
||||
else:
|
||||
EM = num_tokens * config["BLOCK_SIZE_M"]
|
||||
grid = lambda META: (
|
||||
triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
|
||||
@@ -798,6 +815,7 @@ def invoke_fused_moe_triton_kernel(
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
naive_block_assignment=(sorted_token_ids is None),
|
||||
HAS_BIAS=HAS_BIAS,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
**config,
|
||||
@@ -812,7 +830,7 @@ def dispatch_fused_moe_kernel(
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
@@ -829,7 +847,7 @@ def dispatch_fused_moe_kernel(
|
||||
) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
|
||||
|
||||
M = A.size(0)
|
||||
num_tokens = M * top_k
|
||||
@@ -2165,14 +2183,37 @@ def fused_experts_impl(
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
|
||||
# activates only a small fraction of total experts
|
||||
SPARSITY_FACTOR = 4
|
||||
# block quantized code path is not implemented yet.
|
||||
naive_block_assignment = (
|
||||
expert_map is None
|
||||
and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
|
||||
and not (
|
||||
(use_int8_w8a16 or use_int4_w4a16)
|
||||
and block_shape is not None
|
||||
and block_shape[1] > 0
|
||||
)
|
||||
)
|
||||
|
||||
if not naive_block_assignment:
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
else:
|
||||
max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
|
||||
expert_ids = curr_topk_ids.view(-1)
|
||||
num_tokens_post_padded = torch.empty(
|
||||
(1), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_padded.fill_(max_num_tokens_padded)
|
||||
sorted_token_ids = None
|
||||
|
||||
dispatch_fused_moe_kernel(
|
||||
qcurr_hidden_states,
|
||||
w1,
|
||||
|
||||
Reference in New Issue
Block a user