[Kernel] Add gpt-oss Router GEMM kernel (#37205)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-03-18 08:15:56 -07:00
committed by GitHub
parent 17808394bc
commit b1169d7be8
13 changed files with 875 additions and 13 deletions

View File

@@ -3,9 +3,11 @@
import torch
from torch.nn.parameter import Parameter
import vllm._custom_ops as ops
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@PluggableLayer.register("gate_linear")
@@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear):
"""MoE gate linear layer with three-tier GEMM dispatch:
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
3. F.linear via ReplicatedLinear (ultimate fallback)
2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims)
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
4. F.linear via ReplicatedLinear (ultimate fallback)
The ``out_dtype`` attribute is mutable and can be set after init
(e.g. when the required dtype depends on the expert quantization
@@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear):
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def __init__(
self,
input_size: int,
@@ -65,6 +72,15 @@ class GateLinear(ReplicatedLinear):
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
)
# gpt-oss specialized kernel eligibility (SM90+, exact dims)
self.allow_gpt_oss_router_gemm = (
self.weight.dtype == torch.bfloat16
and current_platform.is_cuda()
and is_hopper_or_blackwell
and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS
and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
# cuBLAS bf16→fp32 eligibility
self.allow_cublas_router_gemm = (
self.allow_specialized_router_gemm
@@ -92,8 +108,6 @@ class GateLinear(ReplicatedLinear):
def forward(
self, x: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
import vllm._custom_ops as ops
# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
output = ops.dsv3_router_gemm(
@@ -103,15 +117,47 @@ class GateLinear(ReplicatedLinear):
)
return output, None
# Tier 2: cuBLAS bf16→fp32
# Tier 2: gpt-oss specialized kernel
if self.allow_gpt_oss_router_gemm:
output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias)
return output, None
# Tier 3: cuBLAS bf16→fp32
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
output = ops.router_gemm_bf16_fp32(x, self.weight)
return output, None
# Tier 3: F.linear (ReplicatedLinear)
# Tier 4: F.linear (ReplicatedLinear)
if self.out_dtype is not None and x.dtype != self.weight.dtype:
x = x.to(self.weight.dtype)
output, output_bias = super().forward(x)
if self.out_dtype is not None and output.dtype != self.out_dtype:
output = output.to(self.out_dtype)
return output, output_bias
def gpt_oss_router_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Run the router GEMM, choosing the kernel from the number of input rows.

    Inputs whose leading dimension (``x.shape[0]``, i.e. num_tokens) is at
    most 128 go through the min-latency ``ops.gpt_oss_router_gemm`` kernel;
    larger inputs use ``torch.nn.functional.linear``.

    The branch is kept inside a custom op because the torch.compile
    integration does not support runtime dispatching on num_tokens.

    Args:
        x: Input activations; only ``x.shape[0]`` is inspected here.
        weight: Weight tensor forwarded unchanged to the selected kernel.
        bias: Bias tensor forwarded unchanged to the selected kernel.

    Returns:
        The output tensor produced by whichever kernel was selected.
    """
    num_tokens = x.shape[0]
    # Large batches: fall back to the generic linear implementation.
    if num_tokens > 128:
        return torch.nn.functional.linear(x, weight, bias)
    # Small batches: specialized min-latency kernel.
    return ops.gpt_oss_router_gemm(x, weight, bias)
def gpt_oss_router_gemm_fake(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Shape-only stand-in for ``gpt_oss_router_gemm_impl``.

    No GEMM is computed: the result is an uninitialized tensor created with
    ``x.new_empty`` (so it shares ``x``'s dtype and device) of shape
    ``(x.shape[0], weight.shape[0])``. ``bias`` is accepted only so the
    signature mirrors the real implementation.
    """
    num_rows = x.shape[0]
    out_features = weight.shape[0]
    return x.new_empty((num_rows, out_features))
# Register the dispatching wrapper as a custom op named "gpt_oss_router_gemm".
# GateLinear.forward in this file invokes it as
# torch.ops.vllm.gpt_oss_router_gemm. The fake_impl only allocates an output
# of the right shape and never runs a GEMM.
direct_register_custom_op(
op_name="gpt_oss_router_gemm",
op_func=gpt_oss_router_gemm_impl,
fake_impl=gpt_oss_router_gemm_fake,
)