[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -4,12 +4,9 @@
|
||||
Tests that triton_kernel_moe_forward correctly applies expert_map
|
||||
remapping when expert parallelism (EP) is enabled.
|
||||
|
||||
Previously, legacy_routing was always used and it produced routing data
|
||||
with global expert IDs that didn't correspond to local weight indices,
|
||||
causing illegal memory access with EP. The fix splits routing: when
|
||||
expert_map is provided, topk selection is performed first, expert_map is
|
||||
applied to remap global→local IDs, and make_routing_data builds routing
|
||||
structures from the local IDs.
|
||||
Both EP and non-EP paths use topk + make_routing_data. When expert_map
|
||||
is provided, global expert IDs are remapped to local IDs before building
|
||||
routing structures.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap:
|
||||
|
||||
@pytest.mark.parametrize("expert_map_present", [False, True])
|
||||
def test_routing_path_selection(self, expert_map_present):
|
||||
"""Verify that the EP-aware routing path is taken when expert_map
|
||||
is present, and the legacy_routing path is taken otherwise."""
|
||||
"""Verify that both EP and non-EP paths use topk + make_routing_data,
|
||||
and that expert_map remapping is applied when present."""
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
# This is a structural test: we mock the routing functions to
|
||||
# verify the correct path is exercised.
|
||||
mock_expert_map = (
|
||||
torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.model_executor.layers.fused_moe."
|
||||
"gpt_oss_triton_kernels_moe.legacy_routing"
|
||||
) as mock_legacy,
|
||||
patch("triton_kernels.topk.topk") as mock_topk,
|
||||
patch(
|
||||
"vllm.model_executor.layers.fused_moe."
|
||||
@@ -53,27 +44,19 @@ class TestTritonMoeForwardExpertMap:
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
|
||||
# Set up return values
|
||||
mock_routing_data = MagicMock()
|
||||
mock_gather = MagicMock()
|
||||
mock_scatter = MagicMock()
|
||||
|
||||
if expert_map_present:
|
||||
sparse_result = MagicMock()
|
||||
sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
|
||||
sparse_result.vals = torch.tensor([[0.6, 0.4]])
|
||||
mock_topk.return_value = sparse_result
|
||||
mock_make_routing.return_value = (
|
||||
mock_routing_data,
|
||||
mock_gather,
|
||||
mock_scatter,
|
||||
)
|
||||
else:
|
||||
mock_legacy.return_value = (
|
||||
mock_routing_data,
|
||||
mock_gather,
|
||||
mock_scatter,
|
||||
)
|
||||
sparse_result = MagicMock()
|
||||
sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
|
||||
sparse_result.vals = torch.tensor([[0.6, 0.4]])
|
||||
mock_topk.return_value = sparse_result
|
||||
mock_make_routing.return_value = (
|
||||
mock_routing_data,
|
||||
mock_gather,
|
||||
mock_scatter,
|
||||
)
|
||||
|
||||
mock_fused_experts.return_value = torch.zeros((1, 8), device=device)
|
||||
|
||||
@@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap:
|
||||
expert_map=mock_expert_map,
|
||||
)
|
||||
|
||||
# Both paths use topk + make_routing_data
|
||||
mock_topk.assert_called_once()
|
||||
mock_make_routing.assert_called_once()
|
||||
|
||||
if expert_map_present:
|
||||
# EP path: should use topk + make_routing_data, NOT
|
||||
# legacy_routing
|
||||
mock_topk.assert_called_once()
|
||||
mock_make_routing.assert_called_once()
|
||||
mock_legacy.assert_not_called()
|
||||
# expert_map should be None in the fused_experts call
|
||||
# (already applied)
|
||||
call_kwargs = mock_fused_experts.call_args
|
||||
assert call_kwargs[1].get("expert_map") is None or (
|
||||
len(call_kwargs[0]) > 0
|
||||
)
|
||||
else:
|
||||
# Non-EP path: should use legacy_routing
|
||||
mock_legacy.assert_called_once()
|
||||
mock_topk.assert_not_called()
|
||||
mock_make_routing.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user