Fix CUDA stream API: getCurrentCUDAStream().stream()
This commit is contained in:
@@ -104,11 +104,12 @@ void launch_blackwell_swizzle(
|
||||
// Pybind11 bindings for torch.utils.cpp_extension.load
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("blackwell_swizzle_32_4_4", [](at::Tensor input, at::Tensor output, int32_t rows, int32_t cols) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
launch_blackwell_swizzle(
|
||||
input.data_ptr<uint8_t>(),
|
||||
output.data_ptr<uint8_t>(),
|
||||
rows, cols,
|
||||
at::cuda::getCurrentCUDAStream()
|
||||
stream
|
||||
);
|
||||
}, "Blackwell 32_4_4 scale swizzle");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user