diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 1bff517fd..fe047e0df 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -683,7 +683,7 @@ def get_accelerator_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tens if current_platform.is_xpu(): assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" return torch.ops._C.get_xpu_view_from_cpu_tensor(cpu_tensor) - elif current_platform.is_cuda(): + elif current_platform.is_cuda() or current_platform.is_rocm(): return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) else: raise ValueError(