Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -683,7 +683,7 @@ def get_accelerator_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tens
|
||||
if current_platform.is_xpu():
|
||||
assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
|
||||
return torch.ops._C.get_xpu_view_from_cpu_tensor(cpu_tensor)
|
||||
elif current_platform.is_cuda():
|
||||
elif current_platform.is_cuda() or current_platform.is_rocm():
|
||||
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user