Fix CUDA stream API: getCurrentCUDAStream().stream()

This commit is contained in:
2026-06-04 03:43:04 +00:00
parent 9b3917e248
commit e26c28a1ce

View File

@@ -104,11 +104,12 @@ void launch_blackwell_swizzle(
// Pybind11 bindings for torch.utils.cpp_extension.load
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Exposes the swizzle launcher to Python as `blackwell_swizzle_32_4_4`.
    // The binding takes the input and output tensors (both accessed as
    // uint8_t buffers) plus `rows` and `cols`, which are forwarded unchanged
    // to launch_blackwell_swizzle.
    m.def("blackwell_swizzle_32_4_4", [](at::Tensor input, at::Tensor output, int32_t rows, int32_t cols) {
        // Unwrap the current CUDAStream wrapper into its raw stream handle
        // once, and hand only that handle to the launcher (the wrapper object
        // itself must not also be passed as an argument).
        // NOTE(review): assumes launch_blackwell_swizzle takes a cudaStream_t
        // as its last parameter; its full signature is outside this view.
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        launch_blackwell_swizzle(
            input.data_ptr<uint8_t>(),
            output.data_ptr<uint8_t>(),
            rows, cols,
            stream
        );
    }, "Blackwell 32_4_4 scale swizzle");
}