2025-03-31 20:22:34 +08:00
|
|
|
#include <torch/all.h>
|
|
|
|
|
#include <torch/cuda.h>
|
|
|
|
|
#include <cuda_runtime.h>
|
|
|
|
|
|
2026-02-13 11:11:26 -05:00
|
|
|
// This function assumes that `cpu_tensor` is a CPU tensor,
|
|
|
|
|
// and that UVA (Unified Virtual Addressing) is enabled.
|
2025-03-31 20:22:34 +08:00
|
|
|
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
|
|
|
|
|
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
|
|
|
|
|
|
2026-02-13 11:11:26 -05:00
|
|
|
// handle empty tensor
|
|
|
|
|
if (cpu_tensor.numel() == 0) {
|
|
|
|
|
return torch::empty(cpu_tensor.sizes(),
|
|
|
|
|
cpu_tensor.options().device(torch::kCUDA));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (cpu_tensor.is_pinned()) {
|
|
|
|
|
// If CPU tensor is pinned, directly get the device pointer.
|
|
|
|
|
void* host_ptr = const_cast<void*>(cpu_tensor.data_ptr());
|
|
|
|
|
void* device_ptr = nullptr;
|
|
|
|
|
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
|
|
|
|
TORCH_CHECK(err == cudaSuccess,
|
|
|
|
|
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
|
|
|
|
|
|
|
|
|
return torch::from_blob(
|
|
|
|
|
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
|
|
|
|
|
[base = cpu_tensor](void*) {}, // keep cpu tensor alive
|
|
|
|
|
cpu_tensor.options().device(torch::kCUDA));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
|
|
|
|
|
torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
|
|
|
|
|
size_t nbytes = contiguous_cpu.nbytes();
|
|
|
|
|
|
|
|
|
|
void* host_ptr = nullptr;
|
|
|
|
|
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
|
|
|
|
|
if (err != cudaSuccess) {
|
|
|
|
|
AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
|
|
|
|
|
cudaMemcpyDefault);
|
|
|
|
|
if (err != cudaSuccess) {
|
|
|
|
|
cudaFreeHost(host_ptr);
|
|
|
|
|
AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err));
|
|
|
|
|
}
|
2025-03-31 20:22:34 +08:00
|
|
|
|
|
|
|
|
void* device_ptr = nullptr;
|
2026-02-13 11:11:26 -05:00
|
|
|
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
|
|
|
|
if (err != cudaSuccess) {
|
|
|
|
|
cudaFreeHost(host_ptr);
|
|
|
|
|
AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
|
|
|
|
|
|
|
|
|
|
return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
|
|
|
|
|
contiguous_cpu.strides(), deleter,
|
|
|
|
|
contiguous_cpu.options().device(torch::kCUDA));
|
|
|
|
|
}
|