[Kernel] Add gpt-oss Router GEMM kernel (#37205)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -3,9 +3,11 @@
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@PluggableLayer.register("gate_linear")
|
||||
@@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear):
|
||||
"""MoE gate linear layer with three-tier GEMM dispatch:
|
||||
|
||||
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
|
||||
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
|
||||
3. F.linear via ReplicatedLinear (ultimate fallback)
|
||||
2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims)
|
||||
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
|
||||
4. F.linear via ReplicatedLinear (ultimate fallback)
|
||||
|
||||
The ``out_dtype`` attribute is mutable and can be set after init
|
||||
(e.g. when the required dtype depends on the expert quantization
|
||||
@@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear):
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
# Dimensions supported by the gpt-oss specialized kernel
|
||||
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
|
||||
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
@@ -65,6 +72,15 @@ class GateLinear(ReplicatedLinear):
|
||||
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
# gpt-oss specialized kernel eligibility (SM90+, exact dims)
|
||||
self.allow_gpt_oss_router_gemm = (
|
||||
self.weight.dtype == torch.bfloat16
|
||||
and current_platform.is_cuda()
|
||||
and is_hopper_or_blackwell
|
||||
and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS
|
||||
and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
# cuBLAS bf16→fp32 eligibility
|
||||
self.allow_cublas_router_gemm = (
|
||||
self.allow_specialized_router_gemm
|
||||
@@ -92,8 +108,6 @@ class GateLinear(ReplicatedLinear):
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
# Tier 1: DSV3 specialized kernel
|
||||
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
|
||||
output = ops.dsv3_router_gemm(
|
||||
@@ -103,15 +117,47 @@ class GateLinear(ReplicatedLinear):
|
||||
)
|
||||
return output, None
|
||||
|
||||
# Tier 2: cuBLAS bf16→fp32
|
||||
# Tier 2: gpt-oss specialized kernel
|
||||
if self.allow_gpt_oss_router_gemm:
|
||||
output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias)
|
||||
return output, None
|
||||
|
||||
# Tier 3: cuBLAS bf16→fp32
|
||||
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
|
||||
output = ops.router_gemm_bf16_fp32(x, self.weight)
|
||||
return output, None
|
||||
|
||||
# Tier 3: F.linear (ReplicatedLinear)
|
||||
# Tier 4: F.linear (ReplicatedLinear)
|
||||
if self.out_dtype is not None and x.dtype != self.weight.dtype:
|
||||
x = x.to(self.weight.dtype)
|
||||
output, output_bias = super().forward(x)
|
||||
if self.out_dtype is not None and output.dtype != self.out_dtype:
|
||||
output = output.to(self.out_dtype)
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def gpt_oss_router_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """
    Pick the router GEMM implementation from the runtime token count.

    Inputs with at most 128 rows go to the min-latency
    ``ops.gpt_oss_router_gemm`` kernel; anything larger falls back to
    ``torch.nn.functional.linear``.

    The branch lives inside a custom op because our torch.compile
    integration does not support runtime dispatching on num_tokens.
    """
    num_tokens = x.shape[0]

    # Large batches: use the generic linear path.
    if num_tokens > 128:
        return torch.nn.functional.linear(x, weight, bias)

    # Small batches: use the specialized min-latency kernel.
    return ops.gpt_oss_router_gemm(x, weight, bias)
|
||||
|
||||
|
||||
def gpt_oss_router_gemm_fake(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """
    Shape-only counterpart of ``gpt_oss_router_gemm_impl``.

    Returns an uninitialized tensor of shape
    ``(x.shape[0], weight.shape[0])`` created via ``x.new_empty``, so it
    inherits the dtype and device of ``x``. ``bias`` is accepted only to
    mirror the real op's signature and is not read.
    """
    out_rows = x.shape[0]
    out_cols = weight.shape[0]
    return x.new_empty((out_rows, out_cols))
|
||||
|
||||
|
||||
# Register the token-count-dispatching GEMM as a custom op; it is invoked
# elsewhere in this file as ``torch.ops.vllm.gpt_oss_router_gemm``. The fake
# implementation only produces an output of the right shape.
direct_register_custom_op(
    op_name="gpt_oss_router_gemm",
    op_func=gpt_oss_router_gemm_impl,
    fake_impl=gpt_oss_router_gemm_fake,
)
|
||||
|
||||
Reference in New Issue
Block a user