Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -1030,7 +1030,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/moe_wna16.cu"
|
||||
"csrc/moe/grouped_topk_kernels.cu"
|
||||
"csrc/moe/gpt_oss_router_gemm.cu"
|
||||
"csrc/moe/router_gemm.cu")
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# Dimensions supported by the DSV3 specialized kernel
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
# Dimensions supported by the gpt-oss specialized kernel
|
||||
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
|
||||
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
|
||||
|
||||
|
||||
def get_batch_size_range(max_batch_size):
    """Return the ascending powers of two (up to 2**13) not exceeding max_batch_size."""
    sizes = []
    for exp in range(14):
        candidate = 1 << exp
        if candidate <= max_batch_size:
            sizes.append(candidate)
    return sizes
|
||||
|
||||
|
||||
def get_model_params(config):
    """Extract (num_experts, hidden_size) from a supported HF model config.

    Raises ValueError for architectures without a specialized router kernel.
    """
    arch = config.architectures[0]
    deepseek_archs = (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
    )
    if arch in deepseek_archs:
        # DeepSeek configs name the routed-expert count `n_routed_experts`.
        return config.n_routed_experts, config.hidden_size
    if arch == "GptOssForCausalLM":
        # gpt-oss configs use `num_local_experts` instead.
        return config.num_local_experts, config.hidden_size
    raise ValueError(f"Unsupported architecture: {config.architectures}")
|
||||
|
||||
|
||||
def get_benchmark(model, max_batch_size, trust_remote_code):
    """Build a triton ``perf_report`` benchmark comparing PyTorch ``F.linear``
    against vLLM's specialized router GEMM kernels for the given model.

    Args:
        model: HF model name/path; its config supplies (num_experts, hidden_size).
        max_batch_size: upper bound of the power-of-two batch-size sweep.
        trust_remote_code: forwarded to the HF config loader.

    Returns:
        The decorated benchmark function; call ``.run()`` on it to execute.
    """
    # Fix: resolve the model config once here rather than inside the inner
    # closure, where it was re-fetched for every (batch_size, provider) pair.
    config = get_config(model=model, trust_remote_code=trust_remote_code)
    num_experts, hidden_size = get_model_params(config)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=get_batch_size_range(max_batch_size),
            x_log=False,
            line_arg="provider",
            line_vals=[
                "torch",
                "vllm",
            ],
            line_names=["PyTorch", "vLLM"],
            styles=([("blue", "-"), ("red", "-")]),
            ylabel="TFLOPs",
            plot_name=f"{model} router gemm throughput",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        mat_a = torch.randn(
            (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        mat_b = torch.randn(
            (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        bias = torch.randn(
            num_experts, dtype=torch.bfloat16, device="cuda"
        ).contiguous()

        is_hopper_or_blackwell = current_platform.is_device_capability(
            90
        ) or current_platform.is_device_capability_family(100)
        allow_dsv3_router_gemm = (
            is_hopper_or_blackwell
            and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
            and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
        )
        allow_gpt_oss_router_gemm = (
            is_hopper_or_blackwell
            and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
            and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
        )

        # Only the gpt-oss kernel fuses a bias, so the torch baseline mirrors it.
        has_bias = allow_gpt_oss_router_gemm

        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":

            def runner():
                if has_bias:
                    F.linear(mat_a, mat_b, bias)
                else:
                    F.linear(mat_a, mat_b)
        elif provider == "vllm":

            def runner():
                if allow_dsv3_router_gemm:
                    ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)
                elif allow_gpt_oss_router_gemm:
                    ops.gpt_oss_router_gemm(mat_a, mat_b, bias)
                else:
                    raise ValueError("Unsupported router gemm")

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            runner, quantiles=quantiles
        )

        def tflops(t_ms):
            # 2*M*N*K flops for a GEMM; convert ms -> s and flops -> TFLOPs.
            flops = 2 * batch_size * hidden_size * num_experts
            return flops / (t_ms * 1e-3) / 1e12

        # Quantiles are [median, p20, p80] runtimes; a faster time maps to a
        # higher TFLOPs value, hence min/max swap places on conversion.
        return tflops(ms), tflops(max_ms), tflops(min_ms)

    return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
    cli = FlexibleArgumentParser()
    cli.add_argument("--model", default="openai/gpt-oss-20b", type=str)
    cli.add_argument("--max-batch-size", type=int, default=16)
    cli.add_argument("--trust-remote-code", action="store_true")
    cli_args = cli.parse_args()

    # Construct the perf_report benchmark for the requested model, then
    # execute the sweep and print the results table.
    bench = get_benchmark(cli_args.model, cli_args.max_batch_size, cli_args.trust_remote_code)
    bench.run(print_data=True)
|
||||
@@ -1,144 +0,0 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
|
||||
* Copyright (c) 2025, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
|
||||
* All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
#include "gpt_oss_router_gemm.cuh"
|
||||
|
||||
// Build the TMA descriptors and launch the gpt-oss router GEMM kernel.
//   gA:   activations, bf16, row-major [batch_size, input_features]
//   gB:   weights, bf16, row-major [output_features, input_features]
//   gC:   output buffer written by the kernel
//   bias: per-output-feature bias, bf16 [output_features]
// NOTE(review): the TILE_* constants must match the kernel template args, and
// the strides below assume densely packed row-major tensors — confirm callers
// always pass contiguous tensors.
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
                                __nv_bfloat16* gC, __nv_bfloat16* bias,
                                int batch_size, int output_features,
                                int input_features, cudaStream_t stream) {
  // Compile-time tiling configuration for the kernel template.
  static int const WARP_TILE_M = 16;
  static int const TILE_M = WARP_TILE_M;
  static int const TILE_N = 8;
  static int const TILE_K = 64;
  static int const STAGES = 16;
  static int const STAGE_UNROLL = 4;
  static bool const PROFILE = false;

  CUtensorMap weight_map{};
  CUtensorMap activation_map{};

  // 2-D TMA descriptor layout: dim 0 = K (fastest), dim 1 = rows.
  // First encode the weight matrix [output_features, input_features].
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
  uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
  uint32_t box_size[rank] = {TILE_K, TILE_M};
  uint32_t elem_stride[rank] = {1, 1};

  CUresult res = cuTensorMapEncodeTiled(
      &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
      gB, size, stride, box_size, elem_stride,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  TORCH_CHECK(res == CUDA_SUCCESS,
              "cuTensorMapEncodeTiled failed for weight_map, error code=",
              static_cast<int>(res));

  // Reuse the same descriptor scaffolding for the activation matrix
  // [batch_size, input_features]; only the row count and row-tile differ.
  size[1] = batch_size;
  box_size[1] = TILE_N;

  res = cuTensorMapEncodeTiled(
      &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
      rank, gA, size, stride, box_size, elem_stride,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  TORCH_CHECK(res == CUDA_SUCCESS,
              "cuTensorMapEncodeTiled failed for activation_map, error code=",
              static_cast<int>(res));

  // Dynamic shared memory: all pipeline stages of both weight and activation
  // tiles. This can exceed the default 48 KB limit, hence the opt-in below.
  int smem_size = STAGES * STAGE_UNROLL *
                  (TILE_M * TILE_K * sizeof(__nv_bfloat16) +
                   TILE_N * TILE_K * sizeof(__nv_bfloat16));

  gpuErrChk(cudaFuncSetAttribute(
      gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
                                 STAGE_UNROLL, PROFILE>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

  // One CTA per (TILE_M output rows) x (TILE_N batch rows) tile.
  int tiles_m = (output_features + TILE_M - 1) / TILE_M;
  int tiles_n = (batch_size + TILE_N - 1) / TILE_N;

  dim3 grid(tiles_m, tiles_n);
  dim3 block(384);

  // Launch with programmatic stream serialization (programmatic dependent
  // launch) enabled, matching the cudaGridDependencySynchronize() call made
  // inside the kernel.
  cudaLaunchConfig_t config;
  cudaLaunchAttribute attrs[1];
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = smem_size;
  config.stream = stream;
  config.attrs = attrs;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 1;
  config.numAttrs = 1;

  cudaLaunchKernelEx(
      &config,
      &gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
                                  STAGE_UNROLL, PROFILE>,
      gC, gA, gB, bias, output_features, batch_size, input_features, weight_map,
      activation_map, nullptr);
}
|
||||
|
||||
// Torch-facing shim: extract sizes and the current CUDA stream, then dispatch
// to the bf16 launcher. `output` must be preallocated with shape
// [input.size(0), weight.size(0)] (see the Python wrapper that allocates it).
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
                                      torch::Tensor input, torch::Tensor weight,
                                      torch::Tensor bias) {
  auto const batch_size = input.size(0);
  auto const input_dim = input.size(1);
  auto const output_dim = weight.size(0);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (input.scalar_type() == at::ScalarType::BFloat16) {
    launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(),
                               (__nv_bfloat16*)weight.data_ptr(),
                               (__nv_bfloat16*)output.mutable_data_ptr(),
                               (__nv_bfloat16*)bias.data_ptr(), batch_size,
                               output_dim, input_dim, stream);
  } else {
    // The TMA descriptors and MMA path are bf16-only.
    throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
  }
}
|
||||
|
||||
// Validate all tensors for the fused gpt-oss router GEMM and run it.
//
// Fills `output` in place with input @ weight^T + bias on the current CUDA
// stream, where input is [num_tokens, hidden], weight is [num_experts,
// hidden], bias is [num_experts] and output is preallocated
// [num_tokens, num_experts]. All tensors must be bf16.
// Raises c10::Error (via TORCH_CHECK) on any shape/dtype/layout mismatch.
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
                         torch::Tensor weight, torch::Tensor bias) {
  TORCH_CHECK(input.dim() == 2, "input must be 2D");
  TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
  TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
  // The kernel writes `output` in place, so its shape and dtype must be
  // validated here as well — previously only the inputs were checked.
  TORCH_CHECK(output.dim() == 2, "output must be 2D");
  TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
              "input.size(1) must match weight.size(1)");
  TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
              "weight.size(0) must match bias.size(0)");
  TORCH_CHECK(output.sizes()[0] == input.sizes()[0],
              "output.size(0) must match input.size(0)");
  TORCH_CHECK(output.sizes()[1] == weight.sizes()[0],
              "output.size(1) must match weight.size(0)");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
              "input tensor must be bfloat16");
  TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
              "weight tensor must be bfloat16");
  TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
              "bias tensor must be bfloat16");
  TORCH_CHECK(output.scalar_type() == at::ScalarType::BFloat16,
              "output tensor must be bfloat16");
  // The launcher derives TMA strides from the logical sizes, i.e. it assumes
  // densely packed row-major storage for every tensor.
  TORCH_CHECK(input.is_contiguous(), "input tensor must be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor must be contiguous");
  TORCH_CHECK(bias.is_contiguous(), "bias tensor must be contiguous");
  TORCH_CHECK(output.is_contiguous(), "output tensor must be contiguous");
  gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}
|
||||
@@ -1,447 +0,0 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
|
||||
* Copyright (c) 2025, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
|
||||
* All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cuda_bf16.h"
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
|
||||
#include "cuda_pipeline.h"
|
||||
#include <cuda.h>
|
||||
#include <cuda/barrier>
|
||||
#include <cuda/std/utility>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
using barrier = cuda::barrier<cuda::thread_scope_block>;
|
||||
namespace cde = cuda::device::experimental;
|
||||
namespace ptx = cuda::ptx;
|
||||
|
||||
#define gpuErrChk(ans) \
|
||||
{ \
|
||||
gpuAssert((ans), __FILE__, __LINE__); \
|
||||
}
|
||||
|
||||
// Report a failed CUDA runtime call (file/line supplied by the gpuErrChk
// macro) and, unless abort == false, throw so the failure surfaces as a
// Python exception when called from the torch extension.
inline void gpuAssert(cudaError_t code, char const* file, int line,
                      bool abort = true) {
  if (code != cudaSuccess) {
    fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
            line);
    if (abort) {
      throw std::runtime_error(cudaGetErrorString(code));
    }
  }
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
// Read the 64-bit global timer register (%globaltimer); used only for the
// optional PROFILE timestamps.
__device__ uint64_t gclock64() {
  unsigned long long int rv;
  asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
  return rv;
}
|
||||
|
||||
// Warp-cooperative 8x8 b16 matrix load from shared memory (ldmatrix.x1).
// Each lane receives two bf16 values packed into one 32-bit register.
__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
  int dst;
  asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(dst)
               : "r"(smem_ptr));
  // Reinterpret the 32-bit register as the two packed bf16 outputs.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = dst;
}
|
||||
|
||||
// Warp-cooperative load of two 8x8 b16 matrices from shared memory
// (ldmatrix.x2); each lane receives four bf16 values in two 32-bit registers.
__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
  int x, y;
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
               : "=r"(x), "=r"(y)
               : "r"(smem_ptr));

  // Pack the two 32-bit registers into the bf16 output fragment.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = x;
  rvi[1] = y;
}
|
||||
|
||||
// Warp-cooperative load of four 8x8 b16 matrices from shared memory
// (ldmatrix.x4); each lane receives eight bf16 values in four 32-bit
// registers — the A-fragment shape consumed by HMMA_16816.
__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
  int x, y, z, w;
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
      : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
      : "r"(smem_ptr));
  // Pack the four 32-bit registers into the bf16 output fragment.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = x;
  rvi[1] = y;
  rvi[2] = z;
  rvi[3] = w;
}
|
||||
|
||||
// Tensor-core mma.m16n8k8 (bf16 inputs, fp32 accumulate): d = a*b + c.
// NOTE(review): appears unused by the kernel below, which uses HMMA_16816.
__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
                          float c[4]) {
  // View the bf16 fragments as packed 32-bit registers for the asm operands.
  uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
  uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
  float const* C = reinterpret_cast<float const*>(&c[0]);
  float* D = reinterpret_cast<float*>(&d[0]);

  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
        "f"(C[3]));
}
|
||||
|
||||
// Tensor-core mma.m16n8k16 (bf16 inputs, fp32 accumulate): d = a*b + c.
// In the kernel below it is called as HMMA_16816(accum, a, b, accum), i.e.
// aliasing d and c for an in-place accumulate.
__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
                           float c[4]) {
  // View the bf16 fragments as packed 32-bit registers for the asm operands.
  uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
  uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
  float const* C = reinterpret_cast<float const*>(&c[0]);
  float* D = reinterpret_cast<float*>(&d[0]);

  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
}
|
||||
|
||||
// Blocking wait: spin on mbarrier.try_wait until the mbarrier at shared
// address bar_ptr completes the given phase.
__device__ void bar_wait(uint32_t bar_ptr, int phase) {
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "LAB_WAIT:\n"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
      "@P1 bra.uni DONE;\n"
      "bra.uni LAB_WAIT;\n"
      "DONE:\n"
      "}\n" ::"r"(bar_ptr),
      "r"(phase));
}
|
||||
|
||||
// Non-blocking mbarrier phase test; returns true when the given phase has
// completed. Used to overlap barrier polling with other work.
__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
  uint32_t success;
#ifdef INTERNAL
  // Internal-build pragma; prevents the compiler from inserting yields here.
  asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
#endif
  asm volatile(
      "{\n\t"
      ".reg .pred P1; \n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t"
      "selp.b32 %0, 1, 0, P1; \n\t"
      "}"
      : "=r"(success)
      : "r"(bar_ptr), "r"(phase));
  return success;
}
|
||||
|
||||
// elect.sync wrapper: returns 1 on exactly one (elected) active lane of the
// warp and 0 on all others. Used so a single thread per DMA warp issues TMA
// copies.
__device__ uint32_t elect_one_sync() {
  uint32_t pred = 0;
  uint32_t laneid = 0;
  asm volatile(
      "{\n"
      ".reg .b32 %%rx;\n"
      ".reg .pred %%px;\n"
      " elect.sync %%rx|%%px, %2;\n"
      "@%%px mov.s32 %1, 1;\n"
      " mov.s32 %0, %%rx;\n"
      "}\n"
      : "+r"(laneid), "+r"(pred)
      : "r"(0xFFFFFFFF));
  return pred;
}
|
||||
#endif
|
||||
|
||||
// Per-CTA timestamps (read from %globaltimer via gclock64); written only when
// the kernel's PROFILE template flag is set and a profile buffer is passed.
struct Profile {
  uint64_t start;              // kernel entry
  uint64_t weight_load_start;  // first weight TMA issued
  uint64_t act_load_start;     // first activation TMA issued
  uint64_t compute_start;      // first compute stage ready
  uint64_t complete;           // results stored to global memory
};
|
||||
|
||||
// Warp-specialized bf16 GEMM kernel for the gpt-oss router:
//   output[n, m] = sum_k weights[m, k] * activations[n, k] + bias[m]
// with M = output features (experts), N = batch size, K = hidden size.
//
// Block layout (384 threads = 12 warps):
//   warps 0-3:  compute (tensor-core MMA), one pipeline stage slot each
//   warps 4-7:  weight TMA loads (one elected lane per warp)
//   warps 8-11: activation TMA loads (one elected lane per warp)
// Producer/consumer hand-off uses shared-memory mbarriers per stage.
// NOTE(review): output is indexed output[tn * M + tm], i.e. with batch (N) as
// the slower dimension — consistent with a [batch, experts] row-major tensor.
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
          int STAGE_UNROLL, bool PROFILE>
__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel(
    __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations,
    __nv_bfloat16* bias, int M, int N, int K,
    const __grid_constant__ CUtensorMap weight_map,
    const __grid_constant__ CUtensorMap activation_map,
    Profile* profile = nullptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

  if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0)
    profile[blockIdx.x].start = gclock64();

  // Dynamic shared memory: weight tiles first, then activation tiles.
  extern __shared__ __align__(128) char smem[];

  __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0];
  __nv_bfloat16* sh_activations =
      (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K *
                            sizeof(__nv_bfloat16)];

#pragma nv_diag_suppress static_var_with_dynamic_init
  // Per-stage barriers: data-ready (signaled by TMA completion) and
  // data-consumed (signaled by the 32 lanes of a compute warp).
  __shared__ barrier bar_wt_ready[STAGES];
  __shared__ barrier bar_act_ready[STAGES];
  __shared__ barrier bar_data_consumed[STAGES];

  // Cross-warp accumulator exchange for the final K-reduction (4 compute
  // warps x 32 lanes = 128 float4 slots).
  __shared__ float4 reduction_buffer[128];

  __shared__ nv_bfloat16 sh_bias[TILE_M];

  if (threadIdx.x == 0) {
    for (int i = 0; i < STAGES; i++) {
      init(&bar_wt_ready[i], 1);
      init(&bar_act_ready[i], 1);
      init(&bar_data_consumed[i], 32);
    }
    // Make barrier initialization visible to the async (TMA) proxy.
    ptx::fence_proxy_async(ptx::space_shared);
    asm volatile("prefetch.tensormap [%0];"
                 :
                 : "l"(reinterpret_cast<uint64_t>(&weight_map))
                 : "memory");
    asm volatile("prefetch.tensormap [%0];"
                 :
                 : "l"(reinterpret_cast<uint64_t>(&activation_map))
                 : "memory");
  }
  __syncthreads();

  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;

  int phase = 0;

  // Tile origin of this CTA: mib in output features, ni in batch rows.
  int mib = blockIdx.x * TILE_M;
  int ni = blockIdx.y * TILE_N;

  float accum[4];
  for (int i = 0; i < 4; i++) accum[i] = 0.f;

  // Each outer iteration covers 4 stage slots x STAGE_UNROLL K-tiles.
  int const K_LOOPS_DMA =
      (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL));
  int const K_LOOPS_COMPUTE = K_LOOPS_DMA;

  // Data loading thread
  if (warp_id >= 4 && elect_one_sync()) {
    // Stage slots rotate over warp_id % 4; warps 4-7 load weights,
    // warps 8-11 load activations.
    int stage = warp_id % 4;

    bool weight_warp = warp_id < 8;
    if (!weight_warp) {
      // Programmatic dependent launch: wait for the prior grid before
      // touching activations, then allow the next grid to start.
      cudaGridDependencySynchronize();
      cudaTriggerProgrammaticLaunchCompletion();
    }

    for (int ki = 0; ki < K_LOOPS_DMA; ki++) {
      int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL;

      uint64_t desc_ptr_wt = reinterpret_cast<uint64_t>(&weight_map);
      uint64_t desc_ptr_act = reinterpret_cast<uint64_t>(&activation_map);

      uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
      uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
      int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16);
      int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16);

      // Don't overwrite a stage until the compute warp has consumed it.
      bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);

      // Announce the expected transaction byte count for this stage.
      if (weight_warp)
        asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
                     :
                     : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt));
      if (!weight_warp)
        asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
                     :
                     : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act));

      if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp)
        profile[blockIdx.x].weight_load_start = gclock64();
      if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp)
        profile[blockIdx.x].act_load_start = gclock64();

      // Issue STAGE_UNROLL 2-D TMA copies for this stage slot.
      for (int i = 0; i < STAGE_UNROLL; i++) {
        uint32_t smem_ptr_wt = __cvta_generic_to_shared(
            &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]);
        uint32_t crd0 = k + i * TILE_K;
        uint32_t crd1 = mib;
        if (weight_warp)
          asm volatile(
              "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
              "tx::bytes [%0], [%1, {%3,%4}], "
              "[%2];"
              :
              : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0),
                "r"(crd1)
              : "memory");

        uint32_t smem_ptr_act = __cvta_generic_to_shared(
            &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]);
        crd0 = k + i * TILE_K;
        crd1 = ni;
        if (!weight_warp)
          asm volatile(
              "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
              "tx::bytes [%0], [%1, {%3,%4}], "
              "[%2];"
              :
              : "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act),
                "r"(crd0), "r"(crd1)
              : "memory");
      }

      // Advance to this warp's next stage slot; flip the phase on wrap.
      stage += 4;
      if (stage >= STAGES) {
        stage = warp_id % 4;
        phase ^= 1;
      }
    }
    // Wait for pending loads to be consumed before exiting, to avoid race
    for (int i = 0; i < (STAGES / 4) - 1; i++) {
      bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
      stage += 4;
      if (stage >= STAGES) {
        stage = warp_id % 4;
        phase ^= 1;
      }
    }
  }
  // Compute threads
  else if (warp_id < 4) {
    // Sneak the bias load into the compute warps since they're just waiting for
    // stuff anyway
    if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x];

    int stage = warp_id;

    int phase = 0;
    int lane_id_div8 = lane_id / 8;
    int lane_id_mod8 = lane_id % 8;

    // Per-lane ldmatrix addressing within the 128B-swizzled tiles.
    int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0;
    int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0;

    int row_wt = lane_id_mod8 + lane_row_offset_wt;
    int row_act = lane_id_mod8;

    // Swizzle offset induced by the tile's base address alignment.
    int row_offset_wt = (reinterpret_cast<uintptr_t>(sh_weights) / 128) % 8;
    int row_offset_act = row_offset_wt;

    uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
    uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);

    bool weight_ready = bar_try_wait(bar_ptr_wt, phase);
    bool act_ready = bar_try_wait(bar_ptr_act, phase);

#pragma unroll 2
    for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) {
      int next_stage = stage + 4;
      int next_phase = phase;
      if (next_stage >= STAGES) {
        next_stage = warp_id;
        next_phase ^= 1;
      }

      // Spin until both tiles of the current stage have landed in smem.
      while (!weight_ready || !act_ready) {
        weight_ready = bar_try_wait(bar_ptr_wt, phase);
        act_ready = bar_try_wait(bar_ptr_act, phase);
      }

      if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0)
        profile[blockIdx.x].compute_start = gclock64();

      // Kick off the readiness poll for the next stage before computing.
      if (ki + 1 < K_LOOPS_COMPUTE) {
        weight_ready = bar_try_wait(
            __cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase);
        act_ready = bar_try_wait(
            __cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase);
      }

#pragma unroll
      for (int su = 0; su < STAGE_UNROLL; su++) {
        __nv_bfloat16* ptr_weights =
            &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K];
        __nv_bfloat16* ptr_act =
            &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K];

#pragma unroll
        for (int kii = 0; kii < TILE_K / 16; kii++) {
          __nv_bfloat16 a[8];
          __nv_bfloat16 b[4];

          // Undo the 128B swizzle: XOR the column with the (offset) row.
          int col = 2 * kii + lane_col_offset_wt;
          int col_sw = ((row_wt + row_offset_wt) % 8) ^ col;

          ldmatrix4(a, __cvta_generic_to_shared(
                           &ptr_weights[row_wt * TILE_K + col_sw * 8]));

          col = 2 * kii + lane_id_div8;
          col_sw = ((row_act + row_offset_act) % 8) ^ col;

          ldmatrix2(b, __cvta_generic_to_shared(
                           &ptr_act[row_act * TILE_K + 8 * col_sw]));

          // In-place fp32 accumulate: accum += a * b.
          HMMA_16816(accum, a, b, accum);
        }
      }

      // Release this stage back to the DMA warps.
      uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]);
      asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c));

      stage = next_stage;
      phase = next_phase;
    }

    // Stash this warp's partial sums so warp 0 can reduce across warps.
    float4 accum4;
    accum4.x = accum[0];
    accum4.y = accum[1];
    accum4.z = accum[2];
    accum4.w = accum[3];
    reduction_buffer[threadIdx.x] = accum4;

    __syncthreads();

    if (warp_id == 0) {
      // Map each lane to its (tm, tn) output coordinates per the m16n8
      // MMA fragment layout.
      int mi = mib + warp_id * WARP_TILE_M;
      int tm = mi + lane_id / 4;
      int tn = ni + 2 * (lane_id % 4);

      // Reduce the four compute warps' partial accumulators.
      float4 accum1 = reduction_buffer[32 + threadIdx.x];
      float4 accum2 = reduction_buffer[64 + threadIdx.x];
      float4 accum3 = reduction_buffer[96 + threadIdx.x];

      accum[0] = accum[0] + accum1.x + accum2.x + accum3.x;
      accum[1] = accum[1] + accum1.y + accum2.y + accum3.y;
      accum[2] = accum[2] + accum1.z + accum2.z + accum3.z;
      accum[3] = accum[3] + accum1.w + accum2.w + accum3.w;

      float bias_lo = __bfloat162float(sh_bias[tm - mib]);
      float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]);

      // Bounds-checked stores: add bias, convert back to bf16.
      if (tn < N && tm < M)
        output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo);
      if (tn + 1 < N && tm < M)
        output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo);
      if (tn < N && tm + 8 < M)
        output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi);
      if (tn + 1 < N && tm + 8 < M)
        output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi);

      if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
        profile[blockIdx.x].complete = gclock64();
    }
  }
#endif  // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
|
||||
@@ -70,8 +70,4 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
|
||||
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
|
||||
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a,
|
||||
const torch::Tensor& mat_b);
|
||||
|
||||
// gpt-oss optimized router GEMM kernel for SM90+
|
||||
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
|
||||
torch::Tensor weight, torch::Tensor bias);
|
||||
#endif
|
||||
|
||||
@@ -132,12 +132,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
// DeepSeek V3 optimized router GEMM for SM90+
|
||||
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// gpt-oss optimized router GEMM kernel for SM90+
|
||||
m.def(
|
||||
"gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, "
|
||||
"Tensor bias) -> ()");
|
||||
m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for optimized router GEMM kernel
|
||||
|
||||
Run `pytest tests/kernels/moe/test_router_gemm.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability(90)
            or current_platform.is_device_capability_family(100)
        )
    ),
    reason="This test only runs on Hopper or Blackwell GPUs.",
)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880])
@pytest.mark.parametrize("output_dim", [32, 64, 128])
def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim):
    """Check the fused router GEMM against F.linear on random bf16 inputs."""
    set_random_seed(0)
    # Same creation order/shapes as before, so the seeded values are identical.
    hidden = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16)
    gate_w = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16)
    gate_b = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16)

    actual = ops.gpt_oss_router_gemm(hidden, gate_w, gate_b)
    expected = torch.nn.functional.linear(hidden, gate_w, gate_b)
    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
|
||||
@@ -2327,19 +2327,6 @@ def dsv3_router_gemm(
|
||||
return output
|
||||
|
||||
|
||||
def gpt_oss_router_gemm(
    hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Fused gpt-oss router GEMM: hidden_states @ weight.T + bias.

    Allocates the [num_tokens, num_experts] result and fills it via the
    in-place ``_moe_C.gpt_oss_router_gemm`` custom op.
    """
    num_tokens = hidden_states.shape[0]
    num_experts = weight.shape[0]
    output = torch.empty(
        num_tokens,
        num_experts,
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias)
    return output
|
||||
|
||||
|
||||
def topk_softmax(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
|
||||
@@ -13,7 +13,6 @@ from vllm.lora.layers.column_parallel_linear import (
|
||||
QKVParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
|
||||
from vllm.lora.layers.gate_linear import GateLinearWithLoRA
|
||||
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
|
||||
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
|
||||
from vllm.lora.layers.row_parallel_linear import (
|
||||
@@ -39,7 +38,6 @@ __all__ = [
|
||||
"RowParallelLinearWithLoRA",
|
||||
"RowParallelLinearWithShardedLoRA",
|
||||
"ReplicatedLinearWithLoRA",
|
||||
"GateLinearWithLoRA",
|
||||
"LoRAMapping",
|
||||
"LoRAMappingType",
|
||||
"FusedMoEWithLoRA",
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.model_executor.custom_op import maybe_get_oot_by_class
|
||||
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
|
||||
|
||||
from .replicated_linear import ReplicatedLinearWithLoRA
|
||||
|
||||
|
||||
class GateLinearWithLoRA(ReplicatedLinearWithLoRA):
    """LoRA wrapper for the MoE router's GateLinear layer.

    All LoRA application logic is inherited from ReplicatedLinearWithLoRA;
    only the replacement predicate differs.
    """

    def __init__(self, base_layer: GateLinear) -> None:
        super().__init__(
            base_layer,
        )

    # GateLinearWithLoRA should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True iff source_layer is exactly the (possibly
        out-of-tree-overridden) GateLinear class."""
        return type(source_layer) is maybe_get_oot_by_class(GateLinear)
|
||||
@@ -21,7 +21,6 @@ from vllm.lora.layers import (
|
||||
ColumnParallelLinearWithShardedLoRA,
|
||||
FusedMoE3DWithLoRA,
|
||||
FusedMoEWithLoRA,
|
||||
GateLinearWithLoRA,
|
||||
LogitsProcessorWithLoRA,
|
||||
MergedColumnParallelLinearVariableSliceWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
@@ -82,7 +81,6 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
|
||||
MergedQKVParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA,
|
||||
ReplicatedLinearWithLoRA,
|
||||
GateLinearWithLoRA,
|
||||
LogitsProcessorWithLoRA,
|
||||
ColumnParallelLinearWithShardedLoRA,
|
||||
QKVParallelLinearWithShardedLoRA,
|
||||
|
||||
@@ -3,11 +3,9 @@
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@PluggableLayer.register("gate_linear")
|
||||
@@ -15,9 +13,8 @@ class GateLinear(ReplicatedLinear):
|
||||
"""MoE gate linear layer with three-tier GEMM dispatch:
|
||||
|
||||
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
|
||||
2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims)
|
||||
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
|
||||
4. F.linear via ReplicatedLinear (ultimate fallback)
|
||||
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
|
||||
3. F.linear via ReplicatedLinear (ultimate fallback)
|
||||
|
||||
The ``out_dtype`` attribute is mutable and can be set after init
|
||||
(e.g. when the required dtype depends on the expert quantization
|
||||
@@ -28,10 +25,6 @@ class GateLinear(ReplicatedLinear):
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
# Dimensions supported by the gpt-oss specialized kernel
|
||||
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
|
||||
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
@@ -72,15 +65,6 @@ class GateLinear(ReplicatedLinear):
|
||||
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
# gpt-oss specialized kernel eligibility (SM90+, exact dims)
|
||||
self.allow_gpt_oss_router_gemm = (
|
||||
self.weight.dtype == torch.bfloat16
|
||||
and current_platform.is_cuda()
|
||||
and is_hopper_or_blackwell
|
||||
and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS
|
||||
and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
# cuBLAS bf16→fp32 eligibility
|
||||
self.allow_cublas_router_gemm = (
|
||||
self.allow_specialized_router_gemm
|
||||
@@ -108,6 +92,8 @@ class GateLinear(ReplicatedLinear):
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
# Tier 1: DSV3 specialized kernel
|
||||
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
|
||||
output = ops.dsv3_router_gemm(
|
||||
@@ -117,47 +103,15 @@ class GateLinear(ReplicatedLinear):
|
||||
)
|
||||
return output, None
|
||||
|
||||
# Tier 2: gpt-oss specialized kernel
|
||||
if self.allow_gpt_oss_router_gemm:
|
||||
output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias)
|
||||
return output, None
|
||||
|
||||
# Tier 3: cuBLAS bf16→fp32
|
||||
# Tier 2: cuBLAS bf16→fp32
|
||||
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
|
||||
output = ops.router_gemm_bf16_fp32(x, self.weight)
|
||||
return output, None
|
||||
|
||||
# Tier 4: F.linear (ReplicatedLinear)
|
||||
# Tier 3: F.linear (ReplicatedLinear)
|
||||
if self.out_dtype is not None and x.dtype != self.weight.dtype:
|
||||
x = x.to(self.weight.dtype)
|
||||
output, output_bias = super().forward(x)
|
||||
if self.out_dtype is not None and output.dtype != self.out_dtype:
|
||||
output = output.to(self.out_dtype)
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def gpt_oss_router_gemm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Dynamically run min-latency gemm if num_tokens <= 128.
|
||||
This must be wrapped in a custom op because our torch.compile integration
|
||||
does not support runtime dispatching on num_tokens.
|
||||
"""
|
||||
if x.shape[0] <= 128:
|
||||
return ops.gpt_oss_router_gemm(x, weight, bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def gpt_oss_router_gemm_fake(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return x.new_empty((x.shape[0], weight.shape[0]))
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="gpt_oss_router_gemm",
|
||||
op_func=gpt_oss_router_gemm_impl,
|
||||
fake_impl=gpt_oss_router_gemm_fake,
|
||||
)
|
||||
|
||||
@@ -20,11 +20,12 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@@ -174,11 +175,13 @@ class MLPBlock(torch.nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.experts_per_token = config.num_experts_per_tok
|
||||
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
self.router = GateLinear(
|
||||
self.router = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_local_experts,
|
||||
bias=True,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.router",
|
||||
return_bias=False,
|
||||
)
|
||||
assert config.intermediate_size % self.world_size == 0
|
||||
self.experts = FusedMoE(
|
||||
@@ -206,7 +209,7 @@ class MLPBlock(torch.nn.Module):
|
||||
self, x[:, : self.hidden_size], self.router.weight, self.router.bias
|
||||
)
|
||||
else:
|
||||
g, _ = self.router(x)
|
||||
g = self.router(x)
|
||||
x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
|
||||
Reference in New Issue
Block a user