Files
vllm/tests/kernels/moe/test_router_gemm.py
2026-03-18 08:15:56 -07:00

38 lines
1.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the optimized router GEMM kernel.

Run `pytest tests/kernels/moe/test_router_gemm.py`.
"""
import pytest
import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability_family(100)
    ),
    reason="This test only runs on Hopper or Blackwell GPUs.",
)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880])
@pytest.mark.parametrize("output_dim", [32, 64, 128])
def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim):
    """Compare ``ops.gpt_oss_router_gemm`` with ``torch.nn.functional.linear``.

    Random bfloat16 CUDA tensors are drawn with a fixed seed: an input of
    shape ``(batch_size, input_dim)``, a weight of shape
    ``(output_dim, input_dim)`` and a bias of shape ``(output_dim,)``. The
    custom op's result must match the PyTorch reference within
    ``atol=1e-2`` / ``rtol=1e-2``.

    The test is skipped unless the platform is CUDA with device capability
    90 or in the capability-100 family.
    """
    # Seed first so the three draws below are reproducible across runs.
    set_random_seed(0)

    # Every tensor in this test shares the same placement and precision.
    tensor_opts = {"device": "cuda", "dtype": torch.bfloat16}

    # Keep the draw order (input, weight, bias): it fixes the random values.
    hidden = torch.randn(batch_size, input_dim, **tensor_opts)
    router_weight = torch.randn(output_dim, input_dim, **tensor_opts)
    router_bias = torch.randn(output_dim, **tensor_opts)

    actual = ops.gpt_oss_router_gemm(hidden, router_weight, router_bias)
    expected = torch.nn.functional.linear(hidden, router_weight, router_bias)

    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)