[Perf] Eliminate redundant SparseMatrix creation in gpt_oss_triton_kernels (#37683)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -21,12 +21,16 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
from triton_kernels.topk import topk as topk_fn
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
legacy_routing,
|
||||
make_routing_data,
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
from .utils import shuffle_weight
|
||||
|
||||
@@ -355,3 +359,43 @@ def test_unit_shuffle():
|
||||
)
|
||||
|
||||
assert_close(ref=out_ref, tri=out)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [2, 8, 64])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_legacy_routing(
    num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype
):
    """Check that legacy_routing reproduces the reference routing indices.

    The reference path applies softmax/topk manually and builds routing data
    via make_routing_data; legacy_routing must yield bit-identical gather and
    scatter index tensors.
    """
    set_random_seed(0)
    gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)

    # renormalize=False means softmax is applied before top-k selection.
    softmax_first = not renormalize
    scores = gating_output
    if softmax_first:
        scores = torch.softmax(scores, dim=-1)
    sparse = topk_fn(scores, topk, apply_softmax=not softmax_first)
    ref_topk_ids = sparse.indx.to(torch.long)
    ref_topk_weights = sparse.vals
    _, gather_ref, scatter_ref = make_routing_data(
        ref_topk_ids, ref_topk_weights, num_experts
    )

    _, gather_tri, scatter_tri = legacy_routing(
        gating_output, topk, sm_first=softmax_first
    )

    # Indices are integer permutations, so require exact equality.
    index_pairs = (
        (gather_ref.src_indx, gather_tri.src_indx),
        (gather_ref.dst_indx, gather_tri.dst_indx),
        (scatter_ref.src_indx, scatter_tri.src_indx),
        (scatter_ref.dst_indx, scatter_tri.dst_indx),
    )
    for ref_indx, tri_indx in index_pairs:
        assert_close(ref=ref_indx, tri=tri_indx, maxtol=0, rmstol=0)
|
||||
|
||||
@@ -142,6 +142,33 @@ def legacy_routing_from_bitmatrix(
|
||||
return routing_data, gather_idx, scatter_idx
|
||||
|
||||
|
||||
def legacy_routing_from_sparsematrix(
    sparse_logits: "SparseMatrix",
    n_expts_tot: int,
    n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
    """
    Creates routing data from a SparseMatrix representation.

    Reuses the mask metadata already computed on the SparseMatrix
    (row/column sorted index permutations and per-expert counts) instead of
    rebuilding it, and returns the routing data plus the gather/scatter
    index pairs derived from those permutations.
    """
    metadata = sparse_logits.mask_metadata
    # dispatch: token->expert order; combine: expert->token order.
    dispatch_indx = metadata.row_sorted_indx
    combine_indx = metadata.col_sorted_indx
    ragged_batch_metadata = make_ragged_tensor_metadata(
        metadata.col_sum, dispatch_indx.shape[0]
    )
    # Reorder the flattened gate values to match the combine permutation.
    gate_scal = sparse_logits.vals.flatten()[combine_indx]
    routing_data = RoutingData(
        gate_scal,
        ragged_batch_metadata.block_sizes,
        n_expts_tot,
        n_expts_act,
        ragged_batch_metadata,
    )
    return (
        routing_data,
        GatherIndx(combine_indx, dispatch_indx),
        ScatterIndx(dispatch_indx, combine_indx),
    )
|
||||
|
||||
|
||||
def legacy_routing(
|
||||
logits: torch.Tensor,
|
||||
n_expts_act: int,
|
||||
@@ -158,10 +185,8 @@ def legacy_routing(
|
||||
if sm_first:
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
|
||||
return legacy_routing_from_bitmatrix(
|
||||
sparse_logits.mask,
|
||||
sparse_logits.vals,
|
||||
sparse_logits.indx,
|
||||
return legacy_routing_from_sparsematrix(
|
||||
sparse_logits,
|
||||
logits.shape[-1],
|
||||
n_expts_act,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user