# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import pytest
import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.router_factory import (
    create_fused_moe_router,
)
from vllm.model_executor.models.llama4 import Llama4MoE

# Test parameters
MK_S = [(32, 256), (64, 512)]
TOP_KS = [2, 4, 6]
NUM_EXPERTS = [8, 16, 64]


def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerState:
    if not enable_eplb:
        return EplbLayerState()

    # Initialize EPLB state with proper tensors for testing.
    # For testing purposes, we use a simple 1:1 mapping (no redundant experts).

    # expert_load_view: tracks load on each expert (shape: num_experts)
    expert_load_view = torch.zeros(
        global_num_experts, dtype=torch.int32, device="cuda"
    )

    # logical_to_physical_map: maps logical experts to physical experts
    # Shape: (num_logical_experts, max_slots)
    # For testing, use simple 1:1 mapping with single slot per expert
    logical_to_physical_map = torch.arange(
        global_num_experts, dtype=torch.int64, device="cuda"
    ).unsqueeze(-1)

    # logical_replica_count: number of replicas per logical expert
    # Shape: (num_logical_experts,)
    # For testing, each logical expert has exactly 1 replica
    logical_replica_count = torch.ones(
        global_num_experts, dtype=torch.int64, device="cuda"
    )

    return EplbLayerState(
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
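

# A minimal sketch of the 1:1 EPLB state built above, assuming 4 experts
# (illustrative values, not produced by running this file):
#   expert_load_view        -> tensor([0, 0, 0, 0], dtype=torch.int32)
#   logical_to_physical_map -> tensor([[0], [1], [2], [3]])  # one slot each
#   logical_replica_count   -> tensor([1, 1, 1, 1])
# With no redundant experts this indirection is an identity map, which is why
# the tests below can compare against the same baselines for both EPLB modes.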


def make_test_data(
    m: int, k: int, num_experts: int
) -> tuple[torch.Tensor, torch.Tensor]:
    hidden_states = torch.randn((m, k), device="cuda") / 10
    logits = torch.randn((m, num_experts), device="cuda")
    return hidden_states, logits


def make_e_score_correction_bias(
    e_score_correction_bias_val: float,
    num_experts: int,
) -> torch.Tensor:
    # return torch.randn(num_experts, device="cuda") * e_score_correction_bias_val
    return torch.full(
        (num_experts,), e_score_correction_bias_val, device="cuda", dtype=torch.float32
    )


def assert_routing_results_close(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    baseline_weights: torch.Tensor,
    baseline_ids: torch.Tensor,
    rtol: float = 1e-3,
    atol: float = 1e-3,
):
    """
    Compare routing results, sorting by expert ID first to handle non-deterministic
    ordering from sorted=False in topk.
    """
    # Sort both results by expert IDs for consistent comparison
    sorted_indices_actual = torch.argsort(topk_ids, dim=-1)
    sorted_indices_baseline = torch.argsort(baseline_ids.to(topk_ids.dtype), dim=-1)

    # Gather the sorted values
    topk_ids_sorted = torch.gather(topk_ids, 1, sorted_indices_actual)
    topk_weights_sorted = torch.gather(topk_weights, 1, sorted_indices_actual)
    baseline_ids_sorted = torch.gather(
        baseline_ids.to(topk_ids.dtype), 1, sorted_indices_baseline
    )
    baseline_weights_sorted = torch.gather(
        baseline_weights, 1, sorted_indices_baseline
    )

    # Compare
    torch.testing.assert_close(topk_ids_sorted, baseline_ids_sorted)
    torch.testing.assert_close(
        topk_weights_sorted, baseline_weights_sorted, rtol=rtol, atol=atol
    )
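

# Worked example of why sorting is needed (illustrative values): with
# sorted=False, two correct implementations may return, for the same token,
#     topk_ids = [3, 1], topk_weights = [0.6, 0.4]
#     topk_ids = [1, 3], topk_weights = [0.4, 0.6]
# After argsort on the ids, both become ids = [1, 3], weights = [0.4, 0.6],
# so the comparison is invariant to the order in which the selected experts
# are emitted.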


def baseline_fused_topk(
    router_logits: torch.Tensor, top_k: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Baseline for standard fused top-k routing.

    Algorithm:
    1. Apply softmax to router logits
    2. Select top-k experts
    3. Optionally renormalize the weights
    """
    scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    # Use sorted=False to match vllm implementation (vllm_is_batch_invariant
    # defaults to False)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
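

# Worked example (illustrative): for logits [1.0, 2.0, 3.0] and top_k=2,
# softmax gives ~[0.090, 0.245, 0.665]; top-2 selects experts [2, 1] with
# weights [0.665, 0.245]. With renormalize=True, dividing by their sum
# (~0.910) yields [0.731, 0.269].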


def baseline_fused_topk_bias(
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Baseline for fused top-k with bias correction.

    Algorithm:
    1. Apply softmax to router logits
    2. Add bias to scores for expert selection
    3. Select top-k experts using biased scores
    4. Get weights from original (unbiased) scores
    5. Optionally renormalize the weights
    6. Apply routed scaling factor
    """
    # Apply softmax to get scores
    scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)

    # Add bias for expert selection
    scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)

    # Select top-k using biased scores (sorted=False to match implementation)
    topk_ids = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]

    # Get weights from original scores (not biased)
    topk_weights = scores.gather(1, topk_ids)

    # Renormalize if needed (BEFORE applying scaling factor)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Apply scaling factor (AFTER renormalization, if applicable)
    if routed_scaling_factor != 1.0:
        topk_weights *= routed_scaling_factor

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
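

# Worked example (illustrative): with softmax scores [0.10, 0.50, 0.40],
# e_score_correction_bias [0.6, 0.0, 0.0] and top_k=1, the biased scores
# [0.70, 0.50, 0.40] select expert 0, but the returned weight is the
# unbiased 0.10: the bias steers selection (e.g., toward underloaded
# experts) without distorting the combine weights.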


def baseline_grouped_topk(
    router_logits: torch.Tensor,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    scoring_func: str,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor | None,
    routed_scaling_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Baseline for grouped top-k routing (e.g., DeepSeek).

    Algorithm:
    1. Apply scoring function (softmax or sigmoid)
    2. Optionally add bias
    3. Select top-k groups (max score per group, or sum of top-2 when bias is used)
    4. Mask scores to only include selected groups
    5. Select top-k experts from masked scores
    6. Optionally renormalize
    7. Apply scaling factor
    """
    num_token = router_logits.shape[0]

    # Apply scoring function
    if scoring_func == "softmax":
        scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    elif scoring_func == "sigmoid":
        scores = torch.sigmoid(router_logits.float())
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Handle bias correction
    if e_score_correction_bias is not None:
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        # For the bias case, use the sum of the top-2 scores in each group
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        # Use the max score in each group
        group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values

    # Select top-k groups
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]

    # Create mask for selected groups
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)

    # Expand mask to all experts
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )

    # Mask scores (set non-selected to -inf)
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))

    # Select top-k experts
    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)[1]
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)

    # Renormalize if needed
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Apply scaling factor
    if routed_scaling_factor != 1.0:
        topk_weights *= routed_scaling_factor

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
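

# Worked example of the group masking (illustrative): with 4 experts in
# num_expert_group=2 groups, sigmoid scores [0.9, 0.1 | 0.2, 0.3], no bias,
# and topk_group=1, the per-group maxima are [0.9, 0.3], so group 0 is kept
# and group 1 is masked to -inf: [0.9, 0.1, -inf, -inf]. A top_k=2 pick then
# returns experts [0, 1]; the globally second-best expert (id 3, score 0.3)
# is excluded because its group was not selected.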


def baseline_custom_llama4(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Baseline for Llama4 custom routing.

    Algorithm:
    1. Select top-k expert indices (without softmax)
    2. Apply sigmoid to the selected scores
    """
    router_scores, router_indices = torch.topk(router_logits, top_k, dim=-1)
    router_scores = torch.sigmoid(router_scores.float())
    return router_scores.to(torch.float32), router_indices.to(torch.int32)
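

# Worked example (illustrative): for logits [2.0, -1.0, 0.5] and top_k=2,
# topk on the raw logits selects indices [0, 2]; sigmoid of the selected
# logits gives weights ~[0.881, 0.622]. Unlike the softmax baselines above,
# the weights are computed independently per expert and are not renormalized.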


@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
def test_fused_topk(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
):
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
    router = create_fused_moe_router(
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=eplb_state,
    )
    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Get router output
    topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)

    # Compute baseline
    baseline_weights, baseline_ids = baseline_fused_topk(
        router_logits, top_k, renormalize
    )

    # Compare results
    assert_routing_results_close(
        topk_weights, topk_ids, baseline_weights, baseline_ids
    )


@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("e_score_correction_bias_val", [0.9])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1])
def test_fused_topk_bias(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    e_score_correction_bias_val: float,
    routed_scaling_factor: float,
):
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
    e_score_correction_bias = make_e_score_correction_bias(
        e_score_correction_bias_val,
        global_num_experts,
    )
    router = create_fused_moe_router(
        e_score_correction_bias=e_score_correction_bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=eplb_state,
    )
    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Get router output
    topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)

    # Compute baseline
    baseline_weights, baseline_ids = baseline_fused_topk_bias(
        router_logits,
        top_k,
        renormalize,
        e_score_correction_bias,
        routed_scaling_factor,
    )

    # Compare results
    assert_routing_results_close(
        topk_weights, topk_ids, baseline_weights, baseline_ids
    )


@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize(
    "global_num_experts,num_expert_group,topk_group",
    [
        (64, 8, 4),  # 8 groups of 8 experts, select 4 groups
        (32, 4, 2),  # 4 groups of 8 experts, select 2 groups
    ],
)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize("e_score_correction_bias_val", [0.9])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1])
@pytest.mark.parametrize("scoring_func", ["sigmoid", "softmax"])
def test_grouped_topk(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    num_expert_group: int,
    topk_group: int,
    scoring_func: str,
    e_score_correction_bias_val: float,
    routed_scaling_factor: float,
):
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
    e_score_correction_bias = make_e_score_correction_bias(
        e_score_correction_bias_val,
        global_num_experts,
    )
    router = create_fused_moe_router(
        use_grouped_topk=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
        global_num_experts=global_num_experts,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=eplb_state,
    )
    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Get router output
    topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)

    # Compute baseline
    baseline_weights, baseline_ids = baseline_grouped_topk(
        router_logits,
        top_k,
        num_expert_group,
        topk_group,
        scoring_func,
        renormalize,
        e_score_correction_bias,
        routed_scaling_factor,
    )

    # Compare results
    assert_routing_results_close(
        topk_weights, topk_ids, baseline_weights, baseline_ids
    )


@pytest.mark.parametrize("m,k", MK_S)
@pytest.mark.parametrize("top_k", TOP_KS)
@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("enable_eplb", [False, True])
@pytest.mark.parametrize(
    "custom_routing_function", [Llama4MoE.custom_routing_function]
)
def test_custom(
    m: int,
    k: int,
    top_k: int,
    global_num_experts: int,
    renormalize: bool,
    enable_eplb: bool,
    custom_routing_function: Callable,
):
    if top_k > global_num_experts:
        pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})")

    eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
    router = create_fused_moe_router(
        top_k=top_k,
        global_num_experts=global_num_experts,
        custom_routing_function=custom_routing_function,
        renormalize=renormalize,
        enable_eplb=enable_eplb,
        eplb_state=eplb_state,
    )
    hidden_states, router_logits = make_test_data(m, k, global_num_experts)

    # Get router output
    topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)

    # Compute baseline (Llama4 uses sigmoid)
    baseline_weights, baseline_ids = baseline_custom_llama4(router_logits, top_k)

    # Compare results
    assert_routing_results_close(
        topk_weights, topk_ids, baseline_weights, baseline_ids
    )


# TODO: Is the existing test in tests/test_routing_simulatator.py sufficient
# to cover the simulated routing strategies, or should this test be enabled?
# @pytest.mark.parametrize("m,k", MK_S)
# @pytest.mark.parametrize("top_k", TOP_KS)
# @pytest.mark.parametrize("global_num_experts", NUM_EXPERTS)
# @pytest.mark.parametrize("renormalize", [False, True])
# @pytest.mark.parametrize("enable_eplb", [False, True])
# @pytest.mark.parametrize("strategy", ["uniform_random", "normal_routing"])
# def test_simulated(
#     m: int,
#     k: int,
#     top_k: int,
#     global_num_experts: int,
#     renormalize: bool,
#     enable_eplb: bool,
#     strategy: str,
#     monkeypatch,
# ):
#     eplb_state = setup_eplb_state(enable_eplb, global_num_experts)
#     monkeypatch.setenv("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", strategy)
#     router = create_fused_moe_router(
#         top_k=top_k,
#         global_num_experts=global_num_experts,
#         enable_eplb=enable_eplb,
#         eplb_state=eplb_state,
#     )
#     hidden_states, router_logits = make_test_data(m, k, global_num_experts)
#     topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)