Fix CUDA kernel compilation: use c10::cuda::getCurrentCUDAStream

- amax_gsa.cu: fix at::cuda::getCurrentCUDAStream → c10::
- amax_gsa.cu: fix torch::TensorOptions().device() → x.options()
- sampler.cu: same fixes for compilation on B200
- Both kernels now compile cleanly with torch.utils.cpp_extension.load
This commit is contained in:
2026-06-01 20:49:55 +00:00
parent 2dc5b4ec19
commit 60715f89bc
2 changed files with 10 additions and 7 deletions

View File

@@ -5,13 +5,14 @@
* No CPU-GPU sync. The output tensor stays on GPU and can be passed
* directly to CuTeDSL GEMM's global_scale_a parameter via to_cute().
*
* This eliminates ~610 CPU-GPU syncs per decode step from Nvfp4Linear,
* ~183 from Nvfp4MoE, and ~122 from Nvfp4SharedExpert.
* This eliminates ~915 CPU-GPU syncs per decode step from Nvfp4Linear,
* Nvfp4MoE, and Nvfp4SharedExpert.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdint>
#include <cfloat>
@@ -51,10 +52,10 @@ torch::Tensor compute_amax_gsa_cuda(torch::Tensor x, double divisor) {
TORCH_CHECK(x.scalar_type() == torch::kBFloat16, "input must be BF16");
int n = x.numel();
auto out = torch::zeros({}, torch::TensorOptions()
.dtype(torch::kFloat32).device(x.device));
auto options = x.options().dtype(torch::kFloat32);
auto out = torch::zeros({}, options);
compute_amax_gsa_kernel<<<1, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
compute_amax_gsa_kernel<<<1, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<const __nv_bfloat16*>(x.data_ptr<at::BFloat16>()),
n, (float)divisor,
out.data_ptr<float>()

View File

@@ -27,6 +27,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdint>
#include <cfloat>
@@ -168,10 +169,11 @@ torch::Tensor sample_cuda(
int mp = 0; const int64_t* pi = nullptr; const float* pv = nullptr;
if (pen_ids && pen_ids->numel()) { mp = pen_ids->size(1); pi = pen_ids->data_ptr<int64_t>(); pv = pen_vals->data_ptr<float>(); }
auto out = torch::empty({B}, torch::TensorOptions().dtype(torch::kInt64).device(logits.device()));
auto options = logits.options().dtype(torch::kInt64);
auto out = torch::empty({B}, options);
int smem = BDIM * LK * (sizeof(float) + sizeof(int));
fused_sampler_kernel<<<B, BDIM, smem, at::cuda::getCurrentCUDAStream()>>>(
fused_sampler_kernel<<<B, BDIM, smem, c10::cuda::getCurrentCUDAStream()>>>(
logits.data_ptr<float>(), pi, pv,
B, V, logits.stride(0), mp,
(float)temperature, (int)top_k, (float)top_p, (int)min_keep,