[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-04-06 21:57:09 -05:00
committed by GitHub
parent 62095e82c1
commit 2df2c85be4
7 changed files with 84 additions and 216 deletions

View File

@@ -47,7 +47,6 @@ if has_triton_kernels():
BIT,
Bitmatrix,
)
from triton_kernels.topk import topk
try:
from triton_kernels.tensor import (
@@ -89,6 +88,7 @@ def pack_bitmatrix(
offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
valid = indices >= 0
div = indices // 32
rem = indices % 32
one = tl.cast(1, tl.uint32)
@@ -99,8 +99,13 @@ def pack_bitmatrix(
offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
# All topks that need to go into this column has the correct bit set.
# Other bits are 0. x is a 2D tensor.
# Guard with `valid` to prevent negative indices from producing
# spurious bits (on HIP, -1 // 32 == 0 and 1 << (-1 % 32) sets
# bit 31).
x = tl.where(
div[:, :, None] == offs[None, None, :], (one << rem)[:, :, None], 0
valid[:, :, None] & (div[:, :, None] == offs[None, None, :]),
(one << rem)[:, :, None],
0,
)
# Reduce x to get a single int32_t bitpack.
y = tl.reduce_or(x, axis=1)
@@ -108,93 +113,6 @@ def pack_bitmatrix(
tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
def legacy_routing_from_bitmatrix(
bitmatrix: "Bitmatrix",
expt_scal: torch.Tensor,
expt_indx: torch.Tensor,
n_expts_tot: int,
n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
"""
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
Creates routing data from a bitmatrix representation.
"""
if use_legacy_triton_kernels:
from triton_kernels.routing import routing_from_bitmatrix
return routing_from_bitmatrix(
bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
)
sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(
sparse_logits.mask_metadata.col_sum,
dispatch_indx.shape[0],
)
gate_scal = sparse_logits.vals.flatten()[combine_indx]
routing_data = RoutingData(
gate_scal,
ragged_batch_metadata.block_sizes,
n_expts_tot,
n_expts_act,
ragged_batch_metadata,
)
gather_idx = GatherIndx(combine_indx, dispatch_indx)
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
return routing_data, gather_idx, scatter_idx
def legacy_routing_from_sparsematrix(
sparse_logits: "SparseMatrix",
n_expts_tot: int,
n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
"""
Creates routing data from a SparseMatrix representation.
"""
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(
sparse_logits.mask_metadata.col_sum,
dispatch_indx.shape[0],
)
gate_scal = sparse_logits.vals.flatten()[combine_indx]
routing_data = RoutingData(
gate_scal,
ragged_batch_metadata.block_sizes,
n_expts_tot,
n_expts_act,
ragged_batch_metadata,
)
gather_idx = GatherIndx(combine_indx, dispatch_indx)
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
return routing_data, gather_idx, scatter_idx
def legacy_routing(
logits: torch.Tensor,
n_expts_act: int,
sm_first: bool = False,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
"""
Replacement for the removed triton_kernels.routing.routing function.
Computes routing data from gating logits.
"""
if use_legacy_triton_kernels:
from triton_kernels.routing import routing
return routing(logits, n_expts_act, sm_first=sm_first)
if sm_first:
logits = torch.softmax(logits, dim=-1)
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
return legacy_routing_from_sparsematrix(
sparse_logits,
logits.shape[-1],
n_expts_act,
)
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
@@ -241,26 +159,22 @@ def triton_kernel_moe_forward(
unpadded_K_w2=unpadded_K_w2,
)
if expert_map is not None:
# With expert parallelism, legacy_routing produces routing data
# using global expert IDs which don't correspond to local weight
# indices. Split the routing into topk selection + expert_map
# remapping + local routing data construction (matching the
# approach used by OAITritonExperts.apply).
from triton_kernels.topk import topk as topk_fn
from triton_kernels.topk import topk as topk_fn
sm_first = not renormalize
logits = gating_output
if sm_first:
logits = torch.softmax(logits, dim=-1)
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
# topk may return a tuple (vals, indx, bitmatrix) or a
# SparseMatrix depending on the triton_kernels version.
if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, _ = topk_result
else:
topk_weights = topk_result.vals
topk_ids_raw = topk_result.indx
sm_first = not renormalize
logits = gating_output
if sm_first:
logits = torch.softmax(logits, dim=-1)
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
# topk may return a tuple (vals, indx, bitmatrix) or a
# SparseMatrix depending on the triton_kernels version.
if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, _ = topk_result
else:
topk_weights = topk_result.vals
topk_ids_raw = topk_result.indx
if expert_map is not None:
# topk_ids_raw contains global expert IDs - remap to local.
topk_ids = expert_map[topk_ids_raw.to(torch.long)]
local_num_experts = w1.shape[0]
@@ -271,8 +185,9 @@ def triton_kernel_moe_forward(
effective_expert_map = None
effective_global_num_experts = local_num_experts
else:
routing_data, gather_idx, scatter_idx = legacy_routing(
gating_output, topk, sm_first=not renormalize
topk_ids = topk_ids_raw.to(torch.long)
routing_data, gather_idx, scatter_idx = make_routing_data(
topk_ids, topk_weights, gating_output.shape[-1]
)
effective_expert_map = expert_map
effective_global_num_experts = global_num_experts
@@ -539,10 +454,31 @@ def make_routing_data(
# matmul_ogs expects invalid topk_weights to be -1s
topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
)
if use_legacy_triton_kernels:
from triton_kernels.routing import routing_from_bitmatrix
return routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
)
sparse_logits = SparseMatrix(indx=topk_ids, vals=topk_weights, mask=bitmatrix)
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(
sparse_logits.mask_metadata.col_sum,
dispatch_indx.shape[0],
)
gate_scal = sparse_logits.vals.flatten()[combine_indx]
routing_data = RoutingData(
gate_scal,
ragged_batch_metadata.block_sizes,
num_local_experts,
num_topk,
ragged_batch_metadata,
)
gather_indx = GatherIndx(combine_indx, dispatch_indx)
scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
return routing_data, gather_indx, scatter_indx