diff --git a/CMakeLists.txt b/CMakeLists.txt index a6f7f6946..55127a514 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -783,7 +783,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SRCS "${DSV3_FUSED_A_GEMM_SRC}" CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC}) - list(APPEND VLLM_GPU_FLAGS "-DENABLE_DSV3_FUSED_A_GEMM=1") message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") else() message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " diff --git a/csrc/dsv3_fused_a_gemm.cu b/csrc/dsv3_fused_a_gemm.cu index 5b8374303..65dff9c84 100644 --- a/csrc/dsv3_fused_a_gemm.cu +++ b/csrc/dsv3_fused_a_gemm.cu @@ -745,3 +745,7 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, stream); } } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("dsv3_fused_a_gemm", &dsv3_fused_a_gemm); +} diff --git a/csrc/moe/dsv3_router_gemm_entry.cu b/csrc/moe/dsv3_router_gemm_entry.cu index 1ba97bd76..38fb681c2 100644 --- a/csrc/moe/dsv3_router_gemm_entry.cu +++ b/csrc/moe/dsv3_router_gemm_entry.cu @@ -20,10 +20,12 @@ #include #include +#include #include #include +#include "core/registration.h" #include "dsv3_router_gemm_utils.h" static constexpr int DEFAULT_NUM_EXPERTS = 256; @@ -161,3 +163,7 @@ void dsv3_router_gemm(at::Tensor& output, // [num_tokens, num_experts] } } } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("dsv3_router_gemm", &dsv3_router_gemm); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 22b00f20a..438599451 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -127,7 +127,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); - m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm); + // conditionally compiled so impl registration is in source file #endif } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c16b9c223..39b6bc98a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -242,7 +242,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). ops.def( "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); - ops.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm); + // conditionally compiled so impl registration is in source file // Quantized GEMM for AWQ. ops.def(