[ROCm] Fix MoE kernel test failures on gfx950 (#37833)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-25 13:46:40 -05:00
committed by GitHub
parent e38817fadb
commit 7d6917bef5
12 changed files with 478 additions and 86 deletions

View File

@@ -384,12 +384,21 @@ def test_legacy_routing(
logits = gating_output
if sm_first:
logits = torch.softmax(logits, dim=-1)
sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
topk_ids = sparse_logits.indx.to(torch.long)
topk_weights = sparse_logits.vals
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
topk_ids, topk_weights, num_experts
)
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
# topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, bitmatrix = topk_result
from triton_kernels.routing import routing_from_bitmatrix
routing_data_ref, gather_indx_ref, scatter_indx_ref = routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids_raw, num_experts, topk
)
else:
topk_ids = topk_result.indx.to(torch.long)
topk_weights = topk_result.vals
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
topk_ids, topk_weights, num_experts
)
routing_data, gather_indx, scatter_indx = legacy_routing(
gating_output, topk, sm_first=sm_first