[Bugfix] Fix DSV3 kernels breaking _C and _moe_C on unsupported arches (#35123)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-02-23 20:11:27 -05:00
committed by GitHub
parent 22a97e6613
commit 3ef9fd0f98
5 changed files with 12 additions and 3 deletions

View File

@@ -783,7 +783,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SRCS "${DSV3_FUSED_A_GEMM_SRC}"
CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}")
list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC})
list(APPEND VLLM_GPU_FLAGS "-DENABLE_DSV3_FUSED_A_GEMM=1")
message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}")
else()
message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found "

View File

@@ -745,3 +745,7 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a,
stream);
}
}
// Binds dsv3_fused_a_gemm as the CUDA implementation of the
// "dsv3_fused_a_gemm" op. The registration is placed in this source file, next
// to the kernel entry point, so it is only emitted when this file is compiled.
// NOTE(review): the op schema is expected to be declared separately with
// def("dsv3_fused_a_gemm(...)") in the bindings file - confirm the op name
// here stays in sync with that declaration.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("dsv3_fused_a_gemm", &dsv3_fused_a_gemm);
}

View File

@@ -20,10 +20,12 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "core/registration.h"
#include "dsv3_router_gemm_utils.h"
static constexpr int DEFAULT_NUM_EXPERTS = 256;
@@ -161,3 +163,7 @@ void dsv3_router_gemm(at::Tensor& output, // [num_tokens, num_experts]
}
}
}
// Binds dsv3_router_gemm as the CUDA implementation of the
// "dsv3_router_gemm" op. The registration is placed in this source file, next
// to the kernel entry point, so it is only emitted when this file is compiled.
// NOTE(review): the op schema is expected to be declared separately with
// def("dsv3_router_gemm(...)") in the bindings file - confirm the op name
// here stays in sync with that declaration.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("dsv3_router_gemm", &dsv3_router_gemm);
}

View File

@@ -127,7 +127,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
// The kernel is conditionally compiled, so its impl is registered in its own source file.
#endif
}

View File

@@ -242,7 +242,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
ops.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
// The kernel is conditionally compiled, so its impl is registered in its own source file.
// Quantized GEMM for AWQ.
ops.def(