[ROCm] Fix MoE kernel test failures on gfx950 (#37833)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -384,12 +384,21 @@ def test_legacy_routing(
|
||||
logits = gating_output
|
||||
if sm_first:
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
|
||||
topk_ids = sparse_logits.indx.to(torch.long)
|
||||
topk_weights = sparse_logits.vals
|
||||
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
|
||||
topk_ids, topk_weights, num_experts
|
||||
)
|
||||
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
|
||||
# topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
|
||||
if isinstance(topk_result, tuple):
|
||||
topk_weights, topk_ids_raw, bitmatrix = topk_result
|
||||
from triton_kernels.routing import routing_from_bitmatrix
|
||||
|
||||
routing_data_ref, gather_indx_ref, scatter_indx_ref = routing_from_bitmatrix(
|
||||
bitmatrix, topk_weights, topk_ids_raw, num_experts, topk
|
||||
)
|
||||
else:
|
||||
topk_ids = topk_result.indx.to(torch.long)
|
||||
topk_weights = topk_result.vals
|
||||
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
|
||||
topk_ids, topk_weights, num_experts
|
||||
)
|
||||
|
||||
routing_data, gather_indx, scatter_indx = legacy_routing(
|
||||
gating_output, topk, sm_first=sm_first
|
||||
|
||||
Reference in New Issue
Block a user