[ROCm][Kernel] MoE weights padding (#14454)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
parent 8279201ce6
commit f533b5837f
@@ -800,7 +800,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             expert_ids,
             num_tokens_post_padded,
             B.shape[1],
-            A.shape[1],
+            B.shape[2],
             EM,
             topk_ids.numel(),
             A.stride(0),
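For context, a minimal sketch of the padding trick this commit relies on (the dimensions and helper code below are hypothetical, not taken from the diff): the weight's innermost dimension is padded out to a 256-byte boundary and then sliced back off, so the logical shape, and therefore the N and K read off B and passed to the kernel, is unchanged; only the strides, which the kernel call already receives (A.stride(0) and the weight strides passed alongside it), reflect the padded layout.

import torch
import torch.nn.functional as F

# Hypothetical expert weight of shape (E, N, K); real MoE weights are much larger.
E, N, K = 4, 256, 128
B = torch.randn(E, N, K, dtype=torch.bfloat16)

# Pad the innermost (K) dimension to a 256-byte boundary, then slice the padding off.
num_pad = 256 // B.element_size()
B_padded = F.pad(B, (0, num_pad), "constant", 0)[..., :-num_pad]

print(B_padded.shape[1], B_padded.shape[2])   # 256 128 -> N and K are unchanged
print(B_padded.stride())                      # (65536, 256, 1) -> row stride is K + num_pad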
@@ -1322,8 +1322,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
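The relaxed precondition can be illustrated with a small, hypothetical example (check_expert_weight is an illustrative stand-in for the new asserts, not a vLLM function): a padded-then-sliced expert weight is no longer contiguous, so the old is_contiguous() asserts would reject it, but its last dimension is still dense, which is exactly what the stride check verifies.

import torch
import torch.nn.functional as F

def check_expert_weight(w: torch.Tensor) -> None:
    # Mirrors the relaxed assert: the innermost dimension must be dense,
    # but the tensor as a whole may be non-contiguous because of padding.
    assert w.stride(-1) == 1, "Stride of last dimension must be 1"

w1 = torch.randn(4, 256, 128, dtype=torch.float16)
num_pad = 256 // w1.element_size()
w1_padded = F.pad(w1, (0, num_pad), "constant", 0)[..., :-num_pad]

assert not w1_padded.is_contiguous()   # the old contiguity assert would fail here
check_expert_weight(w1_padded)         # the stride-based assert accepts the padded weight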