[Kernels][Bugfix] Use torch op for all kernels in FusedMoE forward. Add additional testing for cudagraphs. (#19717)
Signed-off-by: Bill Nell <bnell@redhat.com>
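Note on the recurring change: the diff below swaps Tensor.shape[i] indexing for the Tensor.size(i) method call throughout the FusedMoE path, in line with the commit's goal of routing everything through torch ops. A minimal illustration of the pattern (not taken from the PR; the tensor name is made up):

    import torch

    A = torch.randn(8, 16)

    # Before: index into the shape attribute (a torch.Size tuple).
    M, K = A.shape[0], A.shape[1]

    # After: call the size() method, which returns the same values
    # (an int per dimension, or the full torch.Size when called with no args).
    M, K = A.size(0), A.size(1)
    assert (M, K) == tuple(A.size())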
@@ -488,10 +488,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
 
     if use_fp8_w8a8 or use_int8_w8a8:
         assert B_scale is not None
-        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
-                == B_scale.shape[-2])
-        assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
-                == B_scale.shape[-1])
+        assert (block_shape is None
+                or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2))
+        assert (block_shape is None
+                or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1))
 
     elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
@@ -500,19 +500,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         assert A_scale is None
         assert B_scale is None
 
-    M = A.shape[0]
+    M = A.size(0)
     num_tokens = M * top_k
 
-    EM = sorted_token_ids.shape[0]
-    if A.shape[0] < config["BLOCK_SIZE_M"]:
+    EM = sorted_token_ids.size(0)
+    if A.size(0) < config["BLOCK_SIZE_M"]:
         # optimize for small batch_size.
         # We assume that top_ids of each token is unique, so
         # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
         # and we can skip some invalid blocks.
-        EM = min(sorted_token_ids.shape[0],
-                 A.shape[0] * top_k * config['BLOCK_SIZE_M'])
+        EM = min(sorted_token_ids.size(0),
+                 A.size(0) * top_k * config['BLOCK_SIZE_M'])
     grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
-        B.shape[1], META['BLOCK_SIZE_N']), )
+        B.size(1), META['BLOCK_SIZE_N']), )
 
     if (use_int8_w8a16 or use_int4_w4a16) and \
             block_shape is not None and block_shape[1] > 0:
@@ -522,16 +522,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
             num_valid_tokens=num_tokens,
             group_size=block_shape[1],
-            num_experts=B.shape[0],
+            num_experts=B.size(0),
             bit=4 if use_int4_w4a16 else 8)
         config = config.copy()
         config.update(
             get_moe_wna16_block_config(config=config,
                                        use_moe_wna16_cuda=use_moe_wna16_cuda,
                                        num_valid_tokens=num_tokens,
-                                       size_k=A.shape[1],
-                                       size_n=B.shape[1],
-                                       num_experts=B.shape[1],
+                                       size_k=A.size(1),
+                                       size_n=B.size(1),
+                                       num_experts=B.size(1),
                                        group_size=block_shape[1],
                                        real_top_k=top_k,
                                        block_size_m=config["BLOCK_SIZE_M"]))
@@ -556,8 +556,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            B.shape[1],
-            A.shape[1],
+            B.size(1),
+            A.size(1),
             EM,
             num_tokens,
             A.stride(0),
@@ -573,7 +573,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             B_zp.stride(0) if B_zp is not None else 0,
             B_zp.stride(2) if B_zp is not None else 0,
             B_zp.stride(1) if B_zp is not None else 0,
-            block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
+            block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
             group_size=block_shape[1],
             MUL_ROUTED_WEIGHT=mul_routed_weight,
             top_k=top_k,
@@ -599,8 +599,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            B.shape[1],
-            B.shape[2],
+            B.size(1),
+            B.size(2),
             EM,
             num_tokens,
             A.stride(0),
@@ -818,7 +818,7 @@ def try_get_optimal_moe_config(
     M: int,
     is_marlin: bool = False,
     block_shape: Optional[list[int]] = None,
-):
+) -> dict[str, int]:
     from vllm.model_executor.layers.fused_moe import get_config
     override_config = get_config()
     if override_config:
@@ -873,10 +873,10 @@ def fused_topk(
     renormalize: bool,
     indices_type: Optional[torch.dtype] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    assert hidden_states.shape[0] == gating_output.shape[0], (
+    assert hidden_states.size(0) == gating_output.size(0), (
         "Number of tokens mismatch")
 
-    M, _ = hidden_states.shape
+    M, _ = hidden_states.size()
 
     topk_weights = torch.empty(M,
                                topk,
@@ -915,7 +915,7 @@ def grouped_topk(
     e_score_correction_bias: Optional[torch.Tensor] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
 
-    assert hidden_states.shape[0] == gating_output.shape[0], (
+    assert hidden_states.size(0) == gating_output.size(0), (
         "Number of tokens mismatch")
 
     if scoring_func == "softmax":
@@ -925,7 +925,7 @@ def grouped_topk(
     else:
         raise ValueError(f"Unsupported scoring function: {scoring_func}")
 
-    num_token = scores.shape[0]
+    num_token = scores.size(0)
     if e_score_correction_bias is not None:
         # Store original scores before applying correction bias. We use biased
         # scores for expert selection but original scores for routing weights
@@ -942,7 +942,7 @@ def grouped_topk(
     group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
-        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
+        scores.size(-1) // num_expert_group).reshape(num_token, -1)  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(),
                                     float("-inf"))  # [n, e]
 
@@ -1162,7 +1162,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   allow_deep_gemm: bool = False) -> torch.Tensor:
     # For now, disable DeepGemm for small N (<= 512) until better
     # permute/unpermute ops are available.
-    N = w1.shape[1]
+    N = w1.size(1)
     if (allow_deep_gemm and use_fp8_w8a8 and N > 512
             and _valid_deep_gemm(hidden_states, w1, w2)):
         assert apply_router_weight_on_input is False
@@ -1233,13 +1233,13 @@ def fused_experts_impl(
 ) -> torch.Tensor:
     # Check constraints.
     if use_int4_w4a16:
-        assert hidden_states.shape[1] // 2 == w1.shape[
-            2], "Hidden size mismatch"
+        assert hidden_states.size(1) // 2 == w1.size(2), (
+            "Hidden size mismatch")
     else:
-        assert hidden_states.shape[1] == w1.shape[2], (
-            f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
+        assert hidden_states.size(1) == w1.size(2), (
+            f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
 
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
     assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
@@ -1247,12 +1247,12 @@ def fused_experts_impl(
         torch.float32, torch.float16, torch.bfloat16
     ]
 
-    num_tokens = hidden_states.shape[0]
-    E, N, _ = w1.shape
-    K = w2.shape[1]
+    num_tokens = hidden_states.size(0)
+    E, N, _ = w1.size()
+    K = w2.size(1)
     if global_num_experts == -1:
         global_num_experts = E
-    top_k_num = topk_ids.shape[1]
+    top_k_num = topk_ids.size(1)
     # We execute the fused_moe kernel in chunks to circumvent this issue:
     # https://github.com/vllm-project/vllm/issues/5938
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
@@ -1269,8 +1269,8 @@ def fused_experts_impl(
 
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
-        w1.shape,
-        w2.shape,
+        w1.size(),
+        w2.size(),
         top_k_num,
         config_dtype,
         block_shape=block_shape,
@@ -1310,7 +1310,7 @@ def fused_experts_impl(
                                           min((chunk + 1) * CHUNK_SIZE,
                                               num_tokens))
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
-        tokens_in_chunk, _ = curr_hidden_states.shape
+        tokens_in_chunk, _ = curr_hidden_states.size()
 
         if tokens_in_chunk == 0:
             break
@@ -1322,7 +1322,7 @@ def fused_experts_impl(
             # do not need to be adjusted.
             intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
             intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
-                                                       topk_ids.shape[1]]
+                                                       topk_ids.size(1)]
             intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
             config = get_config_func(tokens_in_chunk)
 
@@ -1398,7 +1398,7 @@ def fused_experts_impl(
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)
 
-        ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
+        ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx])
 
     return out_hidden_states
@@ -1611,8 +1611,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                          dtype=hidden_states.dtype)
 
         config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            config_dtype,
            num_tokens,
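The commit title also mentions additional cudagraph testing; those test changes are not shown in this excerpt. As a rough, hypothetical sketch (plain PyTorch, not the PR's actual test code), verifying that a function gives the same result when replayed from a captured CUDA graph as in eager mode typically looks like this:

    import torch

    def eager_matches_cudagraph(fn, x: torch.Tensor, warmup: int = 3) -> bool:
        ref = fn(x).clone()  # eager reference output

        # Warm up on a side stream so allocations happen before capture.
        static_x = x.clone()
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(warmup):
                fn(static_x)
        torch.cuda.current_stream().wait_stream(s)

        # Capture a single call into a CUDA graph, then replay it on the real input.
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_out = fn(static_x)
        static_x.copy_(x)
        g.replay()

        return torch.allclose(ref, static_out, atol=1e-3, rtol=1e-3)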