Fix CUDA kernel compilation: use c10::cuda::getCurrentCUDAStream
- amax_gsa.cu: fix at::cuda::getCurrentCUDAStream → c10:: - amax_gsa.cu: fix torch::TensorOptions().device() → x.options() - sampler.cu: same fixes for compilation on B200 - Both kernels now compile cleanly with torch.utils.cpp_extension.load
This commit is contained in:
@@ -5,13 +5,14 @@
|
||||
* No CPU-GPU sync. The output tensor stays on GPU and can be passed
|
||||
* directly to CuTeDSL GEMM's global_scale_a parameter via to_cute().
|
||||
*
|
||||
* This eliminates ~610 CPU-GPU syncs per decode step from Nvfp4Linear,
|
||||
* ~183 from Nvfp4MoE, and ~122 from Nvfp4SharedExpert.
|
||||
* This eliminates ~915 CPU-GPU syncs per decode step from Nvfp4Linear,
|
||||
* Nvfp4MoE, and Nvfp4SharedExpert.
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdint>
|
||||
#include <cfloat>
|
||||
@@ -51,10 +52,10 @@ torch::Tensor compute_amax_gsa_cuda(torch::Tensor x, double divisor) {
|
||||
TORCH_CHECK(x.scalar_type() == torch::kBFloat16, "input must be BF16");
|
||||
|
||||
int n = x.numel();
|
||||
auto out = torch::zeros({}, torch::TensorOptions()
|
||||
.dtype(torch::kFloat32).device(x.device));
|
||||
auto options = x.options().dtype(torch::kFloat32);
|
||||
auto out = torch::zeros({}, options);
|
||||
|
||||
compute_amax_gsa_kernel<<<1, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
compute_amax_gsa_kernel<<<1, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(x.data_ptr<at::BFloat16>()),
|
||||
n, (float)divisor,
|
||||
out.data_ptr<float>()
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdint>
|
||||
#include <cfloat>
|
||||
@@ -168,10 +169,11 @@ torch::Tensor sample_cuda(
|
||||
int mp = 0; const int64_t* pi = nullptr; const float* pv = nullptr;
|
||||
if (pen_ids && pen_ids->numel()) { mp = pen_ids->size(1); pi = pen_ids->data_ptr<int64_t>(); pv = pen_vals->data_ptr<float>(); }
|
||||
|
||||
auto out = torch::empty({B}, torch::TensorOptions().dtype(torch::kInt64).device(logits.device()));
|
||||
auto options = logits.options().dtype(torch::kInt64);
|
||||
auto out = torch::empty({B}, options);
|
||||
int smem = BDIM * LK * (sizeof(float) + sizeof(int));
|
||||
|
||||
fused_sampler_kernel<<<B, BDIM, smem, at::cuda::getCurrentCUDAStream()>>>(
|
||||
fused_sampler_kernel<<<B, BDIM, smem, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
logits.data_ptr<float>(), pi, pv,
|
||||
B, V, logits.stride(0), mp,
|
||||
(float)temperature, (int)top_k, (float)top_p, (int)min_keep,
|
||||
|
||||
Reference in New Issue
Block a user