[Kernels][Bugfix] Use torch op for all kernels in FusedMoE forward. Add additional testing for cudagraphs. (#19717)

Signed-off-by: Bill Nell <bnell@redhat.com>

Author: bnellnm
Date: 2025-06-25 02:22:58 -04:00
Committed by: GitHub
Parent: f59fc60fb3
Commit: 015fab8c2f

14 changed files with 379 additions and 238 deletions
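
Most of the diff below is a mechanical rewrite of tensor shape accesses from the t.shape[i] / t.shape attribute form to the equivalent t.size(i) / t.size() method calls. Both forms return the same values, so the rewrite is behavior-preserving; a minimal sketch of the pattern (hypothetical tensor, not code from this commit):

    import torch

    A = torch.randn(4, 8)

    # Attribute-style access, as on the removed (-) lines ...
    M = A.shape[0]

    # ... and the method-call form used on the added (+) lines.
    # Both return the same values.
    assert A.size(0) == M
    assert A.size() == A.shape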


@@ -488,10 +488,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     if use_fp8_w8a8 or use_int8_w8a8:
         assert B_scale is not None
-        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
-                == B_scale.shape[-2])
-        assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
-                == B_scale.shape[-1])
+        assert (block_shape is None
+                or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2))
+        assert (block_shape is None
+                or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1))
     elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
@@ -500,19 +500,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         assert A_scale is None
         assert B_scale is None
 
-    M = A.shape[0]
+    M = A.size(0)
     num_tokens = M * top_k
 
-    EM = sorted_token_ids.shape[0]
-    if A.shape[0] < config["BLOCK_SIZE_M"]:
+    EM = sorted_token_ids.size(0)
+    if A.size(0) < config["BLOCK_SIZE_M"]:
         # optimize for small batch_size.
         # We assume that top_ids of each token is unique, so
         # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
         # and we can skip some invalid blocks.
-        EM = min(sorted_token_ids.shape[0],
-                 A.shape[0] * top_k * config['BLOCK_SIZE_M'])
+        EM = min(sorted_token_ids.size(0),
+                 A.size(0) * top_k * config['BLOCK_SIZE_M'])
     grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
-        B.shape[1], META['BLOCK_SIZE_N']), )
+        B.size(1), META['BLOCK_SIZE_N']), )
 
     if (use_int8_w8a16 or use_int4_w4a16) and \
             block_shape is not None and block_shape[1] > 0:
@@ -522,16 +522,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
             num_valid_tokens=num_tokens,
             group_size=block_shape[1],
-            num_experts=B.shape[0],
+            num_experts=B.size(0),
             bit=4 if use_int4_w4a16 else 8)
         config = config.copy()
         config.update(
             get_moe_wna16_block_config(config=config,
                                        use_moe_wna16_cuda=use_moe_wna16_cuda,
                                        num_valid_tokens=num_tokens,
-                                       size_k=A.shape[1],
-                                       size_n=B.shape[1],
-                                       num_experts=B.shape[1],
+                                       size_k=A.size(1),
+                                       size_n=B.size(1),
+                                       num_experts=B.size(1),
                                        group_size=block_shape[1],
                                        real_top_k=top_k,
                                        block_size_m=config["BLOCK_SIZE_M"]))
@@ -556,8 +556,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            B.shape[1],
-            A.shape[1],
+            B.size(1),
+            A.size(1),
             EM,
             num_tokens,
             A.stride(0),
@@ -573,7 +573,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             B_zp.stride(0) if B_zp is not None else 0,
             B_zp.stride(2) if B_zp is not None else 0,
             B_zp.stride(1) if B_zp is not None else 0,
-            block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
+            block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
             group_size=block_shape[1],
             MUL_ROUTED_WEIGHT=mul_routed_weight,
             top_k=top_k,
@@ -599,8 +599,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            B.shape[1],
-            B.shape[2],
+            B.size(1),
+            B.size(2),
             EM,
             num_tokens,
             A.stride(0),
@@ -818,7 +818,7 @@ def try_get_optimal_moe_config(
     M: int,
     is_marlin: bool = False,
     block_shape: Optional[list[int]] = None,
-):
+) -> dict[str, int]:
     from vllm.model_executor.layers.fused_moe import get_config
     override_config = get_config()
     if override_config:
@@ -873,10 +873,10 @@ def fused_topk(
     renormalize: bool,
     indices_type: Optional[torch.dtype] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    assert hidden_states.shape[0] == gating_output.shape[0], (
+    assert hidden_states.size(0) == gating_output.size(0), (
         "Number of tokens mismatch")
 
-    M, _ = hidden_states.shape
+    M, _ = hidden_states.size()
 
     topk_weights = torch.empty(M,
                                topk,
@@ -915,7 +915,7 @@ def grouped_topk(
     e_score_correction_bias: Optional[torch.Tensor] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    assert hidden_states.shape[0] == gating_output.shape[0], (
+    assert hidden_states.size(0) == gating_output.size(0), (
         "Number of tokens mismatch")
 
     if scoring_func == "softmax":
@@ -925,7 +925,7 @@ def grouped_topk(
     else:
         raise ValueError(f"Unsupported scoring function: {scoring_func}")
 
-    num_token = scores.shape[0]
+    num_token = scores.size(0)
     if e_score_correction_bias is not None:
         # Store original scores before applying correction bias. We use biased
         # scores for expert selection but original scores for routing weights
@@ -942,7 +942,7 @@ def grouped_topk(
     group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
-        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
+        scores.size(-1) // num_expert_group).reshape(num_token, -1)  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(),
                                     float("-inf"))  # [n, e]
@@ -1162,7 +1162,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   allow_deep_gemm: bool = False) -> torch.Tensor:
     # For now, disable DeepGemm for small N (<= 512) until better
     # permute/unpermute ops are available.
-    N = w1.shape[1]
+    N = w1.size(1)
     if (allow_deep_gemm and use_fp8_w8a8 and N > 512
             and _valid_deep_gemm(hidden_states, w1, w2)):
         assert apply_router_weight_on_input is False
@@ -1233,13 +1233,13 @@ def fused_experts_impl(
 ) -> torch.Tensor:
     # Check constraints.
     if use_int4_w4a16:
-        assert hidden_states.shape[1] // 2 == w1.shape[
-            2], "Hidden size mismatch"
+        assert hidden_states.size(1) // 2 == w1.size(2), (
+            "Hidden size mismatch")
     else:
-        assert hidden_states.shape[1] == w1.shape[2], (
-            f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
+        assert hidden_states.size(1) == w1.size(2), (
+            f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
 
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
     assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
@@ -1247,12 +1247,12 @@ def fused_experts_impl(
         torch.float32, torch.float16, torch.bfloat16
     ]
 
-    num_tokens = hidden_states.shape[0]
-    E, N, _ = w1.shape
-    K = w2.shape[1]
+    num_tokens = hidden_states.size(0)
+    E, N, _ = w1.size()
+    K = w2.size(1)
     if global_num_experts == -1:
         global_num_experts = E
-    top_k_num = topk_ids.shape[1]
+    top_k_num = topk_ids.size(1)
     # We execute the fused_moe kernel in chunks to circumvent this issue:
     # https://github.com/vllm-project/vllm/issues/5938
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
@@ -1269,8 +1269,8 @@ def fused_experts_impl(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
-        w1.shape,
-        w2.shape,
+        w1.size(),
+        w2.size(),
         top_k_num,
         config_dtype,
         block_shape=block_shape,
@@ -1310,7 +1310,7 @@ def fused_experts_impl(
                                           min((chunk + 1) * CHUNK_SIZE,
                                               num_tokens))
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
-        tokens_in_chunk, _ = curr_hidden_states.shape
+        tokens_in_chunk, _ = curr_hidden_states.size()
 
         if tokens_in_chunk == 0:
             break
@@ -1322,7 +1322,7 @@ def fused_experts_impl(
             # do not need to be adjusted.
             intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
             intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
-                                                      topk_ids.shape[1]]
+                                                      topk_ids.size(1)]
             intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
 
             config = get_config_func(tokens_in_chunk)
@@ -1398,7 +1398,7 @@ def fused_experts_impl(
                                 per_channel_quant=per_channel_quant,
                                 block_shape=block_shape)
 
-        ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
+        ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx])
 
     return out_hidden_states
@@ -1611,8 +1611,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                           dtype=hidden_states.dtype)
 
         config = try_get_optimal_moe_config(
-            w1.shape,
-            w2.shape,
+            w1.size(),
+            w2.size(),
             top_k_num,
             config_dtype,
             num_tokens,
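
The hunks above cover only the shape-access cleanup; the torch-op wrapping and the additional cudagraph tests mentioned in the commit title live in the other changed files, which are not shown here. As a rough, hypothetical illustration of that idea (generic torch.library API with made-up names, not vLLM's own registration helper), a kernel entry point can be exposed as a custom torch op with a shape-only fake implementation so it remains an opaque, traceable call under torch.compile and CUDA graph capture:

    # Sketch only; assumes PyTorch >= 2.4 for torch.library.custom_op.
    import torch
    from torch.library import custom_op


    @custom_op("demo::moe_sum", mutates_args=())
    def moe_sum(x: torch.Tensor) -> torch.Tensor:
        # Eager implementation: reduce over the top-k dim of (tokens, top_k, hidden).
        return x.sum(dim=1)


    @moe_sum.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        # Shape-only implementation used when tracing with fake tensors.
        return x.new_empty((x.size(0), x.size(2)))


    out = moe_sum(torch.randn(8, 2, 16))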