[ROCm][Bugfix] Ensure that the moe_wna16_gemm kernel is not built on ROCm platforms. (#14629)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore
2025-03-12 05:00:28 -07:00
committed by GitHub
parent ff47aab056
commit 45f3f3f59e
4 changed files with 8 additions and 3 deletions

View File

@@ -31,6 +31,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()");
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
#ifndef USE_ROCM
m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
"Tensor b_scales, Tensor? b_qzeros, "
@@ -41,7 +42,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
#ifndef USE_ROCM
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "