[MoE Refactor] Split invoke_fused_moe_kernel (#31050)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Yongye Zhu
2026-01-02 13:47:15 -08:00
committed by GitHub
parent 6ef770df7c
commit 5a468ff7c7

View File

@@ -541,11 +541,12 @@ def fused_moe_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
def invoke_fused_moe_kernel( # NOTE(zyongye): we can remove all the wna16 kernel
# once we drop off sm75 support
def invoke_fused_moe_wna16_cuda_kernel(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, C: torch.Tensor,
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None, B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None, B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None, topk_weights: torch.Tensor | None,
@@ -555,69 +556,21 @@ def invoke_fused_moe_kernel(
mul_routed_weight: bool, mul_routed_weight: bool,
top_k: int, top_k: int,
config: dict[str, Any], config: dict[str, Any],
compute_type: tl.dtype, block_shape: list[int],
use_fp8_w8a8: bool, ):
use_int8_w8a8: bool, assert B_scale is not None and B_scale.ndim == 3
use_int8_w8a16: bool, assert B_zp is None or B_zp.ndim == 3
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert block_shape is None or triton.cdiv(
B.size(-2), block_shape[0]
) == B_scale.size(-2)
assert block_shape is None or triton.cdiv(
B.size(-1), block_shape[1]
) == B_scale.size(-1)
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0 assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.size(0) M = A.size(0)
num_tokens = M * top_k num_tokens = M * top_k
bit = 4
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
)
HAS_BIAS = B_bias is not None
if (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
):
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=num_tokens,
group_size=block_shape[1],
num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8,
)
config = config.copy() config = config.copy()
config.update( config.update(
get_moe_wna16_block_config( get_moe_wna16_block_config(
config=config, config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda, use_moe_wna16_cuda=True,
num_valid_tokens=num_tokens, num_valid_tokens=num_tokens,
size_k=A.size(1), size_k=A.size(1),
size_n=B.size(1), size_n=B.size(1),
@@ -628,8 +581,6 @@ def invoke_fused_moe_kernel(
) )
) )
if use_moe_wna16_cuda:
bit = 4 if use_int4_w4a16 else 8
ops.moe_wna16_gemm( ops.moe_wna16_gemm(
A, A,
C, C,
@@ -646,7 +597,61 @@ def invoke_fused_moe_kernel(
config["BLOCK_SIZE_K"], config["BLOCK_SIZE_K"],
bit, bit,
) )
return
# NOTE(zyongye): we can remove all the wna16 kernel
# once we drop off sm75 support
def invoke_fused_moe_wna16_triton_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict[str, Any],
compute_type: tl.dtype,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_shape: list[int],
):
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
assert block_shape is None or block_shape[0] == 0
M = A.size(0)
num_tokens = M * top_k
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
)
config = config.copy()
config.update(
get_moe_wna16_block_config(
config=config,
use_moe_wna16_cuda=False,
num_valid_tokens=num_tokens,
size_k=A.size(1),
size_n=B.size(1),
num_experts=B.size(1),
group_size=block_shape[1],
real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"],
)
)
fused_moe_kernel_gptq_awq[grid]( fused_moe_kernel_gptq_awq[grid](
A, A,
B, B,
@@ -684,7 +689,65 @@ def invoke_fused_moe_kernel(
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
**config, **config,
) )
def invoke_fused_moe_triton_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
):
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert block_shape is None or triton.cdiv(
B.size(-2), block_shape[0]
) == B_scale.size(-2)
assert block_shape is None or triton.cdiv(
B.size(-1), block_shape[1]
) == B_scale.size(-1)
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else: else:
assert A_scale is None
assert B_scale is None
M = A.size(0)
num_tokens = M * top_k
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
)
HAS_BIAS = B_bias is not None
config = config.copy() config = config.copy()
config["SPLIT_K"] = 1 config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
@@ -734,6 +797,109 @@ def invoke_fused_moe_kernel(
) )
def dispatch_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
M = A.size(0)
num_tokens = M * top_k
if (use_int8_w8a16 or use_int4_w4a16) and (
block_shape is not None and block_shape[1] > 0
):
assert B_bias is None
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=num_tokens,
group_size=block_shape[1],
num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8,
)
if use_moe_wna16_cuda:
invoke_fused_moe_wna16_cuda_kernel(
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
config,
block_shape,
)
return
invoke_fused_moe_wna16_triton_kernel(
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
config,
compute_type,
use_int8_w8a16,
use_int4_w4a16,
block_shape,
)
else:
invoke_fused_moe_triton_kernel(
A,
B,
C,
A_scale,
B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
config,
compute_type,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
per_channel_quant,
block_shape,
B_bias,
)
@triton.jit @triton.jit
def compute_identity_kernel( def compute_identity_kernel(
top_k: int, top_k: int,
@@ -1997,7 +2163,7 @@ def fused_experts_impl(
ignore_invalid_experts=True, ignore_invalid_experts=True,
) )
invoke_fused_moe_kernel( dispatch_fused_moe_kernel(
qcurr_hidden_states, qcurr_hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
@@ -2056,7 +2222,7 @@ def fused_experts_impl(
if expert_map is not None: if expert_map is not None:
intermediate_cache3.zero_() intermediate_cache3.zero_()
invoke_fused_moe_kernel( dispatch_fused_moe_kernel(
qintermediate_cache2, qintermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
@@ -2207,13 +2373,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
) )
invoke_fused_moe_kernel( invoke_fused_moe_triton_kernel(
hidden_states, hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
a1q_scale, a1q_scale,
self.w1_scale, self.w1_scale,
self.w1_zp,
None, # topk_weights None, # topk_weights
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
@@ -2245,13 +2410,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.block_shape, self.block_shape,
) )
invoke_fused_moe_kernel( invoke_fused_moe_triton_kernel(
qintermediate_cache2, qintermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
a2q_scale, a2q_scale,
self.w2_scale, self.w2_scale,
self.w2_zp,
topk_weights, topk_weights,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,