[Kernel] Add gpt-oss Router GEMM kernel (#37205)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-03-18 08:15:56 -07:00
committed by GitHub
parent 17808394bc
commit b1169d7be8
13 changed files with 875 additions and 13 deletions

View File

@@ -999,6 +999,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/gpt_oss_router_gemm.cu"
"csrc/moe/router_gemm.cu")
endif()

View File

@@ -0,0 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def get_batch_size_range(max_batch_size):
return [2**x for x in range(14) if 2**x <= max_batch_size]
def get_model_params(config):
if config.architectures[0] in (
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
):
num_experts = config.n_routed_experts
hidden_size = config.hidden_size
elif config.architectures[0] in ("GptOssForCausalLM",):
num_experts = config.num_local_experts
hidden_size = config.hidden_size
else:
raise ValueError(f"Unsupported architecture: {config.architectures}")
return num_experts, hidden_size
def get_benchmark(model, max_batch_size, trust_remote_code):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=get_batch_size_range(max_batch_size),
x_log=False,
line_arg="provider",
line_vals=[
"torch",
"vllm",
],
line_names=["PyTorch", "vLLM"],
styles=([("blue", "-"), ("red", "-")]),
ylabel="TFLOPs",
plot_name=f"{model} router gemm throughput",
args={},
)
)
def benchmark(batch_size, provider):
config = get_config(model=model, trust_remote_code=trust_remote_code)
num_experts, hidden_size = get_model_params(config)
mat_a = torch.randn(
(batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
mat_b = torch.randn(
(num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
).contiguous()
bias = torch.randn(
num_experts, dtype=torch.bfloat16, device="cuda"
).contiguous()
is_hopper_or_blackwell = current_platform.is_device_capability(
90
) or current_platform.is_device_capability_family(100)
allow_dsv3_router_gemm = (
is_hopper_or_blackwell
and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
)
allow_gpt_oss_router_gemm = (
is_hopper_or_blackwell
and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
has_bias = False
if allow_gpt_oss_router_gemm:
has_bias = True
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
def runner():
if has_bias:
F.linear(mat_a, mat_b, bias)
else:
F.linear(mat_a, mat_b)
elif provider == "vllm":
def runner():
if allow_dsv3_router_gemm:
ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
elif allow_gpt_oss_router_gemm:
ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
else:
raise ValueError("Unsupported router gemm")
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
runner, quantiles=quantiles
)
def tflops(t_ms):
flops = 2 * batch_size * hidden_size * num_experts
return flops / (t_ms * 1e-3) / 1e12
return tflops(ms), tflops(max_ms), tflops(min_ms)
return benchmark
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model", type=str, default="openai/gpt-oss-20b")
parser.add_argument("--max-batch-size", default=16, type=int)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
# Get the benchmark function
benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code)
# Run performance benchmark
benchmark.run(print_data=True)

View File

@@ -0,0 +1,144 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "gpt_oss_router_gemm.cuh"
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
__nv_bfloat16* gC, __nv_bfloat16* bias,
int batch_size, int output_features,
int input_features, cudaStream_t stream) {
static int const WARP_TILE_M = 16;
static int const TILE_M = WARP_TILE_M;
static int const TILE_N = 8;
static int const TILE_K = 64;
static int const STAGES = 16;
static int const STAGE_UNROLL = 4;
static bool const PROFILE = false;
CUtensorMap weight_map{};
CUtensorMap activation_map{};
constexpr uint32_t rank = 2;
uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
uint32_t box_size[rank] = {TILE_K, TILE_M};
uint32_t elem_stride[rank] = {1, 1};
CUresult res = cuTensorMapEncodeTiled(
&weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
gB, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for weight_map, error code=",
static_cast<int>(res));
size[1] = batch_size;
box_size[1] = TILE_N;
res = cuTensorMapEncodeTiled(
&activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
rank, gA, size, stride, box_size, elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS,
"cuTensorMapEncodeTiled failed for activation_map, error code=",
static_cast<int>(res));
int smem_size = STAGES * STAGE_UNROLL *
(TILE_M * TILE_K * sizeof(__nv_bfloat16) +
TILE_N * TILE_K * sizeof(__nv_bfloat16));
gpuErrChk(cudaFuncSetAttribute(
gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
int tiles_m = (output_features + TILE_M - 1) / TILE_M;
int tiles_n = (batch_size + TILE_N - 1) / TILE_N;
dim3 grid(tiles_m, tiles_n);
dim3 block(384);
cudaLaunchConfig_t config;
cudaLaunchAttribute attrs[1];
config.gridDim = grid;
config.blockDim = block;
config.dynamicSmemBytes = smem_size;
config.stream = stream;
config.attrs = attrs;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
cudaLaunchKernelEx(
&config,
&gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
STAGE_UNROLL, PROFILE>,
gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
activation_map, nullptr);
}
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
torch::Tensor input, torch::Tensor weight,
torch::Tensor bias) {
auto const batch_size = input.size(0);
auto const input_dim = input.size(1);
auto const output_dim = weight.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::BFloat16) {
launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
(__nv_bfloat16*)weight.data_ptr(),
(__nv_bfloat16*)output.mutable_data_ptr(),
(__nv_bfloat16*)bias.data_ptr(), batch_size,
output_dim, input_dim, stream);
} else {
throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
}
}
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
torch::Tensor weight, torch::Tensor bias) {
TORCH_CHECK(input.dim() == 2, "input must be 2D");
TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
"input.size(1) must match weight.size(1)");
TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
"weight.size(0) must match bias.size(0)");
TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
"input tensor must be bfloat16");
TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
"weight tensor must be bfloat16");
TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
"bias tensor must be bfloat16");
gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}

View File

@@ -0,0 +1,447 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_bf16.h"
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "cuda_pipeline.h"
#include <cuda.h>
#include <cuda/barrier>
#include <cuda/std/utility>
#include <cuda_runtime.h>
using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;
namespace ptx = cuda::ptx;
#define gpuErrChk(ans) \
{ \
gpuAssert((ans), __FILE__, __LINE__); \
}
inline void gpuAssert(cudaError_t code, char const* file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort) {
throw std::runtime_error(cudaGetErrorString(code));
}
}
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
__device__ uint64_t gclock64() {
unsigned long long int rv;
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
return rv;
}
__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
int dst;
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(dst)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = dst;
}
__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
int x, y;
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(x), "=r"(y)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = x;
rvi[1] = y;
}
__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
int x, y, z, w;
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(x), "=r"(y), "=r"(z), "=r"(w)
: "r"(smem_ptr));
int* rvi = reinterpret_cast<int*>(&rv[0]);
rvi[0] = x;
rvi[1] = y;
rvi[2] = z;
rvi[3] = w;
}
__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
float c[4]) {
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
float const* C = reinterpret_cast<float const*>(&c[0]);
float* D = reinterpret_cast<float*>(&d[0]);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
"f"(C[3]));
}
__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
float c[4]) {
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
float const* C = reinterpret_cast<float const*>(&c[0]);
float* D = reinterpret_cast<float*>(&d[0]);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
}
__device__ void bar_wait(uint32_t bar_ptr, int phase) {
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n" ::"r"(bar_ptr),
"r"(phase));
}
__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
uint32_t success;
#ifdef INTERNAL
asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
#endif
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(success)
: "r"(bar_ptr), "r"(phase));
return success;
}
__device__ uint32_t elect_one_sync() {
uint32_t pred = 0;
uint32_t laneid = 0;
asm volatile(
"{\n"
".reg .b32 %%rx;\n"
".reg .pred %%px;\n"
" elect.sync %%rx|%%px, %2;\n"
"@%%px mov.s32 %1, 1;\n"
" mov.s32 %0, %%rx;\n"
"}\n"
: "+r"(laneid), "+r"(pred)
: "r"(0xFFFFFFFF));
return pred;
}
#endif
struct Profile {
uint64_t start;
uint64_t weight_load_start;
uint64_t act_load_start;
uint64_t compute_start;
uint64_t complete;
};
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
int STAGE_UNROLL, bool PROFILE>
__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel(
__nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations,
__nv_bfloat16* bias, int M, int N, int K,
const __grid_constant__ CUtensorMap weight_map,
const __grid_constant__ CUtensorMap activation_map,
Profile* profile = nullptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0)
profile[blockIdx.x].start = gclock64();
extern __shared__ __align__(128) char smem[];
__nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0];
__nv_bfloat16* sh_activations =
(__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K *
sizeof(__nv_bfloat16)];
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ barrier bar_wt_ready[STAGES];
__shared__ barrier bar_act_ready[STAGES];
__shared__ barrier bar_data_consumed[STAGES];
__shared__ float4 reduction_buffer[128];
__shared__ nv_bfloat16 sh_bias[TILE_M];
if (threadIdx.x == 0) {
for (int i = 0; i < STAGES; i++) {
init(&bar_wt_ready[i], 1);
init(&bar_act_ready[i], 1);
init(&bar_data_consumed[i], 32);
}
ptx::fence_proxy_async(ptx::space_shared);
asm volatile("prefetch.tensormap [%0];"
:
: "l"(reinterpret_cast<uint64_t>(&weight_map))
: "memory");
asm volatile("prefetch.tensormap [%0];"
:
: "l"(reinterpret_cast<uint64_t>(&activation_map))
: "memory");
}
__syncthreads();
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
int phase = 0;
int mib = blockIdx.x * TILE_M;
int ni = blockIdx.y * TILE_N;
float accum[4];
for (int i = 0; i < 4; i++) accum[i] = 0.f;
int const K_LOOPS_DMA =
(K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL));
int const K_LOOPS_COMPUTE = K_LOOPS_DMA;
// Data loading thread
if (warp_id >= 4 && elect_one_sync()) {
int stage = warp_id % 4;
bool weight_warp = warp_id < 8;
if (!weight_warp) {
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
for (int ki = 0; ki < K_LOOPS_DMA; ki++) {
int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL;
uint64_t desc_ptr_wt = reinterpret_cast<uint64_t>(&weight_map);
uint64_t desc_ptr_act = reinterpret_cast<uint64_t>(&activation_map);
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16);
int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16);
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
if (weight_warp)
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
:
: "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt));
if (!weight_warp)
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
:
: "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act));
if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp)
profile[blockIdx.x].weight_load_start = gclock64();
if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp)
profile[blockIdx.x].act_load_start = gclock64();
for (int i = 0; i < STAGE_UNROLL; i++) {
uint32_t smem_ptr_wt = __cvta_generic_to_shared(
&sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]);
uint32_t crd0 = k + i * TILE_K;
uint32_t crd1 = mib;
if (weight_warp)
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
"tx::bytes [%0], [%1, {%3,%4}], "
"[%2];"
:
: "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0),
"r"(crd1)
: "memory");
uint32_t smem_ptr_act = __cvta_generic_to_shared(
&sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]);
crd0 = k + i * TILE_K;
crd1 = ni;
if (!weight_warp)
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
"tx::bytes [%0], [%1, {%3,%4}], "
"[%2];"
:
: "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act),
"r"(crd0), "r"(crd1)
: "memory");
}
stage += 4;
if (stage >= STAGES) {
stage = warp_id % 4;
phase ^= 1;
}
}
// Wait for pending loads to be consumed before exiting, to avoid race
for (int i = 0; i < (STAGES / 4) - 1; i++) {
bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
stage += 4;
if (stage >= STAGES) {
stage = warp_id % 4;
phase ^= 1;
}
}
}
// Compute threads
else if (warp_id < 4) {
// Sneak the bias load into the compute warps since they're just waiting for
// stuff anyway
if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x];
int stage = warp_id;
int phase = 0;
int lane_id_div8 = lane_id / 8;
int lane_id_mod8 = lane_id % 8;
int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0;
int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0;
int row_wt = lane_id_mod8 + lane_row_offset_wt;
int row_act = lane_id_mod8;
int row_offset_wt = (reinterpret_cast<uintptr_t>(sh_weights) / 128) % 8;
int row_offset_act = row_offset_wt;
uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
bool weight_ready = bar_try_wait(bar_ptr_wt, phase);
bool act_ready = bar_try_wait(bar_ptr_act, phase);
#pragma unroll 2
for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) {
int next_stage = stage + 4;
int next_phase = phase;
if (next_stage >= STAGES) {
next_stage = warp_id;
next_phase ^= 1;
}
while (!weight_ready || !act_ready) {
weight_ready = bar_try_wait(bar_ptr_wt, phase);
act_ready = bar_try_wait(bar_ptr_act, phase);
}
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0)
profile[blockIdx.x].compute_start = gclock64();
if (ki + 1 < K_LOOPS_COMPUTE) {
weight_ready = bar_try_wait(
__cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase);
act_ready = bar_try_wait(
__cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase);
}
#pragma unroll
for (int su = 0; su < STAGE_UNROLL; su++) {
__nv_bfloat16* ptr_weights =
&sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K];
__nv_bfloat16* ptr_act =
&sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K];
#pragma unroll
for (int kii = 0; kii < TILE_K / 16; kii++) {
__nv_bfloat16 a[8];
__nv_bfloat16 b[4];
int col = 2 * kii + lane_col_offset_wt;
int col_sw = ((row_wt + row_offset_wt) % 8) ^ col;
ldmatrix4(a, __cvta_generic_to_shared(
&ptr_weights[row_wt * TILE_K + col_sw * 8]));
col = 2 * kii + lane_id_div8;
col_sw = ((row_act + row_offset_act) % 8) ^ col;
ldmatrix2(b, __cvta_generic_to_shared(
&ptr_act[row_act * TILE_K + 8 * col_sw]));
HMMA_16816(accum, a, b, accum);
}
}
uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]);
asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c));
stage = next_stage;
phase = next_phase;
}
float4 accum4;
accum4.x = accum[0];
accum4.y = accum[1];
accum4.z = accum[2];
accum4.w = accum[3];
reduction_buffer[threadIdx.x] = accum4;
__syncthreads();
if (warp_id == 0) {
int mi = mib + warp_id * WARP_TILE_M;
int tm = mi + lane_id / 4;
int tn = ni + 2 * (lane_id % 4);
float4 accum1 = reduction_buffer[32 + threadIdx.x];
float4 accum2 = reduction_buffer[64 + threadIdx.x];
float4 accum3 = reduction_buffer[96 + threadIdx.x];
accum[0] = accum[0] + accum1.x + accum2.x + accum3.x;
accum[1] = accum[1] + accum1.y + accum2.y + accum3.y;
accum[2] = accum[2] + accum1.z + accum2.z + accum3.z;
accum[3] = accum[3] + accum1.w + accum2.w + accum3.w;
float bias_lo = __bfloat162float(sh_bias[tm - mib]);
float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]);
if (tn < N && tm < M)
output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo);
if (tn + 1 < N && tm < M)
output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo);
if (tn < N && tm + 8 < M)
output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi);
if (tn + 1 < N && tm + 8 < M)
output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi);
if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
profile[blockIdx.x].complete = gclock64();
}
}
#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

View File

@@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a,
const torch::Tensor& mat_b);
// gpt-oss optimized router GEMM kernel for SM90+
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
torch::Tensor weight, torch::Tensor bias);
#endif

View File

@@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// conditionally compiled so impl registration is in source file
// gpt-oss optimized router GEMM kernel for SM90+
m.def(
"gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, "
"Tensor bias) -> ()");
m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm);
#endif
}

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for optimized router GEMM kernel
Run `pytest tests/kernels/moe/test_router_gemm.py`.
"""
import pytest
import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@pytest.mark.skipif(
not (
current_platform.is_cuda()
and (
current_platform.is_device_capability(90)
or current_platform.is_device_capability_family(100)
)
),
reason="This test only runs on Hopper or Blackwell GPUs.",
)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880])
@pytest.mark.parametrize("output_dim", [32, 64, 128])
def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim):
set_random_seed(0)
x = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16)
output = ops.gpt_oss_router_gemm(x, weight, bias)
output_ref = torch.nn.functional.linear(x, weight, bias)
torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2)

View File

@@ -2362,6 +2362,19 @@ def dsv3_router_gemm(
return output
def gpt_oss_router_gemm(
hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
output = torch.empty(
hidden_states.shape[0],
weight.shape[0],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias)
return output
def topk_softmax(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,

View File

@@ -13,6 +13,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
from vllm.lora.layers.gate_linear import GateLinearWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
@@ -38,6 +39,7 @@ __all__ = [
"RowParallelLinearWithLoRA",
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"GateLinearWithLoRA",
"LoRAMapping",
"LoRAMappingType",
"FusedMoEWithLoRA",

View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from .replicated_linear import ReplicatedLinearWithLoRA
class GateLinearWithLoRA(ReplicatedLinearWithLoRA):
def __init__(self, base_layer: GateLinear) -> None:
super().__init__(
base_layer,
)
# GateLinearWithLoRA should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is maybe_get_oot_by_class(GateLinear)

View File

@@ -21,6 +21,7 @@ from vllm.lora.layers import (
ColumnParallelLinearWithShardedLoRA,
FusedMoE3DWithLoRA,
FusedMoEWithLoRA,
GateLinearWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearVariableSliceWithLoRA,
MergedColumnParallelLinearWithLoRA,
@@ -81,6 +82,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedQKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
GateLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLoRA,

View File

@@ -3,9 +3,11 @@
import torch
from torch.nn.parameter import Parameter
import vllm._custom_ops as ops
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@PluggableLayer.register("gate_linear")
@@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear):
"""MoE gate linear layer with three-tier GEMM dispatch:
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
3. F.linear via ReplicatedLinear (ultimate fallback)
2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims)
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
4. F.linear via ReplicatedLinear (ultimate fallback)
The ``out_dtype`` attribute is mutable and can be set after init
(e.g. when the required dtype depends on the expert quantization
@@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear):
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def __init__(
self,
input_size: int,
@@ -65,6 +72,15 @@ class GateLinear(ReplicatedLinear):
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
)
# gpt-oss specialized kernel eligibility (SM90+, exact dims)
self.allow_gpt_oss_router_gemm = (
self.weight.dtype == torch.bfloat16
and current_platform.is_cuda()
and is_hopper_or_blackwell
and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS
and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
# cuBLAS bf16→fp32 eligibility
self.allow_cublas_router_gemm = (
self.allow_specialized_router_gemm
@@ -92,8 +108,6 @@ class GateLinear(ReplicatedLinear):
def forward(
self, x: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
import vllm._custom_ops as ops
# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
output = ops.dsv3_router_gemm(
@@ -103,15 +117,47 @@ class GateLinear(ReplicatedLinear):
)
return output, None
# Tier 2: cuBLAS bf16→fp32
# Tier 2: gpt-oss specialized kernel
if self.allow_gpt_oss_router_gemm:
output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias)
return output, None
# Tier 3: cuBLAS bf16→fp32
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
output = ops.router_gemm_bf16_fp32(x, self.weight)
return output, None
# Tier 3: F.linear (ReplicatedLinear)
# Tier 4: F.linear (ReplicatedLinear)
if self.out_dtype is not None and x.dtype != self.weight.dtype:
x = x.to(self.weight.dtype)
output, output_bias = super().forward(x)
if self.out_dtype is not None and output.dtype != self.out_dtype:
output = output.to(self.out_dtype)
return output, output_bias
def gpt_oss_router_gemm_impl(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
"""
Dynamically run min-latency gemm if num_tokens <= 128.
This must be wrapped in a custom op because our torch.compile integration
does not support runtime dispatching on num_tokens.
"""
if x.shape[0] <= 128:
return ops.gpt_oss_router_gemm(x, weight, bias)
else:
return torch.nn.functional.linear(x, weight, bias)
def gpt_oss_router_gemm_fake(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
return x.new_empty((x.shape[0], weight.shape[0]))
direct_register_custom_op(
op_name="gpt_oss_router_gemm",
op_func=gpt_oss_router_gemm_impl,
fake_impl=gpt_oss_router_gemm_fake,
)

View File

@@ -20,12 +20,11 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -175,13 +174,11 @@ class MLPBlock(torch.nn.Module):
self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = ReplicatedLinear(
self.router = GateLinear(
config.hidden_size,
config.num_local_experts,
bias=True,
quant_config=None,
prefix=f"{prefix}.router",
return_bias=False,
)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(
@@ -209,7 +206,7 @@ class MLPBlock(torch.nn.Module):
self, x[:, : self.hidden_size], self.router.weight, self.router.bias
)
else:
g = self.router(x)
g, _ = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]
if self.is_sequence_parallel:
@@ -273,7 +270,6 @@ class GptOssModel(nn.Module, EagleModelMixin):
self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.parallel_config = vllm_config.parallel_config
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,