From 9bd72311068919b8f3430278d47859cf312039fa Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Wed, 1 Apr 2026 22:02:32 -0700 Subject: [PATCH] Revert "[Kernel] Add gpt-oss Router GEMM kernel (#37205)" (#38778) Signed-off-by: Xin Yang --- CMakeLists.txt | 1 - benchmarks/kernels/benchmark_router_gemm.py | 134 ------ csrc/moe/gpt_oss_router_gemm.cu | 144 ------ csrc/moe/gpt_oss_router_gemm.cuh | 447 ------------------ csrc/moe/moe_ops.h | 4 - csrc/moe/torch_bindings.cpp | 6 - tests/kernels/moe/test_router_gemm.py | 37 -- vllm/_custom_ops.py | 13 - vllm/lora/layers/__init__.py | 2 - vllm/lora/layers/gate_linear.py | 30 -- vllm/lora/utils.py | 2 - .../layers/fused_moe/router/gate_linear.py | 58 +-- vllm/model_executor/models/gpt_oss.py | 9 +- 13 files changed, 12 insertions(+), 875 deletions(-) delete mode 100644 benchmarks/kernels/benchmark_router_gemm.py delete mode 100644 csrc/moe/gpt_oss_router_gemm.cu delete mode 100644 csrc/moe/gpt_oss_router_gemm.cuh delete mode 100644 tests/kernels/moe/test_router_gemm.py delete mode 100644 vllm/lora/layers/gate_linear.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1695d5ab4..dd6ebce34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1030,7 +1030,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" "csrc/moe/grouped_topk_kernels.cu" - "csrc/moe/gpt_oss_router_gemm.cu" "csrc/moe/router_gemm.cu") endif() diff --git a/benchmarks/kernels/benchmark_router_gemm.py b/benchmarks/kernels/benchmark_router_gemm.py deleted file mode 100644 index cc63f8904..000000000 --- a/benchmarks/kernels/benchmark_router_gemm.py +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -import torch.nn.functional as F - -from vllm import _custom_ops as ops -from vllm.platforms import current_platform -from vllm.transformers_utils.config import get_config -from vllm.triton_utils import triton -from vllm.utils.argparse_utils import FlexibleArgumentParser - -# Dimensions supported by the DSV3 specialized kernel -DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] -DSV3_SUPPORTED_HIDDEN_SIZES = [7168] - -# Dimensions supported by the gpt-oss specialized kernel -GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] -GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] - - -def get_batch_size_range(max_batch_size): - return [2**x for x in range(14) if 2**x <= max_batch_size] - - -def get_model_params(config): - if config.architectures[0] in ( - "DeepseekV2ForCausalLM", - "DeepseekV3ForCausalLM", - "DeepseekV32ForCausalLM", - ): - num_experts = config.n_routed_experts - hidden_size = config.hidden_size - elif config.architectures[0] in ("GptOssForCausalLM",): - num_experts = config.num_local_experts - hidden_size = config.hidden_size - else: - raise ValueError(f"Unsupported architecture: {config.architectures}") - return num_experts, hidden_size - - -def get_benchmark(model, max_batch_size, trust_remote_code): - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size"], - x_vals=get_batch_size_range(max_batch_size), - x_log=False, - line_arg="provider", - line_vals=[ - "torch", - "vllm", - ], - line_names=["PyTorch", "vLLM"], - styles=([("blue", "-"), ("red", "-")]), - ylabel="TFLOPs", - plot_name=f"{model} router gemm throughput", - args={}, - ) - ) - def benchmark(batch_size, provider): - config = get_config(model=model, trust_remote_code=trust_remote_code) - num_experts, hidden_size = get_model_params(config) - - mat_a = torch.randn( - (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda" - ).contiguous() - mat_b = torch.randn( - (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda" - ).contiguous() - bias = torch.randn( - num_experts, dtype=torch.bfloat16, device="cuda" - ).contiguous() - - is_hopper_or_blackwell = current_platform.is_device_capability( - 90 - ) or current_platform.is_device_capability_family(100) - allow_dsv3_router_gemm = ( - is_hopper_or_blackwell - and num_experts in DSV3_SUPPORTED_NUM_EXPERTS - and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES - ) - allow_gpt_oss_router_gemm = ( - is_hopper_or_blackwell - and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS - and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES - ) - - has_bias = False - if allow_gpt_oss_router_gemm: - has_bias = True - - quantiles = [0.5, 0.2, 0.8] - - if provider == "torch": - - def runner(): - if has_bias: - F.linear(mat_a, mat_b, bias) - else: - F.linear(mat_a, mat_b) - elif provider == "vllm": - - def runner(): - if allow_dsv3_router_gemm: - ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16) - elif allow_gpt_oss_router_gemm: - ops.gpt_oss_router_gemm(mat_a, mat_b, bias) - else: - raise ValueError("Unsupported router gemm") - - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - runner, quantiles=quantiles - ) - - def tflops(t_ms): - flops = 2 * batch_size * hidden_size * num_experts - return flops / (t_ms * 1e-3) / 1e12 - - return tflops(ms), tflops(max_ms), tflops(min_ms) - - return benchmark - - -if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser.add_argument("--model", type=str, default="openai/gpt-oss-20b") - parser.add_argument("--max-batch-size", default=16, type=int) - parser.add_argument("--trust-remote-code", action="store_true") - args = parser.parse_args() - - # Get the benchmark function - benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code) - # Run performance benchmark - benchmark.run(print_data=True) diff --git a/csrc/moe/gpt_oss_router_gemm.cu b/csrc/moe/gpt_oss_router_gemm.cu deleted file mode 100644 index 0294cd36a..000000000 --- a/csrc/moe/gpt_oss_router_gemm.cu +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Adapted from - * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu - * Copyright (c) 2025, The vLLM team. - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. - * All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include "gpt_oss_router_gemm.cuh" - -void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB, - __nv_bfloat16* gC, __nv_bfloat16* bias, - int batch_size, int output_features, - int input_features, cudaStream_t stream) { - static int const WARP_TILE_M = 16; - static int const TILE_M = WARP_TILE_M; - static int const TILE_N = 8; - static int const TILE_K = 64; - static int const STAGES = 16; - static int const STAGE_UNROLL = 4; - static bool const PROFILE = false; - - CUtensorMap weight_map{}; - CUtensorMap activation_map{}; - - constexpr uint32_t rank = 2; - uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features}; - uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)}; - uint32_t box_size[rank] = {TILE_K, TILE_M}; - uint32_t elem_stride[rank] = {1, 1}; - - CUresult res = cuTensorMapEncodeTiled( - &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, - gB, size, stride, box_size, elem_stride, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - TORCH_CHECK(res == CUDA_SUCCESS, - "cuTensorMapEncodeTiled failed for weight_map, error code=", - static_cast(res)); - - size[1] = batch_size; - box_size[1] = TILE_N; - - res = cuTensorMapEncodeTiled( - &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, - rank, gA, size, stride, box_size, elem_stride, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - TORCH_CHECK(res == CUDA_SUCCESS, - "cuTensorMapEncodeTiled failed for activation_map, error code=", - static_cast(res)); - - int smem_size = STAGES * STAGE_UNROLL * - (TILE_M * TILE_K * sizeof(__nv_bfloat16) + - TILE_N * TILE_K * sizeof(__nv_bfloat16)); - - gpuErrChk(cudaFuncSetAttribute( - gpt_oss_router_gemm_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - int tiles_m = (output_features + TILE_M - 1) / TILE_M; - int tiles_n = (batch_size + TILE_N - 1) / TILE_N; - - dim3 grid(tiles_m, tiles_n); - dim3 block(384); - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.gridDim = grid; - config.blockDim = block; - config.dynamicSmemBytes = smem_size; - config.stream = stream; - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = 1; - config.numAttrs = 1; - - cudaLaunchKernelEx( - &config, - &gpt_oss_router_gemm_kernel, - gC, gA, gB, bias, output_features, batch_size, input_features, weight_map, - activation_map, nullptr); -} - -void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output, - torch::Tensor input, torch::Tensor weight, - torch::Tensor bias) { - auto const batch_size = input.size(0); - auto const input_dim = input.size(1); - auto const output_dim = weight.size(0); - - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::BFloat16) { - launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(), - (__nv_bfloat16*)weight.data_ptr(), - (__nv_bfloat16*)output.mutable_data_ptr(), - (__nv_bfloat16*)bias.data_ptr(), batch_size, - output_dim, input_dim, stream); - } else { - throw std::invalid_argument("Unsupported dtype, only supports bfloat16"); - } -} - -void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, - torch::Tensor weight, torch::Tensor bias) { - TORCH_CHECK(input.dim() == 2, "input must be 2D"); - TORCH_CHECK(weight.dim() == 2, "weight must be 2D"); - TORCH_CHECK(bias.dim() == 1, "bias must be 1D"); - TORCH_CHECK(input.sizes()[1] == weight.sizes()[1], - "input.size(1) must match weight.size(1)"); - TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0], - "weight.size(0) must match bias.size(0)"); - TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, - "input tensor must be bfloat16"); - TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16, - "weight tensor must be bfloat16"); - TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16, - "bias tensor must be bfloat16"); - gpt_oss_router_gemm_cuda_forward(output, input, weight, bias); -} diff --git a/csrc/moe/gpt_oss_router_gemm.cuh b/csrc/moe/gpt_oss_router_gemm.cuh deleted file mode 100644 index 5cc653f19..000000000 --- a/csrc/moe/gpt_oss_router_gemm.cuh +++ /dev/null @@ -1,447 +0,0 @@ -/* - * Adapted from - * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh - * Copyright (c) 2025, The vLLM team. - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. - * All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cuda_bf16.h" -#include -#include -#include - -#include "cuda_pipeline.h" -#include -#include -#include -#include - -using barrier = cuda::barrier; -namespace cde = cuda::device::experimental; -namespace ptx = cuda::ptx; - -#define gpuErrChk(ans) \ - { \ - gpuAssert((ans), __FILE__, __LINE__); \ - } - -inline void gpuAssert(cudaError_t code, char const* file, int line, - bool abort = true) { - if (code != cudaSuccess) { - fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, - line); - if (abort) { - throw std::runtime_error(cudaGetErrorString(code)); - } - } -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -__device__ uint64_t gclock64() { - unsigned long long int rv; - asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv)); - return rv; -} - -__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) { - int dst; - asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" - : "=r"(dst) - : "r"(smem_ptr)); - int* rvi = reinterpret_cast(&rv[0]); - rvi[0] = dst; -} - -__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) { - int x, y; - asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" - : "=r"(x), "=r"(y) - : "r"(smem_ptr)); - - int* rvi = reinterpret_cast(&rv[0]); - rvi[0] = x; - rvi[1] = y; -} - -__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) { - int x, y, z, w; - asm volatile( - "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" - : "=r"(x), "=r"(y), "=r"(z), "=r"(w) - : "r"(smem_ptr)); - int* rvi = reinterpret_cast(&rv[0]); - rvi[0] = x; - rvi[1] = y; - rvi[2] = z; - rvi[3] = w; -} - -__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2], - float c[4]) { - uint32_t const* A = reinterpret_cast(&a[0]); - uint32_t const* B = reinterpret_cast(&b[0]); - float const* C = reinterpret_cast(&c[0]); - float* D = reinterpret_cast(&d[0]); - - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), - "f"(C[3])); -} - -__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4], - float c[4]) { - uint32_t const* A = reinterpret_cast(&a[0]); - uint32_t const* B = reinterpret_cast(&b[0]); - float const* C = reinterpret_cast(&c[0]); - float* D = reinterpret_cast(&d[0]); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); -} - -__device__ void bar_wait(uint32_t bar_ptr, int phase) { - asm volatile( - "{\n" - ".reg .pred P1;\n" - "LAB_WAIT:\n" - "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" - "@P1 bra.uni DONE;\n" - "bra.uni LAB_WAIT;\n" - "DONE:\n" - "}\n" ::"r"(bar_ptr), - "r"(phase)); -} - -__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) { - uint32_t success; - #ifdef INTERNAL - asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory"); - #endif - asm volatile( - "{\n\t" - ".reg .pred P1; \n\t" - "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" - "selp.b32 %0, 1, 0, P1; \n\t" - "}" - : "=r"(success) - : "r"(bar_ptr), "r"(phase)); - return success; -} - -__device__ uint32_t elect_one_sync() { - uint32_t pred = 0; - uint32_t laneid = 0; - asm volatile( - "{\n" - ".reg .b32 %%rx;\n" - ".reg .pred %%px;\n" - " elect.sync %%rx|%%px, %2;\n" - "@%%px mov.s32 %1, 1;\n" - " mov.s32 %0, %%rx;\n" - "}\n" - : "+r"(laneid), "+r"(pred) - : "r"(0xFFFFFFFF)); - return pred; -} -#endif - -struct Profile { - uint64_t start; - uint64_t weight_load_start; - uint64_t act_load_start; - uint64_t compute_start; - uint64_t complete; -}; - -template -__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel( - __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations, - __nv_bfloat16* bias, int M, int N, int K, - const __grid_constant__ CUtensorMap weight_map, - const __grid_constant__ CUtensorMap activation_map, - Profile* profile = nullptr) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - - if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0) - profile[blockIdx.x].start = gclock64(); - - extern __shared__ __align__(128) char smem[]; - - __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0]; - __nv_bfloat16* sh_activations = - (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K * - sizeof(__nv_bfloat16)]; - - #pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ barrier bar_wt_ready[STAGES]; - __shared__ barrier bar_act_ready[STAGES]; - __shared__ barrier bar_data_consumed[STAGES]; - - __shared__ float4 reduction_buffer[128]; - - __shared__ nv_bfloat16 sh_bias[TILE_M]; - - if (threadIdx.x == 0) { - for (int i = 0; i < STAGES; i++) { - init(&bar_wt_ready[i], 1); - init(&bar_act_ready[i], 1); - init(&bar_data_consumed[i], 32); - } - ptx::fence_proxy_async(ptx::space_shared); - asm volatile("prefetch.tensormap [%0];" - : - : "l"(reinterpret_cast(&weight_map)) - : "memory"); - asm volatile("prefetch.tensormap [%0];" - : - : "l"(reinterpret_cast(&activation_map)) - : "memory"); - } - __syncthreads(); - - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - - int phase = 0; - - int mib = blockIdx.x * TILE_M; - int ni = blockIdx.y * TILE_N; - - float accum[4]; - for (int i = 0; i < 4; i++) accum[i] = 0.f; - - int const K_LOOPS_DMA = - (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL)); - int const K_LOOPS_COMPUTE = K_LOOPS_DMA; - - // Data loading thread - if (warp_id >= 4 && elect_one_sync()) { - int stage = warp_id % 4; - - bool weight_warp = warp_id < 8; - if (!weight_warp) { - cudaGridDependencySynchronize(); - cudaTriggerProgrammaticLaunchCompletion(); - } - - for (int ki = 0; ki < K_LOOPS_DMA; ki++) { - int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL; - - uint64_t desc_ptr_wt = reinterpret_cast(&weight_map); - uint64_t desc_ptr_act = reinterpret_cast(&activation_map); - - uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); - uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); - int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16); - int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16); - - bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); - - if (weight_warp) - asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" - : - : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt)); - if (!weight_warp) - asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" - : - : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act)); - - if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp) - profile[blockIdx.x].weight_load_start = gclock64(); - if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp) - profile[blockIdx.x].act_load_start = gclock64(); - - for (int i = 0; i < STAGE_UNROLL; i++) { - uint32_t smem_ptr_wt = __cvta_generic_to_shared( - &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]); - uint32_t crd0 = k + i * TILE_K; - uint32_t crd1 = mib; - if (weight_warp) - asm volatile( - "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" - "tx::bytes [%0], [%1, {%3,%4}], " - "[%2];" - : - : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0), - "r"(crd1) - : "memory"); - - uint32_t smem_ptr_act = __cvta_generic_to_shared( - &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]); - crd0 = k + i * TILE_K; - crd1 = ni; - if (!weight_warp) - asm volatile( - "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" - "tx::bytes [%0], [%1, {%3,%4}], " - "[%2];" - : - : "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act), - "r"(crd0), "r"(crd1) - : "memory"); - } - - stage += 4; - if (stage >= STAGES) { - stage = warp_id % 4; - phase ^= 1; - } - } - // Wait for pending loads to be consumed before exiting, to avoid race - for (int i = 0; i < (STAGES / 4) - 1; i++) { - bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); - stage += 4; - if (stage >= STAGES) { - stage = warp_id % 4; - phase ^= 1; - } - } - } - // Compute threads - else if (warp_id < 4) { - // Sneak the bias load into the compute warps since they're just waiting for - // stuff anyway - if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x]; - - int stage = warp_id; - - int phase = 0; - int lane_id_div8 = lane_id / 8; - int lane_id_mod8 = lane_id % 8; - - int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0; - int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0; - - int row_wt = lane_id_mod8 + lane_row_offset_wt; - int row_act = lane_id_mod8; - - int row_offset_wt = (reinterpret_cast(sh_weights) / 128) % 8; - int row_offset_act = row_offset_wt; - - uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); - uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); - - bool weight_ready = bar_try_wait(bar_ptr_wt, phase); - bool act_ready = bar_try_wait(bar_ptr_act, phase); - - #pragma unroll 2 - for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) { - int next_stage = stage + 4; - int next_phase = phase; - if (next_stage >= STAGES) { - next_stage = warp_id; - next_phase ^= 1; - } - - while (!weight_ready || !act_ready) { - weight_ready = bar_try_wait(bar_ptr_wt, phase); - act_ready = bar_try_wait(bar_ptr_act, phase); - } - - if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0) - profile[blockIdx.x].compute_start = gclock64(); - - if (ki + 1 < K_LOOPS_COMPUTE) { - weight_ready = bar_try_wait( - __cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase); - act_ready = bar_try_wait( - __cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase); - } - - #pragma unroll - for (int su = 0; su < STAGE_UNROLL; su++) { - __nv_bfloat16* ptr_weights = - &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K]; - __nv_bfloat16* ptr_act = - &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K]; - - #pragma unroll - for (int kii = 0; kii < TILE_K / 16; kii++) { - __nv_bfloat16 a[8]; - __nv_bfloat16 b[4]; - - int col = 2 * kii + lane_col_offset_wt; - int col_sw = ((row_wt + row_offset_wt) % 8) ^ col; - - ldmatrix4(a, __cvta_generic_to_shared( - &ptr_weights[row_wt * TILE_K + col_sw * 8])); - - col = 2 * kii + lane_id_div8; - col_sw = ((row_act + row_offset_act) % 8) ^ col; - - ldmatrix2(b, __cvta_generic_to_shared( - &ptr_act[row_act * TILE_K + 8 * col_sw])); - - HMMA_16816(accum, a, b, accum); - } - } - - uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]); - asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c)); - - stage = next_stage; - phase = next_phase; - } - - float4 accum4; - accum4.x = accum[0]; - accum4.y = accum[1]; - accum4.z = accum[2]; - accum4.w = accum[3]; - reduction_buffer[threadIdx.x] = accum4; - - __syncthreads(); - - if (warp_id == 0) { - int mi = mib + warp_id * WARP_TILE_M; - int tm = mi + lane_id / 4; - int tn = ni + 2 * (lane_id % 4); - - float4 accum1 = reduction_buffer[32 + threadIdx.x]; - float4 accum2 = reduction_buffer[64 + threadIdx.x]; - float4 accum3 = reduction_buffer[96 + threadIdx.x]; - - accum[0] = accum[0] + accum1.x + accum2.x + accum3.x; - accum[1] = accum[1] + accum1.y + accum2.y + accum3.y; - accum[2] = accum[2] + accum1.z + accum2.z + accum3.z; - accum[3] = accum[3] + accum1.w + accum2.w + accum3.w; - - float bias_lo = __bfloat162float(sh_bias[tm - mib]); - float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]); - - if (tn < N && tm < M) - output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo); - if (tn + 1 < N && tm < M) - output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo); - if (tn < N && tm + 8 < M) - output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi); - if (tn + 1 < N && tm + 8 < M) - output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi); - - if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) - profile[blockIdx.x].complete = gclock64(); - } - } -#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index de931dc76..d8d962887 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -70,8 +70,4 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input, // Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); - -// gpt-oss optimized router GEMM kernel for SM90+ -void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, - torch::Tensor weight, torch::Tensor bias); #endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 4cd74366e..7b627a6f8 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -132,12 +132,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); // conditionally compiled so impl registration is in source file - - // gpt-oss optimized router GEMM kernel for SM90+ - m.def( - "gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, " - "Tensor bias) -> ()"); - m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm); #endif } diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py deleted file mode 100644 index 906e47708..000000000 --- a/tests/kernels/moe/test_router_gemm.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for optimized router GEMM kernel - -Run `pytest tests/kernels/moe/test_router_gemm.py`. -""" - -import pytest -import torch - -import vllm._custom_ops as ops -from vllm.platforms import current_platform -from vllm.utils.torch_utils import set_random_seed - - -@pytest.mark.skipif( - not ( - current_platform.is_cuda() - and ( - current_platform.is_device_capability(90) - or current_platform.is_device_capability_family(100) - ) - ), - reason="This test only runs on Hopper or Blackwell GPUs.", -) -@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) -@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) -@pytest.mark.parametrize("output_dim", [32, 64, 128]) -def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim): - set_random_seed(0) - x = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16) - weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16) - bias = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16) - - output = ops.gpt_oss_router_gemm(x, weight, bias) - output_ref = torch.nn.functional.linear(x, weight, bias) - torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b5fbb6071..9cc023138 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2327,19 +2327,6 @@ def dsv3_router_gemm( return output -def gpt_oss_router_gemm( - hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - output = torch.empty( - hidden_states.shape[0], - weight.shape[0], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias) - return output - - def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py index 235f40b73..1f3fdea2c 100644 --- a/vllm/lora/layers/__init__.py +++ b/vllm/lora/layers/__init__.py @@ -13,7 +13,6 @@ from vllm.lora.layers.column_parallel_linear import ( QKVParallelLinearWithShardedLoRA, ) from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA -from vllm.lora.layers.gate_linear import GateLinearWithLoRA from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA from vllm.lora.layers.row_parallel_linear import ( @@ -39,7 +38,6 @@ __all__ = [ "RowParallelLinearWithLoRA", "RowParallelLinearWithShardedLoRA", "ReplicatedLinearWithLoRA", - "GateLinearWithLoRA", "LoRAMapping", "LoRAMappingType", "FusedMoEWithLoRA", diff --git a/vllm/lora/layers/gate_linear.py b/vllm/lora/layers/gate_linear.py deleted file mode 100644 index 9bcaaa5b8..000000000 --- a/vllm/lora/layers/gate_linear.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.config.lora import LoRAConfig -from vllm.model_executor.custom_op import maybe_get_oot_by_class -from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear - -from .replicated_linear import ReplicatedLinearWithLoRA - - -class GateLinearWithLoRA(ReplicatedLinearWithLoRA): - def __init__(self, base_layer: GateLinear) -> None: - super().__init__( - base_layer, - ) - - # GateLinearWithLoRA should always be replaced, regardless of the fully - # sharded LoRAs setting, because it is, by definition, copied per GPU. - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: PretrainedConfig | None = None, - ) -> bool: - return type(source_layer) is maybe_get_oot_by_class(GateLinear) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 75ed9674a..2349ace70 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -21,7 +21,6 @@ from vllm.lora.layers import ( ColumnParallelLinearWithShardedLoRA, FusedMoE3DWithLoRA, FusedMoEWithLoRA, - GateLinearWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearVariableSliceWithLoRA, MergedColumnParallelLinearWithLoRA, @@ -82,7 +81,6 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { MergedQKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, ReplicatedLinearWithLoRA, - GateLinearWithLoRA, LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index e8ed8a524..77d8e7560 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -3,11 +3,9 @@ import torch from torch.nn.parameter import Parameter -import vllm._custom_ops as ops from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op @PluggableLayer.register("gate_linear") @@ -15,9 +13,8 @@ class GateLinear(ReplicatedLinear): """MoE gate linear layer with three-tier GEMM dispatch: 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) - 2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims) - 3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) - 4. F.linear via ReplicatedLinear (ultimate fallback) + 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) + 3. F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init (e.g. when the required dtype depends on the expert quantization @@ -28,10 +25,6 @@ class GateLinear(ReplicatedLinear): DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] DSV3_SUPPORTED_HIDDEN_SIZES = [7168] - # Dimensions supported by the gpt-oss specialized kernel - GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] - GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] - def __init__( self, input_size: int, @@ -72,15 +65,6 @@ class GateLinear(ReplicatedLinear): and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES ) - # gpt-oss specialized kernel eligibility (SM90+, exact dims) - self.allow_gpt_oss_router_gemm = ( - self.weight.dtype == torch.bfloat16 - and current_platform.is_cuda() - and is_hopper_or_blackwell - and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS - and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES - ) - # cuBLAS bf16→fp32 eligibility self.allow_cublas_router_gemm = ( self.allow_specialized_router_gemm @@ -108,6 +92,8 @@ class GateLinear(ReplicatedLinear): def forward( self, x: torch.Tensor ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + import vllm._custom_ops as ops + # Tier 1: DSV3 specialized kernel if self.allow_dsv3_router_gemm and x.shape[0] <= 16: output = ops.dsv3_router_gemm( @@ -117,47 +103,15 @@ class GateLinear(ReplicatedLinear): ) return output, None - # Tier 2: gpt-oss specialized kernel - if self.allow_gpt_oss_router_gemm: - output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias) - return output, None - - # Tier 3: cuBLAS bf16→fp32 + # Tier 2: cuBLAS bf16→fp32 if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None - # Tier 4: F.linear (ReplicatedLinear) + # Tier 3: F.linear (ReplicatedLinear) if self.out_dtype is not None and x.dtype != self.weight.dtype: x = x.to(self.weight.dtype) output, output_bias = super().forward(x) if self.out_dtype is not None and output.dtype != self.out_dtype: output = output.to(self.out_dtype) return output, output_bias - - -def gpt_oss_router_gemm_impl( - x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - """ - Dynamically run min-latency gemm if num_tokens <= 128. - This must be wrapped in a custom op because our torch.compile integration - does not support runtime dispatching on num_tokens. - """ - if x.shape[0] <= 128: - return ops.gpt_oss_router_gemm(x, weight, bias) - else: - return torch.nn.functional.linear(x, weight, bias) - - -def gpt_oss_router_gemm_fake( - x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - return x.new_empty((x.shape[0], weight.shape[0])) - - -direct_register_custom_op( - op_name="gpt_oss_router_gemm", - op_func=gpt_oss_router_gemm_impl, - fake_impl=gpt_oss_router_gemm_fake, -) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 482056250..a9ec82974 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -20,11 +20,12 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -174,11 +175,13 @@ class MLPBlock(torch.nn.Module): self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = GateLinear( + self.router = ReplicatedLinear( config.hidden_size, config.num_local_experts, bias=True, + quant_config=None, prefix=f"{prefix}.router", + return_bias=False, ) assert config.intermediate_size % self.world_size == 0 self.experts = FusedMoE( @@ -206,7 +209,7 @@ class MLPBlock(torch.nn.Module): self, x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: - g, _ = self.router(x) + g = self.router(x) x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: