[ROCm][Kernel] MoE weights padding (#14454)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
Author: Gregory Shtrasberg
Date: 2025-03-24 19:45:30 -04:00
Committed by: GitHub
Commit: f533b5837f (parent 8279201ce6)
5 changed files with 65 additions and 16 deletions

@@ -800,7 +800,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        A.shape[1],
+        B.shape[2],
         EM,
         topk_ids.numel(),
         A.stride(0),
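
The change in this hunk swaps the K dimension passed to the Triton kernel from A.shape[1] (the activations) to B.shape[2] (the expert weights). For unpadded weights the two are equal, and reading K from B keeps the launch arguments tied to the same tensor whose strides are also handed to the kernel, which is what matters once the weights carry padding. A minimal sketch of the invariant, with illustrative tensor names and sizes that are not from this commit:

import torch

E, N, K = 8, 1024, 512      # experts, intermediate size, hidden size
M = 16                      # tokens entering the kernel
A = torch.randn(M, K)       # activations: [M, K]
B = torch.randn(E, N, K)    # stacked expert weights: [E, N, K]

# K as the kernel now receives it. This equals A.shape[1] for plain
# weights and stays correct for a padded weight view, because slicing
# the padding back off preserves the logical shape.
assert B.shape[2] == A.shape[1]
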
@@ -1322,8 +1322,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
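
The relaxed asserts matter for padded weights: a tensor padded along its innermost dimension and then sliced back to its logical shape fails is_contiguous(), because its row stride exceeds its row length, but it still satisfies stride(-1) == 1, which is all the kernel's pointer arithmetic needs. A minimal sketch of such a pad-and-slice view; the helper name and the 256-byte pad size are assumptions, not code from this commit:

import torch
import torch.nn.functional as F

def pad_moe_weight(weight: torch.Tensor, pad_bytes: int = 256) -> torch.Tensor:
    # Hypothetical helper: pad the last dim, then slice the padding off.
    # The result keeps the original shape but gains a larger stride(-2),
    # so it is no longer contiguous.
    num_pad = pad_bytes // weight.element_size()
    padded = F.pad(weight, (0, num_pad), "constant", 0)
    return padded[..., :-num_pad]

w1 = torch.randn(8, 2048, 512)      # [E, 2N, K] expert weights
w1 = pad_moe_weight(w1)

assert w1.shape == (8, 2048, 512)   # logical shape is unchanged
assert not w1.is_contiguous()       # the old assert would now fail
assert w1.stride(-1) == 1           # the new, relaxed assert holds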