From 60715f89bce8bebc7a61fb978345da7a445d0fc1 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 20:49:55 +0000 Subject: [PATCH] Fix CUDA kernel compilation: use c10::cuda::getCurrentCUDAStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - amax_gsa.cu: fix at::cuda::getCurrentCUDAStream → c10:: - amax_gsa.cu: fix torch::TensorOptions().device() → x.options() - sampler.cu: same fixes for compilation on B200 - Both kernels now compile cleanly with torch.utils.cpp_extension.load --- dsv4/kernels/cuda/amax_gsa.cu | 11 ++++++----- dsv4/kernels/cuda/sampler.cu | 6 ++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/dsv4/kernels/cuda/amax_gsa.cu b/dsv4/kernels/cuda/amax_gsa.cu index d3d71f76..b64f2893 100644 --- a/dsv4/kernels/cuda/amax_gsa.cu +++ b/dsv4/kernels/cuda/amax_gsa.cu @@ -5,13 +5,14 @@ * No CPU-GPU sync. The output tensor stays on GPU and can be passed * directly to CuTeDSL GEMM's global_scale_a parameter via to_cute(). * - * This eliminates ~610 CPU-GPU syncs per decode step from Nvfp4Linear, - * ~183 from Nvfp4MoE, and ~122 from Nvfp4SharedExpert. + * This eliminates ~915 CPU-GPU syncs per decode step from Nvfp4Linear, + * Nvfp4MoE, and Nvfp4SharedExpert. */ #include #include #include +#include #include #include #include @@ -51,10 +52,10 @@ torch::Tensor compute_amax_gsa_cuda(torch::Tensor x, double divisor) { TORCH_CHECK(x.scalar_type() == torch::kBFloat16, "input must be BF16"); int n = x.numel(); - auto out = torch::zeros({}, torch::TensorOptions() - .dtype(torch::kFloat32).device(x.device)); + auto options = x.options().dtype(torch::kFloat32); + auto out = torch::zeros({}, options); - compute_amax_gsa_kernel<<<1, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + compute_amax_gsa_kernel<<<1, 256, 0, c10::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(x.data_ptr()), n, (float)divisor, out.data_ptr() diff --git a/dsv4/kernels/cuda/sampler.cu b/dsv4/kernels/cuda/sampler.cu index 06cfeb88..77fef23d 100644 --- a/dsv4/kernels/cuda/sampler.cu +++ b/dsv4/kernels/cuda/sampler.cu @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -168,10 +169,11 @@ torch::Tensor sample_cuda( int mp = 0; const int64_t* pi = nullptr; const float* pv = nullptr; if (pen_ids && pen_ids->numel()) { mp = pen_ids->size(1); pi = pen_ids->data_ptr(); pv = pen_vals->data_ptr(); } - auto out = torch::empty({B}, torch::TensorOptions().dtype(torch::kInt64).device(logits.device())); + auto options = logits.options().dtype(torch::kInt64); + auto out = torch::empty({B}, options); int smem = BDIM * LK * (sizeof(float) + sizeof(int)); - fused_sampler_kernel<<>>( + fused_sampler_kernel<<>>( logits.data_ptr(), pi, pv, B, V, logits.stride(0), mp, (float)temperature, (int)top_k, (float)top_p, (int)min_keep,