[Perf] Eliminate redundant SparseMatrix creation in gpt_oss_triton_kernels (#37683)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-03-20 10:28:41 -07:00
committed by GitHub
parent fb4e8bf442
commit d0532bf38d
2 changed files with 73 additions and 4 deletions

View File

@@ -142,6 +142,33 @@ def legacy_routing_from_bitmatrix(
return routing_data, gather_idx, scatter_idx
def legacy_routing_from_sparsematrix(
sparse_logits: "SparseMatrix",
n_expts_tot: int,
n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
"""
Creates routing data from a SparseMatrix representation.
"""
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(
sparse_logits.mask_metadata.col_sum,
dispatch_indx.shape[0],
)
gate_scal = sparse_logits.vals.flatten()[combine_indx]
routing_data = RoutingData(
gate_scal,
ragged_batch_metadata.block_sizes,
n_expts_tot,
n_expts_act,
ragged_batch_metadata,
)
gather_idx = GatherIndx(combine_indx, dispatch_indx)
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
return routing_data, gather_idx, scatter_idx
def legacy_routing(
logits: torch.Tensor,
n_expts_act: int,
@@ -158,10 +185,8 @@ def legacy_routing(
if sm_first:
logits = torch.softmax(logits, dim=-1)
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
return legacy_routing_from_bitmatrix(
sparse_logits.mask,
sparse_logits.vals,
sparse_logits.indx,
return legacy_routing_from_sparsematrix(
sparse_logits,
logits.shape[-1],
n_expts_act,
)